From 37adcc7e9702dc61da618da70892bd070e0c12d7 Mon Sep 17 00:00:00 2001 From: snsd0805 Date: Tue, 14 Mar 2023 22:46:53 +0800 Subject: [PATCH] feat: check whether generating seq pic --- ddpm.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ddpm.py b/ddpm.py index 24ddad1..6a430a3 100644 --- a/ddpm.py +++ b/ddpm.py @@ -53,7 +53,7 @@ class DDPM(nn.Module): return mu + sigma * epsilon, epsilon # (b, c, w, h) - def sample(self, model, n): + def sample(self, model, generate_iteration_pic=False, n=self.batch_size): ''' Inputs: model (nn.Module): Unet instance @@ -84,13 +84,15 @@ class DDPM(nn.Module): x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta - if i % (self.iteration/10) == 0: - p = x_t[0].cpu() - p = ( p.clamp(-1, 1) + 1 ) / 2 - p = p * 255 - p = p.permute(1, 2, 0) - plt.imshow(p, vmin=0, vmax=255, cmap='gray') - plt.savefig("output/iter_{}.png".format(i)) + # generate 10 pic on the different denoising times + if generate_iteration_pic: + if i % (self.iteration/10) == 0: + p = x_t[0].cpu() + p = ( p.clamp(-1, 1) + 1 ) / 2 + p = p * 255 + p = p.permute(1, 2, 0) + plt.imshow(p, vmin=0, vmax=255, cmap='gray') + plt.savefig("output/iter_{}.png".format(i)) x_t = ( x_t.clamp(-1, 1) + 1 ) / 2 x_t = x_t * 255