Udemy-PyTorch/CIFAR-10 - CNN.ipynb
2023-01-30 00:30:21 +08:00

366 lines
24 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "da7346f9",
"metadata": {},
"source": [
"> **Note:** This notebook trains without a validation set.\n",
"> The accuracy it reports is therefore measured on the training data and overstates how well the model generalizes."
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "e8d6eed6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchvision.datasets import CIFAR10\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"from torchinfo import summary\n",
"from datetime import datetime"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "4ffcbf20",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Use GPU: NVIDIA GeForce RTX 3060\n"
]
}
],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device('cuda:0')\n",
" print(\"Use GPU:\", torch.cuda.get_device_name(device))\n",
"else:\n",
" device = torch.device('cpu')\n",
" print(\"Use CPU\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "845643e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"# load train data\n",
"\n",
"transform = transforms.Compose([transforms.ToTensor(), ])\n",
"train = CIFAR10(\"./data\", train=True, transform=transform, download=True)\n",
"test = CIFAR10(\"./data\", train=False, transform=transform, download=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a7bdab59",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"50000\n"
]
}
],
"source": [
"# first, we should find mean & std in 3 channels\n",
"# becouse we need use transforms.Normalize() to normalize features\n",
"\n",
"# get pictures's tensor\n",
"pics = []\n",
"for i in train:\n",
" pics.append(i[0].unsqueeze(0))\n",
"print( len(pics) )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6225f61b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([50000, 3, 32, 32])\n"
]
}
],
"source": [
"pics = torch.cat(pics, dim=0)\n",
"print(pics.shape)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ddf244b0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([50000, 3, 1024])\n",
"We get mean: [0.4913996756076813, 0.48215845227241516, 0.44653093814849854]\n",
"We get std: [0.20230092108249664, 0.19941280782222748, 0.2009616196155548]\n"
]
}
],
"source": [
"testPics = pics.view(50000, 3, -1)\n",
"print( testPics.shape )\n",
"\n",
"# count mean & std\n",
"feature_mean = ( testPics.mean(2).sum(0)/50000 ).tolist()\n",
"feature_std = ( testPics.std(2).sum(0)/50000 ).tolist()\n",
"\n",
"print(\"We get mean: {}\\nWe get std: {}\".format(feature_mean, feature_std))\n",
"\n",
"# On the internet, many people suggest to use ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) to normalization"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5c394c14",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAnN0lEQVR4nO3df3TU9Z3v8VegZABJhoZAfiwJG0ChFkgrlZirpSiRH/ZwQGiPv3YLloWFBreaWjV76y/ably6q6gXw7m3FOo9IlZX4Oq9xR9owmoDLlEKaM2VbNqEQxIKPcyEYEJMvvcPrtOOJPJ9hxk+mfB8nDPnSPLinc/kG3g5ZPKeJM/zPAEAcIENcH0AAMDFiQICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4MQXXB/gs7q6unTkyBGlpKQoKSnJ9XEAAEae56mlpUXZ2dkaMKDnxzl9roCOHDminJwc18cAAJynhoYGjR49usf3x62A1q1bp5/97GdqampSfn6+nnzySU2bNu2cvy8lJSVeR0po3zDme77kZxthnG31giF7xDh7niE7xDi7zpA9ZJw93pgPG7I1xtnWry2Lw4ZsqnG2JX+NcfbXLrXlMzP9Z9+3fGFJ+mSQIWz8Im8wnOVPHf6zpz1pQ+e5/z6PSwE999xzKikp0fr161VQUKC1a9dq9uzZqqmp0ahRoz739/LPbt2zXqhkQzZgnG0Vz280Wv5sWj4nku1zbv2qtV7Pgca8RTz/GcRybut9tJx7sHH2JcbDDDMcZqjxD0SHJW8892DDF26gF381n+vv87j83fDoo49q2bJluv3223X55Zdr/fr1Gjp0qH7xi1/E48MBABJQzAvo9OnTqq6uVlFR0Z8/yIABKioqUlVV1Vn59vZ2hcPhqBsAoP+LeQEdO3ZMnZ2dysjIiHp7RkaGmpqazsqXlZUpGAxGbjwBAQAuDs5/Dqi0tFShUChya2hocH0kAMAFEPPvP6anp2vgwIFqbm6Oentzc7Myu3mqSCAQUCAQ72+DAwD6mpg/AkpOTtbUqVO1c+fOyNu6urq0c+dOFRYWxvrDAQASVFyegVlSUqLFixfra1/7mqZNm6a1a9eqtbVVt99+ezw+HAAgAcWlgG666Sb98Y9/1AMPPKCmpiZ95Stf0Y4dO856YgIA4OKV5Hme5/oQfykcDisYDLo+Rp9j3Q9h+KFlU1aSOo159B/WH+hsM2Qt2zsk25YFK+tZLD8Qbcla/d84zu6NUCik1NSed1Y4fxYcAODiRAEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyI58vB4xzSDFnr+o7mc0cuOtZ1Ri2GrHXLYciYt6y0iSfrkqx0Q3aCcfYxQ9a6bsoyW7J9Xq42zn7bmE8kPAICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOsAvOoT+5PsBF5hZjvsKQrTfOjudut8HGvOUsQ42zMw1Z6348y565w3GcLdk+L7XG2f0Zj4AAAE5QQAAAJyggAIATFBAAwAkKCADgBAUEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJ1jFcw6WtSbxXK+SyNIM2XiuJ/rvcZzdl8Tz63CIMR80ZA8aZ3/VkB1hnF1jzDcb8ziDR0AAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJdsGdw1BDdpJxdpMhe8o4O5471az60llwfj6Ic97isCE7xzh7kDG/15jHGTwCAgA4EfMCeuihh5SUlBR1mzhxYqw/DAAgwcXln+C+/OUv6/XXX//zB/kC/9IHAIgWl2b4whe+oMzMzHiMBgD0E3H5HtBHH32k7OxsjR07Vrfddpvq6+t7zLa3tyscDkfdAAD9X8wLqKCgQJs2bdKOHTtUXl6uuro6ff
3rX1dLS0u3+bKyMgWDwcgtJycn1kcCAPRBSZ7nefH8ACdOnNCYMWP06KOPaunSpWe9v729Xe3t7ZFfh8PhPlVClpeTHmucfbE8DRtwzfo07GPGPE/D7l4oFFJqamqP74/7swOGDx+uyy67TIcOHer2/YFAQIFAIN7HAAD0MXH/OaCTJ0+qtrZWWVlZ8f5QAIAEEvMCuvvuu1VZWanf//73+s1vfqMbb7xRAwcO1C233BLrDwUASGAx/ye4w4cP65ZbbtHx48c1cuRIXXPNNdq9e7dGjhwZ6w91QVhXcsRrdjy/pzPYmJ9mzO8y5nFx+poxb/m+S4dx9jeN+f80ZPn+7J/FvIC2bNkS65EAgH6IXXAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAE3F/OYa+JsOY/6oha9kHJUmNhux042zL7qsG42zL6xhJ0uWG7AfG2YnK/CJc6f6jSdYXs+kj4vmaOvH+lEwwZKvidgq7gYZsZxw+Po+AAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACf67CqeK+X/cAcNc4caz/G2IdtinG3xnjFvuZ/NxtmFxrx11Q+6Ydgl87+No79pzCei37o+wF9IM+ZPGbJtxtmW9TqWlVqdkmp85HgEBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnOizu+AWSBrsM/stw1zrLrjthuwO42wLyz4oSeowZC8zzm4y5g8b8xeDnxjzxw3ZWuNsnG1QHGdnGvOfGLKWP/eSacWg6Rx+d8zxCAgA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADjRZ3fBWcRzb9M4Q9a6U22EIWvdYddoyFruoyS9ZMwnqkJD1np9LLvdJOlVQzbdOPtyQzbXOPvfDL/hb+tts1+0xU2sO9Us19P6tWLZHWfdGWnJW8oiyWeOR0AAACfMBbRr1y7NmzdP2dnZSkpK0rZt26Le73meHnjgAWVlZWnIkCEqKirSRx99FKvzAgD6CXMBtba2Kj8/X+vWrev2/WvWrNETTzyh9evXa8+ePbrkkks0e/ZstbW1nfdhAQD9h/l7QHPnztXcuXO7fZ/neVq7dq1+9KMfaf78+ZKkp59+WhkZGdq2bZtuvvnm8zstAKDfiOn3gOrq6tTU1KSioqLI24LBoAoKClRVVdXt72lvb1c4HI66AQD6v5gWUFPTmdfKzMjIiHp7RkZG5H2fVVZWpmAwGLnl5OTE8kgAgD7K+bPgSktLFQqFIreGhgbXRwIAXAAxLaDMzDPPWG9ubo56e3Nzc+R9nxUIBJSamhp1AwD0fzEtoLy8PGVmZmrnzp2Rt4XDYe3Zs0eFhZYf6wMA9HfmZ8GdPHlShw4divy6rq5O+/btU1pamnJzc3XnnXfqJz/5iS699FLl5eXp/vvvV3Z2thYsWBDLcwMAEpy5gPbu3atrr7028uuSkhJJ0uLFi7Vp0ybdc889am1t1fLly3XixAldc8012rFjhwYPHmz6OB9ISvaZjecqnj2GrHGTiI4Zsn8yzrawroWJp+XG/LuG7D8ZZ1ueDvOCcfZkY97iMWN+uyE7v8w4/Nv+7+m/lR8wjU76V+NZDLKMect6ne6fjtUzy7oc63OILef+xJDt9JkzF9CMGTPkeV6P709KStLq1au1evVq62gAwEXE+bPgAAAXJwoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOBEkvd5e3UcCIfDCgaDmif/O946DPMtWUlKN2SDxtnrjPlENd2QrbTuGis1ZK2zF0/0ny3/0DZ7gS2uU4aNh0ONX4nlhq2E37SNVq3/6Pa7baP/3pBtPnckymhj3vL3inV3peUFaqyzLXnLnrlOnb
n0oVDoc19ih0dAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBN9dhXP+kulIQP9/Z5jhi0obxvPY1lV8Y5xdp0xn6jeMmSvtn41zjZkVxpnL7AcZqtx+Hpj3rLqp9442yBkzJ8yZA1reyQp9GP/2YpXbbOftMXVaMjWGGdnGbLTjLMtXylDDdlPJP1GrOIBAPRRFBAAwAkKCADgBAUEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJyggAIATFBAAwAkKCADgRJ/dBXebpGSfvyfdMN+ymspqjzG/Ny6niL80Y/7444bwMdvs7xr2gf3i38fZhl9zyJaPK8sfU+tX+S8N2Sbj7FsMWeP18f03hCTtN01en5Rvyt9jyFr2S0q2HWyWrGTbYWeZ3SXpj2IXHACgj6KAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOfMH1AXqSKingMxsyzO0wnsOyfiJsnJ2ojv/A+Bv+wf8amW8kJZlG32sJ19eaZuvD7/nPTnzKNtvM8nm5xDjbcD/NjhiyltU6VlNM6aGTbdNbDvjPDrSN1p+MeQvLWVri8PF5BAQAcIICAgA4YS6gXbt2ad68ecrOzlZSUpK2bdsW9f4lS5YoKSkp6jZnzpxYnRcA0E+YC6i1tVX5+flat25dj5k5c+aosbExcnv22WfP65AAgP7H/CSEuXPnau7cuZ+bCQQCyszM7PWhAAD9X1y+B1RRUaFRo0ZpwoQJWrlypY4fP95jtr29XeFwOOoGAOj/Yl5Ac+bM0dNPP62dO3fqn//5n1VZWam5c+eqs7Oz23xZWZmCwWDklpOTE+sjAQD6oJj/HNDNN98c+e/JkydrypQpGjdunCoqKjRz5syz8qWlpSopKYn8OhwOU0IAcBGI+9Owx44dq/T0dB06dKjb9wcCAaWmpkbdAAD9X9wL6PDhwzp+/LiysrLi/aEAAAnE/E9wJ0+ejHo0U1dXp3379iktLU1paWl6+OGHtWjRImVmZqq2tlb33HOPxo8fr9mzZ8f04ACAxGYuoL179+raa6+N/PrT798sXrxY5eXl2r9/v375y1/qxIkTys7O1qxZs/TjH/9YgYDfzW5nDPr/Nz8su+D8zvxUuiFr3TPXV7xq/Q3/4n+32xm/850cZ5x8wxWGsPXiT1xh/A0XgzdM6f86++zv+/bkp69UGc9ylTHvX5bl60rS7Y3+s6csCyYlher9Z4/ZRmuvMR9r5gKaMWOGPK/nv4BeeeWV8zoQAODiwC44AIATFBAAwAkKCADgBAUEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJyggAIATFBAAwImYvx5QrHwsqfuXsDvbAcPcXOM5LLvg+tILSfyNIXu9tzpu55CkU6X5vrP/zTq8+puG8IvG4cnGfB+xp9SWb3zXf/baiabR3/07Q7i20DRb4yw7Cf9oGt1h+UtF0iDDEjbL7krJtmNysnG25e/D1wxZT9JJHzkeAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABO9NlVPKMkDfaZ/cAw15KVpB2G7EDj7Hj6n//jPkP6/ridQ5Ia3/S/TGScdZeIXrb+Bv8a/8V/9sA7ptHvvLrNlN/wr/4/h0HTZGnCIP/Zpc+8apo97ttlhvQ002ybkaZ0+oe26YZlRjplG63jhuzbxtlfNWRLsvxn27ukR5rPneMREADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcKLP7oL7vaRkn1m/O+MkybD2SpJkWH8k/9u6zkg1ZP/XPxiH/51lB1d8BYf6z364xzZ7okr9hxtDptlJ2eW+s/mmyVKNMX+vIXt9um323x7zn21YZpv90Lctf8VcZxseR+8aF7btjc8x4q7KkA02+s9+4jPHIyAAgBMUEA
DACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADAiT67iudj+V/nYGFdxWPYUmKePcGQzX38qHF6HHW8Yor/Z63/7JvGo0y8+xH/YeOupBRD9hrbaM0y5v9+ov/sTz+0zbZ8jb9r22akY4/80Hc2/b5bbcNDhq+WoO3iH7P+Ybbu4UpAlq+TTp85HgEBAJwwFVBZWZmuvPJKpaSkaNSoUVqwYIFqaqLXKra1tam4uFgjRozQsGHDtGjRIjU3N8f00ACAxGcqoMrKShUXF2v37t167bXX1NHRoVmzZqm1tTWSueuuu/TSSy/p+eefV2VlpY4cOaKFCxfG/OAAgMRm+h7Qjh07on69adMmjRo1StXV1Zo+fbpCoZA2bNigzZs367rrzqxW37hxo770pS9p9+7duuqqq2J3cgBAQjuv7wGFQme+I5mWliZJqq6uVkdHh4qKiiKZiRMnKjc3V1VV3b/yRHt7u8LhcNQNAND/9bqAurq6dOedd+rqq6/WpEmTJElNTU1KTk7W8OHDo7IZGRlqamrqdk5ZWZmCwWDklpOT09sjAQASSK8LqLi4WAcPHtSWLVvO6wClpaUKhUKRW0NDw3nNAwAkhl79HNCqVav08ssva9euXRo9enTk7ZmZmTp9+rROnDgR9SioublZmZmZ3c4KBAIKBAK9OQYAIIGZHgF5nqdVq1Zp69ateuONN5SXlxf1/qlTp2rQoEHauXNn5G01NTWqr69XYWFhbE4MAOgXTI+AiouLtXnzZm3fvl0pKSmR7+sEg0ENGTJEwWBQS5cuVUlJidLS0pSamqo77rhDhYWFPAMOABDFVEDl5eWSpBkzZkS9fePGjVqyZIkk6bHHHtOAAQO0aNEitbe3a/bs2XrqqadiclgAQP+R5Hme5/oQfykcDisYDGqh/O9We84wf/S5I1Esq69ajLMNW8zU/XMIe3bLrem+s9PuX2sb/u7/McXrn9jsO9uwx3aUq+83hLNss3/yPf/Zg7bRmmTMW1aTvWOcbVljVm+cbTm3dZ+eYcWgmXHlnXbF5RR2aca85euw8t/9Z8OtUnDOmR/VSU1N7THHLjgAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADAiV69HMOFcEr+13hY1utM6MVZ/LKsNJGkjw3ZDcbZb20+5jtb2fg3ptmnjGfJ/bb/tUC5q4fahof8L4c5+IBt9CeGrPHUsr7qVdCQnWacfTxO57Cyrr+xfM5fM87+kzFv+sRMNs42/IH707u20VdcawhbdiX5fGFrHgEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAn+uwuuKHyvwvOst/Nuq+twJC17uB6z5D9jnG2/+1rUv2bttmWHXaSlJ7lfy+dfnCfbXi9/11wk35nW8L1dFKp/2OYJktZxvwBQ3aWcXaOIZtqnP22IbvDOLsvyVvgP1tn/UtoszFvMM2yCy4OeAQEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJyggAIATFBAAwAkKCADgBAUEAHCCAgIAONFnV/GMlhTwmbWstLGuEjllyFrOIdnOMtY427LqJXeibfZQ466Xd8r9Z6eVNdqG5z5jCP/KNHrN/sX+w+W/NM3+e8PnRJLeMWQta5gk6QZD1u96rE8l6nqdgcb8oA8N4T3G4QZzjPlbFvjPPjrFf7at01+OR0AAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMCJJM/zPNeH+EvhcFjBYFA1GVKKz3q8x7A+7IDxPJbdVx3G2ZadXdcbZ19hyAaNs6dda8u/86b/7HbbaP10tf/s+gdss4casleMs81+utaWf9WQbbKNNt3PY8bZlj8TE4yzLSsMrecOGfMHDdk24+x8Q3bfy8bh3/QfvTLJf7ZTZ3Zjhk
Ihpab2vPWSR0AAACdMBVRWVqYrr7xSKSkpGjVqlBYsWKCampqozIwZM5SUlBR1W7FiRUwPDQBIfKYCqqysVHFxsXbv3q3XXntNHR0dmjVrllpbW6Nyy5YtU2NjY+S2Zs2amB4aAJD4TK8HtGNH9Kt7bNq0SaNGjVJ1dbWmT58eefvQoUOVmZkZmxMCAPql8/oeUCh05lt1aWlpUW9/5plnlJ6erkmTJqm0tFSnTvX8sm7t7e0Kh8NRNwBA/9frV0Tt6urSnXfeqauvvlqTJk2KvP3WW2/VmDFjlJ2drf379+vee+9VTU2NXnzxxW7nlJWV6eGHH+7tMQAACarXBVRcXKyDBw/qrbfeinr78uXLI/89efJkZWVlaebMmaqtrdW4cWc/V7W0tFQlJSWRX4fDYeXk5PT2WACABNGrAlq1apVefvll7dq1S6NHj/7cbEFBgSTp0KFD3RZQIBBQIBDozTEAAAnMVECe5+mOO+7Q1q1bVVFRoby8vHP+nn379kmSsrKyenVAAED/ZCqg4uJibd68Wdu3b1dKSoqams78zHUwGNSQIUNUW1urzZs364YbbtCIESO0f/9+3XXXXZo+fbqmTJkSlzsAAEhMpgIqLy+XdOaHTf/Sxo0btWTJEiUnJ+v111/X2rVr1draqpycHC1atEg/+tGPYnZgAED/0Gd3wYU2Sal+l1Td7X/+Q/W28zxqyLbYRutyQ9ay98qa73lTU/esu8YMq+BUc+5IlIcMWeMKO9O1rzDO7vkHE7pXcu5IhPUbuy8Yspb9hZI035DNNc42rtMzeceYrzBkLbv3JKnC8Ic593fG4b/0H01aYpwtdsEBAPooCggA4AQFBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4ESvXw8o7m6U/x0x5f7Hfsu4isdigzFvWTtjXYFiuZvWtTCNxvxvDdnPf3GPs1k+hweNsy2fw5Bxdocx/9a5IxHXG2fPMmQPGGdb/kxcY5w9KE5Zyb4WaGwcZ+e+bAgbv7B6s14nlngEBABwggICADhBAQEAnKCAAABOUEAAACcoIACAExQQAMAJCggA4AQFBABwggICADhBAQEAnOi7u+AsDMusJhkXTo0wLL/K3Gab/ZotbjLJkA0bZ1v3ng00ZDONs4OGbK1xdoMha93tZt1NZplvvZ7Wz7nFDkPWuqvvBkPWuktxqDH/niFbZZy9xbAI8Nge43DHeAQEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJyggAIATFBAAwAkKCADgBAUEAHCCAgIAONF3V/Ecl3TaZ9ayj6XAdoysyf6zK8bZZt9gWJvxtmEdhyS9bcgaR6vJmJ9hyE40zj5myFrXseQYstZ1NtZVPFcYskOMs48bspbPiSRNN2R3GWdXGLLXG2e/a8y3GfMW/2WJ/2yJ5QulD+AREADACQoIAOAEBQQAcIICAgA4QQEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcKLv7oK7Q74XZh3b5n9s+krjOQy74HSNbXTurYasYW+cJF2/zX/2hVdts1+wxU2r+k4ZZ4eMeYuhhqx1t5vlcyJJBwzZDuPsLEPWsntPiu/1aTFkX4zbKeKvwZD91jeNw61L72KMR0AAACdMBVReXq4pU6YoNTVVqampKiws1K9//evI+9va2lRcXKwRI0Zo2LBhWrRokZqbm2N+aABA4jMV0OjRo/XII4+ourpae/fu1XXXXaf58+fr/ffflyTdddddeumll/T888+rsrJSR44c0cKFC+NycABAYjN9D2jevHlRv/7pT3+q8vJy7d69W6NHj9aGDRu0efNmXXfddZKkjRs36ktf+pJ2796tq666KnanBgAkvF5/D6izs1NbtmxRa2urCgsLVV1drY6ODhUVFUUyEydOVG5urqqqqnqc097ernA4HHUDAPR/5gI6cOCAhg0bpkAgoBUrVmjr1q26/PLL1dTUpOTkZA0fPj
wqn5GRoaamnl9Ds6ysTMFgMHLLybG+5iIAIBGZC2jChAnat2+f9uzZo5UrV2rx4sX64IMPen2A0tJShUKhyK2hwfKkQwBAojL/HFBycrLGjx8vSZo6dar+4z/+Q48//rhuuukmnT59WidOnIh6FNTc3KzMzMwe5wUCAQUCAfvJAQAJ7bx/Dqirq0vt7e2aOnWqBg0apJ07d0beV1NTo/r6ehUWFp7vhwEA9DOmR0ClpaWaO3eucnNz1dLSos2bN6uiokKvvPKKgsGgli5dqpKSEqWlpSk1NVV33HGHCgsLeQYcAOAspgI6evSovvOd76ixsVHBYFBTpkzRK6+8ouuvv16S9Nhjj2nAgAFatGiR2tvbNXv2bD311FO9O1nmECk5yVe0Kdf/Apd0664Xy+4R6xqMXEP2Cttoy8qhFcY1P9/5uS1/0DA/ZN31Ytg7c8o4+0ND1nIpJSlo3N1TYbiflrU9km0VzwzjbMsft98aZ1tkGPN96cfnK75tCK9ebZr94I8f8J192DTZH1MBbdiw4XPfP3jwYK1bt07r1q07r0MBAPo/dsEBAJyggAAATlBAAAAnKCAAgBMUEADACQoIAOAEBQQAcIICAgA4QQEBAJwwb8OON8/zJEnh057v33Oyy//88GnjgdoM2Rbj7L7y2nuttvgp4+fwZKfhKIZrKUky5D82jrZceuuGp0H+v7wlSe2G7Ce20abZ1vtp/eMWL9Yvq76kxbCGKRy2fNXarn1vfPr3eU+SvHMlLrDDhw/zonQA0A80NDRo9OjRPb6/zxVQV1eXjhw5opSUFCUl/XkZaTgcVk5OjhoaGpSamurwhPHF/ew/Lob7KHE/+5tY3E/P89TS0qLs7GwNGNDzd3r63D/BDRgw4HMbMzU1tV9f/E9xP/uPi+E+StzP/uZ872cwGDxnhichAACcoIAAAE4kTAEFAgE9+OCDCgQCro8SV9zP/uNiuI8S97O/uZD3s889CQEAcHFImEdAAID+hQICADhBAQEAnKCAAABOJEwBrVu3Tn/913+twYMHq6CgQO+8847rI8XUQw89pKSkpKjbxIkTXR/rvOzatUvz5s1Tdna2kpKStG3btqj3e56nBx54QFlZWRoyZIiKior00UcfuTnseTjX/VyyZMlZ13bOnDluDttLZWVluvLKK5WSkqJRo0ZpwYIFqqmpicq0tbWpuLhYI0aM0LBhw7Ro0SI1Nzc7OnHv+LmfM2bMOOt6rlixwtGJe6e8vFxTpkyJ/LBpYWGhfv3rX0fef6GuZUIU0HPPPaeSkhI9+OCDevfdd5Wfn6/Zs2fr6NGjro8WU1/+8pfV2NgYub311luuj3ReWltblZ+fr3Xr1nX7/jVr1uiJJ57Q+vXrtWfPHl1yySWaPXu22tpsCxVdO9f9lKQ5c+ZEXdtnn332Ap7w/FVWVqq4uFi7d+/Wa6+9po6ODs2aNUutrX/eZHvXXXfppZde0vPPP6/KykodOXJECxcudHhqOz/3U5KWLVsWdT3XrFnj6MS9M3r0aD3yyCOqrq7W3r17dd1112n+/Pl6//33JV3Aa+klgGnTpnnFxcWRX3d2dnrZ2dleWVmZw1PF1oMPPujl5+e7PkbcSPK2bt0a+XVXV5eXmZnp/exnP4u87cSJE14gEPCeffZZByeMjc/eT8/zvMWLF3vz5893cp54OXr0qCfJq6ys9DzvzLUbNGiQ9/zzz0cyv/vd7zxJXlVVlatjnrfP3k/P87xvfOMb3ve//313h4qTL37xi97Pf/7zC3ot+/wjoNOnT6u6ulpFRUWRtw0YMEBFRUWqqqpyeLLY++ijj5Sdna2xY8fqtttuU319vesjxU1dXZ2ampqirmswGFRBQUG/u66SVFFRoVGjRmnChAlauXKljh8/7vpI5yUUCkmS0tLSJEnV1dXq6OiIup4TJ05Ubm5uQl/Pz97PTz3zzDNKT0/XpEmTVFpaqlOnrC9U0Xd0dnZqy5Ytam1tVWFh4QW9ln
1uGelnHTt2TJ2dncrIyIh6e0ZGhj788ENHp4q9goICbdq0SRMmTFBjY6Mefvhhff3rX9fBgweVkpLi+ngx19TUJEndXtdP39dfzJkzRwsXLlReXp5qa2v1j//4j5o7d66qqqo0cOBA18cz6+rq0p133qmrr75akyZNknTmeiYnJ2v48OFR2US+nt3dT0m69dZbNWbMGGVnZ2v//v269957VVNToxdffNHhae0OHDigwsJCtbW1adiwYdq6dasuv/xy7du374Jdyz5fQBeLuXPnRv57ypQpKigo0JgxY/SrX/1KS5cudXgynK+bb7458t+TJ0/WlClTNG7cOFVUVGjmzJkOT9Y7xcXFOnjwYMJ/j/Jcerqfy5cvj/z35MmTlZWVpZkzZ6q2tlbjxo270MfstQkTJmjfvn0KhUJ64YUXtHjxYlVWVl7QM/T5f4JLT0/XwIEDz3oGRnNzszIzMx2dKv6GDx+uyy67TIcOHXJ9lLj49NpdbNdVksaOHav09PSEvLarVq3Syy+/rDfffDPqZVMyMzN1+vRpnThxIiqfqNezp/vZnYKCAklKuOuZnJys8ePHa+rUqSorK1N+fr4ef/zxC3ot+3wBJScna+rUqdq5c2fkbV1dXdq5c6cKCwsdniy+Tp48qdraWmVlZbk+Slzk5eUpMzMz6rqGw2Ht2bOnX19X6cyr/h4/fjyhrq3neVq1apW2bt2qN954Q3l5eVHvnzp1qgYNGhR1PWtqalRfX59Q1/Nc97M7+/btk6SEup7d6erqUnt7+4W9ljF9SkOcbNmyxQsEAt6mTZu8Dz74wFu+fLk3fPhwr6mpyfXRYuYHP/iBV1FR4dXV1Xlvv/22V1RU5KWnp3tHjx51fbRea2lp8d577z3vvffe8yR5jz76qPfee+95f/jDHzzP87xHHnnEGz58uLd9+3Zv//793vz58728vDzv448/dnxym8+7ny0tLd7dd9/tVVVVeXV1dd7rr7/uXXHFFd6ll17qtbW1uT66bytXrvSCwaBXUVHhNTY2Rm6nTp2KZFasWOHl5uZ6b7zxhrd3716vsLDQKywsdHhqu3Pdz0OHDnmrV6/29u7d69XV1Xnbt2/3xo4d602fPt3xyW3uu+8+r7Ky0qurq/P279/v3XfffV5SUpL36quvep534a5lQhSQ53nek08+6eXm5nrJycnetGnTvN27d7s+UkzddNNNXlZWlpecnOz91V/9lXfTTTd5hw4dcn2s8/Lmm296ks66LV682PO8M0/Fvv/++72MjAwvEAh4M2fO9Gpqatweuhc+736eOnXKmzVrljdy5Ehv0KBB3pgxY7xly5Yl3P88dXf/JHkbN26MZD7++GPve9/7nvfFL37RGzp0qHfjjTd6jY2N7g7dC+e6n/X19d706dO9tLQ0LxAIeOPHj/d++MMfeqFQyO3Bjb773e96Y8aM8ZKTk72RI0d6M2fOjJSP5124a8nLMQAAnOjz3wMCAPRPFBAAwAkKCADgBAUEAHCCAgIAOEEBAQCcoIAAAE5QQAAAJyggAIATFBAAwAkKCADgBAUEAHDi/wE7oxUZw9frpQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# actual data\n",
"\n",
"transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(feature_mean, feature_std)])\n",
"train = CIFAR10(\"./data\", train=True, transform=transform, download=True)\n",
"\n",
"# show 1 picture, it's a frog.\n",
"feature, label = train[0]\n",
"\n",
"feature = feature.permute(1, 2, 0)\n",
"plt.imshow(feature)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "ed317276",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"CIFAR10_ClassificationModel [128, 10] --\n",
"├─Conv2d: 1-1 [128, 64, 32, 32] 1,792\n",
"├─ReLU: 1-2 [128, 64, 32, 32] --\n",
"├─Conv2d: 1-3 [128, 128, 30, 30] 73,856\n",
"├─ReLU: 1-4 [128, 128, 30, 30] --\n",
"├─Conv2d: 1-5 [128, 256, 28, 28] 295,168\n",
"├─ReLU: 1-6 [128, 256, 28, 28] --\n",
"├─AvgPool2d: 1-7 [128, 256, 14, 14] --\n",
"├─Conv2d: 1-8 [128, 256, 14, 14] 590,080\n",
"├─ReLU: 1-9 [128, 256, 14, 14] --\n",
"├─Conv2d: 1-10 [128, 256, 12, 12] 590,080\n",
"├─ReLU: 1-11 [128, 256, 12, 12] --\n",
"├─AvgPool2d: 1-12 [128, 256, 6, 6] --\n",
"├─Linear: 1-13 [128, 512] 4,719,104\n",
"├─ReLU: 1-14 [128, 512] --\n",
"├─Linear: 1-15 [128, 128] 65,664\n",
"├─ReLU: 1-16 [128, 128] --\n",
"├─Linear: 1-17 [128, 10] 1,290\n",
"==========================================================================================\n",
"Total params: 6,337,034\n",
"Trainable params: 6,337,034\n",
"Non-trainable params: 0\n",
"Total mult-adds (G): 64.66\n",
"==========================================================================================\n",
"Input size (MB): 1.57\n",
"Forward/backward pass size (MB): 480.39\n",
"Params size (MB): 25.35\n",
"Estimated Total Size (MB): 507.31\n",
"=========================================================================================="
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"BATCH_SIZE = 128\n",
"\n",
"class CIFAR10_ClassificationModel(nn.Module):\n",
" def __init__(self):\n",
" super(CIFAR10_ClassificationModel, self).__init__()\n",
" self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)\n",
" self.conv2 = nn.Conv2d(64, 128, 3, stride=1, padding=0)\n",
" self.conv3 = nn.Conv2d(128, 256, 3, stride=1, padding=0)\n",
"# self.conv4 = nn.Conv2d(256, 256, 3, stride=1, padding=0)\n",
"\n",
" self.conv5 = nn.Conv2d(256, 256, 3, stride=1, padding=1)\n",
" self.conv6 = nn.Conv2d(256, 256, 3, stride=1, padding=0)\n",
"\n",
" self.pool = nn.AvgPool2d(2,stride=2)\n",
" \n",
" self.fc1 = nn.Linear(9216, 512)\n",
" self.fc2 = nn.Linear(512, 128)\n",
" self.fc3 = nn.Linear(128, 10)\n",
" self.relu = nn.ReLU()\n",
" \n",
" def forward(self, x):\n",
" batch = x.shape[0]\n",
" x = self.relu( self.conv1(x) )\n",
" x = self.relu( self.conv2(x) )\n",
" x = self.relu( self.conv3(x) )\n",
"\n",
" x = self.pool(x)\n",
" x = self.relu( self.conv5(x) )\n",
" x = self.relu( self.conv6(x) )\n",
" x = self.pool(x)\n",
" x = x.view(batch, -1)\n",
" x = self.relu( self.fc1(x) )\n",
" x = self.relu( self.fc2(x) )\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"model = CIFAR10_ClassificationModel()\n",
"summary(model, input_size=(BATCH_SIZE, 3, 32, 32))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "cbb17733",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current Time = 00:57:57\n",
"Epoch 0: 1.5633342016078626, 42.9739990234375%\n",
"Epoch 1: 1.030615347120768, 63.2859992980957%\n",
"Epoch 2: 0.7714244136420052, 72.7699966430664%\n",
"Epoch 3: 0.6027741941344708, 78.95999908447266%\n",
"Epoch 4: 0.45929622848320495, 84.03199768066406%\n",
"Current Time = 00:59:48\n"
]
}
],
"source": [
"now = datetime.now()\n",
"current_time = now.strftime(\"%H:%M:%S\")\n",
"print(\"Current Time =\", current_time)\n",
"\n",
"loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)\n",
"model = CIFAR10_ClassificationModel().to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"for epoch in range(5):\n",
" loss_sum = 0\n",
" corr_sum = 0\n",
" for x, y in loader:\n",
" features, labels = x.to(device), y.to(device)\n",
"\n",
" # clear gradient in optimizer\n",
" optimizer.zero_grad()\n",
"\n",
" # predict\n",
" predicts = model(features)\n",
"\n",
" # get loss\n",
" loss = criterion(predicts, labels)\n",
" loss.backward()\n",
" \n",
" # back propagation\n",
" optimizer.step()\n",
" \n",
" # loss & acc\n",
" predicts = torch.argmax(predicts, 1)\n",
" loss_sum += loss.item()\n",
" corr_sum += (predicts==labels).sum()\n",
" print(\"Epoch {}: {}, {}%\".format(epoch, loss_sum/len(loader), corr_sum/len(train)*100))\n",
"\n",
"now = datetime.now()\n",
"current_time = now.strftime(\"%H:%M:%S\")\n",
"print(\"Current Time =\", current_time)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}