From 7f62ee3f311f27c3ede803d3e982f9f56137517b Mon Sep 17 00:00:00 2001 From: snsd0805 Date: Wed, 22 Mar 2023 16:22:53 +0800 Subject: [PATCH] fix: update sample.py about conditional ddpm --- sample.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/sample.py b/sample.py index 80a54e1..a20fb11 100644 --- a/sample.py +++ b/sample.py @@ -14,6 +14,8 @@ if __name__ == '__main__': elif len(sys.argv) == 3: target = int( sys.argv[2] ) print("Target: {}".format(target)) + else: + target = None # read config file @@ -27,13 +29,16 @@ if __name__ == '__main__': # start sampling model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE) - ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE) - classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE) - model.load_state_dict(torch.load('unet.pth')) - classifier.load_state_dict(torch.load('classifier.pth')) + ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE) + + if target != None: + classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE) + classifier.load_state_dict(torch.load('classifier.pth')) + x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5) + else: + x_t = ddpm.sample(model) - x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5) if not os.path.isdir('./output'): os.mkdir('./output')