diff --git a/ddpm.py b/ddpm.py index 10cc889..24ddad1 100644 --- a/ddpm.py +++ b/ddpm.py @@ -77,9 +77,21 @@ class DDPM(nn.Module): beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w) - z = torch.randn((n, c, h, w)).to(self.device) + if i!= 0: + z = torch.randn((n, c, h, w)).to(self.device) + else: + z = torch.zeros((n, c, h, w)).to(self.device) x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta + + if i % (self.iteration/10) == 0: + p = x_t[0].cpu() + p = ( p.clamp(-1, 1) + 1 ) / 2 + p = p * 255 + p = p.permute(1, 2, 0) + plt.imshow(p, vmin=0, vmax=255, cmap='gray') + plt.savefig("output/iter_{}.png".format(i)) + x_t = ( x_t.clamp(-1, 1) + 1 ) / 2 x_t = x_t * 255 return x_t \ No newline at end of file