{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1f3c9d2",
   "metadata": {},
   "source": [
    "# Linear regression with PyTorch\n",
    "\n",
    "Fits the line $y = 2x + 1$ on the 100 points $x = 0, 1, \\dots, 99$ with a single `torch.nn.Linear(1, 1)` layer, trained by full-batch SGD on the mean-squared-error loss.\n",
    "\n",
    "Run the notebook top to bottom (**Restart Kernel → Run All**). The seed set in the configuration cell makes the random weight initialisation repeatable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdb7316a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d5d6e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: every value a reader might want to tune.\n",
    "SEED = 0\n",
    "LEARNING_RATE = 1e-5\n",
    "EPOCHS = 10000\n",
    "LOG_EVERY = 100  # print the training loss every LOG_EVERY epochs\n",
    "\n",
    "# Seed torch's default RNG so nn.Linear's random initialisation is reproducible.\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "338e44eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:0')\n",
    "    print(\"Use GPU:\", torch.cuda.get_device_name(device))\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "    print(\"Use CPU\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3d85a61",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "Noise-free samples of $y = 2x + 1$. Both tensors have shape `(100, 1)`: one row per sample and one feature per row, as `nn.Linear(1, 1)` expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54adec76",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.arange(100, dtype=torch.float32).view(-1, 1)\n",
    "y = x * 2 + 1\n",
    "assert x.shape == y.shape == (100, 1)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a07e3b",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "A single affine layer, $\\hat{y} = w x + b$, trained with mean-squared error and plain SGD."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d893e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearRegressionModel(torch.nn.Module):\n",
    "    \"\"\"Affine map ``x @ W.T + b`` implemented by one ``torch.nn.Linear`` layer.\n",
    "\n",
    "    Args:\n",
    "        input_d: number of input features per sample.\n",
    "        output_d: number of outputs per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_d, output_d):\n",
    "        super().__init__()\n",
    "        self.linear = torch.nn.Linear(input_d, output_d)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``self.linear(x)``; the last dimension of ``x`` must equal ``input_d``.\"\"\"\n",
    "        return self.linear(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "735bf7f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the target device *before* constructing the optimizer,\n",
    "# as the torch.optim documentation recommends.\n",
    "model = LinearRegressionModel(1, 1).to(device)\n",
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5b61c47",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Full-batch gradient descent: every epoch uses all 100 samples.\n",
    "\n",
    "**Convergence caveat.** The inputs are not rescaled ($x \\in [0, 99]$), so the MSE loss surface is badly conditioned. With `LEARNING_RATE = 1e-5` the loss drops sharply within roughly the first hundred epochs. After that it decreases only very slowly, because the bias barely moves, and the bias is typically still noticeably different from 1 after 10,000 epochs. Standardizing `x` fixes this, as would a larger epoch budget or an adaptive optimizer such as Adam. The raw inputs are kept here to keep the example minimal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b2e90f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The data never changes, so copy it to the device once instead of every epoch.\n",
    "inputs = x.to(device)\n",
    "labels = y.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e4ce0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train()\n",
    "for epoch in range(EPOCHS):\n",
    "    optimizer.zero_grad()               # clear gradients left over from the previous step\n",
    "\n",
    "    predicts = model(inputs)            # forward pass\n",
    "    loss = criterion(predicts, labels)\n",
    "\n",
    "    loss.backward()                     # back-propagation: compute d(loss)/d(parameters)\n",
    "    optimizer.step()                    # SGD parameter update\n",
    "\n",
    "    if epoch % LOG_EVERY == 0:\n",
    "        print(\"epoch {}, loss={}\".format(epoch, loss.item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2c48d90",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "Compare the fitted parameters with the true values ($w = 2$, $b = 1$), and show a few predictions next to their targets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "052506ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: no autograd graph is needed. eval() is a no-op for nn.Linear,\n",
    "# but it is the correct habit once layers such as dropout or batch-norm are added.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    predictions = model(inputs)\n",
    "\n",
    "print(\"weight:\", model.linear.weight.item(), \"(true value 2)\")\n",
    "print(\"bias:  \", model.linear.bias.item(), \"(true value 1)\")\n",
    "print(\"max |prediction - target|:\", (predictions - labels).abs().max().item())\n",
    "\n",
    "# Columns: x, predicted y, true y (first five samples only).\n",
    "torch.cat([inputs, predictions, labels], dim=1)[:5]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}