feat: check whether generating seq pic
This commit is contained in:
parent
d3eff8e425
commit
37adcc7e97
18
ddpm.py
18
ddpm.py
@ -53,7 +53,7 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
return mu + sigma * epsilon, epsilon # (b, c, w, h)
|
return mu + sigma * epsilon, epsilon # (b, c, w, h)
|
||||||
|
|
||||||
def sample(self, model, n):
|
def sample(self, model, generate_iteration_pic=False, n=self.batch_size):
|
||||||
'''
|
'''
|
||||||
Inputs:
|
Inputs:
|
||||||
model (nn.Module): Unet instance
|
model (nn.Module): Unet instance
|
||||||
@ -84,13 +84,15 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
|
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
|
||||||
|
|
||||||
if i % (self.iteration/10) == 0:
|
# generate 10 pic on the different denoising times
|
||||||
p = x_t[0].cpu()
|
if generate_iteration_pic:
|
||||||
p = ( p.clamp(-1, 1) + 1 ) / 2
|
if i % (self.iteration/10) == 0:
|
||||||
p = p * 255
|
p = x_t[0].cpu()
|
||||||
p = p.permute(1, 2, 0)
|
p = ( p.clamp(-1, 1) + 1 ) / 2
|
||||||
plt.imshow(p, vmin=0, vmax=255, cmap='gray')
|
p = p * 255
|
||||||
plt.savefig("output/iter_{}.png".format(i))
|
p = p.permute(1, 2, 0)
|
||||||
|
plt.imshow(p, vmin=0, vmax=255, cmap='gray')
|
||||||
|
plt.savefig("output/iter_{}.png".format(i))
|
||||||
|
|
||||||
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
|
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
|
||||||
x_t = x_t * 255
|
x_t = x_t * 255
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user