Skip to content

Commit

Permalink
[Tests] fix more sharding tests (huggingface#8797)
Browse files Browse the repository at this point in the history
* fix

* fix

* ugly

* okay

* fix more

* fix oops
  • Loading branch information
sayakpaul authored Jul 9, 2024
1 parent 35cc66d commit a785992
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,11 +885,11 @@ def test_model_parallelism(self):

@require_torch_gpu
def test_sharded_checkpoints(self):
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model = model.to(torch_device)

torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_sizes(model)[""]
Expand All @@ -909,7 +909,8 @@ def test_sharded_checkpoints(self):
new_model = new_model.to(torch_device)

torch.manual_seed(0)
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
Expand Down Expand Up @@ -942,7 +943,8 @@ def test_sharded_checkpoints_device_map(self):
new_model = new_model.to(torch_device)

torch.manual_seed(0)
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

Expand Down

0 comments on commit a785992

Please sign in to comment.