Commit

remove .data in NadamWCosineDecay (google-research#235)
Mxbonn authored Mar 26, 2020
1 parent c14eda7 commit 1486dfe
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions opt_list/opt_list/torch_opt_list.py
@@ -161,6 +161,7 @@ def __init__(
 
     super(NadamWCosineDecay, self).__init__(params, defaults)
 
+  @torch.no_grad()
   def step(self, closure=None):
     """Performs a single optimization step.
@@ -176,13 +177,14 @@ def step(self, closure=None):
     """
     loss = None
     if closure is not None:
-      loss = closure()
+      with torch.enable_grad():
+        loss = closure()
 
     for group in self.param_groups:
       for p in group["params"]:
         if p.grad is None:
           continue
-        grad = p.grad.data
+        grad = p.grad
 
         if grad.is_sparse:
           raise RuntimeError("No SparseGrads supported at this time.")
@@ -194,10 +196,10 @@ def step(self, closure=None):
           state["step"] = 0
           # Exponential moving average of gradient values
           state["exp_avg"] = torch.zeros_like(
-              p.data, memory_format=torch.preserve_format)
+              p, memory_format=torch.preserve_format)
           # Exponential moving average of squared gradient values
           state["exp_avg_sq"] = torch.zeros_like(
-              p.data, memory_format=torch.preserve_format)
+              p, memory_format=torch.preserve_format)
 
         lr = get_cosine_learning_rate_fn(group["training_steps"],
                                          group["learning_rate"],
@@ -233,7 +235,7 @@ def step(self, closure=None):
 
         step = step + (lr_t * group["adamw_weight_decay"] * p)
 
-        p.data.add_(-step)
+        p.add_(-step)
 
     return loss
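
For context, the change follows the same pattern PyTorch's built-in optimizers use: decorate step() with @torch.no_grad() so the parameter update runs outside autograd, re-enable gradients only around the closure, and mutate the parameter tensors in place rather than going through the deprecated .data attribute (which silently bypasses autograd's tracking). Below is a minimal sketch of that pattern with a plain SGD-style update; the class name PlainSGD and its learning-rate-only update rule are illustrative assumptions, not code from this repository.

# Minimal sketch (not from this repository): a toy optimizer using the
# same no-.data pattern this commit adopts. PlainSGD and its update rule
# are assumptions for illustration only.
import torch
from torch.optim.optimizer import Optimizer


class PlainSGD(Optimizer):
  """Toy SGD-style optimizer: @torch.no_grad() plus direct in-place tensor ops."""

  def __init__(self, params, lr=1e-2):
    if lr <= 0.0:
      raise ValueError("Invalid learning rate: {}".format(lr))
    super(PlainSGD, self).__init__(params, dict(lr=lr))

  @torch.no_grad()  # run the update outside autograd instead of touching .data
  def step(self, closure=None):
    loss = None
    if closure is not None:
      # The closure re-evaluates the model, so gradients must be re-enabled.
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      for p in group["params"]:
        if p.grad is None:
          continue
        # In-place update on the parameter tensor itself (no p.data).
        p.add_(p.grad, alpha=-group["lr"])
    return loss

Such an optimizer is used like any torch.optim optimizer: construct it with model.parameters(), call loss.backward(), then opt.step(). Working on the tensor directly under torch.no_grad() is the recommended replacement for .data, which skips autograd's version tracking.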

