Skip to content

Commit

Permalink
Merge pull request #4 from young-geng/sharding_annotation
Browse files Browse the repository at this point in the history
Implement sharding annotation options
  • Loading branch information
Sea-Snell authored Feb 26, 2024
2 parents 1da5cfb + eb0e8d9 commit 0d9d5f3
Showing 1 changed file with 56 additions and 13 deletions.
69 changes: 56 additions & 13 deletions scalax/sharding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from functools import partial
import re
import abc
from dataclasses import dataclass
from typing import Optional, Mapping, Union, ClassVar, List
import numpy as np

import jax
Expand Down Expand Up @@ -124,7 +126,6 @@ def get_partition_spec(name, leaf):

class MeshShardingHelper(object):
""" Helper class for creating jit sharding jax functions with sharding rules. """
global_mesh_helper = None

def __init__(self, axis_dims, axis_names, mesh_axis_splitting=False):
""" Create a MeshShardingHelper.
Expand All @@ -145,15 +146,25 @@ def __init__(self, axis_dims, axis_names, mesh_axis_splitting=False):
self.mesh = Mesh(physical_mesh, axis_names)
self.previous_global_meshes = []

def __enter__(self):
# Use current mesh as global mesh
self.previous_global_meshes.append(MeshShardingHelper.global_mesh_helper)
MeshShardingHelper.global_mesh_helper = self
return self
def get_context(self, **kwargs):
return MeshShardingContext(
mesh_helper=self,
**kwargs
)

def __exit__(self, exc_type, exc_value, traceback):
# Restore last global mesh
MeshShardingHelper.global_mesh_helper = self.previous_global_meshes.pop()
@classmethod
def get_global_mesh(cls):
context = MeshShardingContext.get_global_context()
if context is None:
return None
return context.mesh_helper

@classmethod
def get_global_sharding_annotation_rules(cls):
context = MeshShardingContext.get_global_context()
if context is None:
return None
return context.sharding_annotation_rules

def _split_static_dynamic_args(self, static_argnums, args):
if static_argnums is None:
Expand Down Expand Up @@ -210,6 +221,7 @@ def sjit(self,
out_shardings=None,
static_argnums=None,
args_sharding_constraint=None,
sharding_annotation_rules=None,
**kwargs):
""" JIT compile a function with sharding rules.
Expand All @@ -220,6 +232,8 @@ def sjit(self,
static_argnums: The indices of the static arguments.
args_sharding_constraint: The sharding rule or partition specs to constrain
the args after the beginning of the function.
sharding_annotation_rules: A dictionary of sharding annotation rules, which
maps the name of the sharding annotation to a sharding rule or partition specs.
kwargs: Additional arguments for jax.jit.
Returns:
Expand All @@ -242,7 +256,7 @@ def sharding_constrained_fun(*args):
def wrapped(*args):
static_args = tuple(args[i] for i in static_argnums) if static_argnums is not None else ()
if static_args in static_args_jitted_fn_cache:
with self:
with self.get_context(sharding_annotation_rules=sharding_annotation_rules):
results = static_args_jitted_fn_cache[static_args](*args)
return results

Expand Down Expand Up @@ -272,7 +286,7 @@ def wrapped(*args):

static_args_jitted_fn_cache[static_args] = jitted_fn

with self:
with self.get_context(sharding_annotation_rules=sharding_annotation_rules):
results = jitted_fn(*args)
return results

Expand All @@ -285,11 +299,20 @@ def sharded_jit(self, *args, **kwargs):
@classmethod
def with_sharding_constraint(cls, pytree, sharding_rule):
# Enforce shard constraint with global mesh
if cls.global_mesh_helper is None:
if cls.get_global_mesh() is None:
return pytree
named_shardings = cls.global_mesh_helper.match_sharding_rule(sharding_rule, pytree)
named_shardings = cls.get_global_mesh().match_sharding_rule(
sharding_rule, pytree
)
return jax.lax.with_sharding_constraint(pytree, named_shardings)

@classmethod
def with_sharding_annotation(cls, pytree, sharding_name):
rules = cls.get_global_sharding_annotation_rules()
if rules is None or sharding_name not in rules:
return pytree
return cls.with_sharding_constraint(pytree, rules[sharding_name])

def make_shard_and_gather_fns(self, pytree, sharding_rule):
"""
Create pytree of sharding and gathering functions from sharding rule
Expand Down Expand Up @@ -388,3 +411,23 @@ def to_global_array(array):

return jax.tree_util.tree_map(to_global_array, pytree)


@dataclass
class MeshShardingContext(object):
""" Context and context manager for MeshShardingHelper. """
mesh_helper: MeshShardingHelper
sharding_annotation_rules: Optional[Mapping[str, Union[ShardingRule, PartitionSpec]]] = None
global_contexts: ClassVar[List] = []

def __enter__(self):
MeshShardingContext.global_contexts.append(self)
return self

def __exit__(self, exc_type, exc_value, traceback):
MeshShardingContext.global_contexts.pop()

@classmethod
def get_global_context(cls):
if len(cls.global_contexts) == 0:
return None
return cls.global_contexts[-1]

0 comments on commit 0d9d5f3

Please sign in to comment.