fix: decrease batch size, iteration, epoch and lr

This commit is contained in:
snsd0805 2023-03-14 18:45:04 +08:00
parent 3530d91aaf
commit 38fa577706
Signed by: snsd0805
GPG Key ID: 569349933C77A854
2 changed files with 18 additions and 11 deletions

View File

@@ -3,8 +3,8 @@ import matplotlib.pyplot as plt
from ddpm import DDPM
from unet import Unet
-BATCH_SIZE = 512
-ITERATION = 1500
+BATCH_SIZE = 256
+ITERATION = 500
TIME_EMB_DIM = 128
DEVICE = torch.device('cuda')
@@ -13,8 +13,10 @@ ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
model.load_state_dict(torch.load('unet.pth'))
-x_t = ddpm.sample(model, 32)
+x_t = ddpm.sample(model, 256)
for index, pic in enumerate(x_t):
p = pic.to('cpu').permute(1, 2, 0)
-plt.imshow(p)
+plt.imshow(p, cmap='gray', vmin=0, vmax=255)
plt.savefig("output/{}.png".format(index))

View File

@@ -8,12 +8,12 @@ from tqdm import tqdm
from ddpm import DDPM
from unet import Unet
-BATCH_SIZE = 512
-ITERATION = 1500
+BATCH_SIZE = 256
+ITERATION = 500
TIME_EMB_DIM = 128
DEVICE = torch.device('cuda')
-EPOCH_NUM = 3000
-LEARNING_RATE = 1e-3
+EPOCH_NUM = 500
+LEARNING_RATE = 1e-4
def getMnistLoader():
@@ -32,6 +32,8 @@ def train(loader, device, epoch_num, lr):
criterion = nn.MSELoss()
optimzer = torch.optim.Adam(model.parameters(), lr=lr)
min_loss = 99
for epoch in range(epoch_num):
loss_sum = 0
# progress = tqdm(total=len(loader))
@@ -50,8 +52,11 @@ def train(loader, device, epoch_num, lr):
loss.backward()
optimzer.step()
# progress.update(1)
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
if loss_sum/len(loader) < min_loss:
min_loss = loss_sum/len(loader)
print("save model: the best loss is {}".format(min_loss))
torch.save(model.state_dict(), 'unet.pth')
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, loss_sum/len(loader)))
loader = getMnistLoader()
train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)