feat: complete sample.py

This commit is contained in:
snsd0805 2023-03-14 22:58:07 +08:00
parent a330666a17
commit 667346c561
Signed by: snsd0805
GPG Key ID: 569349933C77A854

View File

@ -2,21 +2,30 @@ import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from ddpm import DDPM from ddpm import DDPM
from unet import Unet from unet import Unet
import sys
import os
BATCH_SIZE = 256 BATCH_SIZE = 256
ITERATION = 500 ITERATION = 500
TIME_EMB_DIM = 128 TIME_EMB_DIM = 128
DEVICE = torch.device('cuda') DEVICE = torch.device('cuda')
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE) if __name__ == '__main__':
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE) if len(sys.argv) != 2:
print("Usage: python sample.py [pic_num]")
exit()
model.load_state_dict(torch.load('unet.pth')) model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
x_t = ddpm.sample(model, 256) model.load_state_dict(torch.load('unet.pth'))
for index, pic in enumerate(x_t):
x_t = ddpm.sample(model)
if not os.path.isdir('./output'):
os.mkdir('./output')
for index, pic in enumerate(x_t):
p = pic.to('cpu').permute(1, 2, 0) p = pic.to('cpu').permute(1, 2, 0)
plt.imshow(p, cmap='gray', vmin=0, vmax=255) plt.imshow(p, cmap='gray', vmin=0, vmax=255)
plt.savefig("output/{}.png".format(index)) plt.savefig("output/{}.png".format(index))