Skip to content

Commit

Permalink
[autoparallel] add shard option (hpcaitech#2696)
Browse files Browse the repository at this point in the history
* [autoparallel] add shard option

* polish
  • Loading branch information
YuliangLiu0306 authored Feb 15, 2023
1 parent 5b24987 commit 21d6a48
Show file tree
Hide file tree
Showing 15 changed files with 176 additions and 104 deletions.
70 changes: 60 additions & 10 deletions colossalai/auto_parallel/tensor_shard/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,9 @@

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
Expand Down Expand Up @@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
pass


def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
shard_option: str):
'''
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
'''
solver_options = SolverOptions()
if solver_preference == 'standard':
solver_preference = SolverPerference.STANDARD
elif solver_preference == 'tp':
solver_preference = SolverPerference.TP
elif solver_preference == 'dp':
solver_preference = SolverPerference.DP
else:
raise ValueError(f'Invalid solver_preference: {solver_preference}')

if dataloader_option == 'replicated':
dataloader_option = DataloaderOption.REPLICATED
elif dataloader_option == 'distributed':
dataloader_option = DataloaderOption.DISTRIBUTED
else:
raise ValueError(f'Invalid dataloader_option: {dataloader_option}')

if shard_option == 'standard':
shard_option = ShardOption.STANDARD
elif shard_option == 'shard':
shard_option = ShardOption.SHARD
elif shard_option == 'shard_last_axis':
shard_option = ShardOption.SHARD_LAST_AXIS
elif shard_option == 'full_shard':
shard_option = ShardOption.FULL_SHARD
else:
raise ValueError(f'Invalid shard_option: {shard_option}')

solver_options = SolverOptions(solver_perference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

Expand Down Expand Up @@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
overlap: bool = False,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
Expand All @@ -198,6 +226,12 @@ def initialize_model(model: nn.Module,
the memory budget will be infinity.
overlap(optional): the overlap is used to specify whether to overlap gradient communication and
backward computing.
solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
be used. The valid dataloader_option could be 'replicated' or 'distributed'.
shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
Expand All @@ -212,7 +246,12 @@ def initialize_model(model: nn.Module,
graph = tracer.trace(root=model, meta_args=meta_args)
gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh)

strategies_constructor = build_strategy_constructor(graph,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
if load_solver_solution:
solution = torch.load(solution_path)
else:
Expand Down Expand Up @@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
Expand All @@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module,
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
be used. The valid dataloader_option could be 'replicated' or 'distributed'.
shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
Expand All @@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module,
rst_to_unpack = initialize_model(model,
meta_args,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solution_path=solver_solution_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler
from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler
Expand All @@ -31,6 +30,6 @@
'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
'TransposeHandler', 'SplitHandler'
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
'SplitHandler'
]
31 changes: 20 additions & 11 deletions colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
Expand All @@ -32,19 +32,19 @@ class NodeHandler(ABC):
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
'''

def __init__(
self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
) -> None:
def __init__(self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.shard_option = shard_option
self.solver_perference = solver_perference

def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
"""
Expand Down Expand Up @@ -187,15 +187,24 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV

remove_strategy_list = []
for strategy in self.strategies_vector:
shard_level = 0
shard_axis_list = []
last_axis = len(self.device_mesh.mesh_shape) - 1
for op_data, sharding_spec in strategy.sharding_specs.items():
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
for dim, shard_axis in sharding_spec.dim_partition_dict.items():
shard_level += len(shard_axis)
for dim, shard_axes in sharding_spec.dim_partition_dict.items():
for shard_axis in shard_axes:
if shard_axis not in shard_axis_list:
shard_axis_list.append(shard_axis)

shard_level = len(shard_axis_list)
using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
if self.shard_option == ShardOption.SHARD and shard_level == 0:
remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.SHARD_LAST_AXIS:
if shard_level != 1 or using_last_axis == False:
remove_strategy_list.append(strategy)

for strategy in remove_strategy_list:
self.strategies_vector.remove(strategy)
Expand Down
17 changes: 0 additions & 17 deletions colossalai/auto_parallel/tensor_shard/node_handler/option.py

This file was deleted.

49 changes: 49 additions & 0 deletions colossalai/auto_parallel/tensor_shard/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from dataclasses import dataclass
from enum import Enum

__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']


class SolverPerference(Enum):
"""
This enum class is to define the solver preference.
"""
STANDARD = 0
DP = 1
TP = 2


class ShardOption(Enum):
"""
This enum class is to define the shard level required in node strategies.
Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be shard using at least one device mesh axis.
SHARD_ONE_AXIS: We require the node to be shard using the last device mesh axis.
FULL_SHARD: We require the node to be shard using all device mesh axes.
TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
"""
STANDARD = 0
SHARD = 1
SHARD_LAST_AXIS = 2
FULL_SHARD = 3


class DataloaderOption(Enum):
"""
This enum class is to define the dataloader option.
"""
REPLICATED = 0
DISTRIBUTED = 1


@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD
3 changes: 1 addition & 2 deletions colossalai/auto_parallel/tensor_shard/solver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .solver import Solver
from .strategies_constructor import StrategiesConstructor

__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
30 changes: 0 additions & 30 deletions colossalai/auto_parallel/tensor_shard/solver/options.py

This file was deleted.

2 changes: 1 addition & 1 deletion colossalai/auto_parallel/tensor_shard/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
verbose=True):
verbose=False):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.device.device_mesh import DeviceMesh

from .options import DataloaderOption, SolverOptions
from ..options import DataloaderOption, SolverOptions

__all__ = ['StrategiesConstructor']

Expand Down Expand Up @@ -101,15 +101,23 @@ def _check_no_strategy_for_data(data):

# get_attr node
elif node.op == 'get_attr':
getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
getattr_handler = GetattrHandler(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
getattr_handler.register_strategy()

# call_module node
elif node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
handler = operator_registry.get(submod_type)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
Expand All @@ -118,7 +126,11 @@ def _check_no_strategy_for_data(data):
# call_function node
elif node.op == 'call_function':
target = node.target
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
handler = operator_registry.get(target)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
Expand All @@ -127,7 +139,11 @@ def _check_no_strategy_for_data(data):
# call_method node
elif node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
handler = operator_registry.get(method)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
Expand Down
Loading

0 comments on commit 21d6a48

Please sign in to comment.