fix: add dropout layer into unet model

This commit is contained in:
snsd0805 2023-03-14 18:18:13 +08:00
parent 79e8f37f5b
commit 3530d91aaf
Signed by: snsd0805
GPG Key ID: 569349933C77A854

11
unet.py
View File

@ -146,21 +146,32 @@ class Unet(nn.Module):
self.out2 = nn.Conv2d(32, 1, 3, padding=1)
self.relu = nn.ReLU()
self.dropout05 = nn.Dropout(0.2)
self.time_embedding = PositionEncode(time_emb_dim, device)
def forward(self, x, time_seq):
time_emb = self.time_embedding(time_seq) # (b, time_emb_dim)
l1 = self.in1(x) # (b, 32, 28, 28)
l1 = self.dropout05(l1)
l2 = self.down1(l1, time_emb) # (b, 64, 14, 14)
l2 = self.dropout05(l2)
l3 = self.down2(l2, time_emb) # (b,128, 7, 7)
l3 = self.dropout05(l3)
latent = self.latent1(l3) # (b, 256, 7, 7)
latent = self.dropout05(latent)
latent = self.latent2(latent) # (b, 256, 7, 7)
latent = self.dropout05(latent)
latent = self.latent3(latent) # (b, 128, 7, 7)
latent = self.dropout05(latent)
l4 = self.up1(latent, l2, time_emb) # (b, 64, 14, 14)
l4 = self.dropout05(l4)
l5 = self.up2(l4, l1, time_emb) # (b, 32, 28, 28)
l5 = self.dropout05(l5)
out = self.relu(self.out1(l5)) # (b, 1, 28, 28)
out = self.dropout05(out)
out = self.out2(out) # (b, 1, 28, 28)
return out