fix: remove nn.Tanh() on the last layer
This commit is contained in:
parent
fab0bc4af1
commit
7a0b5a2a8a
8
unet.py
8
unet.py
@ -141,7 +141,10 @@ class Unet(nn.Module):
|
||||
self.latent3 = DoubleConv(256, 128, nn.ReLU())
|
||||
self.up1 = UpSampling(128, 64, time_emb_dim)
|
||||
self.up2 = UpSampling(64, 32, time_emb_dim)
|
||||
self.out = DoubleConv(32, 1, nn.Tanh())
|
||||
# self.out = DoubleConv(32, 1, nn.Tanh())
|
||||
self.out1 = nn.Conv2d(32, 32, 3, padding=1)
|
||||
self.out2 = nn.Conv2d(32, 1, 3, padding=1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.time_embedding = PositionEncode(time_emb_dim, device)
|
||||
|
||||
@ -158,5 +161,6 @@ class Unet(nn.Module):
|
||||
|
||||
l4 = self.up1(latent, l2, time_emb) # (b, 64, 14, 14)
|
||||
l5 = self.up2(l4, l1, time_emb) # (b, 32, 28, 28)
|
||||
out = self.out(l5) # (b, 1, 28, 28)
|
||||
out = self.relu(self.out1(l5)) # (b, 1, 28, 28)
|
||||
out = self.out2(out) # (b, 1, 28, 28)
|
||||
return out
|
||||
|
||||
Loading…
Reference in New Issue
Block a user