Allow optional optimizers part 1
mcarilli authored Mar 20, 2019
1 parent bd0db55 commit b80b4d3
Showing 1 changed file with 11 additions and 3 deletions.
apex/amp/_initialize.py: 11 additions & 3 deletions
@@ -129,6 +129,8 @@ def _initialize(models, optimizers, properties):
     if isinstance(optimizers, torch.optim.Optimizer):
         optimizers_was_list = False
         optimizers = [optimizers]
+    elif optimizers is None:
+        optimizers = []
     elif isinstance(optimizers, list):
         optimizers_was_list = True
     else:
@@ -145,7 +147,7 @@ def _initialize(models, optimizers, properties):
     check_models(models)
 
     check_params_fp32(models)
-
+    check_optimizers(optimizers)
 
     # In the future, when FP16_Optimizer can be deprecated and master weights can
@@ -217,6 +219,12 @@ def new_step(*args, **kwargs):
             return models[0], optimizers
     else:
         if models_was_list:
-            return models, optimizers[0]
+            if len(optimizers) == 0:
+                return models
+            else:
+                return models, optimizers[0]
         else:
-            return models[0], optimizers[0]
+            if len(optimizers) == 0:
+                return models[0]
+            else:
+                return models[0], optimizers[0]
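
Taken together, the patch lets _initialize run without any optimizer: None is normalized to an empty list up front, check_optimizers validates whatever list results, and the final unpacking returns only the model(s) when that list is empty. Below is a minimal usage sketch under those semantics; the toy model, the SGD optimizer, and the "O1" opt_level are illustrative, and it assumes the public amp.initialize front end forwards an omitted optimizer as None (the "part 1" in the title suggests the full surface may arrive in a follow-up commit).

import torch
import torch.nn as nn
from apex import amp

# Training setup: pass the model and an optimizer, get both back.
train_model = nn.Linear(4, 2).cuda()  # amp expects CUDA models
optimizer = torch.optim.SGD(train_model.parameters(), lr=1e-3)
train_model, optimizer = amp.initialize(train_model, optimizer, opt_level="O1")

# Inference-only setup: with this patch, the optimizer can be omitted;
# optimizers arrives as None, is normalized to [] inside _initialize,
# and only the model comes back because len(optimizers) == 0.
eval_model = nn.Linear(4, 2).cuda()
eval_model = amp.initialize(eval_model, opt_level="O1")

Normalizing to a list keeps the body of _initialize uniform; only the return statement has to special-case the empty list.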
