DDPM_Mnist/train.py

import torch
import torch.nn as nn
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm

from ddpm import DDPM
from unet import Unet

BATCH_SIZE = 256
ITERATION = 500        # number of diffusion timesteps T
TIME_EMB_DIM = 128     # width of the Unet's time embedding
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCH_NUM = 500
LEARNING_RATE = 1e-4

def getMnistLoader():
    # MNIST as tensors in [0, 1]. (Many DDPM implementations rescale inputs
    # to [-1, 1]; the original pipeline's range is kept here.)
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    data = MNIST("./data", train=True, download=True, transform=transform)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
    return loader

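
# Training objective: the standard DDPM "simple" loss (Ho et al., 2020). For a
# batch of random timesteps t, the forward process gives
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
# and the Unet regresses eps with an MSE loss. ddpm.get_x_t is assumed to
# return (x_t, eps) computed in this closed form, and ddpm.get_time_seq to
# draw the random timesteps.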
def train(loader, device, epoch_num, lr):
    model = Unet(TIME_EMB_DIM, device).to(device)
    # Linear beta schedule from 1e-4 to 2e-2 over ITERATION steps (inferred
    # from the constructor arguments).
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    min_loss = float('inf')
    for epoch in range(epoch_num):
        loss_sum = 0.0
        for x, _ in tqdm(loader, desc=f"Epoch {epoch + 1}/{epoch_num}"):
            optimizer.zero_grad()
            x = x.to(device)
            # Noise each image at a random timestep and train the Unet to
            # predict the injected noise.
            time_seq = ddpm.get_time_seq(x.shape[0])
            x_t, noise = ddpm.get_x_t(x, time_seq)
            predict_noise = model(x_t, time_seq)
            loss = criterion(predict_noise, noise)
            loss_sum += loss.item()
            loss.backward()
            optimizer.step()
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
if loss_sum/len(loader) < min_loss:
min_loss = loss_sum/len(loader)
print("save model: the best loss is {}".format(min_loss))
torch.save(model.state_dict(), 'unet.pth')
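

# A minimal sketch of DDPM ancestral sampling with the trained network. This
# is not part of the original script: it reconstructs the standard reverse
# process (Ho et al., 2020) from the schedule arguments passed to DDPM above,
# and assumes the Unet accepts (x, t) with integer timesteps in [0, ITERATION).
@torch.no_grad()
def sample(model, n=16):
    betas = torch.linspace(1e-4, 2e-2, ITERATION, device=DEVICE)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = torch.randn(n, 1, 28, 28, device=DEVICE)   # start from pure noise
    for t in reversed(range(ITERATION)):
        t_batch = torch.full((n,), t, device=DEVICE, dtype=torch.long)
        eps = model(x, t_batch)
        # posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        x = (x - betas[t] / torch.sqrt(1.0 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])
        if t > 0:
            x = x + torch.sqrt(betas[t]) * torch.randn_like(x)   # add sigma_t * z
    return x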


if __name__ == "__main__":
    loader = getMnistLoader()
    train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
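
# Example usage of the sampling sketch above (hypothetical; run after training):
#
#     model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
#     model.load_state_dict(torch.load('unet.pth', map_location=DEVICE))
#     model.eval()
#     torchvision.utils.save_image(sample(model), 'samples.png', nrow=4)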