Skip to content

Commit

Permalink
[autoparallel] distinguish different parallel strategies (hpcaitech#2699)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuliangLiu0306 authored Feb 15, 2023
1 parent ae86a29 commit 1dc003c
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 219 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
LinearProjectionStrategyGenerator(op_data_mapping,
self.device_mesh,
linear_projection_type='linear',
solver_perference=self.solver_perference))
return generators

def get_operation_data_mapping(self) -> Dict[str, OperationData]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.options import SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
Expand Down Expand Up @@ -209,9 +210,14 @@ def collate_strategies(self) -> List[ShardingStrategy]:

class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):

def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'):
def __init__(self,
operation_data_mapping,
device_mesh,
linear_projection_type='linear',
solver_perference=SolverPerference.STANDARD):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
self.solver_perference = solver_perference

def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C = AB
Expand All @@ -231,16 +237,22 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost

def collate_strategies(self) -> List[ShardingStrategy]:
def dp_strategies(self) -> List[ShardingStrategy]:
strategies = []

# SS = SR x RS
strategies.append(self.split_lhs_space_rhs_space(0, 1))
strategies.append(self.split_lhs_space_rhs_space(1, 0))
# S01R = S01R x RR
strategies.append(self.split_lhs_1st_dim_1d(0, 1))

# SR = SS x SR
strategies.append(self.split_lhs_space_both_contract(0, 1))
strategies.append(self.split_lhs_space_both_contract(1, 0))
return strategies

def tp_strategies(self) -> List[ShardingStrategy]:
strategies = []

# RR = RS01 x S01R
strategies.append(self.split_lhs_2nd_dim_1d(0, 1))

# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

# RS = RS x SS
strategies.append(self.split_rhs_space_both_contract(0, 1))
Expand All @@ -254,20 +266,38 @@ def collate_strategies(self) -> List[ShardingStrategy]:
strategies.append(self.split_rhs_space_only(0))
strategies.append(self.split_rhs_space_only(1))

# S01R = S01R x RR
strategies.append(self.split_lhs_1st_dim_1d(0, 1))
return strategies

# RR = RS01 x S01R
strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
def mix_strategies(self) -> List[ShardingStrategy]:
strategies = []

# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# SS = SR x RS
strategies.append(self.split_lhs_space_rhs_space(0, 1))
strategies.append(self.split_lhs_space_rhs_space(1, 0))

# SR = SS x SR
strategies.append(self.split_lhs_space_both_contract(0, 1))
strategies.append(self.split_lhs_space_both_contract(1, 0))

# RR = RR x RR
strategies.append(self.non_split())

return strategies

def collate_strategies(self) -> List[ShardingStrategy]:
    """Gather the sharding strategies selected by ``self.solver_perference``.

    ``SolverPerference.STANDARD`` concatenates the results of
    ``dp_strategies()``, ``tp_strategies()`` and ``mix_strategies()`` in that
    order. ``SolverPerference.DP`` and ``SolverPerference.TP`` return only the
    matching group. Any other value yields an empty list.
    """
    preference = self.solver_perference

    # Pick which strategy builders apply; they are invoked in tuple order below.
    if preference == SolverPerference.STANDARD:
        builders = (self.dp_strategies, self.tp_strategies, self.mix_strategies)
    elif preference == SolverPerference.DP:
        builders = (self.dp_strategies,)
    elif preference == SolverPerference.TP:
        builders = (self.tp_strategies,)
    else:
        builders = ()

    collected = []
    for build in builders:
        collected.extend(build())
    return collected

@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

strategies_constructor = build_strategy_constructor(graph, device_mesh)
strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
gm = ModuleWrapper(gm, *sharding_spec_dicts)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,79 +243,79 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
if model_cls.__name__ == 'LinearReshapeModel':

if reshape_dims == ((0, 2, 1, 3), (1, 2)):
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list

if reshape_dims == (2, 0, 1, 3):
assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list
assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [S0, R, R, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list
assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list

if reshape_dims == (1, 3):
assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list
assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
Expand Down
Loading

0 comments on commit 1dc003c

Please sign in to comment.