style: use configparser to manage config file

This commit is contained in:
snsd0805 2023-03-14 23:20:32 +08:00
parent 81779cf5e2
commit 4c4297f2bd
Signed by: snsd0805
GPG Key ID: 569349933C77A854
3 changed files with 60 additions and 21 deletions

View File

@ -4,17 +4,23 @@ from ddpm import DDPM
from unet import Unet from unet import Unet
import sys import sys
import os import os
import configparser
BATCH_SIZE = 256
ITERATION = 500
TIME_EMB_DIM = 128
DEVICE = torch.device('cuda')
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("Usage: python sample.py [pic_num]") print("Usage: python sample.py [pic_num]")
exit() exit()
# read config file
config = configparser.ConfigParser()
config.read('training.ini')
BATCH_SIZE = int(config['unet']['batch_size'])
ITERATION = int(config['ddpm']['iteration'])
TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
DEVICE = torch.device(config['unet']['device'])
# start sampling
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE) model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE) ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)

View File

@ -7,15 +7,18 @@ import matplotlib.pyplot as plt
from tqdm import tqdm from tqdm import tqdm
from ddpm import DDPM from ddpm import DDPM
from unet import Unet from unet import Unet
import configparser
BATCH_SIZE = 256 def getMnistLoader(config):
ITERATION = 500 '''
TIME_EMB_DIM = 128 Get MNIST dataset's loader
DEVICE = torch.device('cuda')
EPOCH_NUM = 500
LEARNING_RATE = 1e-4
def getMnistLoader(): Inputs:
config (configparser.ConfigParser)
Outputs:
loader (nn.utils.data.DataLoader)
'''
BATCH_SIZE = int(config['unet']['batch_size'])
transform = torchvision.transforms.Compose([ transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor() torchvision.transforms.ToTensor()
@ -25,18 +28,33 @@ def getMnistLoader():
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True) loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
return loader return loader
def train(loader, device, epoch_num, lr): def train(config):
model = Unet(TIME_EMB_DIM, DEVICE).to(device) '''
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device) Start Unet Training
Inputs:
config (configparser.ConfigParser)
Outputs:
None
'''
BATCH_SIZE = int(config['unet']['batch_size'])
ITERATION = int(config['ddpm']['iteration'])
TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
DEVICE = torch.device(config['unet']['device'])
EPOCH_NUM = int(config['unet']['epoch_num'])
LEARNING_RATE = float(config['unet']['learning_rate'])
# training
model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
criterion = nn.MSELoss() criterion = nn.MSELoss()
optimzer = torch.optim.Adam(model.parameters(), lr=lr) optimzer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
min_loss = 99 min_loss = 99
for epoch in range(epoch_num): for epoch in range(EPOCH_NUM):
loss_sum = 0 loss_sum = 0
# progress = tqdm(total=len(loader))
for x, y in loader: for x, y in loader:
optimzer.zero_grad() optimzer.zero_grad()
@ -51,12 +69,18 @@ def train(loader, device, epoch_num, lr):
loss.backward() loss.backward()
optimzer.step() optimzer.step()
# progress.update(1)
print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader))) print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
if loss_sum/len(loader) < min_loss: if loss_sum/len(loader) < min_loss:
min_loss = loss_sum/len(loader) min_loss = loss_sum/len(loader)
print("save model: the best loss is {}".format(min_loss)) print("save model: the best loss is {}".format(min_loss))
torch.save(model.state_dict(), 'unet.pth') torch.save(model.state_dict(), 'unet.pth')
loader = getMnistLoader() if __name__ == '__main__':
train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE) # read config file
config = configparser.ConfigParser()
config.read('training.ini')
# start training
loader = getMnistLoader(config)
train(config)

9
training.ini Normal file
View File

@ -0,0 +1,9 @@
[unet]
batch_size = 256
time_emb_dim = 128
device = cuda
epoch_num = 500
learning_rate = 1e-4
[ddpm]
iteration = 500