Skip to content

Commit

Permalink
pytorch/torchtune/tests/torchtune/modules/_export
Browse files Browse the repository at this point in the history
Differential Revision: D67388194

Pull Request resolved: pytorch#2179
  • Loading branch information
gmagogsfm authored Dec 20, 2024
1 parent 0cd8bc4 commit 6a53242
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions tests/torchtune/modules/_export/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def test_attention_export(self):
(self.x, self.x),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
strict=True,
)
et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def test_tile_positional_embedding_smoke(self):
torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for export"
)
def test_tile_positional_embedding_export(self):

tpe_ep = torch.export.export(
self.tpe,
(self.x, self.aspect_ratio),
dynamic_shapes=(
self.dynamic_shape,
None,
), # assuming aspect ratio is static
strict=True,
)

y = tpe_ep.module()(self.x, self.aspect_ratio)
Expand Down Expand Up @@ -129,14 +129,14 @@ def test_tiled_token_positional_embedding_smoke(self):
torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for export"
)
def test_tiled_token_positional_embedding_export(self):

tpe_ep = torch.export.export(
self.tpe,
(self.x, self.aspect_ratio),
dynamic_shapes=(
self.dynamic_shape,
None,
), # assuming aspect ratio is static
strict=True,
)

y = tpe_ep.module()(self.x, self.aspect_ratio)
Expand All @@ -155,6 +155,7 @@ def test_tiled_token_positional_embedding_aoti(self):
self.dynamic_shape,
None,
), # assuming aspect ratio is static
strict=True,
)

with tempfile.TemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit 6a53242

Please sign in to comment.