Skip to content

Commit

Permalink
Ensure tensors are on proper device
Browse files Browse the repository at this point in the history
Remove execution output in regularization notebook
  • Loading branch information
AaronCCWong committed Feb 14, 2019
1 parent 3e75af6 commit 9f3a9a9
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 82 deletions.
6 changes: 3 additions & 3 deletions 01-tensor_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -910,9 +910,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "pytorch_latest",
"display_name": "Python 3",
"language": "python",
"name": "pytorch_latest"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -924,7 +924,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down
45 changes: 29 additions & 16 deletions 02-space_stretching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,25 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# generate some points in 2-D space\n",
"n_points = 1000\n",
"X = torch.randn(n_points, 2) \n",
"X = torch.randn(n_points, 2).to(device)\n",
"colors = X[:, 0]\n",
"\n",
"show_scatterplot(X, colors, title='X')\n",
"OI = torch.cat((torch.zeros(2, 2), torch.eye(2)))\n",
"OI = torch.cat((torch.zeros(2, 2), torch.eye(2))).to(device)\n",
"plot_bases(OI)"
]
},
Expand Down Expand Up @@ -76,7 +87,7 @@
"\n",
"for i in range(10):\n",
" # create a random matrix\n",
" W = torch.randn(2, 2)\n",
" W = torch.randn(2, 2).to(device)\n",
" # transform points\n",
" Y = X @ W\n",
" # compute singular values\n",
Expand All @@ -100,16 +111,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"model = nn.Sequential(\n",
Expand Down Expand Up @@ -211,6 +215,7 @@
" NL, \n",
" nn.Linear(n_hidden, 2)\n",
" )\n",
" model.to(device)\n",
" with torch.no_grad():\n",
" Y = model(X)\n",
" show_scatterplot(Y, colors, title='f(x)')\n",
Expand Down Expand Up @@ -244,17 +249,25 @@
" NL, \n",
" nn.Linear(n_hidden, 2)\n",
" )\n",
" model.to(device)\n",
" with torch.no_grad():\n",
" Y = model(X).detach()\n",
" show_scatterplot(Y, colors, title='f(x)')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch_latest",
"display_name": "Python 3",
"language": "python",
"name": "pytorch_latest"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -266,7 +279,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions 03-autograd_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "pytorch_latest",
"display_name": "Python 3",
"language": "python",
"name": "pytorch_latest"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -292,7 +292,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down
35 changes: 21 additions & 14 deletions 04-spiral_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@
"set_default()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -59,8 +68,8 @@
"metadata": {},
"outputs": [],
"source": [
"X = torch.zeros(N * C, D)\n",
"y = torch.zeros(N * C, dtype=torch.long)\n",
"X = torch.zeros(N * C, D).to(device)\n",
"y = torch.zeros(N * C, dtype=torch.long).to(device)\n",
"for c in range(C):\n",
" index = 0\n",
" t = torch.linspace(0, 1, N)\n",
Expand Down Expand Up @@ -114,15 +123,6 @@
"lambda_l2 = 1e-5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -261,13 +261,20 @@
"print(model)\n",
"plot_model(X, y, model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch_latest",
"display_name": "Python 3",
"language": "python",
"name": "pytorch_latest"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -279,7 +286,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down
51 changes: 29 additions & 22 deletions 05-regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@
"set_default()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -60,8 +69,8 @@
"metadata": {},
"outputs": [],
"source": [
"X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)\n",
"y = X.pow(3) + 0.3 * torch.rand(X.size())"
"X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1).to(device)\n",
"y = X.pow(3) + 0.3 * torch.rand(X.size()).to(device)"
]
},
{
Expand All @@ -81,7 +90,7 @@
"metadata": {},
"outputs": [],
"source": [
"plt.scatter(X.numpy(), y.numpy())\n",
"plt.scatter(X.cpu().numpy(), y.cpu().numpy())\n",
"plt.axis('equal')"
]
},
Expand All @@ -102,15 +111,6 @@
"lambda_l2 = 1e-5"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -172,8 +172,8 @@
"metadata": {},
"outputs": [],
"source": [
"plt.scatter(X.data.numpy(), y.data.numpy())\n",
"plt.plot(X.data.numpy(), y_pred.data.numpy(), 'r-', lw=5)\n",
"plt.scatter(X.data.cpu().numpy(), y.data.cpu().numpy())\n",
"plt.plot(X.data.cpu().numpy(), y_pred.data.cpu().numpy(), 'r-', lw=5)\n",
"plt.axis('equal');"
]
},
Expand Down Expand Up @@ -283,13 +283,13 @@
" # New X that ranges from -5 to 5 instead of -1 to 1\n",
" X_new = torch.unsqueeze(torch.linspace(-2, 2, 100), dim=1)\n",
" \n",
" plt.plot(X_new.numpy(), y_pretrain_idx.numpy(), 'r-', lw=1)\n",
" plt.plot(X_new.numpy(), y_pretrain_idx.cpu().numpy(), 'r-', lw=1)\n",
"\n",
"plt.scatter(X.numpy(), y.numpy(), label='data')\n",
"plt.scatter(X.cpu().numpy(), y.cpu().numpy(), label='data')\n",
"plt.axis('square')\n",
"plt.axis((-1.1, 1.1, -1.1, 1.1));\n",
"y_combo = torch.stack(y_pretrain)\n",
"plt.plot(X_new.numpy(), y_combo.var(dim=0).numpy(), 'g', label='variance');\n",
"plt.plot(X_new.numpy(), y_combo.var(dim=0).cpu().numpy(), 'g', label='variance');\n",
"plt.legend()"
]
},
Expand All @@ -309,21 +309,28 @@
"y_pred = list()\n",
"for model in models:\n",
" # New X that ranges from -5 to 5 instead of -1 to 1\n",
" X_new = torch.unsqueeze(torch.linspace(-4, 4, 1001), dim=1)\n",
" X_new = torch.unsqueeze(torch.linspace(-4, 4, 1001), dim=1).to(device)\n",
"\n",
" # Getting predictions from input\n",
" with torch.no_grad():\n",
" y_pred.append(model(X_new))\n",
" \n",
" plt.plot(X_new.numpy(), y_pred[-1].numpy(), 'r-', lw=1)\n",
" plt.plot(X_new.cpu().numpy(), y_pred[-1].cpu().numpy(), 'r-', lw=1)\n",
"\n",
"plt.scatter(X.numpy(), y.numpy(), label='data')\n",
"plt.scatter(X.cpu().numpy(), y.cpu().numpy(), label='data')\n",
"plt.axis('square')\n",
"plt.axis((-1.1, 1.1, -1.1, 1.1));\n",
"# plt.axis((-4.1, 4.1, -4.1, 4.1)); # toggle me :)\n",
"y_combo = torch.stack(y_pred)\n",
"plt.plot(X_new.numpy(), y_combo.var(dim=0).numpy(), 'g', label='variance');"
"plt.plot(X_new.cpu().numpy(), y_combo.var(dim=0).cpu().numpy(), 'g', label='variance');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -342,7 +349,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 9f3a9a9

Please sign in to comment.