Don't persist feature order buffers as they don't need to be checkpointed (pytorch#956)

Summary:
Pull Request resolved: pytorch#956

TorchRec doesn't need to persist feature order buffers: they are correct on instantiation and don't need to be checkpointed.

Previously, this issue was hidden because the sharded modules' buffers() (via the embedding kernels' buffer call) yielded nothing, which is no longer true as of D41964643 (pytorch@e8ab2de).
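
For illustration, a minimal sketch (not TorchRec code; the module and buffer names here are hypothetical) of what persistent=False does to a registered buffer:

import torch
import torch.nn as nn

class FeatureReorder(nn.Module):
    def __init__(self, features_order):
        super().__init__()
        # The order is fully determined by the module's configuration at
        # instantiation time, so there is nothing to restore from a checkpoint.
        self.register_buffer(
            "_features_order_tensor",
            torch.tensor(features_order, dtype=torch.int32),
            persistent=False,
        )

m = FeatureReorder([2, 0, 1])
# Non-persistent buffers are excluded from the state dict...
assert "_features_order_tensor" not in m.state_dict()
# ...but still appear in buffers()/named_buffers() and move with .to(device).
assert "_features_order_tensor" in dict(m.named_buffers())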

Reviewed By: YLGH

Differential Revision: D42591693

fbshipit-source-id: 6694239c54fbbaeb563b1c0221f4a5324f2c99b6
colin2328 authored and facebook-github-bot committed Jan 19, 2023
1 parent 5823e3f commit 1815b91
Showing 4 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/embedding.py
@@ -585,6 +585,7 @@ def _create_input_dist(
         self.register_buffer(
             "_features_order_tensor",
             torch.tensor(self._features_order, device=self._device, dtype=torch.int32),
+            persistent=False,
         )
 
     def _create_lookups(self) -> None:
2 changes: 2 additions & 0 deletions torchrec/distributed/embedding_tower_sharding.py
@@ -197,6 +197,7 @@ def _create_input_dist(
                 torch.tensor(
                     self._kjt_features_order, device=self._device, dtype=torch.int32
                 ),
+                persistent=False,
             )
 
         if self._wkjt_feature_names != wkjt_feature_names:
@@ -208,6 +209,7 @@ def _create_input_dist(
                 torch.tensor(
                     self._wkjt_features_order, device=self._device, dtype=torch.int32
                 ),
+                persistent=False,
             )
 
         node_count = dist.get_world_size(self._cross_pg)
1 change: 1 addition & 0 deletions torchrec/distributed/embeddingbag.py
@@ -542,6 +542,7 @@ def _create_input_dist(
             torch.tensor(
                 self._features_order, device=self._device, dtype=torch.int32
             ),
+            persistent=False,
         )
 
     def _create_lookups(
1 change: 1 addition & 0 deletions torchrec/distributed/quant_embedding.py
@@ -237,6 +237,7 @@ def _create_input_dist(
         self.register_buffer(
             "_features_order_tensor",
             torch.tensor(self._features_order, device=device, dtype=torch.int32),
+            persistent=False,
         )
 
     def _create_lookups(self, fused_params: Optional[Dict[str, Any]]) -> None:
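
As a hedged sketch of the checkpoint-compatibility implication (standalone, reusing the hypothetical FeatureReorder module from above, not TorchRec code): a checkpoint written before this change still carries the buffer key, which the module no longer expects once the buffer is non-persistent, so a lenient load reports it as stale instead of restoring it:

import torch
import torch.nn as nn

class FeatureReorder(nn.Module):
    def __init__(self, features_order):
        super().__init__()
        self.register_buffer(
            "_features_order_tensor",
            torch.tensor(features_order, dtype=torch.int32),
            persistent=False,
        )

# A checkpoint from before the change, which still serialized the buffer.
old_ckpt = {"_features_order_tensor": torch.tensor([2, 0, 1], dtype=torch.int32)}

m = FeatureReorder([2, 0, 1])
# strict=True would raise on the unexpected key; strict=False just reports it.
result = m.load_state_dict(old_ckpt, strict=False)
print(result.unexpected_keys)  # ['_features_order_tensor']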
