[gemini] update the gpt example (hpcaitech#2527)
1SAA authored Jan 30, 2023
1 parent ecbad93 · commit 66dfcf5
Showing 4 changed files with 75 additions and 98 deletions.
9 changes: 6 additions & 3 deletions colossalai/nn/parallel/zero_wrapper.py
@@ -32,16 +32,19 @@ def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Opt
     >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
     >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
     """
-    setattr(model, "_colo_zero_stage", zero_stage)
     assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"

     if gemini_config is None:
         gemini_config = dict()

     if zero_stage in [1, 2]:
-        return model
+        wrapped_model = model
     else:
-        return GeminiDDP(model, **gemini_config)
+        wrapped_model = GeminiDDP(model, **gemini_config)
+
+    setattr(wrapped_model, "_colo_zero_stage", zero_stage)
+
+    return wrapped_model


 def zero_optim_wrapper(model: nn.Module,
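Note: for orientation, here is a minimal end-to-end sketch of the wrapper API this file now exposes, assembled from the docstring above and the calls made in the updated train_gpt_demo.py. The toy model, the empty launch config, and the single-process torchrun launch are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# assumed launch: torchrun --standalone --nproc_per_node=1 sketch.py
colossalai.launch_from_torch(config={})

# build the model under ColoInitContext, as the updated demo does
with ColoInitContext(device=get_current_device()):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

# zero_stage=3 wraps the module in GeminiDDP; stages 1 and 2 return it as-is
# and defer all sharding work to the optimizer wrapper
config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)

optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = zero_optim_wrapper(model, optimizer, optim_config=dict(gpu_margin_mem_ratio=0.))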
2 changes: 1 addition & 1 deletion examples/language/gpt/gemini/benchmark_gemini.sh
@@ -1,5 +1,5 @@
 for MODEL_TYPE in "gpt2_medium"; do
-for DISTPLAN in "colossalai"; do
+for DISTPLAN in "CAI_Gemini"; do
 for BATCH_SIZE in 16; do
 for GPUNUM in 1 2 4 8; do
 for TPDEGREE in 1 2 4 8; do
12 changes: 9 additions & 3 deletions examples/language/gpt/gemini/run_gemini.sh
@@ -1,6 +1,6 @@
 set -x
-# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
-export DISTPLAN=${DISTPLAN:-"colossalai"}
+# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
+export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}

 # The following options only valid when DISTPLAN="colossalai"
 export GPUNUM=${GPUNUM:-1}
@@ -12,14 +12,20 @@ export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
 export TRAIN_STEP=${TRAIN_STEP:-10}
 # export PYTHONPATH=$PWD:$PYTHONPATH

+if [ ${USE_SHARD_INIT} = "True" ]; then
+  USE_SHARD_INIT="--shardinit"
+else
+  USE_SHARD_INIT=""
+fi
+
 mkdir -p gemini_logs

 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
 --tp_degree=${TPDEGREE} \
 --model_type=${MODEL_TYPE} \
 --batch_size=${BATCH_SIZE} \
 --placement=${PLACEMENT} \
---shardinit=${USE_SHARD_INIT} \
+${USE_SHARD_INIT} \
 --distplan=${DISTPLAN} \
 --train_step=${TRAIN_STEP} \
 2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
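Note: the shell change above exists because of the matching argparse change in train_gpt_demo.py below, which turns --shardinit from type=bool into action='store_true'. With type=bool, argparse feeds the raw string through bool(), and any non-empty string, including "False", is truthy, so --shardinit=False silently enabled the option; a presence-only flag must be passed bare or not at all, hence the conditional expansion of ${USE_SHARD_INIT}. A small self-contained demonstration (the parsers here are stand-ins, not code from the commit):

import argparse

# type=bool: bool("False") is True, so the old form misparsed
old = argparse.ArgumentParser()
old.add_argument("--shardinit", type=bool, default=False)
assert old.parse_args(["--shardinit=False"]).shardinit is True

# action='store_true': presence of the bare flag decides the value
new = argparse.ArgumentParser()
new.add_argument("--shardinit", action='store_true')
assert new.parse_args([]).shardinit is False
assert new.parse_args(["--shardinit"]).shardinit is True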
150 changes: 59 additions & 91 deletions examples/language/gpt/gemini/train_gpt_demo.py
@@ -12,26 +12,21 @@

 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext

 CAI_VERSION = colossalai.__version__

-if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-    # These are added after 0.1.10
-    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-    from colossalai.nn.parallel import GeminiDDP
-    from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
-
-
 def parse_args():
     parser = colossalai.get_default_parser()
     parser.add_argument(
         "--distplan",
         type=str,
-        default='colossalai',
+        default='CAI_Gemini',
         help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
     )
@@ -48,8 +43,7 @@ def parse_args():
     )
     parser.add_argument(
         "--shardinit",
-        type=bool,
-        default=False,
+        action='store_true',
         help=
         "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
     )
@@ -186,57 +180,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
             param.visited = True


-# Gemini + ZeRO DDP
-def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
-    fp16_init_scale = 2**5
-    gpu_margin_mem_ratio_for_auto = 0
-
-    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-        model = GeminiDDP(model,
-                          strict_ddp_mode=ddp_flag,
-                          device=get_current_device(),
-                          placement_policy=placement_policy,
-                          pin_memory=True,
-                          hidden_dim=model.config.n_embd,
-                          search_range_mb=128)
-        # configure the const policy
-        if placement_policy == 'const':
-            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
-        # build a highly optimized cpu optimizer
-        optimizer = GeminiAdamOptimizer(model,
-                                        lr=1e-3,
-                                        initial_scale=fp16_init_scale,
-                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
-    elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"):
-        from colossalai.gemini import ChunkManager, GeminiManager
-        from colossalai.nn.optimizer import HybridAdam
-        from colossalai.zero import ZeroOptimizer
-        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True)
-        chunk_manager = ChunkManager(chunk_size,
-                                     pg,
-                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placement_policy))
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        model = ZeroDDP(model, gemini_manager)
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-        optimizer = ZeroOptimizer(optimizer,
-                                  model,
-                                  initial_scale=fp16_init_scale,
-                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
-    else:
-        raise NotImplemented(f"CAI version {CAI_VERSION} is not supported")
-    return model, optimizer
-
-
 def main():
     # version check
-    # this example is supposed to work for versions greater than 0.1.9
-    assert version.parse(CAI_VERSION) >= version.parse("0.1.9")
+    # this example is supposed to work for versions greater than 0.2.0
+    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")

     set_cpu_maximum_parallelism()
     args = parse_args()

-    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+    # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
         raise TypeError(f"{args.distplan} is error")

     # batch size per DP degree
@@ -260,57 +213,71 @@ def main():
     criterion = GPTLMLoss()

     torch.manual_seed(123)
-    if args.distplan == "colossalai":
+    if args.distplan.startswith("CAI"):
         # all param must use the same process group.
         world_size = torch.distributed.get_world_size()
         shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
         default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

+        if args.shardinit and args.distplan != "CAI_Gemini":
+            raise RuntimeError("You can only use shardinit with CAI_Gemini")
+
         # build GPT model
-        if version.parse(CAI_VERSION) > version.parse("0.1.10"):
-            with ColoInitContext(device=get_current_device(),
-                                 dtype=torch.half,
-                                 default_dist_spec=default_dist_spec,
-                                 default_pg=shard_pg):
-                model = model_builder(args.model_type)(checkpoint=True)
-        else:
-            with ColoInitContext(device=get_current_device()):
-                model = model_builder(args.model_type)(checkpoint=True)
+        with ColoInitContext(device=get_current_device(),
+                             dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
+            model = model_builder(args.model_type)(checkpoint=True)

         tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
         # You should notice that v0.1.10 is not compatible with TP degree > 1
         if args.tp_degree > 1:
             tensor_parallelize(model, tp_pg)

-        # build a Gemini model and a highly optimized cpu optimizer
-        # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)
+        # asign running configurations
+        gemini_config = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
+                                 device=get_current_device(),
+                                 placement_policy=args.placement,
+                                 pin_memory=True,
+                                 hidden_dim=model.config.n_embd,
+                                 search_range_mb=128)
+            optim_config = dict(gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
+        if args.distplan == "CAI_ZeRO1":
+            zero_stage = 1
+        elif args.distplan == "CAI_ZeRO2":
+            zero_stage = 2
+        elif args.distplan == "CAI_Gemini":
+            zero_stage = 3
+        else:
+            raise RuntimeError
+
+        # wrap your model and optimizer
+        model = zero_model_wrapper(model, zero_stage, gemini_config)
+        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
-    else:
+    elif args.distplan.startswith("Pytorch"):
         assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
         model = model_builder(args.model_type)(checkpoint=True).cuda()

-        if args.distplan.startswith("torch"):
-            model = DDP(model)
-            if args.distplan.endswith("ddp"):
-                optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-            elif args.distplan.endswith("zero"):
-                optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
-        elif args.distplan.startswith("zero"):
-            model = model.half()
-            partition_flag = (args.distplan == "zero2")
-            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-
-            optimizer = LowLevelZeroOptimizer(
-                optimizer,
-                reduce_bucket_size=12 * 1024 * 1024,
-                overlap_communication=True,
-                partition_grad=partition_flag,
-                verbose=True,
-            )
+        model = DDP(model)
+        if args.distplan.endswith("DDP"):
+            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+        elif args.distplan.endswith("ZeRO"):
+            from torch.distributed.optim import ZeroRedundancyOptimizer
+            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
+    else:
+        raise RuntimeError

     # model is shared after TP
     numel = get_model_size(model)
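Note for readers new to Gemini: here is the same configuration annotated line by line. The comments are our reading of what each knob does and should be checked against the Gemini documentation for the installed version; they are not part of the commit:

gemini_config = dict(
    strict_ddp_mode=args.tp_degree == 1,  # no TP: treat every parameter as replicated across all ranks
    device=get_current_device(),          # device that holds the working copy of each chunk
    placement_policy=args.placement,      # 'cpu' / 'cuda' / 'auto' / 'const': where chunk data lives
    pin_memory=True,                      # pinned host memory speeds up CPU <-> GPU chunk movement
    hidden_dim=model.config.n_embd,       # model width, used as a hint by the chunk-size search
    search_range_mb=128)                  # width of the chunk-size search window, in MB

# fraction of the spare GPU memory the optimizer may claim for its states
# (meaningful mainly under the 'auto' placement policy)
optim_config = dict(gpu_margin_mem_ratio=0.)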
@@ -338,17 +305,18 @@ def main():
         fwd_time = fwd_end - start
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])

-        if args.distplan in ["colossalai", "zero1", "zero2"]:
+        if args.distplan.startswith("CAI"):
             optimizer.backward(loss)
-        elif args.distplan in ["torch_ddp", "torch_zero"]:
+        elif args.distplan.startswith("Pytorch"):
             loss.backward()
+        else:
+            raise RuntimeError

         torch.cuda.synchronize()
         bwd_end = time()
         bwd_time = bwd_end - fwd_end
         logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])

-        if args.distplan in ["zero1", "zero2"]:
-            optimizer.sync_grad()
         optimizer.step()
         torch.cuda.synchronize()
         optim_time = time() - bwd_end
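Note: the last hunk summarizes the runtime contract after this commit. CAI_* plans call optimizer.backward(loss), since the ColossalAI optimizer wrappers take over backward so they can apply loss scaling and gradient handling internally, while Pytorch_* plans keep plain loss.backward(); the zero1/zero2-specific optimizer.sync_grad() call disappears along with those plan names. For comparison, a self-contained analogue of the Pytorch_ZeRO path, with a toy model and random batches standing in for GPT and real data (assumed launch: torchrun --standalone --nproc_per_node=2 sketch.py):

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = DDP(torch.nn.Linear(128, 128).cuda())
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)

for step in range(3):
    x = torch.randn(16, 128, device="cuda")
    loss = model(x).square().mean()
    loss.backward()        # Pytorch_* plans keep plain autograd backward
    optimizer.step()       # optimizer states are sharded across ranks
    optimizer.zero_grad()

dist.destroy_process_group()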
