Compare commits
5 Commits
d3eff8e425
...
4c4297f2bd
| Author | SHA1 | Date | |
|---|---|---|---|
| 4c4297f2bd | |||
| 81779cf5e2 | |||
| 667346c561 | |||
| a330666a17 | |||
| 37adcc7e97 |
25
ddpm.py
25
ddpm.py
@ -6,8 +6,6 @@ class DDPM(nn.Module):
|
||||
'''
|
||||
Denoising Diffussion Probabilistic Model
|
||||
|
||||
Inputs:
|
||||
|
||||
Args:
|
||||
batch_size (int): batch_size, for generate time_seq, etc.
|
||||
iteration (int): max time_seq
|
||||
@ -53,14 +51,17 @@ class DDPM(nn.Module):
|
||||
|
||||
return mu + sigma * epsilon, epsilon # (b, c, w, h)
|
||||
|
||||
def sample(self, model, n):
|
||||
def sample(self, model, generate_iteration_pic=False, n=None):
|
||||
'''
|
||||
Inputs:
|
||||
model (nn.Module): Unet instance
|
||||
n (int): want to sample n pictures
|
||||
generate_iteration_pic (bool): whether generate 10 pic on different denoising time
|
||||
n (int, default=self.batch_size): want to sample n pictures
|
||||
Outputs:
|
||||
x_0 (nn.Tensor): (n, c, h, w)
|
||||
'''
|
||||
if n == None:
|
||||
n = self.batch_size
|
||||
c, h, w = 1, 28, 28
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
@ -84,13 +85,15 @@ class DDPM(nn.Module):
|
||||
|
||||
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
|
||||
|
||||
if i % (self.iteration/10) == 0:
|
||||
p = x_t[0].cpu()
|
||||
p = ( p.clamp(-1, 1) + 1 ) / 2
|
||||
p = p * 255
|
||||
p = p.permute(1, 2, 0)
|
||||
plt.imshow(p, vmin=0, vmax=255, cmap='gray')
|
||||
plt.savefig("output/iter_{}.png".format(i))
|
||||
# generate 10 pic on the different denoising times
|
||||
if generate_iteration_pic:
|
||||
if i % (self.iteration/10) == 0:
|
||||
p = x_t[0].cpu()
|
||||
p = ( p.clamp(-1, 1) + 1 ) / 2
|
||||
p = p * 255
|
||||
p = p.permute(1, 2, 0)
|
||||
plt.imshow(p, vmin=0, vmax=255, cmap='gray')
|
||||
plt.savefig("output/iter_{}.png".format(i))
|
||||
|
||||
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
|
||||
x_t = x_t * 255
|
||||
|
||||
41
sample.py
41
sample.py
@ -2,21 +2,36 @@ import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from ddpm import DDPM
|
||||
from unet import Unet
|
||||
import sys
|
||||
import os
|
||||
import configparser
|
||||
|
||||
BATCH_SIZE = 256
|
||||
ITERATION = 500
|
||||
TIME_EMB_DIM = 128
|
||||
DEVICE = torch.device('cuda')
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python sample.py [pic_num]")
|
||||
exit()
|
||||
|
||||
# read config file
|
||||
config = configparser.ConfigParser()
|
||||
config.read('training.ini')
|
||||
|
||||
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
|
||||
BATCH_SIZE = int(config['unet']['batch_size'])
|
||||
ITERATION = int(config['ddpm']['iteration'])
|
||||
TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
|
||||
DEVICE = torch.device(config['unet']['device'])
|
||||
|
||||
model.load_state_dict(torch.load('unet.pth'))
|
||||
# start sampling
|
||||
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
|
||||
|
||||
x_t = ddpm.sample(model, 256)
|
||||
for index, pic in enumerate(x_t):
|
||||
p = pic.to('cpu').permute(1, 2, 0)
|
||||
plt.imshow(p, cmap='gray', vmin=0, vmax=255)
|
||||
plt.savefig("output/{}.png".format(index))
|
||||
model.load_state_dict(torch.load('unet.pth'))
|
||||
|
||||
|
||||
x_t = ddpm.sample(model)
|
||||
|
||||
if not os.path.isdir('./output'):
|
||||
os.mkdir('./output')
|
||||
|
||||
for index, pic in enumerate(x_t):
|
||||
p = pic.to('cpu').permute(1, 2, 0)
|
||||
plt.imshow(p, cmap='gray', vmin=0, vmax=255)
|
||||
plt.savefig("output/{}.png".format(index))
|
||||
56
train.py
56
train.py
@ -7,15 +7,18 @@ import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
from ddpm import DDPM
|
||||
from unet import Unet
|
||||
import configparser
|
||||
|
||||
BATCH_SIZE = 256
|
||||
ITERATION = 500
|
||||
TIME_EMB_DIM = 128
|
||||
DEVICE = torch.device('cuda')
|
||||
EPOCH_NUM = 500
|
||||
LEARNING_RATE = 1e-4
|
||||
def getMnistLoader(config):
|
||||
'''
|
||||
Get MNIST dataset's loader
|
||||
|
||||
def getMnistLoader():
|
||||
Inputs:
|
||||
config (configparser.ConfigParser)
|
||||
Outputs:
|
||||
loader (nn.utils.data.DataLoader)
|
||||
'''
|
||||
BATCH_SIZE = int(config['unet']['batch_size'])
|
||||
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor()
|
||||
@ -25,18 +28,33 @@ def getMnistLoader():
|
||||
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
|
||||
return loader
|
||||
|
||||
def train(loader, device, epoch_num, lr):
|
||||
model = Unet(TIME_EMB_DIM, DEVICE).to(device)
|
||||
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)
|
||||
def train(config):
|
||||
'''
|
||||
Start Unet Training
|
||||
|
||||
Inputs:
|
||||
config (configparser.ConfigParser)
|
||||
Outputs:
|
||||
None
|
||||
'''
|
||||
BATCH_SIZE = int(config['unet']['batch_size'])
|
||||
ITERATION = int(config['ddpm']['iteration'])
|
||||
TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
|
||||
DEVICE = torch.device(config['unet']['device'])
|
||||
EPOCH_NUM = int(config['unet']['epoch_num'])
|
||||
LEARNING_RATE = float(config['unet']['learning_rate'])
|
||||
|
||||
# training
|
||||
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
|
||||
|
||||
criterion = nn.MSELoss()
|
||||
optimzer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||
optimzer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
min_loss = 99
|
||||
|
||||
for epoch in range(epoch_num):
|
||||
for epoch in range(EPOCH_NUM):
|
||||
loss_sum = 0
|
||||
# progress = tqdm(total=len(loader))
|
||||
for x, y in loader:
|
||||
optimzer.zero_grad()
|
||||
|
||||
@ -51,12 +69,18 @@ def train(loader, device, epoch_num, lr):
|
||||
|
||||
loss.backward()
|
||||
optimzer.step()
|
||||
# progress.update(1)
|
||||
|
||||
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
|
||||
if loss_sum/len(loader) < min_loss:
|
||||
min_loss = loss_sum/len(loader)
|
||||
print("save model: the best loss is {}".format(min_loss))
|
||||
torch.save(model.state_dict(), 'unet.pth')
|
||||
|
||||
loader = getMnistLoader()
|
||||
train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
|
||||
if __name__ == '__main__':
|
||||
# read config file
|
||||
config = configparser.ConfigParser()
|
||||
config.read('training.ini')
|
||||
|
||||
# start training
|
||||
loader = getMnistLoader(config)
|
||||
train(config)
|
||||
9
training.ini
Normal file
9
training.ini
Normal file
@ -0,0 +1,9 @@
|
||||
[unet]
|
||||
batch_size = 256
|
||||
time_emb_dim = 128
|
||||
device = cuda
|
||||
epoch_num = 500
|
||||
learning_rate = 1e-4
|
||||
|
||||
[ddpm]
|
||||
iteration = 500
|
||||
Loading…
Reference in New Issue
Block a user