-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f2cbac3
commit ea8d1ad
Showing
5 changed files
with
263 additions
and
1,071 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Loading pretrained model weights\n", | ||
"If the loading of parameters was successful, a message should be printed out saying `Loaded parameters from /PATH/TO/WEIGHTS`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2\n", | ||
"\n", | ||
"import os\n", | ||
"\n", | ||
"from mouse_vision.core.model_loader_utils import load_model\n", | ||
"from mouse_vision.models.model_paths import MODEL_PATHS" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def load_pretrained_model(model_name):\n", | ||
" model_path = MODEL_PATHS[model_name]\n", | ||
" assert os.path.isfile(model_path)\n", | ||
"\n", | ||
" model, layers = load_model(\n", | ||
" model_name, \n", | ||
" trained=True, \n", | ||
" model_path=model_path, \n", | ||
" model_family=\"imagenet\",\n", | ||
" state_dict_key=\"model_state_dict\", # make sure `model_state_dict` is in the *.pt file\n", | ||
" )\n", | ||
" \n", | ||
" return model, layers" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### AlexNet (instance recognition)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading alexnet_bn_ir_64x64_input_pool_6. Pretrained: True. Model Family: imagenet.\n", | ||
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/alexnet_bn_ir.pt\n", | ||
"======= Model architecture =======\n", | ||
" AlexNetBN(\n", | ||
" (features): Sequential(\n", | ||
" (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n", | ||
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | ||
" (2): ReLU(inplace=True)\n", | ||
" (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | ||
" (4): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", | ||
" (5): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | ||
" (6): ReLU(inplace=True)\n", | ||
" (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | ||
" (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
" (9): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | ||
" (10): ReLU(inplace=True)\n", | ||
" (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
" (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | ||
" (13): ReLU(inplace=True)\n", | ||
" (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
" (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | ||
" (16): ReLU(inplace=True)\n", | ||
" (17): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | ||
" )\n", | ||
" (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n", | ||
" (classifier): Sequential(\n", | ||
" (0): Dropout(p=0.5, inplace=False)\n", | ||
" (1): Linear(in_features=9216, out_features=4096, bias=True)\n", | ||
" (2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | ||
" (3): ReLU(inplace=True)\n", | ||
" (4): Dropout(p=0.5, inplace=False)\n", | ||
" (5): Linear(in_features=4096, out_features=4096, bias=True)\n", | ||
" (6): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | ||
" (7): ReLU(inplace=True)\n", | ||
" )\n", | ||
")\n", | ||
"======= Model layers =======\n", | ||
"['features.3', 'features.7', 'features.10', 'features.13', 'features.17', 'classifier.3', 'classifier.7']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"name = \"alexnet_bn_ir_64x64_input_pool_6\"\n", | ||
"model, model_layers = load_pretrained_model(name)\n", | ||
"print(\"======= Model architecture =======\\n\", model)\n", | ||
"print(f\"======= Model layers =======\\n{model_layers}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### MouseNet of Shi et al. (instance recognition)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### a) Original architecture" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading shi_mousenet_ir. Pretrained: True. Model Family: imagenet.\n", | ||
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/shi_mousenet_ir.pt\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"name = \"shi_mousenet_ir\"\n", | ||
"model, model_layers = load_pretrained_model(name)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### b) Our variant" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading shi_mousenet_vispor5_ir. Pretrained: True. Model Family: imagenet.\n", | ||
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/shi_mousenet_vispor5_ir.pt\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"name = \"shi_mousenet_vispor5_ir\"\n", | ||
"model, model_layers = load_pretrained_model(name)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Dual stream (instance recognition)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading simplified_mousenet_dual_stream_visp_3x3_ir. Pretrained: True. Model Family: imagenet.\n", | ||
"Single stream set to False\n", | ||
"Using {'type': 'BN'} normalization\n", | ||
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/dual_stream_ir.pt\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"name = \"simplified_mousenet_dual_stream_visp_3x3_ir\"\n", | ||
"model, model_layers = load_pretrained_model(name)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Six stream (SimCLR)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading simplified_mousenet_six_stream_visp_3x3_simclr. Pretrained: True. Model Family: imagenet.\n", | ||
"Single stream set to False\n", | ||
"Using {'type': 'SyncBN'} normalization\n", | ||
"Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/six_stream_simclr.pt\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"name = \"simplified_mousenet_six_stream_visp_3x3_simclr\"\n", | ||
"model, model_layers = load_pretrained_model(name)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
Oops, something went wrong.