Compare commits
3 Commits
dcfeac845e
...
25e6a5ff62
| Author | SHA1 | Date | |
|---|---|---|---|
| 25e6a5ff62 | |||
| 7a0b5a2a8a | |||
| fab0bc4af1 |
4
ddpm.py
4
ddpm.py
@ -77,9 +77,9 @@ class DDPM(nn.Module):
|
||||
|
||||
beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
|
||||
|
||||
z = torch.randn((n, c, h, w))
|
||||
z = torch.randn((n, c, h, w)).to(self.device)
|
||||
|
||||
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
|
||||
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
|
||||
x = x * 255
|
||||
x_t = x_t * 255
|
||||
return x_t
|
||||
20
sample.py
Normal file
20
sample.py
Normal file
@ -0,0 +1,20 @@
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from ddpm import DDPM
|
||||
from unet import Unet
|
||||
|
||||
BATCH_SIZE = 512
|
||||
ITERATION = 1500
|
||||
TIME_EMB_DIM = 128
|
||||
DEVICE = torch.device('cuda')
|
||||
|
||||
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
|
||||
|
||||
model.load_state_dict(torch.load('unet.pth'))
|
||||
|
||||
x_t = ddpm.sample(model, 32)
|
||||
for index, pic in enumerate(x_t):
|
||||
p = pic.to('cpu').permute(1, 2, 0)
|
||||
plt.imshow(p)
|
||||
plt.savefig("output/{}.png".format(index))
|
||||
8
unet.py
8
unet.py
@ -141,7 +141,10 @@ class Unet(nn.Module):
|
||||
self.latent3 = DoubleConv(256, 128, nn.ReLU())
|
||||
self.up1 = UpSampling(128, 64, time_emb_dim)
|
||||
self.up2 = UpSampling(64, 32, time_emb_dim)
|
||||
self.out = DoubleConv(32, 1, nn.Tanh())
|
||||
# self.out = DoubleConv(32, 1, nn.Tanh())
|
||||
self.out1 = nn.Conv2d(32, 32, 3, padding=1)
|
||||
self.out2 = nn.Conv2d(32, 1, 3, padding=1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.time_embedding = PositionEncode(time_emb_dim, device)
|
||||
|
||||
@ -158,5 +161,6 @@ class Unet(nn.Module):
|
||||
|
||||
l4 = self.up1(latent, l2, time_emb) # (b, 64, 14, 14)
|
||||
l5 = self.up2(l4, l1, time_emb) # (b, 32, 28, 28)
|
||||
out = self.out(l5) # (b, 1, 28, 28)
|
||||
out = self.relu(self.out1(l5)) # (b, 1, 28, 28)
|
||||
out = self.out2(out) # (b, 1, 28, 28)
|
||||
return out
|
||||
|
||||
Loading…
Reference in New Issue
Block a user