feat: complete ddpm & training process

This commit is contained in:
snsd0805 2023-03-14 00:08:54 +08:00
parent 816f6d1f56
commit dcfeac845e
Signed by: snsd0805
GPG Key ID: 569349933C77A854
2 changed files with 98 additions and 15 deletions

56
ddpm.py
View File

@ -1,6 +1,5 @@
import torch
import torch.nn as nn
from unet import Unet
import matplotlib.pyplot as plt
class DDPM(nn.Module):
@ -13,31 +12,28 @@ class DDPM(nn.Module):
batch_size (int): batch_size, for generate time_seq, etc.
iteration (int): max time_seq
beta_min, beta_max (float): for beta scheduling
time_emb_dim (int): for Unet's PositionEncode layer
device (torch.device)
'''
def __init__(self, batch_size, iteration, beta_min, beta_max, time_emb_dim, device):
def __init__(self, batch_size, iteration, beta_min, beta_max, device):
super(DDPM, self).__init__()
self.batch_size = batch_size
self.iteration = iteration
self.device = device
self.unet = Unet(time_emb_dim, device)
self.time_emb_dim = time_emb_dim
self.beta = torch.linspace(beta_min, beta_max, steps=iteration) # (iteration)
self.alpha = 1 - self.beta # (iteration)
self.beta = torch.linspace(beta_min, beta_max, steps=iteration).to(self.device) # (iteration)
self.alpha = (1 - self.beta).to(self.device) # (iteration)
self.overline_alpha = torch.cumprod(self.alpha, dim=0)
def get_time_seq(self, length):
    '''
    Get random time sequence for each picture in the batch

    Inputs:
        length (int): size of sequence
    Outputs:
        time_seq: (length, ) int64 tensor on self.device,
                  random ints from 0 up to (but not including) ITERATION
    '''
    # Sample on the target device directly; no host -> device copy needed.
    return torch.randint(0, self.iteration, (length,), device=self.device)
def get_x_t(self, x_0, time_seq):
    '''
    Forward process: noise clean pictures straight to their step t.

        x_t = sqrt(overline_alpha_t) * x_0 + sqrt(1 - overline_alpha_t) * epsilon

    Inputs:
        x_0: clean pictures (b, c, w, h)
        time_seq: step index of every picture (b, )
    Outputs:
        x_t: noised pictures (b, c, w, h)
        epsilon: the gaussian noise mixed into x_0 (b, c, w, h)
    '''
    # (b, 1, 1, 1) so it broadcasts over channel and spatial dims; no need
    # to materialise a full (b, c, w, h) copy with repeat().
    overline_alpha_t = self.overline_alpha[time_seq][:, None, None, None]

    mu = torch.sqrt(overline_alpha_t) * x_0          # (b, c, w, h)
    sigma = torch.sqrt(1 - overline_alpha_t)         # (b, 1, 1, 1)
    epsilon = torch.randn_like(x_0)                  # (b, c, w, h), same device as x_0

    return mu + sigma * epsilon, epsilon
def sample(self, model, n, shape=(1, 28, 28)):
    '''
    Reverse process: start from pure noise and denoise step by step.

    Each step applies
        x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1-alpha_t)/sqrt(1-overline_alpha_t) * predict_noise)
                  + sqrt(beta_t) * z
    where z is fresh gaussian noise, except at the last step (t == 0) where
    no noise is added.

    Inputs:
        model (nn.Module): Unet instance, called as model(x_t, time_seq)
        n (int): want to sample n pictures
        shape (tuple): (c, h, w) of one picture, defaults to (1, 28, 28)
    Outputs:
        x_0 (torch.Tensor): (n, c, h, w), clamped to [-1, 1] and rescaled to [0, 255]
    '''
    c, h, w = shape
    model.eval()
    with torch.no_grad():
        x_t = torch.randn((n, c, h, w), device=self.device)                       # (n, c, h, w)
        for i in reversed(range(self.iteration)):
            # every picture in the batch is at the same step i
            time_seq = torch.full((n,), i, dtype=torch.long, device=self.device)  # (n, )
            predict_noise = model(x_t, time_seq)                                  # (n, c, h, w)

            # scalars for step i; they broadcast over the whole batch
            alpha_t = self.alpha[i]
            beta_t = self.beta[i]
            overline_alpha_t = self.overline_alpha[i]

            first_term = 1 / torch.sqrt(alpha_t)
            second_term = (1 - alpha_t) / torch.sqrt(1 - overline_alpha_t)
            mean = first_term * (x_t - second_term * predict_noise)

            if i > 0:
                # noise std is sqrt(beta_t); z lives on the same device as x_t
                z = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(beta_t) * z
            else:
                # final step: return the mean, adding noise here would only blur the result
                x_t = mean

        x_t = (x_t.clamp(-1, 1) + 1) / 2   # [-1, 1] -> [0, 1]
        x_t = x_t * 255                    # [0, 1]  -> [0, 255]
    return x_t

57
train.py Normal file
View File

@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from ddpm import DDPM
from unet import Unet
# Training settings.
BATCH_SIZE = 512      # pictures per batch handed to the DataLoader
ITERATION = 1500      # number of diffusion steps given to DDPM
TIME_EMB_DIM = 128    # first argument of Unet
# Use the GPU when one is present, otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCH_NUM = 3000
LEARNING_RATE = 1e-3
def getMnistLoader():
    '''
    Build the shuffled DataLoader over the MNIST training split.

    The dataset is stored under "./data" (downloaded when missing) and every
    picture goes through ToTensor(). Batches hold BATCH_SIZE pictures.

    Outputs:
        DataLoader yielding (pictures, labels) batches
    '''
    to_tensor = torchvision.transforms.ToTensor()
    pipeline = torchvision.transforms.Compose([to_tensor])
    train_split = MNIST("./data", train=True, download=True, transform=pipeline)
    return DataLoader(train_split, batch_size=BATCH_SIZE, shuffle=True)
def train(loader, device, epoch_num, lr):
    '''
    Train a Unet to predict the noise added by the DDPM forward process.

    After every epoch the weights are saved to 'unet.pth' and the mean
    batch loss of that epoch is printed.

    Inputs:
        loader (DataLoader): yields (pictures, labels); the labels are ignored
        device (torch.device): device for the model, the noise schedule and every batch
        epoch_num (int): number of passes over loader
        lr (float): learning rate for Adam
    '''
    # Use the `device` argument everywhere so the caller's choice is honoured
    # rather than the module-level DEVICE.
    model = Unet(TIME_EMB_DIM, device).to(device)
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epoch_num):
        loss_sum = 0
        for x, _ in loader:
            optimizer.zero_grad()

            x = x.to(device)
            time_seq = ddpm.get_time_seq(x.shape[0])   # size follows the batch, the last one may be smaller
            x_t, noise = ddpm.get_x_t(x, time_seq)

            predict_noise = model(x_t, time_seq)
            loss = criterion(predict_noise, noise)
            loss_sum += loss.item()

            loss.backward()
            optimizer.step()

        torch.save(model.state_dict(), 'unet.pth')
        # Report the values this call was actually given (epoch_num, lr).
        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. loss: {}".format(epoch, epoch_num, lr, BATCH_SIZE, ITERATION, loss_sum/len(loader)))
# Script entry: build the MNIST loader, then train with the module-level settings.
loader = getMnistLoader()
train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)