fix: update sample.py about conditional ddpm
This commit is contained in:
parent
95600b80c9
commit
7f62ee3f31
15
sample.py
15
sample.py
@ -14,6 +14,8 @@ if __name__ == '__main__':
|
||||
elif len(sys.argv) == 3:
|
||||
target = int( sys.argv[2] )
|
||||
print("Target: {}".format(target))
|
||||
else:
|
||||
target = None
|
||||
|
||||
|
||||
# read config file
|
||||
@ -27,13 +29,16 @@ if __name__ == '__main__':
|
||||
|
||||
# start sampling
|
||||
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
|
||||
classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||
|
||||
model.load_state_dict(torch.load('unet.pth'))
|
||||
classifier.load_state_dict(torch.load('classifier.pth'))
|
||||
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
|
||||
|
||||
if target != None:
|
||||
classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||
classifier.load_state_dict(torch.load('classifier.pth'))
|
||||
x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)
|
||||
else:
|
||||
x_t = ddpm.sample(model)
|
||||
|
||||
x_t = ddpm.sample(model, target=target, classifier=classifier, classifier_scale=0.5)
|
||||
|
||||
if not os.path.isdir('./output'):
|
||||
os.mkdir('./output')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user