fix: dont add noise when t=0 && save iteration figure in DDPM sample
This commit is contained in:
parent
25e6a5ff62
commit
79e8f37f5b
14
ddpm.py
14
ddpm.py
@ -77,9 +77,21 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
|
beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
|
||||||
|
|
||||||
z = torch.randn((n, c, h, w)).to(self.device)
|
if i!= 0:
|
||||||
|
z = torch.randn((n, c, h, w)).to(self.device)
|
||||||
|
else:
|
||||||
|
z = torch.zeros((n, c, h, w)).to(self.device)
|
||||||
|
|
||||||
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
|
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
|
||||||
|
|
||||||
|
if i % (self.iteration/10) == 0:
|
||||||
|
p = x_t[0].cpu()
|
||||||
|
p = ( p.clamp(-1, 1) + 1 ) / 2
|
||||||
|
p = p * 255
|
||||||
|
p = p.permute(1, 2, 0)
|
||||||
|
plt.imshow(p, vmin=0, vmax=255, cmap='gray')
|
||||||
|
plt.savefig("output/iter_{}.png".format(i))
|
||||||
|
|
||||||
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
|
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
|
||||||
x_t = x_t * 255
|
x_t = x_t * 255
|
||||||
return x_t
|
return x_t
|
||||||
Loading…
Reference in New Issue
Block a user