# DDPM_Mnist/ddpm.py
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
class DDPM(nn.Module):
    '''
    Denoising Diffusion Probabilistic Model

    Args:
        batch_size (int): batch size, used as the default number of samples, etc.
        iteration (int): maximum number of diffusion steps T
        beta_min, beta_max (float): endpoints of the linear beta schedule
        device (torch.device)
    '''
    def __init__(self, batch_size, iteration, beta_min, beta_max, device):
        super(DDPM, self).__init__()
        self.batch_size = batch_size
        self.iteration = iteration
        self.device = device
        # Linear noise schedule: beta_t, alpha_t = 1 - beta_t, overline_alpha_t = prod(alpha_1..alpha_t)
        self.beta = torch.linspace(beta_min, beta_max, steps=iteration).to(self.device)  # (iteration,)
        self.alpha = (1 - self.beta).to(self.device)                                     # (iteration,)
        self.overline_alpha = torch.cumprod(self.alpha, dim=0)                           # (iteration,)
    def get_time_seq(self, length):
        '''
        Get a random timestep for each picture in the batch

        Inputs:
            length (int): size of the sequence
        Outputs:
            time_seq: random ints in [0, iteration)
        '''
        return torch.randint(0, self.iteration, (length,)).to(self.device)
    def get_x_t(self, x_0, time_seq):
        '''
        Take clean pictures and return noised pictures, using the closed-form forward
        process q(x_t | x_0) = N(sqrt(overline_alpha_t) * x_0, (1 - overline_alpha_t) * I)

        Inputs:
            x_0: pictures (b, c, w, h)
            time_seq: timestep applied to each picture (b,)
        Outputs:
            x_t: noised pictures (b, c, w, h)
            epsilon: the noise that was added (b, c, w, h)
        '''
        b, c, w, h = x_0.shape
        mu = torch.sqrt(self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h)         # (b, c, w, h)
        mu = mu * x_0                                                                                   # (b, c, w, h)
        sigma = torch.sqrt(1 - self.overline_alpha[time_seq])[:, None, None, None].repeat(1, c, w, h)  # (b, c, w, h)
        epsilon = torch.randn_like(x_0).to(self.device)                                                # (b, c, w, h)
        return mu + sigma * epsilon, epsilon                                                           # (b, c, w, h)
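
    # The reverse (denoising) step implemented in sample() below follows the standard
    # DDPM update, taking sigma_t = sqrt(beta_t):
    #   x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - overline_alpha_t) * eps_theta(x_t, t)) + sigma_t * z
    # with z ~ N(0, I) for t > 0 and z = 0 at the final step.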
    def sample(self, model, generate_iteration_pic=False, n=None, target=None, classifier=None, classifier_scale=0.5):
        '''
        Inputs:
            model (nn.Module): Unet instance
            generate_iteration_pic (bool): whether to save 10 pictures at different denoising times
            n (int, default=self.batch_size): number of pictures to sample
            target (int, default=None): conditional target class
            classifier (nn.Module, default=None): noisy-image classifier for classifier guidance
            classifier_scale (float): scale of the classifier's control
        Outputs:
            x_0 (torch.Tensor): (n, c, h, w)
        '''
        if n is None:
            n = self.batch_size
        c, h, w = 1, 28, 28
        model.eval()
        with torch.no_grad():
            x_t = torch.randn((n, c, h, w)).to(self.device)  # (n, c, h, w)
            x_t += 0.2  # small constant offset added to the initial noise
            for i in reversed(range(self.iteration)):
                time_seq = (torch.ones(n) * i).long().to(self.device)  # (n,)
                predict_noise = model(x_t, time_seq)                   # (n, c, h, w)
                first_term = 1 / (torch.sqrt(self.alpha[time_seq]))
                second_term = (1 - self.alpha[time_seq]) / (torch.sqrt(1 - self.overline_alpha[time_seq]))
                first_term = first_term[:, None, None, None].repeat(1, c, h, w)    # (n, c, h, w)
                second_term = second_term[:, None, None, None].repeat(1, c, h, w)  # (n, c, h, w)
                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
                mu = first_term * (x_t - (second_term * predict_noise))  # posterior mean (n, c, h, w)
                # if conditional => shift the mean by the classifier gradient
                if target is not None and classifier is not None:
                    with torch.enable_grad():
                        # ref: https://github.com/clalanliu/IntroductionDiffusionModels/blob/main/control_diffusion.ipynb
                        x = x_t.detach()
                        x.requires_grad_(True)
                        logits = classifier(x, time_seq)          # (b, 10)
                        log_probs = nn.LogSoftmax(dim=1)(logits)  # (b, 10)
                        selected = log_probs[:, target]           # (b,)
                        grad = torch.autograd.grad(selected.sum(), x)[0]  # (b, c, h, w)
                        mu = mu + beta * classifier_scale * grad
                # no noise is added at the final step (t = 0)
                if i != 0:
                    z = torch.randn((n, c, h, w)).to(self.device)
                else:
                    z = torch.zeros((n, c, h, w)).to(self.device)
                x_t = mu + torch.sqrt(beta) * z  # x_{t-1} = mu + sigma_t * z, with sigma_t = sqrt(beta_t)
                # save 10 pictures at different denoising times
                if generate_iteration_pic:
                    if i % (self.iteration // 10) == 0:
                        p = x_t[0].cpu()
                        p = (p.clamp(-1, 1) + 1) / 2
                        p = p * 255
                        p = p.squeeze(0)  # (h, w), single channel
                        plt.imshow(p, vmin=0, vmax=255, cmap='gray')
                        plt.savefig("output/iter_{}.png".format(i))
        x_t = (x_t.clamp(-1, 1) + 1) / 2
        x_t = x_t * 255
        return x_t
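

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it shows one
# training step (noise prediction with the simple MSE loss) and one call to
# sample(). The tiny stand-in network and the dummy batch below are
# assumptions made only so the snippet is self-contained; the repo itself
# pairs DDPM with its own Unet and real MNIST batches.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class TinyNoisePredictor(nn.Module):
        '''Hypothetical stand-in for the Unet: maps (x_t, t) to predicted noise.'''
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(2, 32, 3, padding=1), nn.ReLU(),
                nn.Conv2d(32, 1, 3, padding=1),
            )

        def forward(self, x, t):
            # Broadcast the (normalised) timestep as an extra input channel.
            t_map = (t.float() / 1000.0)[:, None, None, None].expand(-1, 1, *x.shape[2:])
            return self.net(torch.cat([x, t_map], dim=1))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ddpm = DDPM(batch_size=8, iteration=1000, beta_min=1e-4, beta_max=0.02, device=device)
    model = TinyNoisePredictor().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # One training step on dummy data standing in for a normalised MNIST batch.
    x_0 = torch.randn(8, 1, 28, 28, device=device)
    t = ddpm.get_time_seq(x_0.shape[0])                    # random timestep per image
    x_t, epsilon = ddpm.get_x_t(x_0, t)                    # noised images and the true noise
    loss = nn.functional.mse_loss(model(x_t, t), epsilon)  # predict the added noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Sampling: start from Gaussian noise and run the full reverse chain.
    samples = ddpm.sample(model, n=4)                      # (4, 1, 28, 28), values in [0, 255]
    print(samples.shape)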