Skip to content

Commit

Permalink
Created using Colaboratory
Browse files Browse the repository at this point in the history
  • Loading branch information
TonmoyTalukder committed Aug 11, 2023
1 parent 01187da commit d7d7c28
Showing 1 changed file with 333 additions and 0 deletions.
333 changes: 333 additions & 0 deletions PyTorch_Explore/8_Save_and_Load_Model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNVdpHfJPPUhEbP4iWXR/vP",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/TonmoyTalukder/deep-learning-explore/blob/main/PyTorch_Explore/8_Save_and_Load_Model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"\n",
"\n",
"```\n",
"# Methods\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"torch.save(arg, PATH)\n",
"torch.load(PATH)\n",
"model.load_sate_dict(arg)\n",
"\n",
"\n",
"\n",
"# Save\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"#### COMPLETE MODEL ####\n",
"torch.save(arg, PATH)\n",
"model = torch.load(PATH)\n",
"model.eval()\n",
"\n",
"#### STATE DICT ####\n",
"torch.save(model.state_dict(), PATH) # it only saves parameters\n",
"model = Model(*args, **kwargs)\n",
"model.load_state_dict(torch.load(PATH))\n",
"model.eval()\n",
"```\n",
"\n"
],
"metadata": {
"id": "0iz1V7eqW-Hf"
}
},
{
"cell_type": "code",
"source": [
"# Save Load Example\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self, n_input_features):\n",
" super(Model, self).__init__()\n",
" self.linear = nn.Linear(n_input_features, 1)\n",
"\n",
" def forward(self, x):\n",
" y_pred = torch.sigmoid(self.linear())\n",
" return y_pred\n",
"\n",
"model = Model(n_input_features=6)\n",
"\n",
"print(model.state_dict())\n",
"# train model ..."
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HVnVCOtSLQ0R",
"outputId": "88f20b04-3eb5-4fc6-eded-148640e5344a"
},
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"OrderedDict([('linear.weight', tensor([[ 0.0033, 0.0729, 0.1773, 0.3808, -0.3613, 0.0527]])), ('linear.bias', tensor([0.2571]))])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"for param in model.parameters():\n",
" print(param)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6Q0FkDWIS5dQ",
"outputId": "bb35a9e9-f927-4dda-dbe3-3cbb9562dff1"
},
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Parameter containing:\n",
"tensor([[ 0.0033, 0.0729, 0.1773, 0.3808, -0.3613, 0.0527]],\n",
" requires_grad=True)\n",
"Parameter containing:\n",
"tensor([0.2571], requires_grad=True)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"#### LAZY OPTION ####\n",
"FILE = \"lazy_model.pth\"\n",
"torch.save(model, FILE) # Save\n",
"\n",
"modeltest = torch.load(FILE) # Load\n",
"modeltest.eval()\n",
"\n",
"for param in modeltest.parameters():\n",
" print(param)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XaTgitedMRUv",
"outputId": "e78048a0-f0f6-488d-a907-2fefb4a7be6b"
},
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Parameter containing:\n",
"tensor([[ 0.0033, 0.0729, 0.1773, 0.3808, -0.3613, 0.0527]],\n",
" requires_grad=True)\n",
"Parameter containing:\n",
"tensor([0.2571], requires_grad=True)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"#### PREFARABLE WAY ####\n",
"FILE = \"pref_model.pth\"\n",
"torch.save(model.state_dict(), FILE) # Save\n",
"\n",
"loaded_model = Model(n_input_features=6)\n",
"loaded_model.load_state_dict(torch.load(FILE)) # Load\n",
"modeltest.eval()\n",
"\n",
"for param in loaded_model.parameters():\n",
" print(param)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pDt8KUK3M4cm",
"outputId": "4d43fe3a-e64a-4af5-efd8-c44ea8811c16"
},
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Parameter containing:\n",
"tensor([[ 0.0033, 0.0729, 0.1773, 0.3808, -0.3613, 0.0527]],\n",
" requires_grad=True)\n",
"Parameter containing:\n",
"tensor([0.2571], requires_grad=True)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Saving Checkpoints\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self, n_input_features):\n",
" super(Model, self).__init__()\n",
" self.linear = nn.Linear(n_input_features, 1)\n",
"\n",
" def forward(self, x):\n",
" y_pred = torch.sigmoid(self.linear())\n",
" return y_pred\n",
"\n",
"model = Model(n_input_features=6)\n",
"learning_rate = 0.01\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
"\n",
"print(optimizer.state_dict())\n",
"# train model ...\n",
"\n",
"checkpoint = {\n",
" \"epoch\": 90,\n",
" \"model_state\": model.state_dict(),\n",
" \"optim_state\": optimizer.state_dict()\n",
"}\n",
"\n",
"torch.save(checkpoint, \"checkpoint.pth\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iRMKaYZbS9jh",
"outputId": "c86f575c-9207-4e69-bb87-675de9bcac66"
},
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"loaded_checkpoint = torch.load(\"checkpoint.pth\")\n",
"epoch = loaded_checkpoint[\"epoch\"]\n",
"\n",
"model = Model(n_input_features=6)\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0)\n",
"\n",
"model.load_state_dict(checkpoint[\"model_state\"])\n",
"optimizer.load_state_dict(checkpoint[\"optim_state\"])\n",
"\n",
"print(optimizer.state_dict())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FaLi-2P7Tgvj",
"outputId": "5e183e99-0af7-4ba0-fb3e-3a5c430128b3"
},
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"\n",
"\n",
"```\n",
"# GPU | CPU Saving and Loading\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Save on GPU, Load on CPU\n",
"device = torch.device(\"cuda\")\n",
"model.to(device)\n",
"torch.save(model.state_dict(), PATH)\n",
"\n",
"device = torch.device(\"cpu\")\n",
"model = Model(*args, **kwargs)\n",
"model.load_state_dict(torch.load(PATH, map_location=device))\n",
"\n",
"# Save on GPU, Load on GPU\n",
"device = torch.device(\"cuda\")\n",
"model.to(device)\n",
"torch.save(model.state_dict(), PATH)\n",
"\n",
"model = Model(*args, **kwargs)\n",
"model.load_state_dict(torch.load(PATH))\n",
"model.to(device)\n",
"\n",
"# Save on CPU, Load on GPU\n",
"torch.save(model.state_dict(), PATH)\n",
"\n",
"device = torch.device(\"cuda\")\n",
"model = Model(*args, **kwargs)\n",
"model.load_state_dict(torch.load(PATH, map_location=\"cuda:0\"))\n",
"model.to(device) \n",
"```\n",
"\n"
],
"metadata": {
"id": "HwdaDhRQXWYI"
}
}
]
}

0 comments on commit d7d7c28

Please sign in to comment.