feat: classifier guidance diffusion model

This commit is contained in:
snsd0805 2023-03-22 16:15:09 +08:00
parent 687a994bea
commit 9852bf3530
Signed by: snsd0805
GPG Key ID: 569349933C77A854
8 changed files with 213 additions and 63 deletions

49
classifier.py Normal file
View File

@ -0,0 +1,49 @@
import torch
import torch.nn as nn
from unet import DownSampling, DoubleConv
from positionEncoding import PositionEncode
class Classfier(nn.Module):
    '''
    Time-conditioned classifier for (noisy) single-channel images.

    Args:
        time_emb_dim (int): width of the timestep position embedding
        device (nn.device): device handed to the PositionEncode layer
    Inputs:
        x: feature maps, (b, c, h, w); conv1 is built for c == 1
        time_seq: LongTensor holding the timestep t of every x_t, (b, ).
            It is converted to a position embedding inside this module.
    Outputs:
        logits: raw class scores, (b, 10). No softmax is applied here.
    '''
    def __init__(self, time_emb_dim, device):
        super().__init__()
        # Convolutional trunk. Only the down-sampling stages receive the
        # time embedding.
        self.conv1 = DoubleConv(1, 32, nn.ReLU())
        self.conv2 = DownSampling(32, 64, time_emb_dim)
        self.conv3 = DownSampling(64, 128, time_emb_dim)
        self.conv4 = DownSampling(128, 256, time_emb_dim)

        # Fully connected head. dense1 expects a flattened 256 x 3 x 3 map.
        self.dense1 = nn.Linear(256*3*3, 512)
        self.dense2 = nn.Linear(512, 128)
        self.dense3 = nn.Linear(128, 10)

        # Parameter-free helpers. `pooling` is registered but not used by
        # forward().
        self.pooling = nn.AvgPool2d(2, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

        self.time_embedding = PositionEncode(time_emb_dim, device)

    def forward(self, x, time_seq):
        # Shape notes are the original author's annotations for 28x28 input.
        time_emb = self.time_embedding(time_seq)            # (b, time_emb_dim)
        features = self.conv1(x)                            # (b, 32, 28, 28)

        # (b, 64, 14, 14) -> (b, 128, 7, 7) -> (b, 256, 3, 3), with dropout
        # after every stage.
        for down in (self.conv2, self.conv3, self.conv4):
            features = self.dropout(down(features, time_emb))

        flat = features.reshape((features.shape[0], -1))    # (b, 2304)
        hidden = self.dropout(self.relu(self.dense1(flat)))     # (b, 512)
        hidden = self.dropout(self.relu(self.dense2(hidden)))   # (b, 128)
        return self.dense3(hidden)                          # (b, 10)

33
ddpm.py
View File

@ -51,12 +51,15 @@ class DDPM(nn.Module):
return mu + sigma * epsilon, epsilon # (b, c, w, h)
def sample(self, model, generate_iteration_pic=False, n=None):
def sample(self, model, generate_iteration_pic=False, n=None, target=None, classifier=None, classifier_scale=0.5):
'''
Inputs:
model (nn.Module): Unet instance
generate_iteration_pic (bool): whether generate 10 pic on different denoising time
n (int, default=self.batch_size): want to sample n pictures
target (int, default=None): conditional target
classifier (nn.Module, default=None): classifier called as classifier(x, time_seq) to guide sampling toward target
classifier_scale (float): scaling classifier's control
Outputs:
x_0 (nn.Tensor): (n, c, h, w)
'''
@ -66,24 +69,42 @@ class DDPM(nn.Module):
model.eval()
with torch.no_grad():
x_t = torch.randn((n, c, h, w)).to(self.device) # (n, c, h, w)
x_t += 0.2
for i in reversed(range(self.iteration)):
time_seq = (torch.ones(n) * i).long().to(self.device) # (n, )
predict_noise = model(x_t, time_seq) # (n, c, h, w)
first_term = 1/(torch.sqrt(self.alpha[time_seq])) # (n, )
first_term = 1/(torch.sqrt(self.alpha[time_seq]))
second_term = (1-self.alpha[time_seq]) / (torch.sqrt(1-self.overline_alpha[time_seq]))
first_term = first_term[:, None, None, None].repeat(1, c, h, w)
second_term = second_term[:, None, None, None].repeat(1, c, h, w)
first_term = first_term[:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
second_term = second_term[:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
mu = first_term * (x_t-(second_term * predict_noise)) # origin mu (n, c, h, w)
# if conditional => get classifier gradient
if target != None and classifier != None:
with torch.enable_grad():
# ref: https://github.com/clalanliu/IntroductionDiffusionModels/blob/main/control_diffusion.ipynb
x = x_t.detach()
x.requires_grad_(True)
logits = classifier(x, time_seq) # (b, 10)
log_probs = nn.LogSoftmax(dim=1)(logits) # (b, 10)
selected = log_probs[:, target] # (b)
grad = torch.autograd.grad(selected.sum(), x)[0] #
mu = mu + beta * classifier_scale * grad
# mask
if i!= 0:
z = torch.randn((n, c, h, w)).to(self.device)
else:
z = torch.zeros((n, c, h, w)).to(self.device)
x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
x_t = mu - z * beta # origin mu
# generate 10 pic on the different denoising times
if generate_iteration_pic:

22
loader.py Normal file
View File

@ -0,0 +1,22 @@
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset
def getMnistLoader(config):
    '''
    Build a shuffling DataLoader over the MNIST training split.

    Inputs:
        config (configparser.ConfigParser): must provide
            config['unet']['batch_size']
    Outputs:
        loader (torch.utils.data.DataLoader): batches of MNIST(train=True)
            samples, each image passed through ToTensor
    '''
    # Read the config first so a missing key fails before any dataset I/O.
    batch_size = int(config['unet']['batch_size'])

    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    train_set = MNIST("./data", train=True, download=True, transform=to_tensor)
    return DataLoader(train_set, batch_size=batch_size, shuffle=True)

31
positionEncoding.py Normal file
View File

@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import math
class PositionEncode(nn.Module):
    '''
    Input a LongTensor time sequence, return its sinusoidal position embedding.

    Column i of the output uses the frequency 1 / 10000 ** ((i // 2) / time_emb_dim).
    Even columns hold the sine and odd columns the cosine of time * frequency.

    Input:
        time_seq: shape of LongTensor (b, )
    Output:
        ans: shape of tensor (b, time_emb_dim)
    Args:
        time_emb_dim (int): output's dimension
        device (nn.device): device the frequency table is created on
    '''
    def __init__(self, time_emb_dim, device):
        super(PositionEncode, self).__init__()
        self.time_emb_dim = time_emb_dim
        base = torch.Tensor(
            [1 / math.pow(10000, (i // 2) / time_emb_dim) for i in range(time_emb_dim)]
        ).to(device)  # (d)
        # Register as a buffer so that module.to(...) moves the table together
        # with the rest of the model. persistent=False keeps it out of
        # state_dict, so checkpoints saved before this change still load.
        self.register_buffer('base', base, persistent=False)

    def forward(self, time_seq):
        # Broadcasting (b, 1) * (1, d) -> (b, d). The integer timesteps are
        # promoted to the float dtype of the frequency table.
        ans = time_seq[:, None] * self.base[None, :]
        ans[:, 0::2] = torch.sin(ans[:, 0::2])
        ans[:, 1::2] = torch.cos(ans[:, 1::2])
        return ans

View File

@ -5,11 +5,16 @@ from unet import Unet
import sys
import os
import configparser
from classifier import Classfier
if __name__ == '__main__':
if len(sys.argv) != 2:
if len(sys.argv) < 2:
print("Usage: python sample.py [pic_num]")
exit()
elif len(sys.argv) == 3:
target = int( sys.argv[2] )
print("Target: {}".format(target))
# read config file
config = configparser.ConfigParser()
@ -23,10 +28,12 @@ if __name__ == '__main__':
# start sampling
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
model.load_state_dict(torch.load('unet.pth'))
classifier.load_state_dict(torch.load('classifier.pth'))
x_t = ddpm.sample(model)
x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)
if not os.path.isdir('./output'):
os.mkdir('./output')

View File

@ -1,32 +1,11 @@
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from ddpm import DDPM
from unet import Unet
import configparser
def getMnistLoader(config):
    '''
    Get MNIST dataset's loader

    The batch size is read from config['unet']['batch_size']. The MNIST
    training split is stored under ./data (constructed with download=True),
    every image goes through a ToTensor transform, and the returned loader
    shuffles.

    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        loader (torch.utils.data.DataLoader)
    '''
    BATCH_SIZE = int(config['unet']['batch_size'])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    data = MNIST("./data", train=True, download=True, transform=transform)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
    return loader
from loader import getMnistLoader
def train(config):
'''

68
train_classifier.py Normal file
View File

@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import configparser
from loader import getMnistLoader
from classifier import Classfier
from ddpm import DDPM
def train(config, loader=None):
    '''
    Start classifier training.

    Every batch is noised to a timestep chosen by ddpm.get_time_seq. The
    classifier is then trained with cross-entropy to predict the label from
    the noised image and its timestep. After each epoch the mean batch loss
    is compared with the best one so far, and the weights are written to
    'classifier.pth' whenever it improves.

    Inputs:
        config (configparser.ConfigParser): needs the [classifier] keys
            batch_size, time_emb_dim, device, epoch_num, learning_rate,
            and [ddpm] iteration
        loader (iterable of (x, y) batches, default=None): training data.
            When omitted it is built with getMnistLoader(config), so the
            function no longer depends on a module-level `loader` variable.
    Outputs:
        None
    '''
    BATCH_SIZE = int(config['classifier']['batch_size'])
    ITERATION = int(config['ddpm']['iteration'])
    TIME_EMB_DIM = int(config['classifier']['time_emb_dim'])
    DEVICE = torch.device(config['classifier']['device'])
    EPOCH_NUM = int(config['classifier']['epoch_num'])
    LEARNING_RATE = float(config['classifier']['learning_rate'])

    # NOTE(review): getMnistLoader takes its batch size from the [unet]
    # section, so BATCH_SIZE above only sizes the DDPM helper and the log line.
    if loader is None:
        loader = getMnistLoader(config)

    # training
    model = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Start from +inf so that the first epoch always produces a checkpoint.
    min_loss = float('inf')
    for epoch in range(EPOCH_NUM):
        loss_sum = 0.0
        correct = 0
        data_count = 0
        for x, y in loader:
            optimizer.zero_grad()
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            # Size by the actual batch, because the last batch may be smaller.
            time_seq = ddpm.get_time_seq(x.shape[0])
            # get_x_t also returns the injected noise, which is not needed here.
            x_t, _ = ddpm.get_x_t(x, time_seq)

            logits = model(x_t, time_seq)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            # .item() keeps the running statistics as plain Python numbers,
            # so the log shows 0.93 rather than tensor(0.93, device=...).
            loss_sum += loss.item()
            data_count += len(x)
            correct += (logits.argmax(1) == y).sum().item()

        epoch_loss = loss_sum / len(loader)
        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}, acc: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, epoch_loss, correct / data_count))
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            print("save model: the best loss is {}".format(min_loss))
            torch.save(model.state_dict(), 'classifier.pth')


if __name__ == '__main__':
    # read config file
    config = configparser.ConfigParser()
    config.read('training.ini')
    # start training; train() builds the MNIST loader itself
    train(config)

29
unet.py
View File

@ -2,6 +2,7 @@ import torch
import torch.nn as nn
from torchinfo import summary
import math
from positionEncoding import PositionEncode
class DoubleConv(nn.Module):
'''
@ -90,34 +91,6 @@ class UpSampling(nn.Module):
time_emb = time_emb[:, :, None, None].repeat(1, 1, h, w) # (b, out_dim, h*2, w*2)
return x + time_emb
class PositionEncode(nn.Module):
    '''
    Input a LongTensor time sequence, return position embedding

    Column i uses the frequency 1 / 10000 ** ((i // 2) / time_emb_dim).
    Even columns hold the sine and odd columns the cosine of time * frequency.

    Input:
        time_seq: shape of LongTensor (b, )
    Output:
        dim: shape of tensor (b, time_emb_dim)
    Args:
        time_emb_dim (int): output's dimension
        device (nn.device): data on what device
    '''
    def __init__(self, time_emb_dim, device):
        super(PositionEncode, self).__init__()
        self.time_emb_dim = time_emb_dim
        # Per-column frequency table. It is stored as a plain attribute (not a
        # parameter or registered buffer) and placed on `device` explicitly.
        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ]) # (d)
        self.base = self.base.to(device)
    def forward(self, time_seq):
        seq_len = len(time_seq)
        # Tile the frequency table over the batch and the timesteps over the
        # embedding width, so the product is element-wise.
        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1) # (b, time_emb_dim)
        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim) # (b, time_emb_dim)
        ans = dim * time_seq # (b, time_emb_dim)
        # Overwrite in place: sine on even columns, cosine on odd columns.
        ans[:, 0::2] = torch.sin(ans[:, 0::2])
        ans[:, 1::2] = torch.cos(ans[:, 1::2])
        return ans
class Unet(nn.Module):
'''
Unet module, predict the noise