fix: update sample.py about conditional ddpm

This commit is contained in:
snsd0805 2023-03-22 16:22:53 +08:00
parent 95600b80c9
commit 7f62ee3f31
Signed by: snsd0805
GPG Key ID: 569349933C77A854

View File

@ -14,6 +14,8 @@ if __name__ == '__main__':
elif len(sys.argv) == 3:
target = int( sys.argv[2] )
print("Target: {}".format(target))
else:
target = None
# read config file
@ -27,13 +29,16 @@ if __name__ == '__main__':
# start sampling
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
model.load_state_dict(torch.load('unet.pth'))
classifier.load_state_dict(torch.load('classifier.pth'))
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
if target != None:
classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
classifier.load_state_dict(torch.load('classifier.pth'))
x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)
else:
x_t = ddpm.sample(model)
x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)
if not os.path.isdir('./output'):
os.mkdir('./output')