Skip to content

Commit

Permalink
[examples] update autoparallel tutorial demo (hpcaitech#2449)
Browse files Browse the repository at this point in the history
* [examples] update autoparallel tutorial demo

* add test_ci.sh

* polish

* add conda yaml
  • Loading branch information
YuliangLiu0306 authored Jan 12, 2023
1 parent 9358262 commit c20529f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 116 deletions.
132 changes: 16 additions & 116 deletions examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,119 +4,37 @@

import torch
from titans.utils import barrier_context
from torch.fx import GraphModule
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from tqdm import tqdm

import colossalai
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.utils import get_dataloader

DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
return parser.parse_args()


def synthesize_data():
img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
return img, label


def main():
args = parse_args()
colossalai.launch_from_torch(config='./config.py')

logger = get_dist_logger()

if not args.synthetic:
with barrier_context():
# build dataloaders
train_dataset = CIFAR10(root=DATA_ROOT,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(size=32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]))

test_dataset = CIFAR10(root=DATA_ROOT,
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
]))

train_dataloader = get_dataloader(
dataset=train_dataset,
add_sampler=True,
shuffle=True,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
)

test_dataloader = get_dataloader(
dataset=test_dataset,
add_sampler=True,
batch_size=gpc.config.BATCH_SIZE,
pin_memory=True,
)
else:
train_dataloader, test_dataloader = None, None

# initialize device mesh
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

# trace the model with meta data
tracer = ColoTracer()
model = resnet50(num_classes=10).cuda()
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# prepare info for solver
solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)

# solve the solution
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
if gpc.get_global_rank() == 0:
for index, node in enumerate(graph.nodes):
print(node.name, node.strategies_vector[solution[index]].name)

# process the graph for distributed training ability
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()

model = autoparallelize(model, input_sample)
# build criterion
criterion = torch.nn.CrossEntropyLoss()

Expand All @@ -127,65 +45,47 @@ def main():
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)

for epoch in range(gpc.config.NUM_EPOCHS):
gm.train()
model.train()

if args.synthetic:
# if we use synthetic data
# we assume it only has 30 steps per epoch
num_steps = range(30)
# if we use synthetic data
# we assume it only has 30 steps per epoch
num_steps = range(30)

else:
# we use the actual number of steps for training
num_steps = range(len(train_dataloader))
data_iter = iter(train_dataloader)
progress = tqdm(num_steps)

for _ in progress:
if args.synthetic:
# generate fake data
img, label = synthesize_data()
else:
# get the real data
img, label = next(data_iter)
# generate fake data
img, label = synthesize_data()

img = img.cuda()
label = label.cuda()
optimizer.zero_grad()
output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
output = model(img)
train_loss = criterion(output, label)
train_loss.backward(train_loss)
optimizer.step()
lr_scheduler.step()

# run evaluation
gm.eval()
model.eval()
correct = 0
total = 0

if args.synthetic:
# if we use synthetic data
# we assume it only has 10 steps for evaluation
num_steps = range(30)
# if we use synthetic data
# we assume it only has 10 steps for evaluation
num_steps = range(30)

else:
# we use the actual number of steps for training
num_steps = range(len(test_dataloader))
data_iter = iter(test_dataloader)
progress = tqdm(num_steps)

for _ in progress:
if args.synthetic:
# generate fake data
img, label = synthesize_data()
else:
# get the real data
img, label = next(data_iter)
# generate fake data
img, label = synthesize_data()

img = img.cuda()
label = label.cuda()

with torch.no_grad():
output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
output = model(img)
test_loss = criterion(output, label)
pred = torch.argmax(output, dim=-1)
correct += torch.sum(pred == label)
Expand Down
32 changes: 32 additions & 0 deletions examples/tutorial/auto_parallel/environment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: auto
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- blas=1.0=mkl
- brotlipy=0.7.0=py38h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.12.7=ha878542_0
- certifi=2022.12.7=pyhd8ed1ab_0
- cffi=1.15.1=py38h74dc2b5_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- coin-or-cbc=2.10.8=h3786ebc_0
- coin-or-cgl=0.60.6=h6f57e76_2
- coin-or-clp=1.17.7=hc56784d_2
- coin-or-osi=0.108.7=h2720bb7_2
- coin-or-utils=2.11.6=h202d8b1_2
- python=3.8.13
- pip=22.2.2
- cudatoolkit=11.3
- pytorch=1.12.1
- torchvision=0.13.1
- numpy=1.23.1
- pip:
- titans
- torch==1.12.1
- pulp==2.7.0
- datasets
- colossalai
13 changes: 13 additions & 0 deletions examples/tutorial/auto_parallel/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from setuptools import find_packages, setup

setup(
name='auto_parallel',
version='0.0.1',
description='',
packages=find_packages(),
install_requires=[
'torch',
'numpy',
'tqdm',
],
)
11 changes: 11 additions & 0 deletions examples/tutorial/auto_parallel/test_ci.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash
set -euxo pipefail

conda init bash
conda env create -f environment.yaml
conda activate auto
cd ../../..
pip uninstall colossalai
pip install -v .
cd ./examples/tutorial/auto_parallel
colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s

0 comments on commit c20529f

Please sign in to comment.