style: use configparser to manage config file

snsd0805 2023-03-14 23:20:32 +08:00
parent 81779cf5e2
commit 4c4297f2bd
Signed by: snsd0805
GPG Key ID: 569349933C77A854
3 changed files with 60 additions and 21 deletions
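
For context, the pattern applied throughout this commit is shown below as a minimal sketch (not taken verbatim from any one file): configparser reads training.ini from the working directory and returns every value as a string, so each setting is cast explicitly before use. Section and key names match the training.ini added in this commit.

import configparser
import torch

# parse training.ini; configparser stores all values as strings
config = configparser.ConfigParser()
config.read('training.ini')

# cast each string setting to the type the code expects
BATCH_SIZE = int(config['unet']['batch_size'])          # "256"  -> 256
ITERATION = int(config['ddpm']['iteration'])            # "500"  -> 500
LEARNING_RATE = float(config['unet']['learning_rate'])  # "1e-4" -> 0.0001
DEVICE = torch.device(config['unet']['device'])         # "cuda" -> torch.device('cuda')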


@@ -4,17 +4,23 @@ from ddpm import DDPM
from unet import Unet
import sys
import os
BATCH_SIZE = 256
ITERATION = 500
TIME_EMB_DIM = 128
DEVICE = torch.device('cuda')
import configparser
if __name__ == '__main__':
    if len(sys.argv) != 2:
        print("Usage: python sample.py [pic_num]")
        exit()
    # read config file
    config = configparser.ConfigParser()
    config.read('training.ini')
    BATCH_SIZE = int(config['unet']['batch_size'])
    ITERATION = int(config['ddpm']['iteration'])
    TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
    DEVICE = torch.device(config['unet']['device'])
    # start sampling
    model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
    ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
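
The argument check above means the sampling script is run with a single pic_num argument; a hypothetical invocation (the value 16 is only an example) would be:

python sample.py 16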


@@ -7,15 +7,18 @@ import matplotlib.pyplot as plt
from tqdm import tqdm
from ddpm import DDPM
from unet import Unet
import configparser
BATCH_SIZE = 256
ITERATION = 500
TIME_EMB_DIM = 128
DEVICE = torch.device('cuda')
EPOCH_NUM = 500
LEARNING_RATE = 1e-4
def getMnistLoader():
def getMnistLoader(config):
    '''
    Get MNIST dataset's loader
    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        loader (nn.utils.data.DataLoader)
    '''
    BATCH_SIZE = int(config['unet']['batch_size'])
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
@@ -25,18 +28,33 @@ def getMnistLoader():
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
    return loader
def train(loader, device, epoch_num, lr):
    model = Unet(TIME_EMB_DIM, DEVICE).to(device)
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)
def train(config):
    '''
    Start Unet Training
    Inputs:
        config (configparser.ConfigParser)
    Outputs:
        None
    '''
    BATCH_SIZE = int(config['unet']['batch_size'])
    ITERATION = int(config['ddpm']['iteration'])
    TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
    DEVICE = torch.device(config['unet']['device'])
    EPOCH_NUM = int(config['unet']['epoch_num'])
    LEARNING_RATE = float(config['unet']['learning_rate'])
    # training
    model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
    criterion = nn.MSELoss()
    optimzer = torch.optim.Adam(model.parameters(), lr=lr)
    optimzer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    min_loss = 99
    for epoch in range(epoch_num):
    for epoch in range(EPOCH_NUM):
        loss_sum = 0
        # progress = tqdm(total=len(loader))
        for x, y in loader:
            optimzer.zero_grad()
@@ -51,12 +69,18 @@ def train(loader, device, epoch_num, lr):
            loss.backward()
            optimzer.step()
            # progress.update(1)
        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
        if loss_sum/len(loader) < min_loss:
            min_loss = loss_sum/len(loader)
            print("save model: the best loss is {}".format(min_loss))
            torch.save(model.state_dict(), 'unet.pth')
loader = getMnistLoader()
train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
if __name__ == '__main__':
    # read config file
    config = configparser.ConfigParser()
    config.read('training.ini')
    # start training
    loader = getMnistLoader(config)
    train(config)

training.ini Normal file

@@ -0,0 +1,9 @@
[unet]
batch_size = 256
time_emb_dim = 128
device = cuda
epoch_num = 500
learning_rate = 1e-4
[ddpm]
iteration = 500
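
A side note on the parsing used above: because configparser keeps every value as a string, the explicit int()/float() casts in the Python files are required. configparser also ships typed getters that would express the same conversions; the lines below are only a sketch of that alternative, not part of this commit.

import configparser

config = configparser.ConfigParser()
config.read('training.ini')
batch_size = config['unet'].getint('batch_size')          # same as int(config['unet']['batch_size'])
learning_rate = config['unet'].getfloat('learning_rate')  # same as float(config['unet']['learning_rate'])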