Compare commits

..

5 Commits

4 changed files with 91 additions and 40 deletions

25
ddpm.py
View File

@ -6,8 +6,6 @@ class DDPM(nn.Module):
''' '''
Denoising Diffussion Probabilistic Model Denoising Diffussion Probabilistic Model
Inputs:
Args: Args:
batch_size (int): batch_size, for generate time_seq, etc. batch_size (int): batch_size, for generate time_seq, etc.
iteration (int): max time_seq iteration (int): max time_seq
@ -53,14 +51,17 @@ class DDPM(nn.Module):
return mu + sigma * epsilon, epsilon # (b, c, w, h) return mu + sigma * epsilon, epsilon # (b, c, w, h)
def sample(self, model, n): def sample(self, model, generate_iteration_pic=False, n=None):
''' '''
Inputs: Inputs:
model (nn.Module): Unet instance model (nn.Module): Unet instance
n (int): want to sample n pictures generate_iteration_pic (bool): whether generate 10 pic on different denoising time
n (int, default=self.batch_size): want to sample n pictures
Outputs: Outputs:
x_0 (nn.Tensor): (n, c, h, w) x_0 (nn.Tensor): (n, c, h, w)
''' '''
if n == None:
n = self.batch_size
c, h, w = 1, 28, 28 c, h, w = 1, 28, 28
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@ -84,13 +85,15 @@ class DDPM(nn.Module):
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
if i % (self.iteration/10) == 0: # generate 10 pic on the different denoising times
p = x_t[0].cpu() if generate_iteration_pic:
p = ( p.clamp(-1, 1) + 1 ) / 2 if i % (self.iteration/10) == 0:
p = p * 255 p = x_t[0].cpu()
p = p.permute(1, 2, 0) p = ( p.clamp(-1, 1) + 1 ) / 2
plt.imshow(p, vmin=0, vmax=255, cmap='gray') p = p * 255
plt.savefig("output/iter_{}.png".format(i)) p = p.permute(1, 2, 0)
plt.imshow(p, vmin=0, vmax=255, cmap='gray')
plt.savefig("output/iter_{}.png".format(i))
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2 x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
x_t = x_t * 255 x_t = x_t * 255

View File

@ -2,21 +2,36 @@ import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from ddpm import DDPM from ddpm import DDPM
from unet import Unet from unet import Unet
import sys
import os
import configparser
BATCH_SIZE = 256 if __name__ == '__main__':
ITERATION = 500 if len(sys.argv) != 2:
TIME_EMB_DIM = 128 print("Usage: python sample.py [pic_num]")
DEVICE = torch.device('cuda') exit()
# read config file
config = configparser.ConfigParser()
config.read('training.ini')
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE) BATCH_SIZE = int(config['unet']['batch_size'])
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE) ITERATION = int(config['ddpm']['iteration'])
TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
DEVICE = torch.device(config['unet']['device'])
model.load_state_dict(torch.load('unet.pth')) # start sampling
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
x_t = ddpm.sample(model, 256) model.load_state_dict(torch.load('unet.pth'))
for index, pic in enumerate(x_t):
p = pic.to('cpu').permute(1, 2, 0)
plt.imshow(p, cmap='gray', vmin=0, vmax=255)
plt.savefig("output/{}.png".format(index))
x_t = ddpm.sample(model)
if not os.path.isdir('./output'):
os.mkdir('./output')
for index, pic in enumerate(x_t):
p = pic.to('cpu').permute(1, 2, 0)
plt.imshow(p, cmap='gray', vmin=0, vmax=255)
plt.savefig("output/{}.png".format(index))

View File

@ -7,15 +7,18 @@ import matplotlib.pyplot as plt
from tqdm import tqdm from tqdm import tqdm
from ddpm import DDPM from ddpm import DDPM
from unet import Unet from unet import Unet
import configparser
BATCH_SIZE = 256 def getMnistLoader(config):
ITERATION = 500 '''
TIME_EMB_DIM = 128 Get MNIST dataset's loader
DEVICE = torch.device('cuda')
EPOCH_NUM = 500
LEARNING_RATE = 1e-4
def getMnistLoader(): Inputs:
config (configparser.ConfigParser)
Outputs:
loader (nn.utils.data.DataLoader)
'''
BATCH_SIZE = int(config['unet']['batch_size'])
transform = torchvision.transforms.Compose([ transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor() torchvision.transforms.ToTensor()
@ -25,18 +28,33 @@ def getMnistLoader():
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True) loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
return loader return loader
def train(loader, device, epoch_num, lr): def train(config):
model = Unet(TIME_EMB_DIM, DEVICE).to(device) '''
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device) Start Unet Training
Inputs:
config (configparser.ConfigParser)
Outputs:
None
'''
BATCH_SIZE = int(config['unet']['batch_size'])
ITERATION = int(config['ddpm']['iteration'])
TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
DEVICE = torch.device(config['unet']['device'])
EPOCH_NUM = int(config['unet']['epoch_num'])
LEARNING_RATE = float(config['unet']['learning_rate'])
# training
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
criterion = nn.MSELoss() criterion = nn.MSELoss()
optimzer = torch.optim.Adam(model.parameters(), lr=lr) optimzer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
min_loss = 99 min_loss = 99
for epoch in range(epoch_num): for epoch in range(EPOCH_NUM):
loss_sum = 0 loss_sum = 0
# progress = tqdm(total=len(loader))
for x, y in loader: for x, y in loader:
optimzer.zero_grad() optimzer.zero_grad()
@ -51,12 +69,18 @@ def train(loader, device, epoch_num, lr):
loss.backward() loss.backward()
optimzer.step() optimzer.step()
# progress.update(1)
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader))) print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
if loss_sum/len(loader) < min_loss: if loss_sum/len(loader) < min_loss:
min_loss = loss_sum/len(loader) min_loss = loss_sum/len(loader)
print("save model: the best loss is {}".format(min_loss)) print("save model: the best loss is {}".format(min_loss))
torch.save(model.state_dict(), 'unet.pth') torch.save(model.state_dict(), 'unet.pth')
loader = getMnistLoader() if __name__ == '__main__':
train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE) # read config file
config = configparser.ConfigParser()
config.read('training.ini')
# start training
loader = getMnistLoader(config)
train(config)

9
training.ini Normal file
View File

@ -0,0 +1,9 @@
[unet]
batch_size = 256
time_emb_dim = 128
device = cuda
epoch_num = 500
learning_rate = 1e-4
[ddpm]
iteration = 500