diff --git a/benchmarks/dynamo/README.md b/benchmarks/dynamo/README.md
index 91556084cd0db..0f441083cdeef 100644
--- a/benchmarks/dynamo/README.md
+++ b/benchmarks/dynamo/README.md
@@ -7,7 +7,7 @@ The runner integrates with models from TorchBenchmark, HuggingFace and TIMM suit
 
 The infrastructure allows us to specify a loss function. For torchbench models, we use .sum().backward() call in place of the native loss function. For TIMM models, we use a CrossEntropy loss. And HF models contain a loss function inside the model itself, so we don't need any special loss computation handling.
 
-Training benchmarks approximate training by running the model forward, computing loss and then running backward. We entirely skip the optimizer step today.
+Training benchmarks approximate training by running the model forward, computing the loss, running backward, and then running the optimizer (SGD). Note: the optimizer is currently not compiled by TorchDynamo.
 
 Inference benchmarks and Training benchmarks measure correctness by comparing dynamo and eager model outputs given fixed inputs and seeds.
 
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 826b5946f4190..0746d9152add7 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -128,6 +128,12 @@
 ]
 
 
+CI_SKIP_OPTIMIZER = {
+    # TIMM
+    "convmixer_768_32",  # accuracy
+}
+
+
 def model_specified_by_path(path_and_class_str):
     return ":" in path_and_class_str
 
@@ -872,6 +878,7 @@ def __init__(self):
         self.use_amp = False
         self.grad_scaler = DummyGradScaler()
         self.autocast = NullContext
+        self.optimizer = None
         self._args = None
 
     def setup_amp(self):
@@ -900,16 +907,11 @@ def setup_amp(self):
             # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
             self.autocast = torch.cuda.amp.autocast
 
-    def init_optimizer(self, device, params):
-        self.optimizer = None
-        # TODO - Currently, optimizers are used incorrectly. Fix optimizers with
-        # https://github.com/pytorch/pytorch/pull/87492
-        # param_list = list(params)
-        # if device == "cuda" and len(param_list) != 0:
-        #     # capturable is only supported on cuda at the moment
-        #     self.optimizer = torch.optim.Adam(param_list, capturable=True)
-        # else:
-        #     self.optimizer = None
+    def init_optimizer(self, name, device, params):
+        if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER:
+            self.optimizer = torch.optim.SGD(params, lr=0.01)
+        else:
+            self.optimizer = None
 
     @property
     def args(self):
@@ -1092,12 +1094,12 @@ def deepcopy_and_maybe_ddp(model):
         # Collect the fp64 reference outputs to be used later for accuracy checking.
         fp64_outputs = None
         try:
-            fp64_outputs = self.run_n_iterations(
-                *cast_to_fp64(
-                    deepcopy_and_maybe_ddp(model),
-                    clone_inputs(example_inputs),
-                )
+            model_fp64, inputs_fp64 = cast_to_fp64(
+                deepcopy_and_maybe_ddp(model),
+                clone_inputs(example_inputs),
             )
+            self.init_optimizer(name, current_device, model_fp64.parameters())
+            fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64)
         except Exception:
             log.warning(
                 f"fp64 golden ref were not generated for {name}. Setting accuracy check to cosine"
@@ -1118,14 +1120,18 @@ def deepcopy_and_maybe_ddp(model):
         with self.pick_grad(name, self.args.training):
             # Get results of native pytorch
             reset_rng_state()
+            model_copy = deepcopy_and_maybe_ddp(model)
+            self.init_optimizer(name, current_device, model_copy.parameters())
             correct_result = self.run_n_iterations(
-                deepcopy_and_maybe_ddp(model), clone_inputs(example_inputs)
+                model_copy, clone_inputs(example_inputs)
             )
 
             # Rerun native pytorch
             reset_rng_state()
+            model_copy = deepcopy_and_maybe_ddp(model)
+            self.init_optimizer(name, current_device, model_copy.parameters())
             correct_rerun_result = self.run_n_iterations(
-                deepcopy_and_maybe_ddp(model), clone_inputs(example_inputs)
+                model_copy, clone_inputs(example_inputs)
             )
             if not same(
                 correct_result,
@@ -1141,11 +1147,11 @@ def deepcopy_and_maybe_ddp(model):
             reset_rng_state()
             torch._dynamo.reset()
             try:
+                model_copy = deepcopy_and_maybe_ddp(model)
+                self.init_optimizer(name, current_device, model_copy.parameters())
                 optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
-                new_result = optimized_model_iter_fn(
-                    deepcopy_and_maybe_ddp(model), example_inputs
-                )
+                new_result = optimized_model_iter_fn(model_copy, example_inputs)
             except Exception as e:
                 accuracy_status = "fail_to_run"
                 print(
@@ -1193,6 +1199,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
 
         # Cast the model to float16/float32 as necessary
         model, example_inputs = self.maybe_cast(model, example_inputs)
+        self.init_optimizer(name, current_device, model.parameters())
         with self.pick_grad(name, self.args.training):
             ok, total = Stats.reset_counters()
             experiment_kwargs = {}
diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py
index 1add72e8dcb2c..bc4ace91b0eec 100755
--- a/benchmarks/dynamo/huggingface.py
+++ b/benchmarks/dynamo/huggingface.py
@@ -433,8 +433,6 @@ def load_model(
         else:
             model.eval()
 
-        self.init_optimizer(device, model.parameters())
-
         self.validate_model(model, example_inputs)
         return device, model_name, model, example_inputs, batch_size
 
diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py
index b0540238b7949..4ecab848b4140 100755
--- a/benchmarks/dynamo/timm_models.py
+++ b/benchmarks/dynamo/timm_models.py
@@ -261,8 +261,6 @@ def load_model(
         else:
             model.eval()
 
-        self.init_optimizer(device, model.parameters())
-
         self.validate_model(model, example_inputs)
         return device, model_name, model, example_inputs, batch_size
 
diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py
index d138e3e692462..50f9d1e5af5fc 100755
--- a/benchmarks/dynamo/torchbench.py
+++ b/benchmarks/dynamo/torchbench.py
@@ -295,8 +295,6 @@ def load_model(
         gc.collect()
         batch_size = benchmark.batch_size
 
-        self.init_optimizer(device, model.parameters())
-
        # Torchbench has quite different setup for yolov3, so directly passing
         # the right example_inputs
         if model_name == "yolov3":
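For readers skimming the patch, the following is a minimal, self-contained sketch of the training-iteration behaviour this diff describes: SGD is created only for CUDA training runs (skipping models listed in CI_SKIP_OPTIMIZER), and the optimizer step runs eagerly, outside whatever TorchDynamo compiles. The standalone helper names here (init_optimizer, train_iteration, compute_loss) are illustrative assumptions for this sketch, not the benchmark harness's actual API.

# Illustrative sketch only -- not part of the patch above.
import torch

# Models whose optimizer step is skipped in CI (mirrors CI_SKIP_OPTIMIZER in common.py).
CI_SKIP_OPTIMIZER = {"convmixer_768_32"}


def init_optimizer(name, device, params, training=True):
    # Mirrors the new BenchmarkRunner.init_optimizer logic: plain SGD on CUDA
    # during training runs, otherwise no optimizer at all.
    if device == "cuda" and training and name not in CI_SKIP_OPTIMIZER:
        return torch.optim.SGD(params, lr=0.01)
    return None


def train_iteration(model, inputs, optimizer, compute_loss):
    # inputs is assumed to be a tuple of tensors; compute_loss maps the model
    # output to a scalar loss.
    if optimizer is not None:
        optimizer.zero_grad(True)
    # Forward + loss + backward: this is the region TorchDynamo would compile.
    pred = model(*inputs)
    loss = compute_loss(pred)
    loss.backward()
    # Optimizer step stays in eager mode (per the README note, it is not
    # compiled by TorchDynamo).
    if optimizer is not None:
        optimizer.step()
    return pred

For a torchbench-style model, compute_loss could simply be lambda out: out.sum(), matching the loss handling described in the README hunk above.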