Distributed-Training-Example/train/train.py

import torch
from torch import optim
from torch import nn
from dataset import Cifar10Dataset
from model import Network
from torch.utils.data import DataLoader
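
# run on the GPU when CUDA is available, otherwise fall back to the CPU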
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Network().to(device)
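# Cifar10Dataset is the repo's local dataset wrapper; the path points at the
# extracted python-pickle CIFAR-10 batches under ./dataset_dir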
dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
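
# training loop: 50 epochs, reporting average batch loss and accuracy per epoch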
for epoch in range(50):
    model.train()
    train_loss_sum = 0
    train_correct_sum = 0
    for x, y in loader:
        # move the batch to the training device
        x = x.float()
        x, y = x.to(device), y.to(device)
        predict = model(x)
        loss = criterion(predict, y)
        loss.backward()
        # accumulate per-epoch metrics before the optimizer step
        train_loss_sum += loss.item()
        predicted_classes = torch.argmax(predict, dim=1)
        train_correct_sum += (predicted_classes == y).sum().item()
        optimizer.step()
        optimizer.zero_grad()
    # average loss per batch, and accuracy as a percentage of the dataset
    print(f'epoch {epoch + 1}')
    print(f'loss: {train_loss_sum / len(loader):.4f}')
    print(f'accuracy: {100 * train_correct_sum / len(dataset):.2f}%')
    print()