feat: check whether generating seq pic

This commit is contained in:
snsd0805 2023-03-14 22:46:53 +08:00
parent d3eff8e425
commit 37adcc7e97
Signed by: snsd0805
GPG Key ID: 569349933C77A854

18
ddpm.py
View File

@ -53,7 +53,7 @@ class DDPM(nn.Module):
return mu + sigma * epsilon, epsilon # (b, c, w, h) return mu + sigma * epsilon, epsilon # (b, c, w, h)
def sample(self, model, n): def sample(self, model, generate_iteration_pic=False, n=self.batch_size):
''' '''
Inputs: Inputs:
model (nn.Module): Unet instance model (nn.Module): Unet instance
@ -84,13 +84,15 @@ class DDPM(nn.Module):
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
if i % (self.iteration/10) == 0: # generate 10 pic on the different denoising times
p = x_t[0].cpu() if generate_iteration_pic:
p = ( p.clamp(-1, 1) + 1 ) / 2 if i % (self.iteration/10) == 0:
p = p * 255 p = x_t[0].cpu()
p = p.permute(1, 2, 0) p = ( p.clamp(-1, 1) + 1 ) / 2
plt.imshow(p, vmin=0, vmax=255, cmap='gray') p = p * 255
plt.savefig("output/iter_{}.png".format(i)) p = p.permute(1, 2, 0)
plt.imshow(p, vmin=0, vmax=255, cmap='gray')
plt.savefig("output/iter_{}.png".format(i))
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2 x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
x_t = x_t * 255 x_t = x_t * 255