Skip to content

Commit

Permalink
decoupled sophia based on original
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed May 25, 2023
1 parent 17cc685 commit f0af7fb
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 6 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ input_data = ... #input data
optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho = 0.01, weight_decay=1e-1)

#decoupled
#optimizer = DecoupledSophia(model.parameters(), hessian_estimator, lr=1e-3)

#optimizer = DecoupledSophia(model.parameters(), lr=1e-3, betas=(0.9, 0.999), rho=0.04, weight_decay=1e-1, } estimator="Hutchinson")

#training loop
for epoch in range(epochs):
Expand Down
67 changes: 66 additions & 1 deletion Sophia/Sophia.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,69 @@ def _single_tensor_sophiag(params: List[Tensor],
step_size_neg = - lr

ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)
param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)

class DecoupledSophia(torch.optim.Optimizer):
def __init__(self, model, input_data, params, lr=1e-3, betas=(0.9, 0.999), rho=0.04, weight_decay=1e-1, k=10, estimator="Hutchinson"):
self.model = model
self.input_data = input_data
defaults = dict(lr=lr, betas=betas, rho=rho, weight_decay=weight_decay, k=k, estimator=estimator)
super(DecoupledSophia, self).__init__(params, defaults)

@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError("DecoupledSophia does not support sparse gradients")

state = self.state[p]

if len(state) == 0:
state['step'] = 0
state['m'] = torch.zeros_like(p.data)
state['h'] = torch.zeros_like(p.data)

m, h = state['m'], state['h']
beta1, beta2 = group['betas']
state['step'] += 1

if group['weight_decay'] != 0:
grad = grad.add(group["weight_decay"], p.data)

m.mul_(beta1).add_(1 - beta1, grad)

if state['step'] % group['k'] == 1:
if group['estimator'] == "Hutchinson":
hessian_estimate = self.hutchinson(p, grad)
elif group['estimator'] == "Gauss-Newton-Bartlett":
hessian_estimate = self.gauss_newton_bartlett(p, grad)
else:
raise ValueError("Invalid estimator choice")
h.mul_(beta2).add_(1 - beta2, hessian_estimate)

p.data.add_(-group['lr'] * group['weight_decay'], p.data)
p.data.addcdiv_(-group['lr'], m, h.add(group['rho']))

return loss

def hutchinson(self, p, grad):
u = torch.randn_like(grad)
grad_dot_u = torch.einsum("...,...->", grad, u)
hessian_vector_product = torch.autograd.grad(grad_dot_u, p, retain_graph=True)[0]
return u * hessian_vector_product

def gauss_newton_bartlett(self, p, grad):
B = len(self.input_data)
logits = [self.model(xb) for xb in self.input_data]
y_hats = [torch.softmax(logit, dim=0) for logit in logits]
g_hat = torch.autograd.grad(sum([self.loss_function(logit, y_hat) for logit, y_hat in zip(logits, y_hats)]) / B, p, retain_graph=True)[0]
return B * g_hat * g_hat
4 changes: 2 additions & 2 deletions Sophia/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from Sophia.Sophia import Sophia
from Sophia.Sophia import Sophia, SophiaG, DecoupledSophia
from Sophia.Sophiav2 import Sophia2
from Sophia.Sophia import SophiaG
# from decoupled_sophia.decoupled_sophia.decoupled_sophia import DecoupledSophia
from experiments.training import trainer

from Sophia.decoupled_sophia.decoupled_sophia import DecoupledSophia
# from Sophia.decoupled_sophia.decoupled_sophia import DecoupledSophia
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'Sophia-Optimizer',
packages = find_packages(exclude=[]),
version = '0.1.9',
version = '0.2.1',
license='APACHE',
description = 'Sophia Optimizer ULTRA FAST',
author = 'Kye Gomez',
Expand Down

0 comments on commit f0af7fb

Please sign in to comment.