fix: decrease batch size, iteration, epoch and lr

parent 3530d91aaf
commit 38fa577706

sample.py (12 changed lines):
@@ -3,8 +3,8 @@ import matplotlib.pyplot as plt
 from ddpm import DDPM
 from unet import Unet
 
-BATCH_SIZE = 512
-ITERATION = 1500
+BATCH_SIZE = 256
+ITERATION = 500
 TIME_EMB_DIM = 128
 DEVICE = torch.device('cuda')
 
@@ -13,8 +13,10 @@ ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
 
 model.load_state_dict(torch.load('unet.pth'))
 
-x_t = ddpm.sample(model, 32)
+x_t = ddpm.sample(model, 256)
 for index, pic in enumerate(x_t):
     p = pic.to('cpu').permute(1, 2, 0)
-    plt.imshow(p)
+    plt.imshow(p, cmap='gray', vmin=0, vmax=255)
     plt.savefig("output/{}.png".format(index))
+
+
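Note: the plt.imshow change pins the colormap and value range, so every saved sample is rendered on the same gray scale instead of being auto-normalized per image. Below is a minimal runnable sketch of the updated loop; it assumes ddpm.sample returns an (n, 1, 28, 28) batch with pixel values in [0, 255] (shape and range are assumptions, and random data stands in for the model output; the extra squeeze is added because matplotlib rejects (H, W, 1) arrays):

import os
import torch
import matplotlib.pyplot as plt

os.makedirs("output", exist_ok=True)

# Stand-in for x_t = ddpm.sample(model, 256): random uint8 "images".
x_t = torch.randint(0, 256, (4, 1, 28, 28), dtype=torch.uint8)

for index, pic in enumerate(x_t):
    # (C, H, W) -> (H, W): permute to channel-last as in the repo code,
    # then drop the singleton channel so imshow accepts the array.
    p = pic.to('cpu').permute(1, 2, 0).squeeze(-1)
    # Fixed cmap/vmin/vmax: consistent gray rendering across samples.
    plt.imshow(p, cmap='gray', vmin=0, vmax=255)
    plt.savefig("output/{}.png".format(index))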
train.py (17 changed lines):
@@ -8,12 +8,12 @@ from tqdm import tqdm
 from ddpm import DDPM
 from unet import Unet
 
-BATCH_SIZE = 512
-ITERATION = 1500
+BATCH_SIZE = 256
+ITERATION = 500
 TIME_EMB_DIM = 128
 DEVICE = torch.device('cuda')
-EPOCH_NUM = 3000
-LEARNING_RATE = 1e-3
+EPOCH_NUM = 500
+LEARNING_RATE = 1e-4
 
 def getMnistLoader():
 
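Note: ITERATION is passed straight into DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE), so it presumably sets the number of diffusion timesteps, with 1e-4 and 2e-2 as the endpoints of a linear beta schedule as in Ho et al. (2020). A sketch of what the reduced value implies under that assumption (inferred from the call site, not confirmed by the repo):

import torch

T = 500                                 # ITERATION after this commit
betas = torch.linspace(1e-4, 2e-2, T)   # assumed linear beta schedule
alphas = 1.0 - betas
# Cumulative products give the closed-form forward marginal q(x_t | x_0):
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
alpha_bars = torch.cumprod(alphas, dim=0)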
@@ -32,6 +32,8 @@ def train(loader, device, epoch_num, lr):
     criterion = nn.MSELoss()
     optimzer = torch.optim.Adam(model.parameters(), lr=lr)
 
+    min_loss = 99
+
     for epoch in range(epoch_num):
        loss_sum = 0
        # progress = tqdm(total=len(loader))
@@ -50,8 +52,11 @@ def train(loader, device, epoch_num, lr):
             loss.backward()
             optimzer.step()
             # progress.update(1)
-        torch.save(model.state_dict(), 'unet.pth')
-        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, loss_sum/len(loader)))
+        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
+        if loss_sum/len(loader) < min_loss:
+            min_loss = loss_sum/len(loader)
+            print("save model: the best loss is {}".format(min_loss))
+            torch.save(model.state_dict(), 'unet.pth')
 
 loader = getMnistLoader()
 train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
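Note: the substantive change in train.py is the checkpoint policy. unet.pth is no longer overwritten every epoch; it is saved only when the epoch's average loss beats the best seen so far. A self-contained sketch of that pattern (the linear model and random batches are placeholders for the repo's Unet and MNIST loader):

import torch
import torch.nn as nn

# Placeholder model and data; the repo trains a Unet on MNIST batches.
model = nn.Linear(10, 10)
loader = [(torch.randn(8, 10), torch.randn(8, 10)) for _ in range(5)]
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

min_loss = float('inf')  # the commit initializes this to the literal 99
for epoch in range(3):
    loss_sum = 0
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
    avg_loss = loss_sum / len(loader)
    # Checkpoint only on improvement, mirroring the new train.py logic.
    if avg_loss < min_loss:
        min_loss = avg_loss
        torch.save(model.state_dict(), 'best.pth')  # repo uses 'unet.pth'

Starting min_loss at float('inf') rather than the commit's literal 99 avoids the edge case where every epoch's loss exceeds the sentinel and no checkpoint is ever written.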