feat: complete ddpm & training process
This commit is contained in:
parent
816f6d1f56
commit
dcfeac845e
56
ddpm.py
56
ddpm.py
@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from unet import Unet
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
class DDPM(nn.Module):
|
||||
@ -13,31 +12,28 @@ class DDPM(nn.Module):
|
||||
batch_size (int): batch_size, for generate time_seq, etc.
|
||||
iteration (int): max time_seq
|
||||
beta_min, beta_max (float): for beta scheduling
|
||||
time_emb_dim (int): for Unet's PositionEncode layer
|
||||
device (nn.Device)
|
||||
'''
|
||||
def __init__(self, batch_size, iteration, beta_min, beta_max, time_emb_dim, device):
|
||||
def __init__(self, batch_size, iteration, beta_min, beta_max, device):
|
||||
super(DDPM, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.iteration = iteration
|
||||
self.device = device
|
||||
self.unet = Unet(time_emb_dim, device)
|
||||
self.time_emb_dim = time_emb_dim
|
||||
|
||||
self.beta = torch.linspace(beta_min, beta_max, steps=iteration) # (iteration)
|
||||
self.alpha = 1 - self.beta # (iteration)
|
||||
self.beta = torch.linspace(beta_min, beta_max, steps=iteration).to(self.device) # (iteration)
|
||||
self.alpha = (1 - self.beta).to(self.device) # (iteration)
|
||||
self.overline_alpha = torch.cumprod(self.alpha, dim=0)
|
||||
|
||||
def get_time_seq(self):
|
||||
def get_time_seq(self, length):
|
||||
'''
|
||||
Get random time sequence for each picture in the batch
|
||||
|
||||
Inputs:
|
||||
None
|
||||
length (int): size of sequence
|
||||
Outputs:
|
||||
time_seq: rand int from 0 to ITERATION
|
||||
'''
|
||||
return torch.randint(0, self.iteration, (self.batch_size,) )
|
||||
return torch.randint(0, self.iteration, (length,) ).to(self.device)
|
||||
|
||||
def get_x_t(self, x_0, time_seq):
|
||||
'''
|
||||
@ -50,10 +46,40 @@ class DDPM(nn.Module):
|
||||
x_t: noised pictures (b, c, w, h)
|
||||
'''
|
||||
b, c, w, h = x_0.shape
|
||||
mu = torch.sqrt(self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h)
|
||||
mu = mu * x_0
|
||||
mu = torch.sqrt(self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h) # (b, c, w, h)
|
||||
mu = mu * x_0 # (b, c, w, h)
|
||||
sigma = torch.sqrt(1-self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h) # (b, c, w, h)
|
||||
epsilon = torch.randn_like(x_0).to(self.device) # (b, c, w, h)
|
||||
|
||||
sigma = torch.sqrt(1-self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h)
|
||||
epsilon = torch.randn_like(x_0)
|
||||
return mu + sigma * epsilon, epsilon # (b, c, w, h)
|
||||
|
||||
return mu + sigma * epsilon
|
||||
def sample(self, model, n):
|
||||
'''
|
||||
Inputs:
|
||||
model (nn.Module): Unet instance
|
||||
n (int): want to sample n pictures
|
||||
Outputs:
|
||||
x_0 (nn.Tensor): (n, c, h, w)
|
||||
'''
|
||||
c, h, w = 1, 28, 28
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
x_t = torch.randn((n, c, h, w)).to(self.device) # (n, c, h, w)
|
||||
for i in reversed(range(self.iteration)):
|
||||
time_seq = (torch.ones(n) * i).long().to(self.device) # (n, )
|
||||
predict_noise = model(x_t, time_seq) # (n, c, h, w)
|
||||
|
||||
first_term = 1/(torch.sqrt(self.alpha[time_seq])) # (n, )
|
||||
second_term = (1-self.alpha[time_seq]) / (torch.sqrt(1-self.overline_alpha[time_seq]))
|
||||
|
||||
first_term = first_term[:, None, None, None].repeat(1, c, h, w)
|
||||
second_term = second_term[:, None, None, None].repeat(1, c, h, w)
|
||||
|
||||
beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
|
||||
|
||||
z = torch.randn((n, c, h, w))
|
||||
|
||||
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
|
||||
x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
|
||||
x = x * 255
|
||||
return x_t
|
||||
57
train.py
Normal file
57
train.py
Normal file
@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.datasets import MNIST
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
from ddpm import DDPM
|
||||
from unet import Unet
|
||||
|
||||
BATCH_SIZE = 512
|
||||
ITERATION = 1500
|
||||
TIME_EMB_DIM = 128
|
||||
DEVICE = torch.device('cuda')
|
||||
EPOCH_NUM = 3000
|
||||
LEARNING_RATE = 1e-3
|
||||
|
||||
def getMnistLoader():
|
||||
|
||||
transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor()
|
||||
])
|
||||
|
||||
data = MNIST("./data", train=True, download=True, transform=transform)
|
||||
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
|
||||
return loader
|
||||
|
||||
def train(loader, device, epoch_num, lr):
|
||||
model = Unet(TIME_EMB_DIM, DEVICE).to(device)
|
||||
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)
|
||||
|
||||
criterion = nn.MSELoss()
|
||||
optimzer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
for epoch in range(epoch_num):
|
||||
loss_sum = 0
|
||||
# progress = tqdm(total=len(loader))
|
||||
for x, y in loader:
|
||||
optimzer.zero_grad()
|
||||
|
||||
x = x.to(DEVICE)
|
||||
time_seq = ddpm.get_time_seq(x.shape[0])
|
||||
x_t, noise = ddpm.get_x_t(x, time_seq)
|
||||
|
||||
predict_noise = model(x_t, time_seq)
|
||||
loss = criterion(predict_noise, noise)
|
||||
|
||||
loss_sum += loss.item()
|
||||
|
||||
loss.backward()
|
||||
optimzer.step()
|
||||
# progress.update(1)
|
||||
torch.save(model.state_dict(), 'unet.pth')
|
||||
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, loss_sum/len(loader)))
|
||||
|
||||
loader = getMnistLoader()
|
||||
train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
|
||||
Loading…
Reference in New Issue
Block a user