Compare commits


No commits in common. "7f62ee3f311f27c3ede803d3e982f9f56137517b" and "687a994beaf22e655e0cbc3c8cd82f2c1c0f4a30" have entirely different histories.

11 changed files with 65 additions and 228 deletions

README.md

@@ -7,13 +7,6 @@ It's just for fun, the Unet model does not include attention, normalization, etc
![](./images/Screenshot_20230314_225320.png)
![](./images/ezgif.com-gif-maker.gif)

- ## Classifier Guidance DDPM
- ref: Diffusion Models Beat GANs on Image Synthesis (https://arxiv.org/abs/2105.05233)
- - generate "7"
- - ![](./images/Screenshot_20230322_161942.png)
- - generate "2"
- - ![](./images/Screenshot_20230322_162114.png)

## Training
Before training, please set up the config.ini file:
@@ -43,6 +36,5 @@ To generate 16 pictures, run the following command:
The pictures will be output to the `./output` directory.
```
- $ python sample.py 16 # unconditional
- $ python sample.py 16 7 # conditional, to generate pictures of "7"
+ $ python sample.py 16
```
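For reference, the classifier guidance removed by this compare follows the cited Dhariwal & Nichol (2021) approach: shift the denoising mean by the scaled gradient of a classifier's log-probability. A sketch in standard DDPM notation, with the covariance taken to be β_t as in the removed code:

```latex
\tilde{\mu}_t = \mu_t + s\,\beta_t\,\nabla_{x_t} \log p_\phi(y \mid x_t)
```

Here s is the removed `classifier_scale` argument and p_φ is the noise-conditional classifier from classifier.py below.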

classifier.py

@@ -1,49 +0,0 @@
import torch
import torch.nn as nn
from unet import DownSampling, DoubleConv
from positionEncoding import PositionEncode

class Classfier(nn.Module):
    '''
    Args:
        time_emb_dim (int): dimension of the position encoding
        device (torch.device): device for the PositionEncode layer
    Inputs:
        x: feature maps, (b, c, h, w)
        time_seq: a LongTensor of timesteps indicating that x is x_t; transformed to a position embedding inside this module. (b, )
    Outputs:
        p: class logits (b, 10)
    '''
    def __init__(self, time_emb_dim, device):
        super(Classfier, self).__init__()
        self.conv1 = DoubleConv(1, 32, nn.ReLU())
        self.conv2 = DownSampling(32, 64, time_emb_dim)
        self.conv3 = DownSampling(64, 128, time_emb_dim)
        self.conv4 = DownSampling(128, 256, time_emb_dim)
        self.dense1 = nn.Linear(256*3*3, 512)
        self.dense2 = nn.Linear(512, 128)
        self.dense3 = nn.Linear(128, 10)
        self.pooling = nn.AvgPool2d(2, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.time_embedding = PositionEncode(time_emb_dim, device)

    def forward(self, x, time_seq):
        time_emb = self.time_embedding(time_seq)  # (b, time_emb_dim)
        x = self.conv1(x)                # b, 32, 28, 28
        x = self.conv2(x, time_emb)      # b, 64, 14, 14
        x = self.dropout(x)
        x = self.conv3(x, time_emb)      # b, 128, 7, 7
        x = self.dropout(x)
        x = self.conv4(x, time_emb)      # b, 256, 3, 3
        x = self.dropout(x)
        x = x.reshape((x.shape[0], -1))  # b, 2304
        x = self.relu(self.dense1(x))    # b, 512
        x = self.dropout(x)
        x = self.relu(self.dense2(x))    # b, 128
        x = self.dropout(x)
        x = self.dense3(x)               # b, 10
        return x
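A minimal usage sketch of the deleted classifier (illustrative values; assumes the old classifier.py and its unet.py/positionEncoding.py dependencies are importable): it takes a noised batch plus the timesteps it was noised to, and returns digit logits.

```python
import torch
from classifier import Classfier  # module deleted by this compare

device = torch.device('cpu')
clf = Classfier(time_emb_dim=32, device=device)  # 32 is an illustrative embedding size

x_t = torch.randn(4, 1, 28, 28)          # a batch of noised 28x28 MNIST images
time_seq = torch.randint(0, 1000, (4,))  # the timestep each image was noised to
logits = clf(x_t, time_seq)              # (4, 10) digit logits
```

The classifier is evaluated on noisy x_t rather than clean images, which is why it also consumes a timestep embedding.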

ddpm.py

@@ -51,15 +51,12 @@ class DDPM(nn.Module):
        return mu + sigma * epsilon, epsilon # (b, c, w, h)

-    def sample(self, model, generate_iteration_pic=False, n=None, target=None, classifier=None, classifier_scale=0.5):
+    def sample(self, model, generate_iteration_pic=False, n=None):
        '''
        Inputs:
            model (nn.Module): Unet instance
            generate_iteration_pic (bool): whether to save 10 pictures at different denoising times
            n (int, default=self.batch_size): number of pictures to sample
-            target (int, default=None): conditional target class
-            classifier (nn.Module, default=None): classifier for conditional guidance
-            classifier_scale (float): scale of the classifier's influence
        Outputs:
            x_0 (torch.Tensor): (n, c, h, w)
        '''
@@ -68,43 +65,25 @@ class DDPM(nn.Module):
        c, h, w = 1, 28, 28
        model.eval()
        with torch.no_grad():
            x_t = torch.randn((n, c, h, w)).to(self.device) # (n, c, h, w)
-            x_t += 0.2
            for i in reversed(range(self.iteration)):
                time_seq = (torch.ones(n) * i).long().to(self.device) # (n, )
                predict_noise = model(x_t, time_seq) # (n, c, h, w)
-                first_term = 1/(torch.sqrt(self.alpha[time_seq]))
+                first_term = 1/(torch.sqrt(self.alpha[time_seq])) # (n, )
                second_term = (1-self.alpha[time_seq]) / (torch.sqrt(1-self.overline_alpha[time_seq]))
-                first_term = first_term[:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
-                second_term = second_term[:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
-                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w) # (n, c, h, w)
+                first_term = first_term[:, None, None, None].repeat(1, c, h, w)
+                second_term = second_term[:, None, None, None].repeat(1, c, h, w)
+                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
-                mu = first_term * (x_t-(second_term * predict_noise)) # unguided mu (n, c, h, w)
-                # if conditional => add the classifier gradient to mu
-                if target != None and classifier != None:
-                    with torch.enable_grad():
-                        # ref: https://github.com/clalanliu/IntroductionDiffusionModels/blob/main/control_diffusion.ipynb
-                        x = x_t.detach()
-                        x.requires_grad_(True)
-                        logits = classifier(x, time_seq) # (b, 10)
-                        log_probs = nn.LogSoftmax(dim=1)(logits) # (b, 10)
-                        selected = log_probs[:, target] # (b, )
-                        grad = torch.autograd.grad(selected.sum(), x)[0]
-                        mu = mu + beta * classifier_scale * grad
-                # no noise is added at the final step
                if i != 0:
                    z = torch.randn((n, c, h, w)).to(self.device)
                else:
                    z = torch.zeros((n, c, h, w)).to(self.device)
-                x_t = mu - z * beta
+                x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
                # generate 10 pictures at different denoising times
                if generate_iteration_pic:
@@ -118,4 +97,4 @@ class DDPM(nn.Module):
        x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
        x_t = x_t * 255
        return x_t
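For reference, the surviving loop implements the standard DDPM reverse update (a sketch in the usual notation; note that the original paper scales the noise by σ_t = √β_t, while this code multiplies z by β_t):

```latex
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z,
\qquad z \sim \mathcal{N}(0, \mathbf{I})
```

`first_term` and `second_term` are the two coefficients, `predict_noise` is ε_θ(x_t, t), and z is drawn at every step except the last.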

Two binary image files deleted (not shown): 30 KiB and 28 KiB, likely the two classifier-guidance screenshots removed from the README above.

loader.py

@@ -1,22 +0,0 @@
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset

def getMnistLoader(config):
    '''
    Get a DataLoader for the MNIST dataset
    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        loader (torch.utils.data.DataLoader)
    '''
    BATCH_SIZE = int(config['unet']['batch_size'])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    data = MNIST("./data", train=True, download=True, transform=transform)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
    return loader
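A minimal sketch of how this deleted helper was consumed (the inline config is illustrative; the repo reads `batch_size` from its .ini files):

```python
import configparser
from loader import getMnistLoader  # module deleted by this compare

config = configparser.ConfigParser()
config['unet'] = {'batch_size': '64'}  # illustrative; normally read from config.ini

loader = getMnistLoader(config)
x, y = next(iter(loader))  # x: (64, 1, 28, 28) floats in [0, 1]; y: (64,) digit labels
```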

positionEncoding.py

@@ -1,31 +0,0 @@
import torch
import torch.nn as nn
import math

class PositionEncode(nn.Module):
    '''
    Input a LongTensor time sequence, return the position embedding
    Inputs:
        time_seq: LongTensor of shape (b, )
    Outputs:
        ans: tensor of shape (b, time_emb_dim)
    Args:
        time_emb_dim (int): output dimension
        device (torch.device): device the data lives on
    '''
    def __init__(self, time_emb_dim, device):
        super(PositionEncode, self).__init__()
        self.time_emb_dim = time_emb_dim
        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ]) # (d, )
        self.base = self.base.to(device)

    def forward(self, time_seq):
        seq_len = len(time_seq)
        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1) # (b, time_emb_dim)
        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim) # (b, time_emb_dim)
        ans = dim * time_seq # (b, time_emb_dim)
        ans[:, 0::2] = torch.sin(ans[:, 0::2])
        ans[:, 1::2] = torch.cos(ans[:, 1::2])
        return ans
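A minimal usage sketch (illustrative dimensions; assumes the old positionEncoding.py is importable): each timestep maps to a fixed sinusoidal embedding, sine in the even slots and cosine in the odd slots.

```python
import torch
from positionEncoding import PositionEncode  # module deleted by this compare

pe = PositionEncode(time_emb_dim=32, device=torch.device('cpu'))

time_seq = torch.tensor([0, 10, 500])  # three diffusion timesteps, shape (3,)
emb = pe(time_seq)                     # sinusoidal embeddings, shape (3, 32)
```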

sample.py

@@ -5,18 +5,11 @@ from unet import Unet
import sys
import os
import configparser
-from classifier import Classfier

if __name__ == '__main__':
-    if len(sys.argv) < 2:
+    if len(sys.argv) != 2:
        print("Usage: python sample.py [pic_num]")
        exit()
-    elif len(sys.argv) == 3:
-        target = int( sys.argv[2] )
-        print("Target: {}".format(target))
-    else:
-        target = None

    # read config file
    config = configparser.ConfigParser()
@@ -29,16 +22,11 @@ if __name__ == '__main__':
    # start sampling
    model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
-    model.load_state_dict(torch.load('unet.pth'))
    ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
-    if target != None:
-        classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
-        classifier.load_state_dict(torch.load('classifier.pth'))
-        x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)
-    else:
-        x_t = ddpm.sample(model)
+    model.load_state_dict(torch.load('unet.pth'))
+    x_t = ddpm.sample(model)

    if not os.path.isdir('./output'):
        os.mkdir('./output')
@@ -46,4 +34,4 @@ if __name__ == '__main__':
    for index, pic in enumerate(x_t):
        p = pic.to('cpu').permute(1, 2, 0)
        plt.imshow(p, cmap='gray', vmin=0, vmax=255)
        plt.savefig("output/{}.png".format(index))

train.py

@@ -1,11 +1,32 @@
import torch
import torch.nn as nn
+from torchvision.datasets import MNIST
+import torchvision
+from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from ddpm import DDPM
from unet import Unet
import configparser
-from loader import getMnistLoader

+def getMnistLoader(config):
+    '''
+    Get a DataLoader for the MNIST dataset
+    Inputs:
+        config (configparser.ConfigParser)
+    Outputs:
+        loader (torch.utils.data.DataLoader)
+    '''
+    BATCH_SIZE = int(config['unet']['batch_size'])
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.ToTensor()
+    ])
+    data = MNIST("./data", train=True, download=True, transform=transform)
+    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
+    return loader

def train(config):
    '''

trainClassifier.py

@@ -1,68 +0,0 @@
import torch
import torch.nn as nn
import configparser
from loader import getMnistLoader
from classifier import Classfier
from ddpm import DDPM

def train(config):
    '''
    Start classifier training
    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        None
    '''
    BATCH_SIZE = int(config['classifier']['batch_size'])
    ITERATION = int(config['ddpm']['iteration'])
    TIME_EMB_DIM = int(config['classifier']['time_emb_dim'])
    DEVICE = torch.device(config['classifier']['device'])
    EPOCH_NUM = int(config['classifier']['epoch_num'])
    LEARNING_RATE = float(config['classifier']['learning_rate'])

    # training
    model = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    min_loss = 99
    for epoch in range(EPOCH_NUM):
        loss_sum = 0
        acc_sum = 0
        data_count = 0
        for x, y in loader:  # `loader` is the module-level DataLoader created in __main__ below
            optimizer.zero_grad()
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            time_seq = ddpm.get_time_seq(x.shape[0])  # random timesteps, one per image
            x_t, noise = ddpm.get_x_t(x, time_seq)    # forward-noise the clean batch
            p = model(x_t, time_seq)
            loss = criterion(p, y)
            loss_sum += loss.item()
            data_count += len(x)
            acc_sum += (p.argmax(1)==y).sum()
            loss.backward()
            optimizer.step()
        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}, acc: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader), acc_sum/data_count))
        if loss_sum/len(loader) < min_loss:
            min_loss = loss_sum/len(loader)
            print("save model: the best loss is {}".format(min_loss))
            torch.save(model.state_dict(), 'classifier.pth')

if __name__ == '__main__':
    # read config file
    config = configparser.ConfigParser()
    config.read('training.ini')

    # start training
    loader = getMnistLoader(config)
    train(config)
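For context, `ddpm.get_x_t` above noises the clean batch in closed form; matching the `return mu + sigma * epsilon, epsilon` line visible at the top of the ddpm.py hunk, this is presumably the standard DDPM forward step. Training the classifier on x_t at random timesteps is what makes its gradients usable on the noisy inputs seen during guided sampling.

```latex
x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})
```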

unet.py

@@ -2,7 +2,6 @@ import torch
import torch.nn as nn
from torchinfo import summary
import math
-from positionEncoding import PositionEncode

class DoubleConv(nn.Module):
    '''
@@ -91,6 +90,34 @@ class UpSampling(nn.Module):
        time_emb = time_emb[:, :, None, None].repeat(1, 1, h, w) # (b, out_dim, h*2, w*2)
        return x + time_emb

+class PositionEncode(nn.Module):
+    '''
+    Input a LongTensor time sequence, return the position embedding
+    Inputs:
+        time_seq: LongTensor of shape (b, )
+    Outputs:
+        ans: tensor of shape (b, time_emb_dim)
+    Args:
+        time_emb_dim (int): output dimension
+        device (torch.device): device the data lives on
+    '''
+    def __init__(self, time_emb_dim, device):
+        super(PositionEncode, self).__init__()
+        self.time_emb_dim = time_emb_dim
+        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ]) # (d, )
+        self.base = self.base.to(device)
+
+    def forward(self, time_seq):
+        seq_len = len(time_seq)
+        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1) # (b, time_emb_dim)
+        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim) # (b, time_emb_dim)
+        ans = dim * time_seq # (b, time_emb_dim)
+        ans[:, 0::2] = torch.sin(ans[:, 0::2])
+        ans[:, 1::2] = torch.cos(ans[:, 1::2])
+        return ans

class Unet(nn.Module):
    '''
    Unet module, predict the noise