[optim] More cleanup and reorg of test_optim.py (pytorch#100917)
Pull Request resolved: pytorch#100917
Approved by: https://github.com/albanD
janeyx99 authored and pytorchmergebot committed May 9, 2023
1 parent d0dab77 commit d63e0b1
Showing 1 changed file with 38 additions and 44 deletions.
82 changes: 38 additions & 44 deletions test/test_optim.py
@@ -55,14 +55,14 @@
load_tests = load_tests


# Assumes an input of a tensor with exactly 2 numerical scalars
def rosenbrock(tensor):
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
x, y = tensor
return (1 - x) ** 2 + 100 * (y - x**2) ** 2


# Assumes an input of a tensor with exactly 2 numerical scalars
def drosenbrock(tensor):
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
x, y = tensor
return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))

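For context, a minimal sketch (not part of the diff) of how the hand-written derivative can be checked against autograd on a 2-element input; the simplified definitions below omit the new asserts:

import torch

def rosenbrock(tensor):
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2

def drosenbrock(tensor):
    x, y = tensor
    return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))

# Autograd and the analytic gradient should agree up to floating-point error.
t = torch.tensor([1.5, 1.5], requires_grad=True)
rosenbrock(t).backward()
assert torch.allclose(t.grad, drosenbrock(t.detach()))
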
@@ -79,28 +79,29 @@ def _test_rosenbrock_sparse(
):
if scheduler_constructors is None:
scheduler_constructors = []
params_t = torch.tensor([1.5, 1.5])
# For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
param_t = torch.tensor([1.5, 1.5])

params = Parameter(params_t)
optimizer = constructor([params])
param = Parameter(param_t)
optimizer = constructor([param])
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))

if not sparse_only:
params_c = Parameter(params_t.clone())
optimizer_c = constructor([params_c])
param_c = Parameter(param_t.clone())
optimizer_c = constructor([param_c])

solution = torch.tensor([1, 1])
with torch.no_grad():
initial_dist = params.dist(solution)
initial_dist = param.dist(solution)

def eval(params, sparse_grad, w):
def eval(param, sparse_grad, w):
# Depending on w, provide only the x or y gradient
optimizer.zero_grad()
loss = rosenbrock(params)
loss = rosenbrock(param)
loss.backward()
grad = drosenbrock(params.data)
grad = drosenbrock(param)
# NB: We torture test the optimizer by returning an
# uncoalesced sparse tensor
if w:
@@ -114,28 +115,28 @@ def eval(params, sparse_grad, w):
x = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
with torch.no_grad():
if sparse_grad:
params.grad = x
param.grad = x
else:
params.grad = x.to_dense()
param.grad = x.to_dense()
return loss

for i in range(2000):
# Do cyclic coordinate descent
w = i % 2
optimizer.step(functools.partial(eval, params, True, w))
optimizer.step(functools.partial(eval, param, True, w))
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(rosenbrock(params))
scheduler.step(rosenbrock(param))
else:
scheduler.step()
if not sparse_only:
optimizer_c.step(functools.partial(eval, params_c, False, w))
self.assertEqual(params, params_c)
optimizer_c.step(functools.partial(eval, param_c, False, w))
self.assertEqual(param, param_c)

if not maximize:
self.assertLessEqual(params.data.dist(solution), initial_dist)
self.assertLessEqual(param.dist(solution), initial_dist)
else:
self.assertGreaterEqual(rosenbrock(params), rosenbrock(params_t))
self.assertGreaterEqual(rosenbrock(param), rosenbrock(param_t))

def _test_basic_cases_template(
self,
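The "uncoalesced sparse tensor" torture test above relies on duplicate indices whose values only sum to the true gradient entry. A standalone sketch of that behavior (with assumed values, not the test's exact tensors):

import torch

i = torch.tensor([[0, 0]])    # the same index twice -> an uncoalesced tensor
v = torch.tensor([1.0, 2.0])  # two partial values for that single entry
g = torch.sparse_coo_tensor(i, v, (2,))

print(g.is_coalesced())           # False
print(g.coalesce().to_dense())    # tensor([3., 0.]); duplicates are summed

An optimizer that assumes one value per index would read the wrong gradient here, which is exactly what the test is probing.
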
@@ -342,7 +343,7 @@ def _test_basic_cases(
scheduler_constructors = []

def make_two_arg_constructor(
constructor, maximize: bool = False, foreach: bool = False
constructor, maximize: bool, foreach: bool
):
if constructor_accepts_maximize and constructor_accepts_foreach:
return lambda weight, bias: constructor(weight, bias, maximize, foreach)
Expand Down Expand Up @@ -462,13 +463,6 @@ def test_sgd(self):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
self._build_params_dict(weight, bias, lr=1e-2),
Expand Down Expand Up @@ -502,15 +496,15 @@ def test_sgd(self):
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
scheduler_constructors=[
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.8, total_iters=4
)
@@ -522,15 +516,23 @@ def test_sgd(self):
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
scheduler_constructors=[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
scheduler_constructors=[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.6, total_iters=4
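Roughly what the harness exercises for the case above, as a sketch with made-up parameters and loss rather than the test's own fixtures: an SGD optimizer driven by two chained schedulers, each stepped once per optimizer step.

import torch
from torch import optim
from torch.optim.lr_scheduler import LinearLR, StepLR

weight = torch.nn.Parameter(torch.randn(10, 5))
bias = torch.nn.Parameter(torch.randn(10))
optimizer = optim.SGD([weight, bias], lr=1e-3)
schedulers = [
    StepLR(optimizer, gamma=0.9, step_size=10),
    LinearLR(optimizer, start_factor=0.4, end_factor=0.6, total_iters=4),
]

for _ in range(20):
    optimizer.zero_grad()
    loss = (weight.mv(torch.randn(5)) + bias).pow(2).sum()
    loss.backward()
    optimizer.step()
    for scheduler in schedulers:   # both schedulers advance every step
        scheduler.step()
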
@@ -598,14 +600,6 @@ def test_sgd(self):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
optim.SGD(None, lr=1e-2, momentum=-0.5)

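A tiny illustration (not from the test file) of the constructor-time validation that the assertRaisesRegex above depends on:

import torch
from torch import optim

param = torch.nn.Parameter(torch.zeros(1))
try:
    optim.SGD([param], lr=1e-2, momentum=-0.5)
except ValueError as err:
    print(err)  # Invalid momentum value: -0.5
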
@@ -616,7 +610,7 @@ def test_sgd_sparse(self):
)
self._test_rosenbrock_sparse(
lambda params: optim.SGD(params, lr=0.0048, foreach=foreach),
[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
)

def test_sgd_complex(self):
@@ -1096,9 +1090,9 @@ def test_sparse_adam(self):
)
self._test_rosenbrock_sparse(
lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True),
[],
True,
True,
scheduler_constructors=[],
sparse_only=True,
maximize=True,
)
with self.assertRaisesRegex(
ValueError, "Invalid beta parameter at index 0: 1.0"
@@ -1271,7 +1265,7 @@ def test_adagrad_sparse(self):
)
self._test_rosenbrock_sparse(
lambda params: optim.Adagrad(params, lr=0.1, foreach=foreach),
[
scheduler_constructors=[
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
],
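The reason ReduceLROnPlateau is special-cased in the _test_rosenbrock_sparse loop earlier: unlike time-based schedulers, it must be stepped with the metric it monitors. A minimal sketch with assumed values:

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

param = torch.nn.Parameter(torch.tensor([1.5, 1.5]))
optimizer = optim.Adagrad([param], lr=0.1)
step_lr = StepLR(optimizer, gamma=1 - 1e-5, step_size=500)
plateau = ReduceLROnPlateau(optimizer, threshold=1e-4)

for _ in range(5):
    optimizer.zero_grad()
    loss = param.pow(2).sum()   # stand-in for rosenbrock(param)
    loss.backward()
    optimizer.step()
    step_lr.step()              # time-based scheduler: no argument
    plateau.step(loss.item())   # plateau scheduler: pass the tracked metric
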