Compare commits


No commits in common. "4c4297f2bd9b3cbf2f5da63b2827db5368f1e3ef" and "d3eff8e42517f37959731ded3cc65109898cc262" have entirely different histories.

4 changed files with 40 additions and 91 deletions

ddpm.py

@@ -6,6 +6,8 @@ class DDPM(nn.Module):
     '''
     Denoising Diffussion Probabilistic Model
+
+    Inputs:
     Args:
         batch_size (int): batch_size, for generate time_seq, etc.
         iteration (int): max time_seq
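The DDPM constructor itself is outside the hunks shown. For orientation, here is a minimal sketch of the linear noise schedule such a class typically precomputes from the (1e-4, 2e-2) pair the callers below pass in; every name is an assumption, not the repo's actual attribute:

import torch

# Hypothetical schedule setup -- DDPM.__init__ is not part of this diff.
iteration, beta_start, beta_end = 500, 1e-4, 2e-2
beta = torch.linspace(beta_start, beta_end, iteration)  # linear beta schedule
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)                 # used by q(x_t | x_0)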
@@ -51,17 +53,14 @@ class DDPM(nn.Module):
         return mu + sigma * epsilon, epsilon # (b, c, w, h)
 
-    def sample(self, model, generate_iteration_pic=False, n=None):
+    def sample(self, model, n):
         '''
         Inputs:
             model (nn.Module): Unet instance
-            generate_iteration_pic (bool): whether generate 10 pic on different denoising time
-            n (int, default=self.batch_size): want to sample n pictures
+            n (int): want to sample n pictures
         Outputs:
             x_0 (nn.Tensor): (n, c, h, w)
         '''
-        if n == None:
-            n = self.batch_size
         c, h, w = 1, 28, 28
         model.eval()
         with torch.no_grad():
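The context line return mu + sigma * epsilon, epsilon is the reparameterized forward process q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I). A self-contained sketch under that reading; tensor names are illustrative, not the repo's:

import torch

alpha_bar = torch.cumprod(1.0 - torch.linspace(1e-4, 2e-2, 500), dim=0)

x_0 = torch.randn(4, 1, 28, 28)                          # toy batch of (b, c, h, w) images
t = torch.randint(0, 500, (4,))                          # one timestep per sample
mu = torch.sqrt(alpha_bar[t])[:, None, None, None] * x_0
sigma = torch.sqrt(1.0 - alpha_bar[t])[:, None, None, None]
epsilon = torch.randn_like(x_0)
x_t = mu + sigma * epsilon                               # the returned (x_t, epsilon) pair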
@@ -85,8 +84,6 @@ class DDPM(nn.Module):
                 x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
 
-                # generate 10 pic on the different denoising times
-                if generate_iteration_pic:
-                    if i % (self.iteration/10) == 0:
-                        p = x_t[0].cpu()
-                        p = ( p.clamp(-1, 1) + 1 ) / 2
+                if i % (self.iteration/10) == 0:
+                    p = x_t[0].cpu()
+                    p = ( p.clamp(-1, 1) + 1 ) / 2
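For comparison, the ancestral sampling step in Ho et al. (2020) adds sigma_t * z with sigma_t = sqrt(beta_t), whereas the context line above subtracts z * beta; whether that is a deliberate variant or a bug cannot be told from this diff. A sketch of the textbook step, reusing the assumed schedule names from above:

import torch

beta = torch.linspace(1e-4, 2e-2, 500)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)

def reverse_step(x_t, predict_noise, i):
    # Textbook DDPM reverse step (Ho et al. 2020, Algorithm 2); not the repo's code.
    first_term = 1.0 / torch.sqrt(alpha[i])
    second_term = (1.0 - alpha[i]) / torch.sqrt(1.0 - alpha_bar[i])
    z = torch.randn_like(x_t) if i > 0 else torch.zeros_like(x_t)
    return first_term * (x_t - second_term * predict_noise) + torch.sqrt(beta[i]) * z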

sample.py

@@ -2,36 +2,21 @@ import torch
 import matplotlib.pyplot as plt
 from ddpm import DDPM
 from unet import Unet
-import sys
-import os
-import configparser
 
-if __name__ == '__main__':
-    if len(sys.argv) != 2:
-        print("Usage: python sample.py [pic_num]")
-        exit()
-
-    # read config file
-    config = configparser.ConfigParser()
-    config.read('training.ini')
-    BATCH_SIZE = int(config['unet']['batch_size'])
-    ITERATION = int(config['ddpm']['iteration'])
-    TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
-    DEVICE = torch.device(config['unet']['device'])
-
-    # start sampling
-    model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
-    ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
-    model.load_state_dict(torch.load('unet.pth'))
-    x_t = ddpm.sample(model)
-
-    if not os.path.isdir('./output'):
-        os.mkdir('./output')
-    for index, pic in enumerate(x_t):
-        p = pic.to('cpu').permute(1, 2, 0)
-        plt.imshow(p, cmap='gray', vmin=0, vmax=255)
-        plt.savefig("output/{}.png".format(index))
+BATCH_SIZE = 256
+ITERATION = 500
+TIME_EMB_DIM = 128
+DEVICE = torch.device('cuda')
+
+model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
+ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
+model.load_state_dict(torch.load('unet.pth'))
+
+x_t = ddpm.sample(model, 256)
+for index, pic in enumerate(x_t):
+    p = pic.to('cpu').permute(1, 2, 0)
+    plt.imshow(p, cmap='gray', vmin=0, vmax=255)
+    plt.savefig("output/{}.png".format(index))
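One thing to watch in both versions: the intermediate pictures in ddpm.py are scaled to [0, 1] by ( p.clamp(-1, 1) + 1 ) / 2, yet plt.imshow here stretches the display range with vmin=0, vmax=255; if sample() returns [0, 1] floats, the saved figures will render nearly black. A hedged alternative, assuming [0, 1]-scaled output, that also restores the removed output-directory creation:

import os
import torch
from torchvision.utils import save_image

os.makedirs('output', exist_ok=True)
x_t = torch.rand(256, 1, 28, 28)             # stand-in for ddpm.sample(model, 256)
save_image(x_t, 'output/grid.png', nrow=16)  # save_image expects floats in [0, 1]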

train.py

@@ -7,18 +7,15 @@ import matplotlib.pyplot as plt
 from tqdm import tqdm
 from ddpm import DDPM
 from unet import Unet
-import configparser
 
-def getMnistLoader(config):
-    '''
-    Get MNIST dataset's loader
-
-    Inputs:
-        config (configparser.ConfigParser)
-    Outputs:
-        loader (nn.utils.data.DataLoader)
-    '''
-    BATCH_SIZE = int(config['unet']['batch_size'])
+BATCH_SIZE = 256
+ITERATION = 500
+TIME_EMB_DIM = 128
+DEVICE = torch.device('cuda')
+EPOCH_NUM = 500
+LEARNING_RATE = 1e-4
 
+def getMnistLoader():
     transform = torchvision.transforms.Compose([
         torchvision.transforms.ToTensor()
@@ -28,33 +25,18 @@ def getMnistLoader(config):
     loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
     return loader
 
-def train(config):
-    '''
-    Start Unet Training
-
-    Inputs:
-        config (configparser.ConfigParser)
-    Outputs:
-        None
-    '''
-    BATCH_SIZE = int(config['unet']['batch_size'])
-    ITERATION = int(config['ddpm']['iteration'])
-    TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
-    DEVICE = torch.device(config['unet']['device'])
-    EPOCH_NUM = int(config['unet']['epoch_num'])
-    LEARNING_RATE = float(config['unet']['learning_rate'])
-
-    # training
-    model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
-    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
+def train(loader, device, epoch_num, lr):
+    model = Unet(TIME_EMB_DIM, DEVICE).to(device)
+    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)
     criterion = nn.MSELoss()
-    optimzer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+    optimzer = torch.optim.Adam(model.parameters(), lr=lr)
     min_loss = 99
-    for epoch in range(EPOCH_NUM):
+    for epoch in range(epoch_num):
         loss_sum = 0
-        # progress = tqdm(total=len(loader))
         for x, y in loader:
             optimzer.zero_grad()
@@ -69,18 +51,12 @@ def train(config):
             loss.backward()
             optimzer.step()
-            # progress.update(1)
 
         print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
         if loss_sum/len(loader) < min_loss:
            min_loss = loss_sum/len(loader)
            print("save model: the best loss is {}".format(min_loss))
            torch.save(model.state_dict(), 'unet.pth')
 
-if __name__ == '__main__':
-    # read config file
-    config = configparser.ConfigParser()
-    config.read('training.ini')
-
-    # start training
-    loader = getMnistLoader(config)
-    train(config)
+loader = getMnistLoader()
+train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
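The loss computation between optimzer.zero_grad() and loss.backward() falls outside the hunks shown. A typical DDPM noise-prediction step has this shape; make_noisy stands in for whatever forward-noising method the repo's DDPM class exposes, and every name here is assumed:

import torch
import torch.nn as nn

def ddpm_loss(model, make_noisy, x, iteration, device):
    # Assumed shape of the elided training step; not copied from the repo.
    t = torch.randint(0, iteration, (x.shape[0],), device=device)  # random timesteps
    x_t, epsilon = make_noisy(x.to(device), t)                     # noised batch + true noise
    predict_noise = model(x_t, t)                                  # Unet predicts the noise
    return nn.MSELoss()(predict_noise, epsilon)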

training.ini

@@ -1,9 +0,0 @@
-[unet]
-batch_size = 256
-time_emb_dim = 128
-device = cuda
-epoch_num = 500
-learning_rate = 1e-4
-
-[ddpm]
-iteration = 500
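With training.ini deleted, these values now live as hard-coded module-level constants at the top of sample.py and train.py, four of them duplicated across both files, so changing a hyperparameter such as batch_size means editing source in two places rather than one config file.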