feat: classifier guidance diffusion model
parent 687a994bea
commit 9852bf3530
classifier.py (new file)
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
from unet import DownSampling, DoubleConv
from positionEncoding import PositionEncode


class Classfier(nn.Module):
    '''
    Args:
        time_emb_dim (int): dimension of the position encoding
        device (torch.device): device for the PositionEncode layer, i.e. where the data lives

    Inputs:
        x: feature maps, (b, c, h, w)
        time_seq: a LongTensor of timesteps marking each x as an x_t; this module transforms it into a position embedding. (b, )

    Outputs:
        p: class logits (b, 10)
    '''
    def __init__(self, time_emb_dim, device):
        super(Classfier, self).__init__()
        self.conv1 = DoubleConv(1, 32, nn.ReLU())
        self.conv2 = DownSampling(32, 64, time_emb_dim)
        self.conv3 = DownSampling(64, 128, time_emb_dim)
        self.conv4 = DownSampling(128, 256, time_emb_dim)
        self.dense1 = nn.Linear(256*3*3, 512)
        self.dense2 = nn.Linear(512, 128)
        self.dense3 = nn.Linear(128, 10)
        self.pooling = nn.AvgPool2d(2, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

        self.time_embedding = PositionEncode(time_emb_dim, device)

    def forward(self, x, time_seq):
        time_emb = self.time_embedding(time_seq) # (b, time_emb_dim)

        x = self.conv1(x)               # b, 32, 28, 28
        x = self.conv2(x, time_emb)     # b, 64, 14, 14
        x = self.dropout(x)
        x = self.conv3(x, time_emb)     # b, 128, 7, 7
        x = self.dropout(x)
        x = self.conv4(x, time_emb)     # b, 256, 3, 3
        x = self.dropout(x)

        x = x.reshape((x.shape[0], -1)) # b, 2304
        x = self.relu(self.dense1(x))   # b, 512
        x = self.dropout(x)
        x = self.relu(self.dense2(x))   # b, 128
        x = self.dropout(x)
        x = self.dense3(x)              # b, 10
        return x
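A quick shape check of the classifier, as a sketch (it assumes unet.py and positionEncoding.py from this commit are importable; `time_emb_dim=256` is an illustrative value, not taken from the repo's config):

```python
import torch
from classifier import Classfier

device = torch.device('cpu')
clf = Classfier(256, device)              # time_emb_dim=256 is illustrative

x = torch.randn(4, 1, 28, 28)             # a batch of noisy MNIST images x_t
t = torch.randint(0, 1000, (4,)).long()   # their diffusion timesteps (range illustrative)

logits = clf(x, t)
print(logits.shape)                       # expected: torch.Size([4, 10])
```

The 28 -> 14 -> 7 -> 3 spatial path also explains the hard-coded `nn.Linear(256*3*3, 512)`: three downsamplings of a 28x28 input leave a 3x3 map with 256 channels, i.e. 2304 features.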
ddpm.py (39)
@@ -51,12 +51,15 @@ class DDPM(nn.Module):
         return mu + sigma * epsilon, epsilon # (b, c, w, h)

-    def sample(self, model, generate_iteration_pic=False, n=None):
+    def sample(self, model, generate_iteration_pic=False, n=None, target=None, classifier=None, classifier_scale=0.5):
         '''
         Inputs:
             model (nn.Module): Unet instance
             generate_iteration_pic (bool): whether generate 10 pic on different denoising time
             n (int, default=self.batch_size): want to sample n pictures
+            target (int, default=None): class label to condition the samples on
+            classifier (nn.Module, default=None): noisy-image classifier used for classifier guidance
+            classifier_scale (float): strength of the classifier-guidance term
         Outputs:
             x_0 (nn.Tensor): (n, c, h, w)
         '''
@@ -65,25 +68,43 @@ class DDPM(nn.Module):
         c, h, w = 1, 28, 28
         model.eval()
         with torch.no_grad():
             x_t = torch.randn((n, c, h, w)).to(self.device) # (n, c, h, w)
+            x_t += 0.2
             for i in reversed(range(self.iteration)):
                 time_seq = (torch.ones(n) * i).long().to(self.device) # (n, )
                 predict_noise = model(x_t, time_seq) # (n, c, h, w)

-                first_term = 1/(torch.sqrt(self.alpha[time_seq])) # (n, )
+                first_term = 1/(torch.sqrt(self.alpha[time_seq]))
                 second_term = (1-self.alpha[time_seq]) / (torch.sqrt(1-self.overline_alpha[time_seq]))

-                first_term = first_term[:, None, None, None].repeat(1, c, h, w)
-                second_term = second_term[:, None, None, None].repeat(1, c, h, w)
+                first_term = first_term[:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
+                second_term = second_term[:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)

-                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
+                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)

+                mu = first_term * (x_t-(second_term * predict_noise)) # unconditional mu (n, c, h, w)

+                # if conditional => add the classifier gradient to mu
+                if target is not None and classifier is not None:
+                    with torch.enable_grad():
+                        # ref: https://github.com/clalanliu/IntroductionDiffusionModels/blob/main/control_diffusion.ipynb
+                        x = x_t.detach()
+                        x.requires_grad_(True)

+                        logits = classifier(x, time_seq) # (b, 10)
+                        log_probs = nn.LogSoftmax(dim=1)(logits) # (b, 10)
+                        selected = log_probs[:, target] # (b)
+                        grad = torch.autograd.grad(selected.sum(), x)[0] # (b, c, h, w)
+                        mu = mu + beta * classifier_scale * grad

+                # no noise on the final denoising step
                 if i!= 0:
                     z = torch.randn((n, c, h, w)).to(self.device)
                 else:
                     z = torch.zeros((n, c, h, w)).to(self.device)

-                x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
+                x_t = mu - z * beta # mu carries the guidance term when conditioning

                 # generate 10 pic on the different denoising times
                 if generate_iteration_pic:
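The added block is classifier guidance in the sense of Dhariwal & Nichol, "Diffusion Models Beat GANs on Image Synthesis" (2021): at each reverse step, the mean of the denoising distribution is shifted by the scaled gradient of a noisy-image classifier's log-probability for the target class,

\[
\mu_{\text{guided}} = \mu_\theta(x_t, t) + s \,\Sigma\, \nabla_{x_t} \log p_\phi(y \mid x_t, t),
\qquad
x_{t-1} \sim \mathcal{N}(\mu_{\text{guided}},\, \Sigma)
\]

where s is `classifier_scale` and this codebase takes the covariance factor as beta_t, which is why the code multiplies `grad` by `beta`. Note that gradients flow only through the classifier (`torch.enable_grad()` inside the outer `torch.no_grad()`), so the pretrained U-Net needs no retraining to gain class conditioning.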
loader.py (new file)
@@ -0,0 +1,22 @@
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset


def getMnistLoader(config):
    '''
    Get a DataLoader for the MNIST training set

    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        loader (torch.utils.data.DataLoader)
    '''
    BATCH_SIZE = int(config['unet']['batch_size'])

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    data = MNIST("./data", train=True, download=True, transform=transform)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
    return loader
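A minimal usage sketch (the repo reads its settings from training.ini; here the `[unet] batch_size` key is filled in with an illustrative value rather than the project's actual config):

```python
import configparser
from loader import getMnistLoader

# Illustrative stand-in for training.ini; only the key getMnistLoader reads.
config = configparser.ConfigParser()
config.read_dict({'unet': {'batch_size': '64'}})

loader = getMnistLoader(config)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```

ToTensor() scales pixels to [0, 1], so the single channel matches the `DoubleConv(1, 32, ...)` input of both the U-Net and the classifier.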
positionEncoding.py (new file)
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import math


class PositionEncode(nn.Module):
    '''
    Takes a LongTensor time sequence and returns its position embedding

    Input:
        time_seq: LongTensor of shape (b, )
    Output:
        position embedding, tensor of shape (b, time_emb_dim)
    Args:
        time_emb_dim (int): dimension of the output embedding
        device (torch.device): device the data lives on
    '''
    def __init__(self, time_emb_dim, device):
        super(PositionEncode, self).__init__()
        self.time_emb_dim = time_emb_dim

        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ]) # (d)
        self.base = self.base.to(device)

    def forward(self, time_seq):
        seq_len = len(time_seq)
        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1) # (b, time_emb_dim)
        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim) # (b, time_emb_dim)
        ans = dim * time_seq # (b, time_emb_dim)
        ans[:, 0::2] = torch.sin(ans[:, 0::2])
        ans[:, 1::2] = torch.cos(ans[:, 1::2])
        return ans
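Written out, with d = time_emb_dim, the module computes a transformer-style sinusoidal encoding; the `i//2` indexing gives each sin/cos pair the same frequency:

\[
PE(t)_{2i} = \sin\!\left(t \cdot 10000^{-i/d}\right),
\qquad
PE(t)_{2i+1} = \cos\!\left(t \cdot 10000^{-i/d}\right)
\]

Even indices get the sine and odd indices the cosine, matching the `ans[:, 0::2]` / `ans[:, 1::2]` slices.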
sample.py (11)
@@ -5,11 +5,16 @@ from unet import Unet
 import sys
 import os
 import configparser
+from classifier import Classfier
+
 if __name__ == '__main__':
-    if len(sys.argv) != 2:
+    target = None  # default: unconditional sampling
+    if len(sys.argv) < 2:
         print("Usage: python sample.py [pic_num]")
         exit()
+    elif len(sys.argv) == 3:
+        target = int(sys.argv[2])
+        print("Target: {}".format(target))
+
     # read config file
     config = configparser.ConfigParser()
@@ -23,10 +28,12 @@ if __name__ == '__main__':
     # start sampling
     model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
     ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
+    classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
+
     model.load_state_dict(torch.load('unet.pth'))
+    classifier.load_state_dict(torch.load('classifier.pth'))
+
-    x_t = ddpm.sample(model)
+    x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)
+
     if not os.path.isdir('./output'):
         os.mkdir('./output')
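With this change, `python sample.py 16` keeps the old unconditional behaviour (the `target = None` default above guards against an unbound `target` when no class argument is given), while `python sample.py 16 7` samples 16 images guided toward the digit 7, assuming `unet.pth` and `classifier.pth` exist from prior training runs. Note that `sys.argv[1]` doubles as the DDPM batch size, i.e. the number of pictures. Raising `classifier_scale` above the default 0.5 pushes samples harder toward the target class at the cost of diversity, the usual classifier-guidance trade-off.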
train.py (23)
@@ -1,32 +1,11 @@
 import torch
 import torch.nn as nn
-from torchvision.datasets import MNIST
-import torchvision
-from torch.utils.data import DataLoader, Dataset
 import matplotlib.pyplot as plt
 from tqdm import tqdm
 from ddpm import DDPM
 from unet import Unet
 import configparser
+from loader import getMnistLoader
-
-def getMnistLoader(config):
-    '''
-    Get MNIST dataset's loader
-
-    Inputs:
-        config (configparser.ConfigParser)
-    Outputs:
-        loader (nn.utils.data.DataLoader)
-    '''
-    BATCH_SIZE = int(config['unet']['batch_size'])
-
-    transform = torchvision.transforms.Compose([
-        torchvision.transforms.ToTensor()
-    ])
-
-    data = MNIST("./data", train=True, download=True, transform=transform)
-    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
-    return loader
-
-
 def train(config):
     '''
train_classifier.py (new file)
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import configparser
from loader import getMnistLoader
from classifier import Classfier
from ddpm import DDPM


def train(config):
    '''
    Start classifier training

    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        None
    '''
    BATCH_SIZE = int(config['classifier']['batch_size'])
    ITERATION = int(config['ddpm']['iteration'])
    TIME_EMB_DIM = int(config['classifier']['time_emb_dim'])
    DEVICE = torch.device(config['classifier']['device'])
    EPOCH_NUM = int(config['classifier']['epoch_num'])
    LEARNING_RATE = float(config['classifier']['learning_rate'])

    # training
    model = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    min_loss = 99

    for epoch in range(EPOCH_NUM):
        loss_sum = 0
        acc_sum = 0
        data_count = 0
        for x, y in loader:  # `loader` is created in the __main__ block below
            optimizer.zero_grad()

            x = x.to(DEVICE)
            y = y.to(DEVICE)
            time_seq = ddpm.get_time_seq(x.shape[0])
            x_t, noise = ddpm.get_x_t(x, time_seq)  # train on noised images x_t

            p = model(x_t, time_seq)
            loss = criterion(p, y)

            loss_sum += loss.item()
            data_count += len(x)
            acc_sum += (p.argmax(1)==y).sum()

            loss.backward()
            optimizer.step()

        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}, acc: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader), acc_sum/data_count))
        if loss_sum/len(loader) < min_loss:
            min_loss = loss_sum/len(loader)
            print("save model: the best loss is {}".format(min_loss))
            torch.save(model.state_dict(), 'classifier.pth')


if __name__ == '__main__':
    # read config file
    config = configparser.ConfigParser()
    config.read('training.ini')

    # start training
    loader = getMnistLoader(config)
    train(config)
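The design point here is that the classifier is deliberately trained on noised samples x_t (paired with their timesteps), not on clean MNIST digits; that is what makes its gradient meaningful at every noise level of the reverse trajectory in ddpm.py. A sketch for checking this, reusing the setup above (it assumes `TIME_EMB_DIM`, `DEVICE`, `ITERATION`, `BATCH_SIZE`, and `loader` have been defined as in train_classifier.py, and that `classifier.pth` exists):

```python
import torch
from classifier import Classfier
from ddpm import DDPM

model = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
model.load_state_dict(torch.load('classifier.pth'))
model.eval()
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)

x, y = next(iter(loader))
x, y = x.to(DEVICE), y.to(DEVICE)
with torch.no_grad():
    # accuracy should degrade gracefully as t grows, but stay above chance
    for t in [0, ITERATION // 4, ITERATION // 2, ITERATION - 1]:
        time_seq = torch.full((x.shape[0],), t, dtype=torch.long, device=DEVICE)
        x_t, _ = ddpm.get_x_t(x, time_seq)
        acc = (model(x_t, time_seq).argmax(1) == y).float().mean().item()
        print(f"t={t}: acc={acc:.3f}")
```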
unet.py (29)
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from torchinfo import summary
 import math
+from positionEncoding import PositionEncode
+
 class DoubleConv(nn.Module):
     '''
@@ -90,34 +91,6 @@ class UpSampling(nn.Module):
         time_emb = time_emb[:, :, None, None].repeat(1, 1, h, w) # (b, out_dim, h*2, w*2)
         return x + time_emb
-
-
-class PositionEncode(nn.Module):
-    '''
-    Input a LongTensor time sequence, return position embedding
-
-    Input:
-        time_seq: shape of LongTensor (b, )
-    Output:
-        dim: shape of tensor (b, time_emb_dim)
-    Args:
-        time_emb_dim (int): output's dimension
-        device (nn.device): data on what device
-    '''
-    def __init__(self, time_emb_dim, device):
-        super(PositionEncode, self).__init__()
-        self.time_emb_dim = time_emb_dim
-
-        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ]) # (d)
-        self.base = self.base.to(device)
-
-    def forward(self, time_seq):
-        seq_len = len(time_seq)
-        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1) # (b, time_emb_dim)
-        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim) # (b, time_emb_dim)
-        ans = dim * time_seq # (b, time_emb_dim)
-        ans[:, 0::2] = torch.sin(ans[:, 0::2])
-        ans[:, 1::2] = torch.cos(ans[:, 1::2])
-        return ans
-
-
 class Unet(nn.Module):
     '''
     Unet module, predict the noise