Skip to content

Commit

Permalink
Hotfix for diverging behaviours in between vissl and classy (facebook…
Browse files Browse the repository at this point in the history
…research#669)

Summary:
Pull Request resolved: facebookresearch#669

Pull Request resolved: facebookresearch/vissl#107

Testing, fix an issue which comes from the different definitions for AMP in between Classy and VISSL

Reviewed By: prigoyal

Differential Revision: D25598769

fbshipit-source-id: 088c9b370b0d2cabeb0cab5901b8725c19c22181
  • Loading branch information
blefaudeux authored and facebook-github-bot committed Dec 17, 2020
1 parent 17c4a18 commit 1cea86f
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions classy_vision/tasks/classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,11 +870,12 @@ def get_classy_state(self, deep_copy: bool = False):
if isinstance(self.base_loss, ClassyLoss):
classy_state_dict["loss"] = self.base_loss.get_classy_state()
if self.amp_args is not None:
classy_state_dict["amp"] = (
apex.amp.state_dict()
if self.amp_type == AmpType.APEX
else self.amp_grad_scaler.state_dict()
)
if self.amp_type == AmpType.APEX:
classy_state_dict["amp"] = apex.amp.state_dict()

elif self.amp_grad_scaler is not None:
classy_state_dict["amp"] = self.amp_grad_scaler.state_dict()

if deep_copy:
classy_state_dict = copy.deepcopy(classy_state_dict)
return classy_state_dict
Expand Down

0 comments on commit 1cea86f

Please sign in to comment.