From 667346c561786792f159e9886acb17b98c580b88 Mon Sep 17 00:00:00 2001 From: snsd0805 Date: Tue, 14 Mar 2023 22:58:07 +0800 Subject: [PATCH] feat: complete sample.py --- sample.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/sample.py b/sample.py index afaebaf..fb03388 100644 --- a/sample.py +++ b/sample.py @@ -2,21 +2,30 @@ import torch import matplotlib.pyplot as plt from ddpm import DDPM from unet import Unet +import sys +import os BATCH_SIZE = 256 ITERATION = 500 TIME_EMB_DIM = 128 DEVICE = torch.device('cuda') -model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE) -ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE) +if __name__ == '__main__': + if len(sys.argv) != 2: + print("Usage: python sample.py [pic_num]") + exit() -model.load_state_dict(torch.load('unet.pth')) + model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE) + ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE) -x_t = ddpm.sample(model, 256) -for index, pic in enumerate(x_t): - p = pic.to('cpu').permute(1, 2, 0) - plt.imshow(p, cmap='gray', vmin=0, vmax=255) - plt.savefig("output/{}.png".format(index)) + model.load_state_dict(torch.load('unet.pth')) - \ No newline at end of file + x_t = ddpm.sample(model) + + if not os.path.isdir('./output'): + os.mkdir('./output') + + for index, pic in enumerate(x_t): + p = pic.to('cpu').permute(1, 2, 0) + plt.imshow(p, cmap='gray', vmin=0, vmax=255) + plt.savefig("output/{}.png".format(index)) \ No newline at end of file