Rename sharded_jit to sjit in readme
young-geng committed Feb 11, 2024
1 parent b7e0cba commit 1da5cfb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -44,9 +44,9 @@ def train_step(train_state, batch):
This works fine for a single GPU/TPU, but if we want to scale up to multiple
GPU/TPUs, we need to partition the data or the model in order to parallelize
the training across devices. This is where scalax comes in. We can first create
- a device mesh and then replace the `jax.jit` decorator with `mesh.sharded_jit`.
+ a device mesh and then replace the `jax.jit` decorator with `mesh.sjit`.
To use different parallelization strategies, we can provide different sharding
- rules to the `sharded_jit` function. For example, to change the previous example
+ rules to the `sjit` function. For example, to change the previous example
into a data parallel training, we can do the following:

```python
@@ -56,7 +56,7 @@ from scalax.sharding import MeshShardingHelper, PartitionSpec

mesh = MeshShardingHelper([-1], ['dp']) # Create a 1D mesh with a data parallelism axis
@partial(
-    mesh.sharded_jit,
+    mesh.sjit,
    in_shardings=None,
    out_shardings=None,
    # constrain the batch argument to be sharded along the dp axis to enable data parallelism
@@ -80,7 +80,7 @@ from scalax.sharding import MeshShardingHelper, PartitionSpec, FSDPShardingRule

mesh = MeshShardingHelper([-1], ['fsdp']) # Create a 1D mesh with an FSDP (fully sharded data parallelism) axis
@partial(
-    mesh.sharded_jit,
+    mesh.sjit,
    in_shardings=(FSDPShardingRule(), None), # Shard the train_state using FSDP
    out_shardings=(FSDPShardingRule(), None),
    args_sharding_constraint=(FSDPShardingRule(), PartitionSpec('fsdp')),
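For context, here is a minimal end-to-end sketch of how the renamed `sjit` decorator reads, assembled from the snippets in the diff above. The `loss_fn` and the Flax-style `train_state` are illustrative assumptions, not part of this commit:

```python
# A minimal sketch of the renamed API, assembled from the README snippets above.
# Assumptions (not part of this commit): loss_fn is defined elsewhere, and
# train_state is a Flax-style TrainState with .params and .apply_gradients.
from functools import partial

import jax
from scalax.sharding import MeshShardingHelper, PartitionSpec

mesh = MeshShardingHelper([-1], ['dp'])  # 1D mesh with a data parallelism axis

@partial(
    mesh.sjit,  # formerly mesh.sharded_jit
    in_shardings=None,
    out_shardings=None,
    # Constrain the batch argument to be sharded along the dp axis
    args_sharding_constraint=(None, PartitionSpec('dp')),
)
def train_step(train_state, batch):
    grads = jax.grad(lambda params: loss_fn(params, batch))(train_state.params)
    return train_state.apply_gradients(grads=grads)
```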
