From 25e6a5ff62b2a68e5e1bdfc6acdc59edb8d4304b Mon Sep 17 00:00:00 2001 From: snsd0805 Date: Tue, 14 Mar 2023 02:16:06 +0800 Subject: [PATCH] feat: add sample.py --- sample.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 sample.py diff --git a/sample.py b/sample.py new file mode 100644 index 0000000..e06c028 --- /dev/null +++ b/sample.py @@ -0,0 +1,20 @@ +import torch +import matplotlib.pyplot as plt +from ddpm import DDPM +from unet import Unet + +BATCH_SIZE = 512 +ITERATION = 1500 +TIME_EMB_DIM = 128 +DEVICE = torch.device('cuda') + +model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE) +ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE) + +model.load_state_dict(torch.load('unet.pth')) + +x_t = ddpm.sample(model, 32) +for index, pic in enumerate(x_t): + p = pic.to('cpu').permute(1, 2, 0) + plt.imshow(p) + plt.savefig("output/{}.png".format(index)) \ No newline at end of file