Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/WongKinYiu/yolov7
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyAB committed Jul 28, 2022
2 parents 954cde6 + 2a731fc commit 264fc09
Showing 1 changed file with 43 additions and 8 deletions.
51 changes: 43 additions & 8 deletions tools/reparameterization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@
"from models.yolo import Model\n",
"import torch\n",
"from utils.torch_utils import select_device, is_parallel\n",
"import yaml\n",
"\n",
"device = select_device('0', batch_size=1)\n",
"# model trained by cfg/training/*.yaml\n",
"ckpt = torch.load('cfg/training/yolov7.pt', map_location=device)\n",
"# reparameterized model in cfg/deploy/*.yaml\n",
"model = Model('cfg/deploy/yolov7.yaml', ch=3, nc=80).to(device)\n",
"\n",
"with open('cfg/deploy/yolov7.yaml') as f:\n",
" yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
"anchors = len(yml['anchors'])\n",
"\n",
"# copy intersect weights\n",
"state_dict = ckpt['model'].float().state_dict()\n",
"exclude = []\n",
Expand All @@ -44,7 +49,7 @@
"model.nc = ckpt['model'].nc\n",
"\n",
"# reparametrized YOLOR\n",
"for i in range(255):\n",
"for i in range((model.nc+5)*anchors):\n",
" model.state_dict()['model.105.m.0.weight'].data[i, :, :, :] *= state_dict['model.105.im.0.implicit'].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.105.m.1.weight'].data[i, :, :, :] *= state_dict['model.105.im.1.implicit'].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.105.m.2.weight'].data[i, :, :, :] *= state_dict['model.105.im.2.implicit'].data[:, i, : :].squeeze()\n",
Expand Down Expand Up @@ -85,13 +90,18 @@
"from models.yolo import Model\n",
"import torch\n",
"from utils.torch_utils import select_device, is_parallel\n",
"import yaml\n",
"\n",
"device = select_device('0', batch_size=1)\n",
"# model trained by cfg/training/*.yaml\n",
"ckpt = torch.load('cfg/training/yolov7x.pt', map_location=device)\n",
"# reparameterized model in cfg/deploy/*.yaml\n",
"model = Model('cfg/deploy/yolov7x.yaml', ch=3, nc=80).to(device)\n",
"\n",
"with open('cfg/deploy/yolov7x.yaml') as f:\n",
" yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
"anchors = len(yml['anchors'])\n",
"\n",
"# copy intersect weights\n",
"state_dict = ckpt['model'].float().state_dict()\n",
"exclude = []\n",
Expand All @@ -101,7 +111,7 @@
"model.nc = ckpt['model'].nc\n",
"\n",
"# reparametrized YOLOR\n",
"for i in range(255):\n",
"for i in range((model.nc+5)*anchors):\n",
" model.state_dict()['model.121.m.0.weight'].data[i, :, :, :] *= state_dict['model.121.im.0.implicit'].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.121.m.1.weight'].data[i, :, :, :] *= state_dict['model.121.im.1.implicit'].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.121.m.2.weight'].data[i, :, :, :] *= state_dict['model.121.im.2.implicit'].data[:, i, : :].squeeze()\n",
Expand Down Expand Up @@ -142,13 +152,18 @@
"from models.yolo import Model\n",
"import torch\n",
"from utils.torch_utils import select_device, is_parallel\n",
"import yaml\n",
"\n",
"device = select_device('0', batch_size=1)\n",
"# model trained by cfg/training/*.yaml\n",
"ckpt = torch.load('cfg/training/yolov7-w6.pt', map_location=device)\n",
"# reparameterized model in cfg/deploy/*.yaml\n",
"model = Model('cfg/deploy/yolov7-w6.yaml', ch=3, nc=80).to(device)\n",
"\n",
"with open('cfg/deploy/yolov7-w6.yaml') as f:\n",
" yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
"anchors = len(yml['anchors'])\n",
"\n",
"# copy intersect weights\n",
"state_dict = ckpt['model'].float().state_dict()\n",
"exclude = []\n",
Expand Down Expand Up @@ -179,7 +194,7 @@
"model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
"\n",
"# reparametrized YOLOR\n",
"for i in range(255):\n",
"for i in range((model.nc+5)*anchors):\n",
" model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
Expand Down Expand Up @@ -223,13 +238,18 @@
"from models.yolo import Model\n",
"import torch\n",
"from utils.torch_utils import select_device, is_parallel\n",
"import yaml\n",
"\n",
"device = select_device('0', batch_size=1)\n",
"# model trained by cfg/training/*.yaml\n",
"ckpt = torch.load('cfg/training/yolov7-e6.pt', map_location=device)\n",
"# reparameterized model in cfg/deploy/*.yaml\n",
"model = Model('cfg/deploy/yolov7-e6.yaml', ch=3, nc=80).to(device)\n",
"\n",
"with open('cfg/deploy/yolov7-e6.yaml') as f:\n",
" yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
"anchors = len(yml['anchors'])\n",
"\n",
"# copy intersect weights\n",
"state_dict = ckpt['model'].float().state_dict()\n",
"exclude = []\n",
Expand Down Expand Up @@ -260,7 +280,7 @@
"model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
"\n",
"# reparametrized YOLOR\n",
"for i in range(255):\n",
"for i in range((model.nc+5)*anchors):\n",
" model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
Expand Down Expand Up @@ -304,13 +324,18 @@
"from models.yolo import Model\n",
"import torch\n",
"from utils.torch_utils import select_device, is_parallel\n",
"import yaml\n",
"\n",
"device = select_device('0', batch_size=1)\n",
"# model trained by cfg/training/*.yaml\n",
"ckpt = torch.load('cfg/training/yolov7-d6.pt', map_location=device)\n",
"# reparameterized model in cfg/deploy/*.yaml\n",
"model = Model('cfg/deploy/yolov7-d6.yaml', ch=3, nc=80).to(device)\n",
"\n",
"with open('cfg/deploy/yolov7-d6.yaml') as f:\n",
" yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
"anchors = len(yml['anchors'])\n",
"\n",
"# copy intersect weights\n",
"state_dict = ckpt['model'].float().state_dict()\n",
"exclude = []\n",
Expand Down Expand Up @@ -341,7 +366,7 @@
"model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
"\n",
"# reparametrized YOLOR\n",
"for i in range(255):\n",
"for i in range((model.nc+5)*anchors):\n",
" model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
Expand Down Expand Up @@ -385,13 +410,18 @@
"from models.yolo import Model\n",
"import torch\n",
"from utils.torch_utils import select_device, is_parallel\n",
"import yaml\n",
"\n",
"device = select_device('0', batch_size=1)\n",
"# model trained by cfg/training/*.yaml\n",
"ckpt = torch.load('cfg/training/yolov7-e6e.pt', map_location=device)\n",
"# reparameterized model in cfg/deploy/*.yaml\n",
"model = Model('cfg/deploy/yolov7-e6e.yaml', ch=3, nc=80).to(device)\n",
"\n",
"with open('cfg/deploy/yolov7-e6e.yaml') as f:\n",
" yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
"anchors = len(yml['anchors'])\n",
"\n",
"# copy intersect weights\n",
"state_dict = ckpt['model'].float().state_dict()\n",
"exclude = []\n",
Expand Down Expand Up @@ -422,7 +452,7 @@
"model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
"\n",
"# reparametrized YOLOR\n",
"for i in range(255):\n",
"for i in range((model.nc+5)*anchors):\n",
" model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
" model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
Expand Down Expand Up @@ -457,7 +487,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.7.0 ('py37')",
"language": "python",
"name": "python3"
},
Expand All @@ -471,7 +501,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.7.0"
},
"vscode": {
"interpreter": {
"hash": "73080970ff6fd25f9fcdf9c6f9e85b950a97864bb936ee53fb633f473cbfae4b"
}
}
},
"nbformat": 4,
Expand Down

0 comments on commit 264fc09

Please sign in to comment.