feat: classifier guidance diffusion model

snsd0805 2023-03-22 16:15:09 +08:00
parent 687a994bea
commit 9852bf3530
Signed by: snsd0805
GPG Key ID: 569349933C77A854
8 changed files with 213 additions and 63 deletions

classifier.py Normal file

@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
from unet import DownSampling, DoubleConv
from positionEncoding import PositionEncode

class Classfier(nn.Module):
    '''
    Args:
        time_emb_dim (int): dimension of the position encoding
        device (torch.device): device for the PositionEncode layer, i.e. where the data lives
    Inputs:
        x: feature maps, (b, c, h, w)
        time_seq: a LongTensor indicating that x is x_t; this module transforms time_seq into a position embedding. (b, )
    Outputs:
        p: class logits (b, 10)
    '''
    def __init__(self, time_emb_dim, device):
        super(Classfier, self).__init__()
        self.conv1 = DoubleConv(1, 32, nn.ReLU())
        self.conv2 = DownSampling(32, 64, time_emb_dim)
        self.conv3 = DownSampling(64, 128, time_emb_dim)
        self.conv4 = DownSampling(128, 256, time_emb_dim)
        self.dense1 = nn.Linear(256*3*3, 512)
        self.dense2 = nn.Linear(512, 128)
        self.dense3 = nn.Linear(128, 10)
        self.pooling = nn.AvgPool2d(2, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.time_embedding = PositionEncode(time_emb_dim, device)

    def forward(self, x, time_seq):
        time_emb = self.time_embedding(time_seq)  # (b, time_emb_dim)
        x = self.conv1(x)                         # b, 32, 28, 28
        x = self.conv2(x, time_emb)               # b, 64, 14, 14
        x = self.dropout(x)
        x = self.conv3(x, time_emb)               # b, 128, 7, 7
        x = self.dropout(x)
        x = self.conv4(x, time_emb)               # b, 256, 3, 3
        x = self.dropout(x)
        x = x.reshape((x.shape[0], -1))           # b, 2304
        x = self.relu(self.dense1(x))             # b, 512
        x = self.dropout(x)
        x = self.relu(self.dense2(x))             # b, 128
        x = self.dropout(x)
        x = self.dense3(x)                        # b, 10
        return x
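
This classifier must be noise-aware: during guidance it is evaluated on partially denoised x_t at every timestep, which is why it takes time_seq alongside the image. A minimal smoke test of the forward pass, assuming the DownSampling/DoubleConv blocks from unet.py behave as the shape comments above indicate (time_emb_dim=256 and the timestep range are placeholder values, not settings from this repo):

import torch
from classifier import Classfier

device = torch.device('cpu')
clf = Classfier(time_emb_dim=256, device=device)    # time_emb_dim is an assumed value
x_t = torch.randn(4, 1, 28, 28)                     # stand-in for a noised MNIST batch
time_seq = torch.randint(0, 1000, (4,)).long()      # placeholder timestep range
logits = clf(x_t, time_seq)
print(logits.shape)                                 # expect torch.Size([4, 10])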

ddpm.py

@@ -51,12 +51,15 @@ class DDPM(nn.Module):
         return mu + sigma * epsilon, epsilon # (b, c, w, h)

-    def sample(self, model, generate_iteration_pic=False, n=None):
+    def sample(self, model, generate_iteration_pic=False, n=None, target=None, classifier=None, classifier_scale=0.5):
         '''
         Inputs:
             model (nn.Module): Unet instance
             generate_iteration_pic (bool): whether to generate 10 pictures at different denoising timesteps
             n (int, default=self.batch_size): number of pictures to sample
+            target (int, default=None): class label to condition the samples on
+            classifier (nn.Module, default=None): noise-aware classifier for classifier guidance
+            classifier_scale (float): how strongly the classifier steers the samples
         Outputs:
             x_0 (nn.Tensor): (n, c, h, w)
         '''
@@ -65,25 +68,43 @@ class DDPM(nn.Module):
         c, h, w = 1, 28, 28
         model.eval()
         with torch.no_grad():
             x_t = torch.randn((n, c, h, w)).to(self.device) # (n, c, h, w)
+            x_t += 0.2
             for i in reversed(range(self.iteration)):
                 time_seq = (torch.ones(n) * i).long().to(self.device) # (n, )
                 predict_noise = model(x_t, time_seq) # (n, c, h, w)
                 first_term = 1/(torch.sqrt(self.alpha[time_seq])) # (n, )
                 second_term = (1-self.alpha[time_seq]) / (torch.sqrt(1-self.overline_alpha[time_seq]))
                 first_term = first_term[:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
                 second_term = second_term[:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
                 beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
+                mu = first_term * (x_t - (second_term * predict_noise)) # unconditional mu (n, c, h, w)
+
+                # if conditional => shift mu by the classifier gradient
+                if target is not None and classifier is not None:
+                    with torch.enable_grad():
+                        # ref: https://github.com/clalanliu/IntroductionDiffusionModels/blob/main/control_diffusion.ipynb
+                        x = x_t.detach()
+                        x.requires_grad_(True)
+                        logits = classifier(x, time_seq)          # (b, 10)
+                        log_probs = nn.LogSoftmax(dim=1)(logits)  # (b, 10)
+                        selected = log_probs[:, target]           # (b, )
+                        grad = torch.autograd.grad(selected.sum(), x)[0] # (b, c, h, w)
+                    mu = mu + beta * classifier_scale * grad
+
+                # no noise is added at the final step
                 if i != 0:
                     z = torch.randn((n, c, h, w)).to(self.device)
                 else:
                     z = torch.zeros((n, c, h, w)).to(self.device)
-                x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
+                x_t = mu - z * beta

             # generate 10 pictures at different denoising timesteps
             if generate_iteration_pic:
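
The gradient step above implements classifier guidance in the sense of Dhariwal & Nichol, "Diffusion Models Beat GANs on Image Synthesis" (2021): to bias sampling toward class y, the reverse-process mean is shifted by the gradient of a noise-aware classifier's log-probability,

    \hat{\mu} = \mu_\theta(x_t, t) + s \,\Sigma_t \,\nabla_{x_t} \log p_\phi(y \mid x_t, t)

Here s is classifier_scale, and this code stands in beta for \Sigma_t. Note that differentiating selected.sum() instead of looping over the batch is valid because each sample's log-probability depends only on that sample, so the gradient of the sum recovers exactly the per-sample gradients.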

loader.py Normal file

@@ -0,0 +1,22 @@
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset

def getMnistLoader(config):
    '''
    Get a DataLoader for the MNIST dataset
    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        loader (torch.utils.data.DataLoader)
    '''
    BATCH_SIZE = int(config['unet']['batch_size'])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    data = MNIST("./data", train=True, download=True, transform=transform)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
    return loader
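
A minimal usage sketch for the new loader; training.ini is the config file read by train_classifier.py below, and the [unet] batch_size key is the one this function expects:

import configparser
from loader import getMnistLoader

config = configparser.ConfigParser()
config.read('training.ini')
loader = getMnistLoader(config)
x, y = next(iter(loader))   # x: (batch_size, 1, 28, 28) in [0, 1], y: (batch_size, )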

positionEncoding.py Normal file

@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import math

class PositionEncode(nn.Module):
    '''
    Input a LongTensor time sequence, return its sinusoidal position embedding
    Input:
        time_seq: LongTensor of shape (b, )
    Output:
        ans: tensor of shape (b, time_emb_dim)
    Args:
        time_emb_dim (int): dimension of the output embedding
        device (torch.device): device the data is placed on
    '''
    def __init__(self, time_emb_dim, device):
        super(PositionEncode, self).__init__()
        self.time_emb_dim = time_emb_dim
        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ]) # (d)
        self.base = self.base.to(device)

    def forward(self, time_seq):
        seq_len = len(time_seq)
        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1)  # (b, time_emb_dim)
        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim)         # (b, time_emb_dim)
        ans = dim * time_seq                                              # (b, time_emb_dim)
        ans[:, 0::2] = torch.sin(ans[:, 0::2])
        ans[:, 1::2] = torch.cos(ans[:, 1::2])
        return ans
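
This is a transformer-style sinusoidal encoding: each pair of dimensions shares a frequency, with the sine in the even slot and the cosine in the odd one. A quick shape and value sanity check (all values here are arbitrary):

import torch
from positionEncoding import PositionEncode

pe = PositionEncode(time_emb_dim=8, device=torch.device('cpu'))
emb = pe(torch.tensor([0, 1, 500]).long())
print(emb.shape)   # torch.Size([3, 8])
print(emb[0])      # t=0: sine slots are 0, cosine slots are 1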

sample.py

@@ -5,11 +5,16 @@ from unet import Unet
 import sys
 import os
 import configparser
+from classifier import Classfier

 if __name__ == '__main__':
-    if len(sys.argv) != 2:
+    target = None # default: unconditional sampling
+    if len(sys.argv) < 2:
         print("Usage: python sample.py [pic_num] [target]")
         exit()
+    elif len(sys.argv) == 3:
+        target = int(sys.argv[2])
+        print("Target: {}".format(target))

     # read config file
     config = configparser.ConfigParser()
@@ -23,10 +28,12 @@ if __name__ == '__main__':
     # start sampling
     model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
     ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
+    classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
     model.load_state_dict(torch.load('unet.pth'))
+    classifier.load_state_dict(torch.load('classifier.pth'))

-    x_t = ddpm.sample(model)
+    x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)

     if not os.path.isdir('./output'):
         os.mkdir('./output')
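
With this change the script takes an optional second argument for the target class, for example:

python sample.py 16      # 16 unconditional samples
python sample.py 16 7    # 16 samples guided toward the digit 7 with classifier_scale=0.5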

train.py

@@ -1,32 +1,11 @@
 import torch
 import torch.nn as nn
-from torchvision.datasets import MNIST
-import torchvision
-from torch.utils.data import DataLoader, Dataset
 import matplotlib.pyplot as plt
 from tqdm import tqdm
 from ddpm import DDPM
 from unet import Unet
 import configparser
+from loader import getMnistLoader

-def getMnistLoader(config):
-    '''
-    Get MNIST dataset's loader
-    Inputs:
-        config (configparser.ConfigParser)
-    Outputs:
-        loader (nn.utils.data.DataLoader)
-    '''
-    BATCH_SIZE = int(config['unet']['batch_size'])
-    transform = torchvision.transforms.Compose([
-        torchvision.transforms.ToTensor()
-    ])
-    data = MNIST("./data", train=True, download=True, transform=transform)
-    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
-    return loader

 def train(config):
     '''

train_classifier.py Normal file

@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import configparser
from loader import getMnistLoader
from classifier import Classfier
from ddpm import DDPM

def train(config):
    '''
    Start classifier training
    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        None
    '''
    BATCH_SIZE = int(config['classifier']['batch_size'])
    ITERATION = int(config['ddpm']['iteration'])
    TIME_EMB_DIM = int(config['classifier']['time_emb_dim'])
    DEVICE = torch.device(config['classifier']['device'])
    EPOCH_NUM = int(config['classifier']['epoch_num'])
    LEARNING_RATE = float(config['classifier']['learning_rate'])

    # training
    model = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    min_loss = 99  # sentinel; any real epoch loss will be lower
    for epoch in range(EPOCH_NUM):
        loss_sum = 0
        acc_sum = 0
        data_count = 0
        for x, y in loader:  # `loader` is a module-level global, created in __main__ below
            optimizer.zero_grad()
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            # noise the clean images to a random timestep so the classifier
            # learns to classify x_t, not just x_0
            time_seq = ddpm.get_time_seq(x.shape[0])
            x_t, noise = ddpm.get_x_t(x, time_seq)

            p = model(x_t, time_seq)
            loss = criterion(p, y)
            loss_sum += loss.item()
            data_count += len(x)
            acc_sum += (p.argmax(1) == y).sum().item()

            loss.backward()
            optimizer.step()
        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}, acc: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader), acc_sum/data_count))
        if loss_sum/len(loader) < min_loss:
            min_loss = loss_sum/len(loader)
            print("save model: the best loss is {}".format(min_loss))
            torch.save(model.state_dict(), 'classifier.pth')

if __name__ == '__main__':
    # read config file
    config = configparser.ConfigParser()
    config.read('training.ini')

    # start training
    loader = getMnistLoader(config)
    train(config)
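
For reference, a sketch of the training.ini sections this script reads; the section and key names come from the code above, but every value shown is a placeholder assumption, not the author's actual setting:

[unet]
batch_size = 64

[ddpm]
iteration = 1000

[classifier]
batch_size = 64
time_emb_dim = 256
device = cuda
epoch_num = 30
learning_rate = 0.0001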

unet.py

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from torchinfo import summary
 import math
+from positionEncoding import PositionEncode

 class DoubleConv(nn.Module):
     '''
@@ -90,34 +91,6 @@ class UpSampling(nn.Module):
         time_emb = time_emb[:, :, None, None].repeat(1, 1, h, w) # (b, out_dim, h*2, w*2)
         return x + time_emb

-class PositionEncode(nn.Module):
-    '''
-    Input a LongTensor time sequence, return position embedding
-    Input:
-        time_seq: shape of LongTensor (b, )
-    Output:
-        dim: shape of tensor (b, time_emb_dim)
-    Args:
-        time_emb_dim (int): output's dimension
-        device (nn.device): data on what device
-    '''
-    def __init__(self, time_emb_dim, device):
-        super(PositionEncode, self).__init__()
-        self.time_emb_dim = time_emb_dim
-        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ]) # (d)
-        self.base = self.base.to(device)
-
-    def forward(self, time_seq):
-        seq_len = len(time_seq)
-        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1) # (b, time_emb_dim)
-        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim) # (b, time_emb_dim)
-        ans = dim * time_seq # (b, time_emb_dim)
-        ans[:, 0::2] = torch.sin(ans[:, 0::2])
-        ans[:, 1::2] = torch.cos(ans[:, 1::2])
-        return ans

 class Unet(nn.Module):
     '''
     Unet module, predict the noise