fix: move trainer
This commit is contained in:
parent
86e0c50a65
commit
20fd2fbe08
28
train/main.py
Normal file
28
train/main.py
Normal file
@ -0,0 +1,28 @@
|
||||
# python3 -m torch.distributed.run --nproc_per_node=1 --nnodes=1 --node_rank=0 --rdzv_id=21046 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:21046 main.py
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
|
||||
from model import Network
|
||||
from dataset import Cifar10Dataset
|
||||
from trainer import ddp_init, Trainer
|
||||
from model import Network
|
||||
|
||||
BATCH_SIZE = 64
|
||||
EPOCH_NUM = 30
|
||||
|
||||
def main(batch_size, epoch_num):
|
||||
print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
|
||||
ddp_init()
|
||||
|
||||
model = Network()
|
||||
dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
|
||||
trainer.train(epoch_num)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(BATCH_SIZE, EPOCH_NUM)
|
||||
|
||||
@ -55,35 +55,3 @@ class Trainer():
|
||||
print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")
|
||||
|
||||
|
||||
|
||||
def main(batch_size, epoch_num):
|
||||
print(f'training with {world_size} GPUs')
|
||||
print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
|
||||
print(f'LOCAL_RANK={os.environ["LOCAL_RANK"]}')
|
||||
print(f'RANK={os.environ["RANK"]}')
|
||||
print(f'LOCAL_WORLD_SIZE={os.environ["LOCAL_WORLD_SIZE"]}')
|
||||
print(f'WORLD_SIZE={os.environ["WORLD_SIZE"]}')
|
||||
print(f'MASTER_ADDR={os.environ["MASTER_ADDR"]}')
|
||||
print(f'MASTER_PORT={os.environ["MASTER_PORT"]}')
|
||||
ddp_init()
|
||||
|
||||
model = Network()
|
||||
dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
|
||||
trainer.train(epoch_num)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--batch_size', type=int, default=32, help="batch size for training")
|
||||
parser.add_argument('--epoch_num', type=int, default=50, help="training epoch")
|
||||
args = parser.parse_args()
|
||||
|
||||
world_size = torch.cuda.device_count()
|
||||
main(args.batch_size, args.epoch_num)
|
||||
# mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user