Skip to content

Commit

Permalink
notebook for loading model weights
Browse files Browse the repository at this point in the history
  • Loading branch information
nathankong committed Sep 28, 2023
1 parent f2cbac3 commit ea8d1ad
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 1,071 deletions.
248 changes: 248 additions & 0 deletions Loading model weights.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading pretrained model weights\n",
"If the loading of parameters was successful, a message should be printed out saying `Loaded parameters from /PATH/TO/WEIGHTS`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import os\n",
"\n",
"from mouse_vision.core.model_loader_utils import load_model\n",
"from mouse_vision.models.model_paths import MODEL_PATHS"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def load_pretrained_model(model_name):\n",
" model_path = MODEL_PATHS[model_name]\n",
" assert os.path.isfile(model_path)\n",
"\n",
" model, layers = load_model(\n",
" model_name, \n",
" trained=True, \n",
" model_path=model_path, \n",
" model_family=\"imagenet\",\n",
" state_dict_key=\"model_state_dict\", # make sure `model_state_dict` is in the *.pt file\n",
" )\n",
" \n",
" return model, layers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### AlexNet (instance recognition)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading alexnet_bn_ir_64x64_input_pool_6. Pretrained: True. Model Family: imagenet.\n",
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/alexnet_bn_ir.pt\n",
"======= Model architecture =======\n",
" AlexNetBN(\n",
" (features): Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU(inplace=True)\n",
" (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (4): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
" (5): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): ReLU(inplace=True)\n",
" (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (9): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (10): ReLU(inplace=True)\n",
" (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (13): ReLU(inplace=True)\n",
" (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (16): ReLU(inplace=True)\n",
" (17): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
" (classifier): Sequential(\n",
" (0): Dropout(p=0.5, inplace=False)\n",
" (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
" (2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): ReLU(inplace=True)\n",
" (4): Dropout(p=0.5, inplace=False)\n",
" (5): Linear(in_features=4096, out_features=4096, bias=True)\n",
" (6): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (7): ReLU(inplace=True)\n",
" )\n",
")\n",
"======= Model layers =======\n",
"['features.3', 'features.7', 'features.10', 'features.13', 'features.17', 'classifier.3', 'classifier.7']\n"
]
}
],
"source": [
"name = \"alexnet_bn_ir_64x64_input_pool_6\"\n",
"model, model_layers = load_pretrained_model(name)\n",
"print(\"======= Model architecture =======\\n\", model)\n",
"print(f\"======= Model layers =======\\n{model_layers}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MouseNet of Shi et al. (instance recognition)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### a) Original architecture"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading shi_mousenet_ir. Pretrained: True. Model Family: imagenet.\n",
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/shi_mousenet_ir.pt\n"
]
}
],
"source": [
"name = \"shi_mousenet_ir\"\n",
"model, model_layers = load_pretrained_model(name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### b) Our variant"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading shi_mousenet_vispor5_ir. Pretrained: True. Model Family: imagenet.\n",
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/shi_mousenet_vispor5_ir.pt\n"
]
}
],
"source": [
"name = \"shi_mousenet_vispor5_ir\"\n",
"model, model_layers = load_pretrained_model(name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dual stream (instance recognition)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading simplified_mousenet_dual_stream_visp_3x3_ir. Pretrained: True. Model Family: imagenet.\n",
"Single stream set to False\n",
"Using {'type': 'BN'} normalization\n",
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/dual_stream_ir.pt\n"
]
}
],
"source": [
"name = \"simplified_mousenet_dual_stream_visp_3x3_ir\"\n",
"model, model_layers = load_pretrained_model(name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Six stream (SimCLR)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading simplified_mousenet_six_stream_visp_3x3_simclr. Pretrained: True. Model Family: imagenet.\n",
"Single stream set to False\n",
"Using {'type': 'SyncBN'} normalization\n",
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/six_stream_simclr.pt\n"
]
}
],
"source": [
"name = \"simplified_mousenet_six_stream_visp_3x3_simclr\"\n",
"model, model_layers = load_pretrained_model(name)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit ea8d1ad

Please sign in to comment.