diff --git a/sample.py b/sample.py
index e06c028..afaebaf 100644
--- a/sample.py
+++ b/sample.py
@@ -3,8 +3,8 @@
 import matplotlib.pyplot as plt
 from ddpm import DDPM
 from unet import Unet
 
-BATCH_SIZE = 512
-ITERATION = 1500
+BATCH_SIZE = 256
+ITERATION = 500
 TIME_EMB_DIM = 128
 DEVICE = torch.device('cuda')
@@ -13,8 +13,10 @@
 ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
 model.load_state_dict(torch.load('unet.pth'))
 
-x_t = ddpm.sample(model, 32)
+x_t = ddpm.sample(model, 256)
 for index, pic in enumerate(x_t):
     p = pic.to('cpu').permute(1, 2, 0)
-    plt.imshow(p)
-    plt.savefig("output/{}.png".format(index))
\ No newline at end of file
+    plt.imshow(p, cmap='gray', vmin=0, vmax=255)
+    plt.savefig("output/{}.png".format(index))
+
+
\ No newline at end of file
diff --git a/train.py b/train.py
index ef0a6c8..ba10994 100644
--- a/train.py
+++ b/train.py
@@ -8,12 +8,12 @@
 from tqdm import tqdm
 from ddpm import DDPM
 from unet import Unet
 
-BATCH_SIZE = 512
-ITERATION = 1500
+BATCH_SIZE = 256
+ITERATION = 500
 TIME_EMB_DIM = 128
 DEVICE = torch.device('cuda')
-EPOCH_NUM = 3000
-LEARNING_RATE = 1e-3
+EPOCH_NUM = 500
+LEARNING_RATE = 1e-4
 
 def getMnistLoader():
@@ -32,6 +32,8 @@
 def train(loader, device, epoch_num, lr):
     criterion = nn.MSELoss()
     optimzer = torch.optim.Adam(model.parameters(), lr=lr)
+    min_loss = 99
+
     for epoch in range(epoch_num):
         loss_sum = 0
         # progress = tqdm(total=len(loader))
@@ -50,8 +52,11 @@
             loss.backward()
             optimzer.step()
             # progress.update(1)
-        torch.save(model.state_dict(), 'unet.pth')
-        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, loss_sum/len(loader)))
+        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
+        if loss_sum/len(loader) < min_loss:
+            min_loss = loss_sum/len(loader)
+            print("save model: the best loss is {}".format(min_loss))
+            torch.save(model.state_dict(), 'unet.pth')
 
 loader = getMnistLoader()
 train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
\ No newline at end of file
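
For context, a minimal standalone sketch of the best-loss checkpointing pattern that the train.py hunks introduce: average the loss over each epoch and save the weights only when that average improves. The `model`/`loader` stand-ins below are illustrative, not the repo's Unet or getMnistLoader(), and `float('inf')` is swapped in for the diff's hard-coded `99` starting value, which behaves the same only while the average loss stays below 99.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for the repo's Unet model and MNIST loader.
model = nn.Linear(10, 10)
loader = [(torch.randn(4, 10), torch.randn(4, 10)) for _ in range(8)]

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# float('inf') guarantees the first epoch always checkpoints; the diff's
# hard-coded 99 works the same only while losses stay below 99.
min_loss = float('inf')

for epoch in range(3):
    loss_sum = 0.0
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
    avg_loss = loss_sum / len(loader)
    print("Epoch {}: loss {:.6f} (best {:.6f})".format(epoch, avg_loss, min_loss))
    # Save only on improvement, mirroring the new logic in train.py.
    if avg_loss < min_loss:
        min_loss = avg_loss
        torch.save(model.state_dict(), 'unet.pth')
```

Compared with the old behavior (saving unconditionally every epoch), this keeps the checkpoint at the best-seen average loss rather than the most recent one.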