Commit 537779e
add explicit gradient for lls example
jtamir committed Jul 28, 2020
1 parent 1a7aed3 commit 537779e
Showing 1 changed file with 12 additions and 6 deletions.
Physics_based_Learning.ipynb (18 changes: 12 additions & 6 deletions)
@@ -115,7 +115,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "def pbnet(A, alpha, lamb, x0, xt, K, testFlag=True):\n",
+ "def pbnet(A, alpha, lamb, x0, xt, K, testFlag=True, explicit_grad=True):\n",
  "    # pbnet - Physics-based Network\n",
  "    # args in - \n",
  "    #    A - measurement matrix\n",
@@ -132,6 +132,10 @@
  "    x.requires_grad = True\n",
  "\n",
  "    y_meas = Aop(A, xt)\n",
+ "    \n",
+ "    if explicit_grad:\n",
+ "        AHA = torch.matmul(torch.transpose(A, 0, 1), A)\n",
+ "        AHy = torch.matmul(torch.transpose(A, 0, 1), y_meas)\n",
  "\n",
  "    if testFlag: y_meas = y_meas.detach()\n",
  "    \n",
@@ -141,12 +145,14 @@
  "        y_est = Aop(A,x)\n",
  "        res = y_est - y_meas\n",
  "        loss_dc = torch.sum(res**2)\n",
- "        g = torch.autograd.grad(loss_dc, \n",
- "                                x, \n",
- "                                create_graph = not testFlag)\n",
- "\n",
+ "        if explicit_grad:\n",
+ "            g = torch.matmul(AHA, x) - AHy\n",
+ "        else:\n",
+ "            g = torch.autograd.grad(loss_dc, \n",
+ "                                    x, \n",
+ "                                    create_graph = not testFlag)[0]\n",
  "        \n",
- "        x = x - alpha*g[0] # gradient update\n",
+ "        x = x - alpha*g # gradient update\n",
  "        x = softthr(x, lamb*alpha) # proximal update\n",
  "        \n",
  "        with torch.no_grad():\n",
