Udemy-PyTorch/CIFAR-10 - CNN.ipynb
2023-01-30 00:30:21 +08:00

366 lines
24 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "da7346f9",
"metadata": {},
"source": [
"> **Note:** this notebook trains without a validation set.\n",
"> The accuracy reported below is measured on the training data, so it does not show how well the model generalizes."
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "e8d6eed6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchvision.datasets import CIFAR10\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"from torchinfo import summary\n",
"from datetime import datetime"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "4ffcbf20",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Use GPU: NVIDIA GeForce RTX 3060\n"
]
}
],
"source": [
"# Prefer the first CUDA device when one is visible; otherwise fall back to the CPU.\n",
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device('cuda:0' if use_cuda else 'cpu')\n",
"if use_cuda:\n",
"    print(\"Use GPU:\", torch.cuda.get_device_name(device))\n",
"else:\n",
"    print(\"Use CPU\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "845643e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"# Load both CIFAR-10 splits with only ToTensor applied (no normalization yet);\n",
"# the training split is what the channel statistics are measured on afterwards.\n",
"\n",
"transform = transforms.Compose([transforms.ToTensor()])\n",
"common_kwargs = dict(root=\"./data\", transform=transform, download=True)\n",
"train = CIFAR10(train=True, **common_kwargs)\n",
"test = CIFAR10(train=False, **common_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a7bdab59",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"50000\n"
]
}
],
"source": [
"# Before transforms.Normalize() can be used we need the mean & std of each of the 3 channels.\n",
"\n",
"# Collect every training image as a tensor with a leading batch dimension of 1.\n",
"pics = [image.unsqueeze(0) for image, _ in train]\n",
"print(len(pics))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6225f61b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([50000, 3, 32, 32])\n"
]
}
],
"source": [
"# Join the per-image tensors along dim 0 into a single (N, 3, 32, 32) batch.\n",
"pics = torch.cat(tensors=pics, dim=0)\n",
"print(pics.size())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ddf244b0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([50000, 3, 1024])\n",
"We get mean: [0.4913996756076813, 0.48215845227241516, 0.44653093814849854]\n",
"We get std: [0.20230092108249664, 0.19941280782222748, 0.2009616196155548]\n"
]
}
],
"source": [
"# Flatten the spatial dims so each channel of each image is one row of pixels.\n",
"testPics = pics.flatten(start_dim=2)\n",
"print( testPics.shape )\n",
"\n",
"# Per-channel mean & std taken over ALL pixels of ALL training images (dims 0, 2, 3).\n",
"# Averaging the per-image std values instead would under-estimate the dataset std,\n",
"# because the variation between images is ignored.\n",
"feature_mean = pics.mean(dim=(0, 2, 3)).tolist()\n",
"feature_std = pics.std(dim=(0, 2, 3)).tolist()\n",
"\n",
"print(\"We get mean: {}\\nWe get std: {}\".format(feature_mean, feature_std))\n",
"\n",
"# On the internet, many people suggest using ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) for normalization;\n",
"# here the statistics measured on this training set are used instead."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5c394c14",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Actual training data: same images, now normalized with the measured channel statistics.\n",
"\n",
"transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(feature_mean, feature_std)])\n",
"train = CIFAR10(\"./data\", train=True, transform=transform, download=True)\n",
"\n",
"# show 1 picture, it's a frog.\n",
"feature, label = train[0]\n",
"\n",
"# Undo the normalization before plotting: imshow expects float RGB values in [0, 1],\n",
"# and a normalized tensor has values outside that range which would be clipped.\n",
"channel_mean = torch.tensor(feature_mean).view(3, 1, 1)\n",
"channel_std = torch.tensor(feature_std).view(3, 1, 1)\n",
"feature = (feature * channel_std + channel_mean).clamp(0, 1)\n",
"\n",
"# matplotlib wants channels last: (C, H, W) -> (H, W, C)\n",
"feature = feature.permute(1, 2, 0)\n",
"plt.imshow(feature)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "ed317276",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"CIFAR10_ClassificationModel [128, 10] --\n",
"├─Conv2d: 1-1 [128, 64, 32, 32] 1,792\n",
"├─ReLU: 1-2 [128, 64, 32, 32] --\n",
"├─Conv2d: 1-3 [128, 128, 30, 30] 73,856\n",
"├─ReLU: 1-4 [128, 128, 30, 30] --\n",
"├─Conv2d: 1-5 [128, 256, 28, 28] 295,168\n",
"├─ReLU: 1-6 [128, 256, 28, 28] --\n",
"├─AvgPool2d: 1-7 [128, 256, 14, 14] --\n",
"├─Conv2d: 1-8 [128, 256, 14, 14] 590,080\n",
"├─ReLU: 1-9 [128, 256, 14, 14] --\n",
"├─Conv2d: 1-10 [128, 256, 12, 12] 590,080\n",
"├─ReLU: 1-11 [128, 256, 12, 12] --\n",
"├─AvgPool2d: 1-12 [128, 256, 6, 6] --\n",
"├─Linear: 1-13 [128, 512] 4,719,104\n",
"├─ReLU: 1-14 [128, 512] --\n",
"├─Linear: 1-15 [128, 128] 65,664\n",
"├─ReLU: 1-16 [128, 128] --\n",
"├─Linear: 1-17 [128, 10] 1,290\n",
"==========================================================================================\n",
"Total params: 6,337,034\n",
"Trainable params: 6,337,034\n",
"Non-trainable params: 0\n",
"Total mult-adds (G): 64.66\n",
"==========================================================================================\n",
"Input size (MB): 1.57\n",
"Forward/backward pass size (MB): 480.39\n",
"Params size (MB): 25.35\n",
"Estimated Total Size (MB): 507.31\n",
"=========================================================================================="
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"BATCH_SIZE = 128\n",
"\n",
"\n",
"class CIFAR10_ClassificationModel(nn.Module):\n",
"    \"\"\"CNN classifier: five 3x3 convolutions, two 2x2 average pools, three linear layers.\n",
"\n",
"    fc1 takes 9216 = 256 * 6 * 6 features, which is what a (N, 3, 32, 32) input\n",
"    produces after the convolution/pooling stack; the output is (N, 10) raw scores.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # First convolution stage (only conv1 is padded, so conv2/conv3 each trim 2 pixels).\n",
"        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)\n",
"        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0)\n",
"        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=0)\n",
"\n",
"        # Second convolution stage (there is intentionally no conv4).\n",
"        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)\n",
"        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0)\n",
"\n",
"        # One pooling module, applied after each convolution stage.\n",
"        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)\n",
"\n",
"        # Classifier head.\n",
"        self.fc1 = nn.Linear(in_features=9216, out_features=512)\n",
"        self.fc2 = nn.Linear(in_features=512, out_features=128)\n",
"        self.fc3 = nn.Linear(in_features=128, out_features=10)\n",
"        self.relu = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return the raw class scores for a batch of images.\"\"\"\n",
"        n_samples = x.size(0)\n",
"\n",
"        # Stage 1: conv + ReLU three times, then halve the spatial size.\n",
"        for conv in (self.conv1, self.conv2, self.conv3):\n",
"            x = self.relu(conv(x))\n",
"        x = self.pool(x)\n",
"\n",
"        # Stage 2: conv + ReLU twice, then halve the spatial size again.\n",
"        for conv in (self.conv5, self.conv6):\n",
"            x = self.relu(conv(x))\n",
"        x = self.pool(x)\n",
"\n",
"        # Flatten everything but the batch dimension and run the linear head.\n",
"        x = x.view(n_samples, -1)\n",
"        for fc in (self.fc1, self.fc2):\n",
"            x = self.relu(fc(x))\n",
"        return self.fc3(x)\n",
"\n",
"# Build an instance and display its layer-by-layer summary for one full batch.\n",
"model = CIFAR10_ClassificationModel()\n",
"summary(model=model, input_size=(BATCH_SIZE, 3, 32, 32))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "cbb17733",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current Time = 00:57:57\n",
"Epoch 0: 1.5633342016078626, 42.9739990234375%\n",
"Epoch 1: 1.030615347120768, 63.2859992980957%\n",
"Epoch 2: 0.7714244136420052, 72.7699966430664%\n",
"Epoch 3: 0.6027741941344708, 78.95999908447266%\n",
"Epoch 4: 0.45929622848320495, 84.03199768066406%\n",
"Current Time = 00:59:48\n"
]
}
],
"source": [
"# Wall-clock time is printed before and after training to show how long it took.\n",
"now = datetime.now()\n",
"current_time = now.strftime(\"%H:%M:%S\")\n",
"print(\"Current Time =\", current_time)\n",
"\n",
"loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)\n",
"model = CIFAR10_ClassificationModel().to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"for epoch in range(5):\n",
"    loss_sum = 0.0  # sum of the per-batch mean losses\n",
"    corr_sum = 0    # number of correctly classified training samples\n",
"    for x, y in loader:\n",
"        features, labels = x.to(device), y.to(device)\n",
"\n",
"        # clear the gradients left over from the previous step\n",
"        optimizer.zero_grad()\n",
"\n",
"        # forward pass\n",
"        predicts = model(features)\n",
"\n",
"        # loss + back propagation\n",
"        loss = criterion(predicts, labels)\n",
"        loss.backward()\n",
"\n",
"        # parameter update\n",
"        optimizer.step()\n",
"\n",
"        # loss & acc: .item() converts to Python numbers, so the accuracy below is\n",
"        # computed with Python floats instead of a float32 tensor on the device\n",
"        loss_sum += loss.item()\n",
"        corr_sum += (predicts.argmax(dim=1) == labels).sum().item()\n",
"    # NOTE: this is accuracy on the training batches, not on held-out data\n",
"    print(\"Epoch {}: {}, {}%\".format(epoch, loss_sum/len(loader), corr_sum/len(train)*100))\n",
"\n",
"now = datetime.now()\n",
"current_time = now.strftime(\"%H:%M:%S\")\n",
"print(\"Current Time =\", current_time)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}