fix: move trainer

parent 86e0c50a65
commit 20fd2fbe08
train/main.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# python3 -m torch.distributed.run --nproc_per_node=1 --nnodes=1 --node_rank=0 --rdzv_id=21046 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:21046 main.py
+from torch import optim
+from torch import nn
+
+from model import Network
+from dataset import Cifar10Dataset
+from trainer import ddp_init, Trainer
+
+BATCH_SIZE = 64
+EPOCH_NUM = 30
+
+
+def main(batch_size, epoch_num):
+    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
+    ddp_init()
+
+    model = Network()
+    dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+    criterion = nn.CrossEntropyLoss()
+
+    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
+    trainer.train(epoch_num)
+
+
+if __name__ == '__main__':
+    main(BATCH_SIZE, EPOCH_NUM)
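The new entry point leans on ddp_init() and Trainer, both imported from the trainer module and neither shown in full by this diff. One fact worth noting: torch.distributed.run exports LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT into every worker's environment, which is why main.py no longer needs to print or pass them explicitly. A minimal sketch of what ddp_init() plausibly does under that launcher (the real body lives in trainer.py and may differ):

import os
import torch
import torch.distributed as dist

# Hypothetical sketch of ddp_init(); the actual implementation in
# trainer.py is not part of this diff and may differ.
def ddp_init():
    # torch.distributed.run sets this variable for every worker process
    local_rank = int(os.environ["LOCAL_RANK"])
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    # the default env:// rendezvous reads MASTER_ADDR, MASTER_PORT,
    # RANK and WORLD_SIZE from the environment set by the launcher
    dist.init_process_group(backend=backend)
    if torch.cuda.is_available():
        # pin this process to its own GPU
        torch.cuda.set_device(local_rank)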
train/trainer.py
@@ -55,35 +55,3 @@ class Trainer():
         print(f"[DEVICE {self.global_rank}] EPOCH {epoch} loss={train_loss_sum/len(self.loader)} acc={(train_correct_sum/train_item_counter).item()}")


-
-def main(batch_size, epoch_num):
-    print(f'training with {world_size} GPUs')
-    print(f'training config: batch_size={batch_size}, epoch={epoch_num}')
-    print(f'LOCAL_RANK={os.environ["LOCAL_RANK"]}')
-    print(f'RANK={os.environ["RANK"]}')
-    print(f'LOCAL_WORLD_SIZE={os.environ["LOCAL_WORLD_SIZE"]}')
-    print(f'WORLD_SIZE={os.environ["WORLD_SIZE"]}')
-    print(f'MASTER_ADDR={os.environ["MASTER_ADDR"]}')
-    print(f'MASTER_PORT={os.environ["MASTER_PORT"]}')
-    ddp_init()
-
-    model = Network()
-    dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
-    optimizer = optim.Adam(model.parameters(), lr=0.001)
-    criterion = nn.CrossEntropyLoss()
-
-    trainer = Trainer(model, dataset, batch_size, optimizer, criterion)
-    trainer.train(epoch_num)
-
-
-if __name__ == '__main__':
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', type=int, default=32, help="batch size for training")
-    parser.add_argument('--epoch_num', type=int, default=50, help="training epoch")
-    args = parser.parse_args()
-
-    world_size = torch.cuda.device_count()
-    main(args.batch_size, args.epoch_num)
-    # mp.spawn(main, args=(world_size, args.batch_size, args.epoch_num), nprocs=world_size)
-
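For context, the Trainer that main.py now drives appears in this diff only through its constructor call and the per-epoch logging line kept above. A hypothetical skeleton consistent with those fragments, sharding data with DistributedSampler and synchronizing gradients with DistributedDataParallel; only the constructor signature, self.loader, and self.global_rank are confirmed by the diff, everything else is an assumption:

import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Hypothetical skeleton; only __init__'s signature, self.loader and
# self.global_rank are confirmed by the diff above.
class Trainer():
    def __init__(self, model, dataset, batch_size, optimizer, criterion):
        self.global_rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        self.device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
        # DistributedSampler hands each rank a disjoint shard of the dataset
        self.loader = DataLoader(dataset, batch_size=batch_size,
                                 sampler=DistributedSampler(dataset))
        # DDP averages gradients across all ranks on every backward pass
        self.model = DDP(model.to(self.device),
                         device_ids=[local_rank] if torch.cuda.is_available() else None)
        self.optimizer = optimizer
        self.criterion = criterion

    def train(self, epoch_num):
        for epoch in range(epoch_num):
            # reshuffle the sampler's shard assignment each epoch
            self.loader.sampler.set_epoch(epoch)
            train_loss_sum = 0.0
            train_correct_sum = torch.tensor(0.0, device=self.device)
            train_item_counter = 0
            for x, y in self.loader:
                x, y = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()
                out = self.model(x)
                loss = self.criterion(out, y)
                loss.backward()
                self.optimizer.step()
                train_loss_sum += loss.item()
                train_correct_sum += (out.argmax(dim=1) == y).sum()
                train_item_counter += y.size(0)
            print(f"[DEVICE {self.global_rank}] EPOCH {epoch} "
                  f"loss={train_loss_sum/len(self.loader)} "
                  f"acc={(train_correct_sum/train_item_counter).item()}")

With this layout the job is launched through the command kept in main.py's header comment; on recent PyTorch, torchrun --standalone --nproc_per_node=1 main.py is an equivalent single-node shorthand, since torchrun is the console entry point for torch.distributed.run.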