Skip to content

Commit

Permalink
Update the arg description.
Browse files Browse the repository at this point in the history
`slice_index` attribute has been added for GPU.

PiperOrigin-RevId: 618354455
  • Loading branch information
changhuilin authored and jax authors committed Mar 23, 2024
1 parent 6f0737b commit b709925
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/experimental/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def create_hybrid_device_mesh(
process_is_granule: if True, this function will treat processes as the units
of the slower/outer network. Otherwise it will look for slice_index
attributes on devices and use slices as the units. Enabling this is meant
as a fallback for platforms (e.g., GPU) that don't set slice_index.
as a fallback for platforms that don't set slice_index.
should_sort_granules_by_key: Whether device granules should be sorted by the
granule key, either slice or process index, depending on
process_is_granule.
Expand Down

0 comments on commit b709925

Please sign in to comment.