From 816f6d1f56c7b8c5a4b97a08c7c9da8d281d06a6 Mon Sep 17 00:00:00 2001 From: snsd0805 Date: Mon, 13 Mar 2023 22:33:04 +0800 Subject: [PATCH] feat: DDPM add noise & get time sequence --- ddpm.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 ddpm.py diff --git a/ddpm.py b/ddpm.py new file mode 100644 index 0000000..4930ed5 --- /dev/null +++ b/ddpm.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +from unet import Unet +import matplotlib.pyplot as plt + +class DDPM(nn.Module): + ''' + Denoising Diffussion Probabilistic Model + + Inputs: + + Args: + batch_size (int): batch_size, for generate time_seq, etc. + iteration (int): max time_seq + beta_min, beta_max (float): for beta scheduling + time_emb_dim (int): for Unet's PositionEncode layer + device (nn.Device) + ''' + def __init__(self, batch_size, iteration, beta_min, beta_max, time_emb_dim, device): + super(DDPM, self).__init__() + self.batch_size = batch_size + self.iteration = iteration + self.device = device + self.unet = Unet(time_emb_dim, device) + self.time_emb_dim = time_emb_dim + + self.beta = torch.linspace(beta_min, beta_max, steps=iteration) # (iteration) + self.alpha = 1 - self.beta # (iteration) + self.overline_alpha = torch.cumprod(self.alpha, dim=0) + + def get_time_seq(self): + ''' + Get random time sequence for each picture in the batch + + Inputs: + None + Outputs: + time_seq: rand int from 0 to ITERATION + ''' + return torch.randint(0, self.iteration, (self.batch_size,) ) + + def get_x_t(self, x_0, time_seq): + ''' + Input pictures then return noised pictures + + Inputs: + x_0: pictures (b, c, w, h) + time_seq: times apply on each pictures (b, ) + Outputs: + x_t: noised pictures (b, c, w, h) + ''' + b, c, w, h = x_0.shape + mu = torch.sqrt(self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h) + mu = mu * x_0 + + sigma = torch.sqrt(1-self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h) + epsilon = torch.randn_like(x_0) + + return mu + sigma * epsilon