Skip to content

Commit

Permalink
Add training loop that runs.
Browse files Browse the repository at this point in the history
- modified creation of adjacency matrix - was 32x32 when it should have been 1024 x 1024
- wrote a very basic training loop that iterates over one data sample
  • Loading branch information
emmabenjaminson committed Aug 1, 2021
1 parent fc3004c commit 55bf488
Show file tree
Hide file tree
Showing 5 changed files with 1,810 additions and 133 deletions.
213 changes: 180 additions & 33 deletions notebooks/.ipynb_checkpoints/1.0_graph_inputs-checkpoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@
"The adjacency matrix contains 1's indicating where 2 publications cite each other. Kipf and Welling define the adj matrix as also having all diagonal elements set to 1 to indicate that nodes/publications are connected to themselves. This matrix is converted to $\\tilde{A}$, which is a symmetrically normalized adjacency matrix computed as $\\tilde{A} = D^{-\\frac{1}{2}} A D^{-\\frac{1}{2}}$. "
]
},
{
"cell_type": "markdown",
"id": "2d305169-cbb0-4d1c-b893-d9f69c9a1ba5",
"metadata": {},
"source": [
"## Overall Process\n",
"\n",
"1. Import the raw data as images with 3 channels using a Dataset class\n",
"2. The Dataset class applies a Transform that converts the raw data to A and X matrices\n",
"3. The DataLoader imports the Dataset\n",
"4. Define GCN layers \n",
"5. Define GCN encoder\n",
"6. Define GCN decoder\n",
"7. Define a training loop\n",
"8. Run training"
]
},
{
"cell_type": "markdown",
"id": "076171a0-ef7d-4aa8-a8f3-963204c87cd7",
Expand All @@ -58,13 +75,14 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "0e539e94-deee-4777-8dd6-e7c0ca3d08d4",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy.sparse as sp\n",
"from scipy.linalg import sqrtm\n",
"import torch\n",
"import matplotlib.pyplot as plt"
]
Expand Down Expand Up @@ -335,7 +353,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "1e05d943-d061-43b6-91fa-a6bbd8c83875",
"metadata": {},
"outputs": [],
Expand All @@ -348,7 +366,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "57209bae-2188-4b68-9fa7-34bfcdcefef6",
"metadata": {},
"outputs": [
Expand All @@ -358,7 +376,7 @@
"(3, 32, 32)"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -369,17 +387,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "590ccef7-ad4c-4ab1-a226-7c2410440c5e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f794d84a3d0>"
"<matplotlib.colorbar.Colorbar at 0x7f1532faf7f0>"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -403,17 +421,17 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "6526e177-02f8-461b-91ea-a11db8866315",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f794d185d90>"
"<matplotlib.colorbar.Colorbar at 0x7f160c5e7940>"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -437,17 +455,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "a933307e-eb28-4ce1-9b26-0aa4528f039d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f794cc6db80>"
"<matplotlib.colorbar.Colorbar at 0x7f1532da3970>"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -471,7 +489,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "fbd24ccc-e836-4388-82e8-8b5e04719337",
"metadata": {},
"outputs": [
Expand All @@ -491,10 +509,10 @@
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f794d7752e0>"
"<matplotlib.colorbar.Colorbar at 0x7f1532db1670>"
]
},
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
Expand Down Expand Up @@ -532,7 +550,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "6683295d-e71f-464a-a6c2-6c0d059f510e",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -561,7 +579,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 11,
"id": "1c2c8a51-6195-47b0-978e-0a8d1aa674cf",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -598,15 +616,27 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "f2195a18-c43d-45f6-95fe-7c4640b7207f",
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "IndexError",
"evalue": "index 70 is out of bounds for axis 0 with size 70",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-e1cbcf9f7b42>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# add these IDs of nonzero pixels as keys to dictionary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmyfile\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow_coords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol_coords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;31m# get x and y indices of i-th nonzero pixel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mIndexError\u001b[0m: index 70 is out of bounds for axis 0 with size 70"
]
}
],
"source": [
"# iterate through the keys\n",
"\n",
"A_dict = {}\n",
"for i in range(row_coords.shape[0]):\n",
"for i in range(row_coords.shape[0]**2):\n",
" \n",
" # add these IDs of nonzero pixels as keys to dictionary\n",
" key = myfile[3, row_coords[i], col_coords[i]]\n",
Expand Down Expand Up @@ -671,7 +701,7 @@
"# print(A_dict)\n",
"# break\n",
"\n",
"# print(A_dict)"
"print(A_dict)"
]
},
{
Expand Down Expand Up @@ -771,7 +801,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 6,
"id": "9d27ce28-e9ef-48b6-bcd6-07fcfe7de673",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -821,36 +851,153 @@
"# THIS IS THE ORIGINAL SCRIPT - WE MATCH THESE RESULTS\n",
"adj2 = sp.coo_matrix(adj)\n",
"print(\"adj2 before transformation\\n\", adj2.todense())\n",
"adj3 = adj2 + adj2.T.multiply(adj2.T > adj2) - adj2.multiply(adj2.T > adj2)\n",
"print(\"adj2 after transformation\\n\", adj3.todense())"
"adj2 = adj2 + adj2.T.multiply(adj2.T > adj2) - adj2.multiply(adj2.T > adj2)\n",
"print(\"adj2 after transformation\\n\", adj2.todense())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a7a5dff1-2fc7-4c85-a5c8-402f633d404f",
"metadata": {},
"outputs": [],
"source": [
"# normalize A\n",
"def normalize(mx):\n",
" \"\"\"Row-normalize sparse matrix\"\"\"\n",
" rowsum = np.array(mx.sum(1))\n",
" r_inv = np.power(rowsum, -1).flatten()\n",
" r_inv[np.isinf(r_inv)] = 0.\n",
" r_mat_inv = sp.diags(r_inv)\n",
" mx = r_mat_inv.dot(mx)\n",
" return mx"
]
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 10,
"id": "99d718a5-4125-407a-827e-b2b95de80a9c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n",
"A not normalized [[1. 1. 0. 0.]\n",
"A not normalized\n",
" [[1. 1. 0. 0.]\n",
" [1. 1. 1. 0.]\n",
" [0. 1. 1. 1.]\n",
" [0. 0. 1. 1.]]\n"
" [0. 0. 1. 1.]]\n",
"rowsum [2. 3. 3. 2.]\n",
"diagonal rowsum\n",
" [[2. 0. 0. 0.]\n",
" [0. 3. 0. 0.]\n",
" [0. 0. 3. 0.]\n",
" [0. 0. 0. 2.]]\n",
"sqrt inverse of D\n",
" [[0.70710678 0. 0. 0. ]\n",
" [0. 0.57735027 0. 0. ]\n",
" [0. 0. 0.57735027 0. ]\n",
" [0. 0. 0. 0.70710678]]\n",
"normalized A\n",
" [[0.5 0.40824829 0. 0. ]\n",
" [0.40824829 0.33333333 0.33333333 0. ]\n",
" [0. 0.33333333 0.33333333 0.40824829]\n",
" [0. 0. 0.40824829 0.5 ]]\n",
"adj2 before normalization\n",
" [[1. 1. 0. 0.]\n",
" [1. 1. 1. 0.]\n",
" [0. 1. 1. 1.]\n",
" [0. 0. 1. 1.]]\n",
"adj2 after normalization\n",
" [[0.5 0.5 0. 0. ]\n",
" [0.33333333 0.33333333 0.33333333 0. ]\n",
" [0. 0.33333333 0.33333333 0.33333333]\n",
" [0. 0. 0.5 0.5 ]]\n"
]
}
],
"source": [
"A_not_normalized = adj1\n",
"A = adj1\n",
"\n",
"# add 1's along the diagonal of A\n",
"A_not_normalized = A_not_normalized + np.eye(A_not_normalized.shape[0])\n",
"print(\"A not normalized\", A_not_normalized)\n",
"A = A + np.eye(A.shape[0])\n",
"print(\"A not normalized\\n\", A)\n",
"\n",
"rowsum = A.sum(axis=0)\n",
"print(\"rowsum\", rowsum)\n",
"D = np.diagflat(rowsum)\n",
"print(\"diagonal rowsum\\n\", D)\n",
"D_inv = np.linalg.inv(sqrtm(D))\n",
"# D_inv = np.linalg.inv(D)\n",
"\n",
"print(\"sqrt inverse of D\\n\", D_inv)\n",
"A_tilde = D_inv.T @ A @ D_inv\n",
"# A_tilde = np.dot(D_inv, A)\n",
"print(\"normalized A\\n\", A_tilde)\n",
"\n",
"# normalize A\n"
"\n",
"# COMPARE AGAINST THE ORIGINAL PAPER'S RESULTS\n",
"adj3 = adj2 + sp.eye(adj2.shape[0])\n",
"print(\"adj2 before normalization\\n\", adj3.todense())\n",
"adj3_tilde = normalize(adj3)\n",
"print(\"adj2 after normalization\\n\", adj3_tilde.todense())\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "4759fd84-4559-41fb-ba71-ae48859dfeb6",
"metadata": {},
"outputs": [],
"source": [
"def preprocess_graph(adj):\n",
" adj = sp.coo_matrix(adj)\n",
" adj_ = adj + sp.eye(adj.shape[0])\n",
" rowsum = np.array(adj_.sum(1))\n",
" degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())\n",
" adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()\n",
" return adj_normalized.todense()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b94e6207-8e9e-4f5c-b0e2-78cbbfcbfa8a",
"metadata": {},
"outputs": [],
"source": [
"def sparse_to_tuple(sparse_mx):\n",
" if not sp.isspmatrix_coo(sparse_mx):\n",
" sparse_mx = sparse_mx.tocoo()\n",
" coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()\n",
" values = sparse_mx.data\n",
" shape = sparse_mx.shape\n",
" return coords, values, shape"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "fda1fcef-1a75-416f-b49e-c6b13d145589",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.5 0.40824829 0. 0. ]\n",
" [0.40824829 0.33333333 0.33333333 0. ]\n",
" [0. 0.33333333 0.33333333 0.40824829]\n",
" [0. 0. 0.40824829 0.5 ]]\n"
]
}
],
"source": [
"# MY RESULTS MATCH VGAE_PYTORCH REPO, I'M GOING TO PROCEED WITH THIS BECAUSE ITS THE MATHEMATICALLY CORRECT DEFINITION\n",
"adj2_tilde = preprocess_graph(adj2)\n",
"print(adj2_tilde)"
]
},
{
Expand Down
Loading

0 comments on commit 55bf488

Please sign in to comment.