diff --git a/unet.py b/unet.py index 6d1a57a..5034544 100644 --- a/unet.py +++ b/unet.py @@ -146,21 +146,32 @@ class Unet(nn.Module): self.out2 = nn.Conv2d(32, 1, 3, padding=1) self.relu = nn.ReLU() + self.dropout05 = nn.Dropout(0.2) + self.time_embedding = PositionEncode(time_emb_dim, device) def forward(self, x, time_seq): time_emb = self.time_embedding(time_seq) # (b, time_emb_dim) l1 = self.in1(x) # (b, 32, 28, 28) + l1 = self.dropout05(l1) l2 = self.down1(l1, time_emb) # (b, 64, 14, 14) + l2 = self.dropout05(l2) l3 = self.down2(l2, time_emb) # (b,128, 7, 7) + l3 = self.dropout05(l3) latent = self.latent1(l3) # (b, 256, 7, 7) + latent = self.dropout05(latent) latent = self.latent2(latent) # (b, 256, 7, 7) + latent = self.dropout05(latent) latent = self.latent3(latent) # (b, 128, 7, 7) + latent = self.dropout05(latent) l4 = self.up1(latent, l2, time_emb) # (b, 64, 14, 14) + l4 = self.dropout05(l4) l5 = self.up2(l4, l1, time_emb) # (b, 32, 28, 28) - out = self.relu(self.out1(l5)) # (b, 1, 28, 28) - out = self.out2(out) # (b, 1, 28, 28) + l5 = self.dropout05(l5) + out = self.relu(self.out1(l5)) # (b, 1, 28, 28) + out = self.dropout05(out) + out = self.out2(out) # (b, 1, 28, 28) return out