From 4c4297f2bd9b3cbf2f5da63b2827db5368f1e3ef Mon Sep 17 00:00:00 2001
From: snsd0805
Date: Tue, 14 Mar 2023 23:20:32 +0800
Subject: [PATCH] style: use configparser to manage config file

---
 sample.py    | 16 ++++++++++-----
 train.py     | 56 +++++++++++++++++++++++++++++++++++++---------------
 training.ini |  9 +++++++++
 3 files changed, 60 insertions(+), 21 deletions(-)
 create mode 100644 training.ini

diff --git a/sample.py b/sample.py
index fb03388..7e3bb01 100644
--- a/sample.py
+++ b/sample.py
@@ -4,17 +4,23 @@ from ddpm import DDPM
 from unet import Unet
 import sys
 import os
-
-BATCH_SIZE = 256
-ITERATION = 500
-TIME_EMB_DIM = 128
-DEVICE = torch.device('cuda')
+import configparser
 
 if __name__ == '__main__':
     if len(sys.argv) != 2:
         print("Usage: python sample.py [pic_num]")
         exit()
+
+    # read config file
+    config = configparser.ConfigParser()
+    config.read('training.ini')
+    BATCH_SIZE = int(config['unet']['batch_size'])
+    ITERATION = int(config['ddpm']['iteration'])
+    TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
+    DEVICE = torch.device(config['unet']['device'])
+
+    # start sampling
     model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
     ddpm = DDPM(int(sys.argv[1]), ITERATION, 1e-4, 2e-2, DEVICE)
 
diff --git a/train.py b/train.py
index ba10994..94ac4ae 100644
--- a/train.py
+++ b/train.py
@@ -7,15 +7,18 @@ import matplotlib.pyplot as plt
 from tqdm import tqdm
 from ddpm import DDPM
 from unet import Unet
+import configparser
 
-BATCH_SIZE = 256
-ITERATION = 500
-TIME_EMB_DIM = 128
-DEVICE = torch.device('cuda')
-EPOCH_NUM = 500
-LEARNING_RATE = 1e-4
+def getMnistLoader(config):
+    '''
+    Get a DataLoader for the MNIST dataset
 
-def getMnistLoader():
+    Inputs:
+        config (configparser.ConfigParser)
+    Outputs:
+        loader (torch.utils.data.DataLoader)
+    '''
+    BATCH_SIZE = int(config['unet']['batch_size'])
 
     transform = torchvision.transforms.Compose([
         torchvision.transforms.ToTensor()
@@ -25,18 +28,33 @@ def getMnistLoader():
     loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
     return loader
 
-def train(loader, device, epoch_num, lr):
-    model = Unet(TIME_EMB_DIM, DEVICE).to(device)
-    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, device)
+def train(config):
+    '''
+    Start Unet training
+
+    Inputs:
+        config (configparser.ConfigParser)
+    Outputs:
+        None
+    '''
+    BATCH_SIZE = int(config['unet']['batch_size'])
+    ITERATION = int(config['ddpm']['iteration'])
+    TIME_EMB_DIM = int(config['unet']['time_emb_dim'])
+    DEVICE = torch.device(config['unet']['device'])
+    EPOCH_NUM = int(config['unet']['epoch_num'])
+    LEARNING_RATE = float(config['unet']['learning_rate'])
+
+    # training
+    model = Unet(TIME_EMB_DIM, DEVICE).to(DEVICE)
+    ddpm = DDPM(BATCH_SIZE, ITERATION, 1e-4, 2e-2, DEVICE)
 
     criterion = nn.MSELoss()
-    optimzer = torch.optim.Adam(model.parameters(), lr=lr)
+    optimzer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
     min_loss = 99
 
-    for epoch in range(epoch_num):
+    for epoch in range(EPOCH_NUM):
         loss_sum = 0
-        # progress = tqdm(total=len(loader))
 
         for x, y in loader:
             optimzer.zero_grad()
@@ -51,12 +69,18 @@
             loss.backward()
             optimzer.step()
 
-            # progress.update(1)
+        print("Epoch {}/{}: With lr={}, batch_size={}, iteration={}. The best loss: {} - loss: {}".format(epoch, EPOCH_NUM, LEARNING_RATE, BATCH_SIZE, ITERATION, min_loss, loss_sum/len(loader)))
 
         if loss_sum/len(loader) < min_loss:
             min_loss = loss_sum/len(loader)
             print("save model: the best loss is {}".format(min_loss))
             torch.save(model.state_dict(), 'unet.pth')
 
-loader = getMnistLoader()
-train(loader, DEVICE, EPOCH_NUM, LEARNING_RATE)
\ No newline at end of file
+if __name__ == '__main__':
+    # read config file
+    config = configparser.ConfigParser()
+    config.read('training.ini')
+
+    # start training
+    loader = getMnistLoader(config)
+    train(config)
\ No newline at end of file
diff --git a/training.ini b/training.ini
new file mode 100644
index 0000000..f9e7532
--- /dev/null
+++ b/training.ini
@@ -0,0 +1,9 @@
+[unet]
+batch_size = 256
+time_emb_dim = 128
+device = cuda
+epoch_num = 500
+learning_rate = 1e-4
+
+[ddpm]
+iteration = 500
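
Note (not part of the patch): configparser also ships typed getters, so the repeated
int()/float() casts above could be written more compactly. A minimal sketch against
the same training.ini, assuming the section and option names shown in the patch; the
fallback values mirror the defaults in training.ini:

    import configparser
    import torch

    # ConfigParser.read() returns the list of files it managed to parse,
    # so an empty result means training.ini was not found
    config = configparser.ConfigParser()
    if not config.read('training.ini'):
        raise FileNotFoundError('training.ini not found')

    # typed getters replace the manual int()/float() casts
    BATCH_SIZE    = config.getint('unet', 'batch_size', fallback=256)
    TIME_EMB_DIM  = config.getint('unet', 'time_emb_dim', fallback=128)
    EPOCH_NUM     = config.getint('unet', 'epoch_num', fallback=500)
    LEARNING_RATE = config.getfloat('unet', 'learning_rate', fallback=1e-4)
    DEVICE        = torch.device(config.get('unet', 'device', fallback='cpu'))
    ITERATION     = config.getint('ddpm', 'iteration', fallback=500)

getfloat() parses scientific notation such as 1e-4 directly, and the fallback=
arguments would keep sample.py and train.py runnable even if a key were missing
from training.ini.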