Commit

Modify: sum of weights
LucasBoTang committed Aug 28, 2022
1 parent b97ee1d commit 5583b9e
Showing 1 changed file with 4 additions and 2 deletions.
gradnorm.py: 4 additions & 2 deletions
@@ -40,7 +40,8 @@ def gradNorm(net, layer, alpha, dataloader, num_epochs, lr1, lr2, log=False):
             if iters == 0:
                 # init weights
                 weights = torch.ones_like(loss)
-                weights = torch.nn.Parameter(weights / weights.sum())
+                weights = torch.nn.Parameter(weights)
+                T = weights.sum().detach()
                 # set optimizer for weights
                 optimizer2 = torch.optim.Adam([weights], lr=lr2)
                 # set L(0)
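In the GradNorm scheme the task weights are kept on the scale of the number of tasks, i.e. they should sum to T rather than to 1. The hunk above therefore no longer normalizes the initial weights to sum to 1; it keeps w_i(0) = 1 and records the constant T, the sum of the initial weights (equal to the number of tasks), for the later renormalization step. A minimal sketch of the new initialization, where num_tasks and the learning rate are illustrative stand-ins rather than values from this repository:

import torch

# Sketch of the initialization step; num_tasks and lr are illustrative.
num_tasks = 3
loss = torch.rand(num_tasks)            # one loss value per task
weights = torch.ones_like(loss)         # w_i(0) = 1 for every task
weights = torch.nn.Parameter(weights)   # learnable loss weights
T = weights.sum().detach()              # T = number of tasks, kept constant
optimizer2 = torch.optim.Adam([weights], lr=0.025)
print(T.item())                         # 3.0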
@@ -79,7 +80,8 @@ def gradNorm(net, layer, alpha, dataloader, num_epochs, lr1, lr2, log=False):
             # update loss weights
             optimizer2.step()
             # renormalize weights
-            weights = torch.nn.Parameter(weights / weights.sum())
+            weights = (weights / weights.sum() * T).detach()
+            weights = torch.nn.Parameter(weights)
             optimizer2 = torch.optim.Adam([weights], lr=lr2)
             # update iters
             iters += 1
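The second hunk makes the per-step renormalization consistent with that constant: after optimizer2 updates the weights, they are rescaled so that their sum returns to T (before this commit they were rescaled to sum to 1), detached, and re-wrapped as a fresh Parameter, which is why optimizer2 is rebuilt around the new tensor. A minimal sketch of that renormalization step, with the weight values and learning rate chosen only for illustration:

import torch

# Sketch of the renormalization step; weight values and lr are illustrative,
# T stands for the constant saved at initialization.
T = torch.tensor(3.0)
weights = torch.nn.Parameter(torch.tensor([0.5, 1.2, 2.3]))

weights = (weights / weights.sum() * T).detach()    # rescale so sum(weights) == T
weights = torch.nn.Parameter(weights)               # fresh leaf tensor for the optimizer
optimizer2 = torch.optim.Adam([weights], lr=0.025)  # recreate optimizer for the new Parameter
print(weights.sum().item())                         # ~3.0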
