Compare commits

..

No commits in common. "3530d91aafe0ac90edeec6eff7b34f2f98596d3d" and "25e6a5ff62b2a68e5e1bdfc6acdc59edb8d4304b" have entirely different histories.

2 changed files with 3 additions and 26 deletions

14
ddpm.py
View File

@@ -77,21 +77,9 @@ class DDPM(nn.Module):
beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w) beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
if i!= 0: z = torch.randn((n, c, h, w)).to(self.device)
z = torch.randn((n, c, h, w)).to(self.device)
else:
z = torch.zeros((n, c, h, w)).to(self.device)
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
if i % (self.iteration/10) == 0:
p = x_t[0].cpu()
p = ( p.clamp(-1, 1) + 1 ) / 2
p = p * 255
p = p.permute(1, 2, 0)
plt.imshow(p, vmin=0, vmax=255, cmap='gray')
plt.savefig("output/iter_{}.png".format(i))
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2 x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
x_t = x_t * 255 x_t = x_t * 255
return x_t return x_t

15
unet.py
View File

@@ -146,32 +146,21 @@ class Unet(nn.Module):
self.out2 = nn.Conv2d(32, 1, 3, padding=1) self.out2 = nn.Conv2d(32, 1, 3, padding=1)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.dropout05 = nn.Dropout(0.2)
self.time_embedding = PositionEncode(time_emb_dim, device) self.time_embedding = PositionEncode(time_emb_dim, device)
def forward(self, x, time_seq): def forward(self, x, time_seq):
time_emb = self.time_embedding(time_seq) # (b, time_emb_dim) time_emb = self.time_embedding(time_seq) # (b, time_emb_dim)
l1 = self.in1(x) # (b, 32, 28, 28) l1 = self.in1(x) # (b, 32, 28, 28)
l1 = self.dropout05(l1)
l2 = self.down1(l1, time_emb) # (b, 64, 14, 14) l2 = self.down1(l1, time_emb) # (b, 64, 14, 14)
l2 = self.dropout05(l2)
l3 = self.down2(l2, time_emb) # (b,128, 7, 7) l3 = self.down2(l2, time_emb) # (b,128, 7, 7)
l3 = self.dropout05(l3)
latent = self.latent1(l3) # (b, 256, 7, 7) latent = self.latent1(l3) # (b, 256, 7, 7)
latent = self.dropout05(latent)
latent = self.latent2(latent) # (b, 256, 7, 7) latent = self.latent2(latent) # (b, 256, 7, 7)
latent = self.dropout05(latent)
latent = self.latent3(latent) # (b, 128, 7, 7) latent = self.latent3(latent) # (b, 128, 7, 7)
latent = self.dropout05(latent)
l4 = self.up1(latent, l2, time_emb) # (b, 64, 14, 14) l4 = self.up1(latent, l2, time_emb) # (b, 64, 14, 14)
l4 = self.dropout05(l4)
l5 = self.up2(l4, l1, time_emb) # (b, 32, 28, 28) l5 = self.up2(l4, l1, time_emb) # (b, 32, 28, 28)
l5 = self.dropout05(l5) out = self.relu(self.out1(l5)) # (b, 1, 28, 28)
out = self.relu(self.out1(l5)) # (b, 1, 28, 28) out = self.out2(out) # (b, 1, 28, 28)
out = self.dropout05(out)
out = self.out2(out) # (b, 1, 28, 28)
return out return out