[reland2][dynamo] Revert "Revert "[reland][dynamo] use optimizers correctly in benchmarking (pytorch#87492)" (pytorch#90746)" (pytorch#90956)

This reverts commit ff1bbc2.

This should be okay to merge now. The flakiness of the HF models will be fixed by seeding the RNG (pytorch#90936), and the numeric mismatch was root-caused to three decomps (why those decomps cause it is still being investigated); see pytorch/torchdynamo#1985 for more detail.

Pull Request resolved: pytorch#90956
Approved by: https://github.com/desertfire
mlazos authored and pytorchmergebot committed Dec 16, 2022
1 parent c2c14f9 commit 8bc38ae
Showing 5 changed files with 28 additions and 27 deletions.
2 changes: 1 addition & 1 deletion benchmarks/dynamo/README.md
@@ -7,7 +7,7 @@ The runner integrates with models from TorchBenchmark, HuggingFace and TIMM suit

The infrastructure allows us to specify a loss function. For torchbench models, we use .sum().backward() call in place of the native loss function. For TIMM models, we use a CrossEntropy loss. And HF models contain a loss function inside the model itself, so we don't need any special loss computation handling.

Training benchmarks approximate training by running the model forward, computing loss and then running backward. We entirely skip the optimizer step today.
Training benchmarks approximate training by running the model forward, computing loss, running backward, and then the optimizer (SGD). Note: the optimizer is currently not compiled by Torchdynamo.

Inference benchmarks and Training benchmarks measure correctness by comparing dynamo and eager model outputs given fixed inputs and seeds.

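To make the README change above concrete, here is a minimal sketch (mine, not code from this PR) of the training iteration it describes: forward pass, a suite-specific loss, backward, then an eager SGD step. The `compute_loss` helper and the `suite` argument are illustrative names, not part of the benchmark harness.

```python
import torch

def compute_loss(suite, pred, target=None):
    # torchbench models: a .sum().backward() call stands in for a native loss
    if suite == "torchbench":
        return pred.sum()
    # TIMM models: CrossEntropy loss
    if suite == "timm":
        return torch.nn.functional.cross_entropy(pred, target)
    # HF models compute their loss inside forward(), so no extra handling is needed
    raise AssertionError("HF models return their loss from forward()")

def train_iteration(model, optimizer, inputs, target, suite="timm"):
    optimizer.zero_grad()
    pred = model(inputs)
    loss = compute_loss(suite, pred, target)
    loss.backward()
    optimizer.step()  # the optimizer step stays in eager; dynamo does not compile it
    return loss

model = torch.nn.Linear(16, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_iteration(model, optimizer, torch.randn(8, 16), torch.randint(0, 10, (8,)))
```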
47 changes: 27 additions & 20 deletions benchmarks/dynamo/common.py
@@ -128,6 +128,12 @@
]


CI_SKIP_OPTIMIZER = {
# TIMM
"convmixer_768_32", # accuracy
}


def model_specified_by_path(path_and_class_str):
return ":" in path_and_class_str

@@ -872,6 +878,7 @@ def __init__(self):
self.use_amp = False
self.grad_scaler = DummyGradScaler()
self.autocast = NullContext
self.optimizer = None
self._args = None

def setup_amp(self):
@@ -900,16 +907,11 @@ def setup_amp(self):
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = torch.cuda.amp.autocast

def init_optimizer(self, device, params):
self.optimizer = None
# TODO - Currently, optimizers are used incorrectly. Fix optimizers with
# https://github.com/pytorch/pytorch/pull/87492
# param_list = list(params)
# if device == "cuda" and len(param_list) != 0:
# # capturable is only supported on cuda at the moment
# self.optimizer = torch.optim.Adam(param_list, capturable=True)
# else:
# self.optimizer = None
def init_optimizer(self, name, device, params):
if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER:
self.optimizer = torch.optim.SGD(params, lr=0.01)
else:
self.optimizer = None

@property
def args(self):
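As a standalone restatement of the guard in the new init_optimizer above (an illustrative free function, not the harness class): an SGD optimizer is created only for CUDA training runs of models that are not in CI_SKIP_OPTIMIZER; every other path yields None, so callers must treat the optimizer as optional. Note the previous, commented-out version had planned to use Adam with capturable=True on CUDA; this change uses plain SGD instead.

```python
from typing import Iterable, Optional
import torch

CI_SKIP_OPTIMIZER = {"convmixer_768_32"}  # "accuracy" per the skip-list comment above

def make_benchmark_optimizer(
    name: str, device: str, training: bool, params: Iterable[torch.nn.Parameter]
) -> Optional[torch.optim.Optimizer]:
    # Mirrors init_optimizer: SGD(lr=0.01) only for CUDA training runs of models
    # not on the skip list; otherwise no optimizer is used at all.
    if device == "cuda" and training and name not in CI_SKIP_OPTIMIZER:
        return torch.optim.SGD(params, lr=0.01)
    return None

model = torch.nn.Linear(4, 4)
assert make_benchmark_optimizer("resnet50", "cpu", True, model.parameters()) is None
```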
Expand Down Expand Up @@ -1092,12 +1094,12 @@ def deepcopy_and_maybe_ddp(model):
# Collect the fp64 reference outputs to be used later for accuracy checking.
fp64_outputs = None
try:
fp64_outputs = self.run_n_iterations(
*cast_to_fp64(
deepcopy_and_maybe_ddp(model),
clone_inputs(example_inputs),
)
model_fp64, inputs_fp64 = cast_to_fp64(
deepcopy_and_maybe_ddp(model),
clone_inputs(example_inputs),
)
self.init_optimizer(name, current_device, model_fp64.parameters())
fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64)
except Exception:
log.warning(
f"fp64 golden ref were not generated for {name}. Setting accuracy check to cosine"
@@ -1118,14 +1120,18 @@ def deepcopy_and_maybe_ddp(model):
with self.pick_grad(name, self.args.training):
# Get results of native pytorch
reset_rng_state()
model_copy = deepcopy_and_maybe_ddp(model)
self.init_optimizer(name, current_device, model_copy.parameters())
correct_result = self.run_n_iterations(
deepcopy_and_maybe_ddp(model), clone_inputs(example_inputs)
model_copy, clone_inputs(example_inputs)
)

# Rerun native pytorch
reset_rng_state()
model_copy = deepcopy_and_maybe_ddp(model)
self.init_optimizer(name, current_device, model_copy.parameters())
correct_rerun_result = self.run_n_iterations(
deepcopy_and_maybe_ddp(model), clone_inputs(example_inputs)
model_copy, clone_inputs(example_inputs)
)
if not same(
correct_result,
@@ -1141,11 +1147,11 @@ def deepcopy_and_maybe_ddp(model):
reset_rng_state()
torch._dynamo.reset()
try:
model_copy = deepcopy_and_maybe_ddp(model)
self.init_optimizer(name, current_device, model_copy.parameters())
optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)

new_result = optimized_model_iter_fn(
deepcopy_and_maybe_ddp(model), example_inputs
)
new_result = optimized_model_iter_fn(model_copy, example_inputs)
except Exception as e:
accuracy_status = "fail_to_run"
print(
@@ -1193,6 +1199,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):

# Cast the model to float16/float32 as necessary
model, example_inputs = self.maybe_cast(model, example_inputs)
self.init_optimizer(name, current_device, model.parameters())
with self.pick_grad(name, self.args.training):
ok, total = Stats.reset_counters()
experiment_kwargs = {}
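The common.py hunks above thread the optimizer through the accuracy check: each run (eager reference, eager rerun, dynamo) now gets a fresh deepcopy of the model, a freshly initialized optimizer, and the same RNG state, so all three runs start from identical conditions. The sketch below is a schematic re-creation of that flow, not the harness itself; run_n_iterations, fresh_run, and the toy model are simplified stand-ins, and the "eager" dynamo backend is used only to keep the example cheap.

```python
import copy
import torch
import torch._dynamo

def reset_rng_state(seed=1337):
    torch.manual_seed(seed)

def run_n_iterations(model, inputs, optimizer, n=2):
    # forward + surrogate loss + backward + eager SGD step, repeated n times
    for _ in range(n):
        optimizer.zero_grad()
        model(inputs).sum().backward()
        optimizer.step()
    return model(inputs)

base_model = torch.nn.Linear(8, 8)
example_inputs = torch.randn(4, 8)

def fresh_run(wrap=lambda fn: fn):
    # fresh model copy + fresh optimizer + reseeded RNG for every run, as in the PR
    reset_rng_state()
    model_copy = copy.deepcopy(base_model)
    optimizer = torch.optim.SGD(model_copy.parameters(), lr=0.01)
    return wrap(run_n_iterations)(model_copy, example_inputs, optimizer)

correct_result = fresh_run()                              # native pytorch
correct_rerun_result = fresh_run()                        # eager nondeterminism check
new_result = fresh_run(torch._dynamo.optimize("eager"))   # dynamo-compiled run
print(torch.allclose(correct_result, correct_rerun_result),
      torch.allclose(correct_result, new_result, atol=1e-5, rtol=1e-4))
```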
2 changes: 0 additions & 2 deletions benchmarks/dynamo/huggingface.py
@@ -433,8 +433,6 @@ def load_model(
else:
model.eval()

self.init_optimizer(device, model.parameters())

self.validate_model(model, example_inputs)
return device, model_name, model, example_inputs, batch_size

2 changes: 0 additions & 2 deletions benchmarks/dynamo/timm_models.py
@@ -261,8 +261,6 @@ def load_model(
else:
model.eval()

self.init_optimizer(device, model.parameters())

self.validate_model(model, example_inputs)

return device, model_name, model, example_inputs, batch_size
2 changes: 0 additions & 2 deletions benchmarks/dynamo/torchbench.py
@@ -295,8 +295,6 @@ def load_model(
gc.collect()
batch_size = benchmark.batch_size

self.init_optimizer(device, model.parameters())

# Torchbench has quite different setup for yolov3, so directly passing
# the right example_inputs
if model_name == "yolov3":
