diff --git a/train/model.py b/train/model.py
new file mode 100644
index 0000000..a658c6d
--- /dev/null
+++ b/train/model.py
@@ -0,0 +1,34 @@
+from torch import nn
+import torch.nn.functional as F
+
+class Network(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding='same')
+        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same')
+        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding='same')
+        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding='same')
+        self.fc1 = nn.Linear(2048, 1024)
+        self.fc2 = nn.Linear(1024, 128)
+        self.fc3 = nn.Linear(128, 10)
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(0.3)
+
+    def forward(self, x):
+        x = self.relu(self.conv1(x))
+        x = self.relu(self.conv2(x))
+        x = self.pool(x)
+        x = self.dropout(x)
+        x = self.relu(self.conv3(x))
+        x = self.relu(self.conv4(x))
+        x = self.pool(x)
+        x = self.dropout(x)
+        x = x.reshape((x.shape[0], -1))
+        x = self.relu(self.fc1(x))
+        x = self.dropout(x)
+        x = self.relu(self.fc2(x))
+        x = self.dropout(x)
+        x = self.fc3(x)
+        return x
+
diff --git a/train/train.py b/train/train.py
new file mode 100644
index 0000000..a8d8fdf
--- /dev/null
+++ b/train/train.py
@@ -0,0 +1,39 @@
+import torch
+from torch import optim
+from torch import nn
+from dataset import Cifar10Dataset
+from model import Network
+from torch.utils.data import DataLoader
+import matplotlib.pyplot as plt
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+model = Network().to(device)
+dataset = Cifar10Dataset('./dataset_dir/cifar-10-batches-py')
+loader = DataLoader(dataset, batch_size=32, shuffle=True)
+
+optimizer = optim.Adam(model.parameters(), lr=0.001)
+criterion = nn.CrossEntropyLoss()
+
+for epoch in range(50):
+    model.train()
+    train_loss_sum = 0
+    train_correct_sum = 0
+    for x, y in loader:
+        x = x.float()
+        x, y = x.to(device), y.to(device)
+
+        predict = model(x)
+        loss = criterion(predict, y)
+        loss.backward()
+
+        # evaluate
+        train_loss_sum += loss.item()
+        predicted_classes = torch.argmax(predict, dim=1)
+        train_correct_sum += (predicted_classes == y).sum()
+
+        optimizer.step()
+        optimizer.zero_grad()
+    print(train_loss_sum / len(loader))
+    print((train_correct_sum / len(dataset) * 100).item(), '%')
+    print()
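
Note: train.py imports Cifar10Dataset from dataset.py, which is not part of this diff. Below is a minimal sketch of what such a dataset class might look like, assuming the standard CIFAR-10 python pickle batches (data_batch_1..5, each a dict with b'data' of shape 10000x3072 and b'labels'); the class name matches the import, but the constructor argument, normalization, and file handling are assumptions, not the actual dataset.py.

```python
# Hypothetical sketch of dataset.py -- not part of this diff.
import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset


class Cifar10Dataset(Dataset):
    """Loads the five CIFAR-10 python training batches found under `root`."""

    def __init__(self, root):
        images, labels = [], []
        for i in range(1, 6):
            # Each batch file is a pickled dict with b'data' (10000 x 3072 uint8,
            # channels stored R/G/B plane-wise) and b'labels' (list of 10000 ints).
            with open(os.path.join(root, f'data_batch_{i}'), 'rb') as f:
                batch = pickle.load(f, encoding='bytes')
            images.append(batch[b'data'])
            labels.extend(batch[b'labels'])
        data = np.concatenate(images).reshape(-1, 3, 32, 32)  # NCHW layout
        self.x = torch.from_numpy(data).float() / 255.0        # scale to [0, 1]
        self.y = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
```

With a dataset along these lines, the DataLoader in train.py yields (image, label) batches shaped (32, 3, 32, 32) and (32,), which matches what Network.forward and CrossEntropyLoss expect.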