fix: remove nn.Tanh() on the last layer

This commit is contained in:
snsd0805 2023-03-14 02:14:19 +08:00
parent fab0bc4af1
commit 7a0b5a2a8a
Signed by: snsd0805
GPG Key ID: 569349933C77A854

View File

@@ -141,7 +141,10 @@ class Unet(nn.Module):
self.latent3 = DoubleConv(256, 128, nn.ReLU())
self.up1 = UpSampling(128, 64, time_emb_dim)
self.up2 = UpSampling(64, 32, time_emb_dim)
self.out = DoubleConv(32, 1, nn.Tanh())
# self.out = DoubleConv(32, 1, nn.Tanh())
self.out1 = nn.Conv2d(32, 32, 3, padding=1)
self.out2 = nn.Conv2d(32, 1, 3, padding=1)
self.relu = nn.ReLU()
self.time_embedding = PositionEncode(time_emb_dim, device)
@@ -158,5 +161,6 @@ class Unet(nn.Module):
l4 = self.up1(latent, l2, time_emb) # (b, 64, 14, 14)
l5 = self.up2(l4, l1, time_emb) # (b, 32, 28, 28)
out = self.out(l5) # (b, 1, 28, 28)
out = self.relu(self.out1(l5)) # (b, 32, 28, 28)
out = self.out2(out) # (b, 1, 28, 28)
return out