Skip to content

Commit

Permalink
Added util to print pytree structure and sharding rule based on calla…
Browse files Browse the repository at this point in the history
…ble policy.
  • Loading branch information
young-geng committed Feb 11, 2024
1 parent 42eb597 commit b7e0cba
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
19 changes: 19 additions & 0 deletions scalax/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ def get_partition_spec(name, leaf):
return named_tree_map(get_partition_spec, pytree, sep='/')


class PolicyShardingRule(ShardingRule):
""" Create PartitionSpec for a pytree with a callable policy. """

def __init__(self, policy):
""" Create a PolicyShardingRule with a callable policy.
Args:
policy: A callable that takes a tree path and a leaf tensor as input
and returns a PartitionSpec.
"""
self.policy = policy

def apply(self, pytree):
""" Returns a pytree of PartitionSpec according to the policy. """
def get_partition_spec(name, leaf):
return self.policy(name, leaf)
return named_tree_map(get_partition_spec, pytree, sep='/')


class MeshShardingHelper(object):
""" Helper class for creating jit sharding jax functions with sharding rules. """
global_mesh_helper = None
Expand Down
7 changes: 7 additions & 0 deletions scalax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,10 @@ def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
is_leaf=is_leaf
)


def print_pytree_structure(tree, sep='/', is_leaf=None):
def print_fn(path, val):
shape = f'shape: {val.shape if hasattr(val, "shape") else "none"}'
dtype = f'dtype: {val.dtype if hasattr(val, "dtype") else "none"}'
print(f'{path}: {shape}, {dtype}')
named_tree_map(print_fn, tree, is_leaf=is_leaf, sep=sep)

0 comments on commit b7e0cba

Please sign in to comment.