style: use configparser to manage config file
This commit is contained in:
parent
81779cf5e2
commit
4c4297f2bd
16
sample.py
16
sample.py
@ -4,17 +4,23 @@ from ddpm import DDPM
|
|||||||
from unet import Unet
|
from unet import Unet
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import configparser
|
||||||
BATCH_SIZE = 256
|
|
||||||
ITERATION = 500
|
|
||||||
TIME_EMB_DIM = 128
|
|
||||||
DEVICE = torch.device('cuda')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if len(sys.argv) != 2:
|
if len(sys.argv) != 2:
|
||||||
print("Usage: python sample.py [pic_num]")
|
print("Usage: python sample.py [pic_num]")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
# read config file
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read('training.ini')
|
||||||
|
|
||||||
|
BATCH_SIZE = int(config['unet']['batch_size'])
|
||||||
|
ITERATION = int(config['ddpm']['iteration'])
|
||||||
|
TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
|
||||||
|
DEVICE = torch.device(config['unet']['device'])
|
||||||
|
|
||||||
|
# start sampling
|
||||||
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||||
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
|
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
|
||||||
|
|
||||||
|
|||||||
56
train.py
56
train.py
@ -7,15 +7,18 @@ import matplotlib.pyplot as plt
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from ddpm import DDPM
|
from ddpm import DDPM
|
||||||
from unet import Unet
|
from unet import Unet
|
||||||
|
import configparser
|
||||||
|
|
||||||
BATCH_SIZE = 256
|
def getMnistLoader(config):
|
||||||
ITERATION = 500
|
'''
|
||||||
TIME_EMB_DIM = 128
|
Get MNIST dataset's loader
|
||||||
DEVICE = torch.device('cuda')
|
|
||||||
EPOCH_NUM = 500
|
|
||||||
LEARNING_RATE = 1e-4
|
|
||||||
|
|
||||||
def getMnistLoader():
|
Inputs:
|
||||||
|
config (configparser.ConfigParser)
|
||||||
|
Outputs:
|
||||||
|
loader (nn.utils.data.DataLoader)
|
||||||
|
'''
|
||||||
|
BATCH_SIZE = int(config['unet']['batch_size'])
|
||||||
|
|
||||||
transform = torchvision.transforms.Compose([
|
transform = torchvision.transforms.Compose([
|
||||||
torchvision.transforms.ToTensor()
|
torchvision.transforms.ToTensor()
|
||||||
@ -25,18 +28,33 @@ def getMnistLoader():
|
|||||||
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
|
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
def train(loader, device, epoch_num, lr):
|
def train(config):
|
||||||
model = Unet(TIME_EMB_DIM, DEVICE).to(device)
|
'''
|
||||||
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)
|
Start Unet Training
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
config (configparser.ConfigParser)
|
||||||
|
Outputs:
|
||||||
|
None
|
||||||
|
'''
|
||||||
|
BATCH_SIZE = int(config['unet']['batch_size'])
|
||||||
|
ITERATION = int(config['ddpm']['iteration'])
|
||||||
|
TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
|
||||||
|
DEVICE = torch.device(config['unet']['device'])
|
||||||
|
EPOCH_NUM = int(config['unet']['epoch_num'])
|
||||||
|
LEARNING_RATE = float(config['unet']['learning_rate'])
|
||||||
|
|
||||||
|
# training
|
||||||
|
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
|
||||||
|
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
|
||||||
|
|
||||||
criterion = nn.MSELoss()
|
criterion = nn.MSELoss()
|
||||||
optimzer = torch.optim.Adam(model.parameters(), lr=lr)
|
optimzer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
||||||
|
|
||||||
min_loss = 99
|
min_loss = 99
|
||||||
|
|
||||||
for epoch in range(epoch_num):
|
for epoch in range(EPOCH_NUM):
|
||||||
loss_sum = 0
|
loss_sum = 0
|
||||||
# progress = tqdm(total=len(loader))
|
|
||||||
for x, y in loader:
|
for x, y in loader:
|
||||||
optimzer.zero_grad()
|
optimzer.zero_grad()
|
||||||
|
|
||||||
@ -51,12 +69,18 @@ def train(loader, device, epoch_num, lr):
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimzer.step()
|
optimzer.step()
|
||||||
# progress.update(1)
|
|
||||||
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
|
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
|
||||||
if loss_sum/len(loader) < min_loss:
|
if loss_sum/len(loader) < min_loss:
|
||||||
min_loss = loss_sum/len(loader)
|
min_loss = loss_sum/len(loader)
|
||||||
print("save model: the best loss is {}".format(min_loss))
|
print("save model: the best loss is {}".format(min_loss))
|
||||||
torch.save(model.state_dict(), 'unet.pth')
|
torch.save(model.state_dict(), 'unet.pth')
|
||||||
|
|
||||||
loader = getMnistLoader()
|
if __name__ == '__main__':
|
||||||
train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
|
# read config file
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read('training.ini')
|
||||||
|
|
||||||
|
# start training
|
||||||
|
loader = getMnistLoader(config)
|
||||||
|
train(config)
|
||||||
9
training.ini
Normal file
9
training.ini
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
[unet]
|
||||||
|
batch_size = 256
|
||||||
|
time_emb_dim = 128
|
||||||
|
device = cuda
|
||||||
|
epoch_num = 500
|
||||||
|
learning_rate = 1e-4
|
||||||
|
|
||||||
|
[ddpm]
|
||||||
|
iteration = 500
|
||||||
Loading…
Reference in New Issue
Block a user