diff --git a/ddpm.py b/ddpm.py index 49bd3fd..10cc889 100644 --- a/ddpm.py +++ b/ddpm.py @@ -77,9 +77,9 @@ class DDPM(nn.Module): beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w) - z = torch.randn((n, c, h, w)) + z = torch.randn((n, c, h, w)).to(self.device) x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta x_t = ( x_t.clamp(-1, 1) + 1 ) / 2 - x = x * 255 + x_t = x_t * 255 return x_t \ No newline at end of file