50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
import torch
|
|
import matplotlib.pyplot as plt
|
|
from ddpm import DDPM
|
|
from unet import Unet
|
|
import sys
|
|
import os
|
|
import configparser
|
|
from classifier import Classfier
|
|
|
|
if __name__ == '__main__':
    # Sample images from a trained U-Net via DDPM and write them to ./output.
    #
    # Command line: python sample.py pic_num [target]
    #   pic_num -- int, passed as the first argument to DDPM (presumably the
    #              number of images to sample -- TODO confirm against ddpm.py)
    #   target  -- optional int; when given, sampling is guided by the trained
    #              classifier (classifier.pth) toward this target
    #
    # Reads hyper-parameters from training.ini and weights from unet.pth
    # (and classifier.pth when a target is given).
    if len(sys.argv) not in (2, 3):
        # Reject extra arguments instead of silently ignoring the target.
        print("Usage: python sample.py pic_num [target]")
        sys.exit(1)

    target = int(sys.argv[2]) if len(sys.argv) == 3 else None
    if target is not None:
        print("Target: {}".format(target))

    # read config file
    config = configparser.ConfigParser()
    # ConfigParser.read() skips missing files and returns the list of files
    # it parsed; fail clearly here rather than with a KeyError below.
    if not config.read('training.ini'):
        print("Error: could not read training.ini")
        sys.exit(1)

    ITERATION = int(config['ddpm']['iteration'])
    TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
    DEVICE = torch.device(config['unet']['device'])

    # start sampling
    model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
    # map_location lets a checkpoint saved on another device (e.g. CUDA)
    # load onto the configured DEVICE.
    model.load_state_dict(torch.load('unet.pth', map_location=DEVICE))
    # Inference mode: disables training-only behaviour such as dropout or
    # batch-norm statistic updates, if the network has any.
    model.eval()

    # Positional args presumably (pic_num, timesteps, beta_start, beta_end,
    # device) -- TODO confirm against ddpm.DDPM.
    ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)

    if target is not None:
        classifier = Classfier(TIME_EMB_DIM, DEVICE).to(DEVICE)
        classifier.load_state_dict(
            torch.load('classifier.pth', map_location=DEVICE)
        )
        # eval() does not block gradients, so guidance w.r.t. x still works.
        classifier.eval()
        x_t = ddpm.sample(model, target=target, classifier=classifier,
                          classifier_scale=0.5)
    else:
        x_t = ddpm.sample(model)

    os.makedirs('./output', exist_ok=True)

    # Reuse one figure and clear its axes per image: calling plt.imshow
    # repeatedly on the implicit figure stacks every image on the same axes,
    # so memory use and savefig time grow with each picture.
    fig, ax = plt.subplots()
    for index, pic in enumerate(x_t):
        # Each pic is 3-D; presumably (C, H, W) -> (H, W, C) for matplotlib
        # (TODO confirm layout in ddpm.sample). detach() in case sampling
        # left the tensor attached to an autograd graph.
        p = pic.detach().to('cpu').permute(1, 2, 0)
        ax.clear()
        # vmin/vmax assume pixel values in [0, 255] -- confirm in ddpm.sample.
        ax.imshow(p, cmap='gray', vmin=0, vmax=255)
        fig.savefig("output/{}.png".format(index))
    plt.close(fig)