diff --git a/ddpm.py b/ddpm.py
index 4930ed5..49bd3fd 100644
--- a/ddpm.py
+++ b/ddpm.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from unet import Unet
 import matplotlib.pyplot as plt
 
 class DDPM(nn.Module):
@@ -13,31 +12,28 @@ class DDPM(nn.Module):
         batch_size (int): batch_size, for generate time_seq, etc.
         iteration (int): max time_seq
         beta_min, beta_max (float): for beta scheduling
-        time_emb_dim (int): for Unet's PositionEncode layer
         device (nn.Device)
     '''
 
-    def __init__(self, batch_size, iteration, beta_min, beta_max, time_emb_dim, device):
+    def __init__(self, batch_size, iteration, beta_min, beta_max, device):
         super(DDPM, self).__init__()
         self.batch_size = batch_size
         self.iteration = iteration
         self.device = device
-        self.unet = Unet(time_emb_dim, device)
-        self.time_emb_dim = time_emb_dim
 
-        self.beta = torch.linspace(beta_min, beta_max, steps=iteration) # (iteration)
-        self.alpha = 1 - self.beta # (iteration)
+        self.beta = torch.linspace(beta_min, beta_max, steps=iteration).to(self.device) # (iteration)
+        self.alpha = (1 - self.beta).to(self.device) # (iteration)
         self.overline_alpha = torch.cumprod(self.alpha, dim=0)
 
-    def get_time_seq(self):
+    def get_time_seq(self, length):
         '''
         Get random time sequence for each picture in the batch
         Inputs:
-            None
+            length (int): size of sequence
         Outputs:
             time_seq: rand int from 0 to ITERATION
         '''
-        return torch.randint(0, self.iteration, (self.batch_size,) )
+        return torch.randint(0, self.iteration, (length,) ).to(self.device)
 
     def get_x_t(self, x_0, time_seq):
         '''
@@ -50,10 +46,44 @@ class DDPM(nn.Module):
             x_t: noised pictures (b, c, w, h)
         '''
         b, c, w, h = x_0.shape
-        mu = torch.sqrt(self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h)
-        mu = mu * x_0
+        mu = torch.sqrt(self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h) # (b, c, w, h)
+        mu = mu * x_0 # (b, c, w, h)
+        sigma = torch.sqrt(1-self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h) # (b, c, w, h)
+        epsilon = torch.randn_like(x_0).to(self.device) # (b, c, w, h)
 
-        sigma = torch.sqrt(1-self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h)
-        epsilon = torch.randn_like(x_0)
+        return mu + sigma * epsilon, epsilon # (b, c, w, h)
+
+    def sample(self, model, n):
+        '''
+        Draw n pictures by running the reverse diffusion process
+        Inputs:
+            model (nn.Module): Unet instance, predicts noise from (x_t, time_seq)
+            n (int): want to sample n pictures
+        Outputs:
+            x_0 (nn.Tensor): (n, c, h, w), pixel values scaled to [0, 255]
+        '''
+        c, h, w = 1, 28, 28
+        model.eval()
+        with torch.no_grad():
+            x_t = torch.randn((n, c, h, w)).to(self.device) # (n, c, h, w)
+            for i in reversed(range(self.iteration)):
+                time_seq = (torch.ones(n) * i).long().to(self.device) # (n, )
+                predict_noise = model(x_t, time_seq) # (n, c, h, w)
 
-        return mu + sigma * epsilon
+                first_term = 1/(torch.sqrt(self.alpha[time_seq])) # (n, )
+                second_term = (1-self.alpha[time_seq]) / (torch.sqrt(1-self.overline_alpha[time_seq]))
+
+                first_term = first_term[:, None, None, None].repeat(1, c, h, w)
+                second_term = second_term[:, None, None, None].repeat(1, c, h, w)
+
+                # sigma_t = sqrt(beta_t) scales the noise re-injected at each step
+                sigma = torch.sqrt(self.beta[time_seq])[:, None, None, None].repeat(1, c, h, w)
+
+                # no noise on the last step (i == 0); keep z on the same device as x_t
+                z = torch.randn_like(x_t) if i > 0 else torch.zeros_like(x_t)
+
+                x_t = first_term * (x_t-(second_term * predict_noise)) + sigma * z
+            # map from [-1, 1] to [0, 1], then to [0, 255]
+            x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
+            x_t = x_t * 255
+        return x_t
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..ef0a6c8
--- /dev/null
+++ b/train.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+from torchvision.datasets import MNIST
+import torchvision
+from torch.utils.data import DataLoader, Dataset
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+from ddpm import DDPM
+from unet import Unet
+
+BATCH_SIZE = 512
+ITERATION = 1500
+TIME_EMB_DIM = 128
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+EPOCH_NUM = 3000
+LEARNING_RATE = 1e-3
+
+def getMnistLoader():
+    '''
+    Build the shuffled MNIST training loader (downloads to ./data if needed)
+    Outputs:
+        loader (DataLoader): yields (x, y) batches of size BATCH_SIZE
+    '''
+    # scale pixels to [-1, 1], the range DDPM.sample maps back from
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize((0.5,), (0.5,))
+    ])
+
+    data = MNIST("./data", train=True, download=True, transform=transform)
+    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
+    return loader
+
+def train(loader, device, epoch_num, lr):
+    '''
+    Train a Unet to predict the noise added by DDPM.get_x_t
+    Inputs:
+        loader (DataLoader): yields (x, y) batches, y is unused
+        device (nn.Device): where the model and the batches live
+        epoch_num (int): number of passes over loader
+        lr (float): Adam learning rate
+    Outputs:
+        None, weights are saved to unet.pth after every epoch
+    '''
+    model = Unet(TIME_EMB_DIM, device).to(device)
+    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)
+
+    criterion = nn.MSELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+    for epoch in range(epoch_num):
+        loss_sum = 0
+        # progress = tqdm(total=len(loader))
+        for x, y in loader:
+            optimizer.zero_grad()
+
+            x = x.to(device)
+            time_seq = ddpm.get_time_seq(x.shape[0])
+            x_t, noise = ddpm.get_x_t(x, time_seq)
+
+            predict_noise = model(x_t, time_seq)
+            loss = criterion(predict_noise, noise)
+
+            loss_sum += loss.item()
+
+            loss.backward()
+            optimizer.step()
+            # progress.update(1)
+        torch.save(model.state_dict(), 'unet.pth')
+        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. loss: {}".format(epoch, epoch_num, lr, BATCH_SIZE, ITERATION, loss_sum/len(loader)))
+
+if __name__ == "__main__":
+    loader = getMnistLoader()
+    train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
\ No newline at end of file