fix: fix typo & move tensor to cuda device

This commit is contained in:
snsd0805 2023-03-14 02:12:58 +08:00
parent dcfeac845e
commit fab0bc4af1
Signed by: snsd0805
GPG Key ID: 569349933C77A854

View File

@ -77,9 +77,9 @@ class DDPM(nn.Module):
beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
z = torch.randn((n, c, h, w))
z = torch.randn((n, c, h, w)).to(self.device)
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
x = x * 255
x_t = x_t * 255
return x_t