feat: classifier guidance diffusion model

parent 687a994bea
commit 9852bf3530
classifier.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
from unet import DownSampling, DoubleConv
from positionEncoding import PositionEncode


class Classfier(nn.Module):
    '''
    Args:
        time_emb_dim (int): dimension of the position encoding
        device (torch.device): device for the PositionEncode layer, i.e. where the data lives

    Inputs:
        x: feature maps, (b, c, h, w)
        time_seq: LongTensor of timesteps indicating that x is x_t; this module transforms it into a position embedding. (b, )

    Outputs:
        p: class logits over the 10 digits (b, 10)
    '''
    def __init__(self, time_emb_dim, device):
        super(Classfier, self).__init__()
        self.conv1 = DoubleConv(1, 32, nn.ReLU())
        self.conv2 = DownSampling(32, 64, time_emb_dim)
        self.conv3 = DownSampling(64, 128, time_emb_dim)
        self.conv4 = DownSampling(128, 256, time_emb_dim)
        self.dense1 = nn.Linear(256*3*3, 512)
        self.dense2 = nn.Linear(512, 128)
        self.dense3 = nn.Linear(128, 10)
        self.pooling = nn.AvgPool2d(2, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

        self.time_embedding = PositionEncode(time_emb_dim, device)

    def forward(self, x, time_seq):
        time_emb = self.time_embedding(time_seq)  # (b, time_emb_dim)

        x = self.conv1(x)                # b, 32, 28, 28
        x = self.conv2(x, time_emb)      # b, 64, 14, 14
        x = self.dropout(x)
        x = self.conv3(x, time_emb)      # b, 128, 7, 7
        x = self.dropout(x)
        x = self.conv4(x, time_emb)      # b, 256, 3, 3
        x = self.dropout(x)

        x = x.reshape((x.shape[0], -1))  # b, 2304
        x = self.relu(self.dense1(x))    # b, 512
        x = self.dropout(x)
        x = self.relu(self.dense2(x))    # b, 128
        x = self.dropout(x)
        x = self.dense3(x)               # b, 10
        return x
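A quick shape check for the classifier above — a minimal sketch, assuming a time_emb_dim of 128 and CPU execution (both hypothetical; the real values come from training.ini):

import torch
from classifier import Classfier

model = Classfier(time_emb_dim=128, device=torch.device('cpu'))
x = torch.randn(4, 1, 28, 28)            # (b, c, h, w) dummy MNIST-sized batch
time_seq = torch.randint(0, 1000, (4,))  # (b, ) random timesteps
logits = model(x, time_seq)              # (4, 10), one logit per digit class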
ddpm.py (41 lines changed)
@@ -51,12 +51,15 @@ class DDPM(nn.Module):
        return mu + sigma * epsilon, epsilon  # (b, c, w, h)

-    def sample(self, model, generate_iteration_pic=False, n=None):
+    def sample(self, model, generate_iteration_pic=False, n=None, target=None, classifier=None, classifier_scale=0.5):
        '''
        Inputs:
            model (nn.Module): Unet instance
            generate_iteration_pic (bool): whether to save 10 pictures at different denoising steps
            n (int, default=self.batch_size): number of pictures to sample
+            target (int, default=None): target class for conditional sampling
+            classifier (nn.Module, default=None): classifier used to guide the conditional diffusion model
+            classifier_scale (float): how strongly the classifier gradient steers sampling
        Outputs:
            x_0 (torch.Tensor): (n, c, h, w)
        '''
@@ -65,25 +68,43 @@ class DDPM(nn.Module):
        c, h, w = 1, 28, 28
        model.eval()
        with torch.no_grad():
            x_t = torch.randn((n, c, h, w)).to(self.device)  # (n, c, h, w)
+            x_t += 0.2
            for i in reversed(range(self.iteration)):
                time_seq = (torch.ones(n) * i).long().to(self.device)  # (n, )
                predict_noise = model(x_t, time_seq)  # (n, c, h, w)

-                first_term = 1/(torch.sqrt(self.alpha[time_seq]))  # (n, )
+                first_term = 1/(torch.sqrt(self.alpha[time_seq]))
                second_term = (1-self.alpha[time_seq]) / (torch.sqrt(1-self.overline_alpha[time_seq]))

-                first_term = first_term[:, None, None, None].repeat(1, c, h, w)
-                second_term = second_term[:, None, None, None].repeat(1, c, h, w)
+                first_term = first_term[:, None, None, None].repeat(1, c, h, w)    # (n, c, h, w)
+                second_term = second_term[:, None, None, None].repeat(1, c, h, w)  # (n, c, h, w)

-                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
+                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)  # (n, c, h, w)

+                mu = first_term * (x_t - (second_term * predict_noise))  # original mu (n, c, h, w)

+                # if conditional => get the classifier gradient
+                if target is not None and classifier is not None:
+                    with torch.enable_grad():
+                        # ref: https://github.com/clalanliu/IntroductionDiffusionModels/blob/main/control_diffusion.ipynb
+                        x = x_t.detach()
+                        x.requires_grad_(True)

+                        logits = classifier(x, time_seq)          # (b, 10)
+                        log_probs = nn.LogSoftmax(dim=1)(logits)  # (b, 10)
+                        selected = log_probs[:, target]           # (b, )
+                        grad = torch.autograd.grad(selected.sum(), x)[0]  # (b, c, h, w)
+                        mu = mu + beta * classifier_scale * grad

                # mask: add noise at every step except the last
                if i != 0:
                    z = torch.randn((n, c, h, w)).to(self.device)
                else:
                    z = torch.zeros((n, c, h, w)).to(self.device)

-                x_t = first_term * (x_t - (second_term * predict_noise)) - z * beta
+                x_t = mu - z * beta

                # save 10 pictures at different denoising steps
                if generate_iteration_pic:
@@ -97,4 +118,4 @@ class DDPM(nn.Module):

        x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
        x_t = x_t * 255
        return x_t
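The guidance step above implements the classifier-guided mean shift of Dhariwal & Nichol: the unconditional posterior mean mu is nudged along the gradient of the classifier's log-probability for the target class, mu_guided = mu + s * Sigma * grad_{x_t} log p(y | x_t), with Sigma = beta_t and s = classifier_scale in this commit. A standalone sketch of just that step (function name hypothetical):

import torch
import torch.nn.functional as F

def guided_mean(mu, x_t, time_seq, target, classifier, beta_t, scale):
    # shift mu by scale * beta_t * grad log p(target | x_t)
    with torch.enable_grad():
        x = x_t.detach().requires_grad_(True)
        log_probs = F.log_softmax(classifier(x, time_seq), dim=1)  # (b, 10)
        grad = torch.autograd.grad(log_probs[:, target].sum(), x)[0]
    return mu + scale * beta_t * grad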
loader.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset

def getMnistLoader(config):
    '''
    Get the MNIST dataset's loader

    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        loader (torch.utils.data.DataLoader)
    '''
    BATCH_SIZE = int(config['unet']['batch_size'])

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    data = MNIST("./data", train=True, download=True, transform=transform)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
    return loader
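A usage sketch, assuming a training.ini whose [unet] section provides batch_size (the same assumption the function makes):

import configparser
from loader import getMnistLoader

config = configparser.ConfigParser()
config.read('training.ini')
loader = getMnistLoader(config)
x, y = next(iter(loader))  # x: (batch_size, 1, 28, 28) in [0, 1], y: (batch_size, )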
positionEncoding.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import math

class PositionEncode(nn.Module):
    '''
    Input a LongTensor time sequence, return its position embedding

    Input:
        time_seq: LongTensor of shape (b, )
    Output:
        ans: tensor of shape (b, time_emb_dim)
    Args:
        time_emb_dim (int): dimension of the output
        device (torch.device): device the data lives on
    '''
    def __init__(self, time_emb_dim, device):
        super(PositionEncode, self).__init__()
        self.time_emb_dim = time_emb_dim

        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ])  # (d)
        self.base = self.base.to(device)

    def forward(self, time_seq):
        seq_len = len(time_seq)
        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1)  # (b, time_emb_dim)
        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim)         # (b, time_emb_dim)
        ans = dim * time_seq                                              # (b, time_emb_dim)
        ans[:, 0::2] = torch.sin(ans[:, 0::2])
        ans[:, 1::2] = torch.cos(ans[:, 1::2])
        return ans
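This is the standard sinusoidal encoding: component i of timestep t is sin or cos of t / 10000^((i//2)/time_emb_dim), with sines on even indices and cosines on odd ones. A quick shape check (values hypothetical):

import torch
from positionEncoding import PositionEncode

pe = PositionEncode(time_emb_dim=128, device=torch.device('cpu'))
emb = pe(torch.tensor([0, 1, 500]).long())  # (3, 128)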
sample.py (13 lines changed)
@@ -5,11 +5,16 @@ from unet import Unet
import sys
import os
import configparser
+from classifier import Classfier

if __name__ == '__main__':
-    if len(sys.argv) != 2:
+    if len(sys.argv) < 2:
        print("Usage: python sample.py [pic_num]")
        exit()
+    elif len(sys.argv) == 3:
+        target = int(sys.argv[2])
+        print("Target: {}".format(target))

    # read config file
    config = configparser.ConfigParser()

@@ -23,10 +28,12 @@ if __name__ == '__main__':
    # start sampling
    model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
    ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
+    classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)

    model.load_state_dict(torch.load('unet.pth'))
+    classifier.load_state_dict(torch.load('classifier.pth'))

-    x_t = ddpm.sample(model)
+    x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)

    if not os.path.isdir('./output'):
        os.mkdir('./output')

@@ -34,4 +41,4 @@ if __name__ == '__main__':

    for index, pic in enumerate(x_t):
        p = pic.to('cpu').permute(1, 2, 0)
        plt.imshow(p, cmap='gray', vmin=0, vmax=255)
        plt.savefig("output/{}.png".format(index))
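As implied by the argv handling, `python sample.py 16 7` would sample 16 images guided toward digit 7 (values hypothetical). Note that `target` is only bound when a second argument is given, yet it is passed to `ddpm.sample` unconditionally, so the single-argument form would raise a NameError; a defensive variant (a sketch, not the committed code) defaults it first:

target = None
if len(sys.argv) == 3:
    target = int(sys.argv[2])
x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)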
train.py (23 lines changed)
@@ -1,32 +1,11 @@
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from ddpm import DDPM
from unet import Unet
import configparser

-def getMnistLoader(config):
-    '''
-    Get MNIST dataset's loader
-
-    Inputs:
-        config (configparser.ConfigParser)
-    Outputs:
-        loader (nn.utils.data.DataLoader)
-    '''
-    BATCH_SIZE = int(config['unet']['batch_size'])
-
-    transform = torchvision.transforms.Compose([
-        torchvision.transforms.ToTensor()
-    ])
-
-    data = MNIST("./data", train=True, download=True, transform=transform)
-    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
-    return loader
+from loader import getMnistLoader

def train(config):
    '''
train_classifier.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import configparser
from loader import getMnistLoader
from classifier import Classfier
from ddpm import DDPM

def train(config):
    '''
    Start classifier training

    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        None
    '''
    BATCH_SIZE = int(config['classifier']['batch_size'])
    ITERATION = int(config['ddpm']['iteration'])
    TIME_EMB_DIM = int(config['classifier']['time_emb_dim'])
    DEVICE = torch.device(config['classifier']['device'])
    EPOCH_NUM = int(config['classifier']['epoch_num'])
    LEARNING_RATE = float(config['classifier']['learning_rate'])

    # training
    model = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    min_loss = 99

    for epoch in range(EPOCH_NUM):
        loss_sum = 0
        acc_sum = 0
        data_count = 0
        for x, y in loader:  # `loader` is created in the __main__ block below
            optimizer.zero_grad()

            x = x.to(DEVICE)
            y = y.to(DEVICE)
            time_seq = ddpm.get_time_seq(x.shape[0])
            x_t, noise = ddpm.get_x_t(x, time_seq)  # train on noised images x_t

            p = model(x_t, time_seq)
            loss = criterion(p, y)

            loss_sum += loss.item()
            data_count += len(x)
            acc_sum += (p.argmax(1) == y).sum()

            loss.backward()
            optimizer.step()

        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}, acc: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader), acc_sum/data_count))
        if loss_sum/len(loader) < min_loss:
            min_loss = loss_sum/len(loader)
            print("save model: the best loss is {}".format(min_loss))
            torch.save(model.state_dict(), 'classifier.pth')

if __name__ == '__main__':
    # read config file
    config = configparser.ConfigParser()
    config.read('training.ini')

    # start training
    loader = getMnistLoader(config)
    train(config)
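Because guidance queries the classifier on noisy x_t at every denoising step, it is trained here on samples from the same forward process (ddpm.get_x_t) rather than on clean images. A diagnostic sketch for checking how accuracy degrades with the noise level t (function name hypothetical; helpers as in the loop above):

import torch

@torch.no_grad()
def acc_at_t(model, ddpm, loader, t, device):
    # accuracy of the classifier on images noised to step t
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        time_seq = torch.full((x.shape[0],), t, dtype=torch.long, device=device)
        x_t, _ = ddpm.get_x_t(x, time_seq)
        correct += (model(x_t, time_seq).argmax(1) == y).sum().item()
        total += len(x)
    return correct / total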
unet.py (29 lines changed)
@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
from torchinfo import summary
import math
+from positionEncoding import PositionEncode

class DoubleConv(nn.Module):
    '''

@@ -90,34 +91,6 @@ class UpSampling(nn.Module):
        time_emb = time_emb[:, :, None, None].repeat(1, 1, h, w)  # (b, out_dim, h*2, w*2)
        return x + time_emb

-class PositionEncode(nn.Module):
-    '''
-    Input a LongTensor time sequence, return position embedding
-
-    Input:
-        time_seq: shape of LongTensor (b, )
-    Output:
-        dim: shape of tensor (b, time_emb_dim)
-    Args:
-        time_emb_dim (int): output's dimension
-        device (nn.device): data on what device
-    '''
-    def __init__(self, time_emb_dim, device):
-        super(PositionEncode, self).__init__()
-        self.time_emb_dim = time_emb_dim
-
-        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ])  # (d)
-        self.base = self.base.to(device)
-
-    def forward(self, time_seq):
-        seq_len = len(time_seq)
-        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1)  # (b, time_emb_dim)
-        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim)  # (b, time_emb_dim)
-        ans = dim * time_seq  # (b, time_emb_dim)
-        ans[:, 0::2] = torch.sin(ans[:, 0::2])
-        ans[:, 1::2] = torch.cos(ans[:, 1::2])
-        return ans

class Unet(nn.Module):
    '''
    Unet module, predicts the noise