diff --git a/classifier.py b/classifier.py
new file mode 100644
index 0000000..8ba7430
--- /dev/null
+++ b/classifier.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+from unet import DownSampling, DoubleConv
+from positionEncoding import PositionEncode
+
+class Classifier(nn.Module):
+    '''
+    Args:
+        time_emb_dim (int): dimension of the position encoding
+        device (torch.device): device the PositionEncode layer keeps its data on
+    Inputs:
+        x: feature maps x_t, (b, c, h, w)
+        time_seq: LongTensor of timesteps marking x as x_t; converted to a position embedding internally. (b, )
+    Outputs:
+        p: class logits (b, 10)
+    '''
+    def __init__(self, time_emb_dim, device):
+        super(Classifier, self).__init__()
+        self.conv1 = DoubleConv(1, 32, nn.ReLU())
+        self.conv2 = DownSampling(32, 64, time_emb_dim)
+        self.conv3 = DownSampling(64, 128, time_emb_dim)
+        self.conv4 = DownSampling(128, 256, time_emb_dim)
+        self.dense1 = nn.Linear(256*3*3, 512)
+        self.dense2 = nn.Linear(512, 128)
+        self.dense3 = nn.Linear(128, 10)
+        self.pooling = nn.AvgPool2d(2, stride=2)
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(0.3)
+
+        self.time_embedding = PositionEncode(time_emb_dim, device)
+
+    def forward(self, x, time_seq):
+        time_emb = self.time_embedding(time_seq)    # (b, time_emb_dim)
+
+        x = self.conv1(x)              # (b, 32, 28, 28)
+        x = self.conv2(x, time_emb)    # (b, 64, 14, 14)
+        x = self.dropout(x)
+        x = self.conv3(x, time_emb)    # (b, 128, 7, 7)
+        x = self.dropout(x)
+        x = self.conv4(x, time_emb)    # (b, 256, 3, 3)
+        x = self.dropout(x)
+
+        x = x.reshape((x.shape[0], -1))    # (b, 2304)
+        x = self.relu(self.dense1(x))      # (b, 512)
+        x = self.dropout(x)
+        x = self.relu(self.dense2(x))      # (b, 128)
+        x = self.dropout(x)
+        x = self.dense3(x)                 # (b, 10)
+        return x
diff --git a/ddpm.py b/ddpm.py
index 9dafdc5..7fb5a44 100644
--- a/ddpm.py
+++ b/ddpm.py
@@ -51,12 +51,15 @@ class DDPM(nn.Module):
         return mu + sigma * epsilon, epsilon # (b, c, w, h)
 
-    def sample(self, model, generate_iteration_pic=False, n=None):
+    def sample(self, model, generate_iteration_pic=False, n=None, target=None, classifier=None, classifier_scale=0.5):
         '''
         Inputs:
             model (nn.Module): Unet instance
             generate_iteration_pic (bool): whether generate 10 pic on different denoising time
             n (int, default=self.batch_size): want to sample n pictures
+            target (int, default=None): class index to condition the samples on
+            classifier (nn.Module, default=None): noisy classifier that supplies the guidance gradient
+            classifier_scale (float): scale of the classifier-guidance term
         Outputs:
             x_0 (nn.Tensor): (n, c, h, w)
         '''
 
@@ -65,25 +68,43 @@ class DDPM(nn.Module):
         c, h, w = 1, 28, 28
         model.eval()
         with torch.no_grad():
-            x_t = torch.randn((n, c, h, w)).to(self.device) # (n, c, h, w)
+            x_t = torch.randn((n, c, h, w)).to(self.device)    # (n, c, h, w)
+            x_t += 0.2    # shift the initial noise distribution slightly
 
             for i in reversed(range(self.iteration)):
-                time_seq = (torch.ones(n) * i).long().to(self.device) # (n, )
-                predict_noise = model(x_t, time_seq) # (n, c, h, w)
+                time_seq = (torch.ones(n) * i).long().to(self.device)    # (n, )
+                predict_noise = model(x_t, time_seq)                     # (n, c, h, w)
 
-                first_term = 1/(torch.sqrt(self.alpha[time_seq])) # (n, )
+                first_term = 1/(torch.sqrt(self.alpha[time_seq]))
                 second_term = (1-self.alpha[time_seq]) / (torch.sqrt(1-self.overline_alpha[time_seq]))
-                first_term = first_term[:, None, None, None].repeat(1, c, h, w)
-                second_term = second_term[:, None, None, None].repeat(1, c, h, w)
+                first_term = first_term[:, None, None, None].repeat(1, c, h, w)      # (n, c, h, w)
+                second_term = second_term[:, None, None, None].repeat(1, c, h, w)    # (n, c, h, w)
 
-                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)
+                beta = self.beta[time_seq][:, None, None, None].repeat(1, c, h, w)    # (n, c, h, w)
+
+                mu = first_term * (x_t-(second_term * predict_noise))    # unguided posterior mean (n, c, h, w)
+
+                # if conditional => nudge the mean with the classifier gradient
+                if target is not None and classifier is not None:
+                    with torch.enable_grad():
+                        # ref: https://github.com/clalanliu/IntroductionDiffusionModels/blob/main/control_diffusion.ipynb
+                        x = x_t.detach()
+                        x.requires_grad_(True)
+
+                        logits = classifier(x, time_seq)            # (n, 10)
+                        log_probs = nn.LogSoftmax(dim=1)(logits)    # (n, 10)
+                        selected = log_probs[:, target]             # (n, )
+                        grad = torch.autograd.grad(selected.sum(), x)[0]    # (n, c, h, w)
+                        mu = mu + beta * classifier_scale * grad
+
+                # no noise is added at the final step (i == 0)
                 if i!= 0:
                     z = torch.randn((n, c, h, w)).to(self.device)
                 else:
                     z = torch.zeros((n, c, h, w)).to(self.device)
-                x_t = first_term * (x_t-(second_term * predict_noise)) - z * beta
+                x_t = mu - z * beta
 
             # generate 10 pic on the different denoising times
             if generate_iteration_pic:
@@ -97,4 +118,4 @@ class DDPM(nn.Module):
 
         x_t = ( x_t.clamp(-1, 1) + 1 ) / 2
         x_t = x_t * 255
-        return x_t
\ No newline at end of file
+        return x_t
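The guidance step above shifts the posterior mean by beta * classifier_scale * grad_x log p(target | x_t), which is classifier-guided DDPM sampling. A minimal sketch of invoking it outside sample.py, assuming the constructor signatures shown in this diff (the time_emb_dim of 32, the 1000 steps, and the batch of 16 are illustrative values, not the repo's config):

    import torch
    from unet import Unet
    from ddpm import DDPM
    from classifier import Classifier

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Unet(32, device).to(device)          # 32 = time_emb_dim (assumed value)
    model.load_state_dict(torch.load('unet.pth', map_location=device))
    clf = Classifier(32, device).to(device)
    clf.load_state_dict(torch.load('classifier.pth', map_location=device))

    # DDPM(batch_size, iteration, beta_start, beta_end, device), as in sample.py
    ddpm = DDPM(16, 1000, 1e-4, 2e-2, device)
    x_0 = ddpm.sample(model, target=7, classifier=clf, classifier_scale=0.5)
    # x_0: (16, 1, 28, 28), pixel values scaled to [0, 255]
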
diff --git a/loader.py b/loader.py
new file mode 100644
index 0000000..b7ed36a
--- /dev/null
+++ b/loader.py
@@ -0,0 +1,22 @@
+import torchvision
+from torchvision.datasets import MNIST
+from torch.utils.data import DataLoader
+
+def getMnistLoader(config):
+    '''
+    Get a DataLoader for the MNIST dataset
+
+    Inputs:
+        config (configparser.ConfigParser)
+    Outputs:
+        loader (torch.utils.data.DataLoader)
+    '''
+    BATCH_SIZE = int(config['unet']['batch_size'])
+
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.ToTensor()
+    ])
+
+    data = MNIST("./data", train=True, download=True, transform=transform)
+    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
+    return loader
diff --git a/positionEncoding.py b/positionEncoding.py
new file mode 100644
index 0000000..2b26842
--- /dev/null
+++ b/positionEncoding.py
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+import math
+
+class PositionEncode(nn.Module):
+    '''
+    Input a LongTensor time sequence, return its position embedding
+
+    Input:
+        time_seq: LongTensor of shape (b, )
+    Output:
+        ans: position embedding, tensor of shape (b, time_emb_dim)
+    Args:
+        time_emb_dim (int): output dimension
+        device (torch.device): device the embedding is computed on
+    '''
+    def __init__(self, time_emb_dim, device):
+        super(PositionEncode, self).__init__()
+        self.time_emb_dim = time_emb_dim
+
+        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ])    # (d, )
+        self.base = self.base.to(device)
+
+    def forward(self, time_seq):
+        seq_len = len(time_seq)
+        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1)    # (b, time_emb_dim)
+        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim)           # (b, time_emb_dim)
+        ans = dim * time_seq                                                # (b, time_emb_dim)
+        ans[:, 0::2] = torch.sin(ans[:, 0::2])
+        ans[:, 1::2] = torch.cos(ans[:, 1::2])
+        return ans
diff --git a/sample.py b/sample.py
index 7e3bb01..80a54e1 100644
--- a/sample.py
+++ b/sample.py
@@ -5,11 +5,16 @@ from unet import Unet
 import sys
 import os
 import configparser
+from classifier import Classifier
 
 if __name__ == '__main__':
-    if len(sys.argv) != 2:
-        print("Usage: python sample.py [pic_num]")
+    if len(sys.argv) < 2:
+        print("Usage: python sample.py pic_num [target]")
         exit()
+    target = int(sys.argv[2]) if len(sys.argv) == 3 else None
+    if target is not None:
+        print("Target: {}".format(target))
+
 
     # read config file
     config = configparser.ConfigParser()
@@ -23,10 +28,12 @@ if __name__ == '__main__':
     # start sampling
     model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
     ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
+    classifier = Classifier(TIME_EMB_DIM, DEVICE).to(DEVICE)
     model.load_state_dict(torch.load('unet.pth'))
+    classifier.load_state_dict(torch.load('classifier.pth'))
 
-    x_t = ddpm.sample(model)
+    x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)
 
     if not os.path.isdir('./output'):
         os.mkdir('./output')
 
@@ -34,4 +41,4 @@ if __name__ == '__main__':
     for index, pic in enumerate(x_t):
         p = pic.to('cpu').permute(1, 2, 0)
         plt.imshow(p, cmap='gray', vmin=0, vmax=255)
-        plt.savefig("output/{}.png".format(index))
\ No newline at end of file
+        plt.savefig("output/{}.png".format(index))
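With the argument handling above, sample.py takes a picture count and an optional target digit; note that the script loads classifier.pth unconditionally, so that checkpoint must exist even for unguided sampling. Example invocations (the count and digit are illustrative):

    python sample.py 16      # 16 unguided samples
    python sample.py 16 7    # 16 samples guided toward the digit 7
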
diff --git a/train.py b/train.py
index 94ac4ae..f01790a 100644
--- a/train.py
+++ b/train.py
@@ -1,32 +1,11 @@
 import torch
 import torch.nn as nn
-from torchvision.datasets import MNIST
-import torchvision
-from torch.utils.data import DataLoader, Dataset
 import matplotlib.pyplot as plt
 from tqdm import tqdm
 from ddpm import DDPM
 from unet import Unet
 import configparser
-
-def getMnistLoader(config):
-    '''
-    Get MNIST dataset's loader
-
-    Inputs:
-        config (configparser.ConfigParser)
-    Outputs:
-        loader (nn.utils.data.DataLoader)
-    '''
-    BATCH_SIZE = int(config['unet']['batch_size'])
-
-    transform = torchvision.transforms.Compose([
-        torchvision.transforms.ToTensor()
-    ])
-
-    data = MNIST("./data", train=True, download=True, transform=transform)
-    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
-    return loader
+from loader import getMnistLoader
 
 def train(config):
diff --git a/train_classifier.py b/train_classifier.py
new file mode 100644
index 0000000..f7ef506
--- /dev/null
+++ b/train_classifier.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import configparser
+from loader import getMnistLoader
+from classifier import Classifier
+from ddpm import DDPM
+
+def train(config):
+    '''
+    Start classifier training
+
+    Inputs:
+        config (configparser.ConfigParser)
+    Outputs:
+        None
+    '''
+    BATCH_SIZE = int(config['classifier']['batch_size'])
+    ITERATION = int(config['ddpm']['iteration'])
+    TIME_EMB_DIM = int(config['classifier']['time_emb_dim'])
+    DEVICE = torch.device(config['classifier']['device'])
+    EPOCH_NUM = int(config['classifier']['epoch_num'])
+    LEARNING_RATE = float(config['classifier']['learning_rate'])
+
+    # training
+    loader = getMnistLoader(config)
+    model = Classifier(TIME_EMB_DIM, DEVICE).to(DEVICE)
+    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+    min_loss = float('inf')
+
+    for epoch in range(EPOCH_NUM):
+        loss_sum = 0
+        acc_sum = 0
+        data_count = 0
+        for x, y in loader:
+            optimizer.zero_grad()
+
+            x = x.to(DEVICE)
+            y = y.to(DEVICE)
+            time_seq = ddpm.get_time_seq(x.shape[0])
+            x_t, noise = ddpm.get_x_t(x, time_seq)    # train the classifier on noisy inputs x_t
+
+            p = model(x_t, time_seq)
+            loss = criterion(p, y)
+
+            loss_sum += loss.item()
+            data_count += len(x)
+            acc_sum += (p.argmax(1) == y).sum().item()
+
+            loss.backward()
+            optimizer.step()
+
+        print("Epoch {}/{} (lr={}, batch_size={}, iteration={}) - best loss: {}, loss: {}, acc: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader), acc_sum/data_count))
+        if loss_sum/len(loader) < min_loss:
+            min_loss = loss_sum/len(loader)
+            print("save model: the best loss is {}".format(min_loss))
+            torch.save(model.state_dict(), 'classifier.pth')
+
+if __name__ == '__main__':
+    # read config file
+    config = configparser.ConfigParser()
+    config.read('training.ini')
+
+    # start training
+    train(config)
\ No newline at end of file
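train_classifier.py reads the [classifier] and [ddpm] sections of training.ini, alongside the [unet] section used by loader.py. A plausible sketch of that file (every value here is an assumption, not the repo's actual config):

    [unet]
    batch_size = 64
    time_emb_dim = 32
    device = cuda

    [ddpm]
    iteration = 1000

    [classifier]
    batch_size = 64
    time_emb_dim = 32
    device = cuda
    epoch_num = 20
    learning_rate = 0.001
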
diff --git a/unet.py b/unet.py
index 5034544..1033884 100644
--- a/unet.py
+++ b/unet.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from torchinfo import summary
 import math
+from positionEncoding import PositionEncode
 
 class DoubleConv(nn.Module):
     '''
@@ -90,34 +91,6 @@ class UpSampling(nn.Module):
         time_emb = time_emb[:, :, None, None].repeat(1, 1, h, w) # (b, out_dim, h*2, w*2)
         return x + time_emb
 
-class PositionEncode(nn.Module):
-    '''
-    Input a LongTensor time sequence, return position embedding
-
-    Input:
-        time_seq: shape of LongTensor (b, )
-    Output:
-        dim: shape of tensor (b, time_emb_dim)
-    Args:
-        time_emb_dim (int): output's dimension
-        device (nn.device): data on what device
-    '''
-    def __init__(self, time_emb_dim, device):
-        super(PositionEncode, self).__init__()
-        self.time_emb_dim = time_emb_dim
-
-        self.base = torch.Tensor([ 1/math.pow(10000, (i//2)/self.time_emb_dim) for i in range(self.time_emb_dim) ]) # (d)
-        self.base = self.base.to(device)
-
-    def forward(self, time_seq):
-        seq_len = len(time_seq)
-        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1) # (b, time_emb_dim)
-        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim) # (b, time_emb_dim)
-        ans = dim * time_seq # (b, time_emb_dim)
-        ans[:, 0::2] = torch.sin(ans[:, 0::2])
-        ans[:, 1::2] = torch.cos(ans[:, 1::2])
-        return ans
-
 class Unet(nn.Module):
     '''
     Unet module, predict the noise
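The PositionEncode move is behavior-preserving, including its (i//2)/time_emb_dim frequency exponent (a slight variant of the transformer paper's 2i/d). It can be sanity-checked in isolation; a small sketch with an arbitrarily chosen dimension of 8:

    import torch
    from positionEncoding import PositionEncode

    enc = PositionEncode(8, torch.device('cpu'))
    t = torch.tensor([0, 1, 500])    # three timesteps, shape (3, )
    emb = enc(t)                     # (3, 8)
    print(emb[0])                    # t=0 row: [0., 1., 0., 1., ...] (sin/cos of zero)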