
Add Accelerator API to Imagenet Example #1349

Open · wants to merge 1 commit into main
7 changes: 5 additions & 2 deletions imagenet/README.md
@@ -33,7 +33,9 @@ python main.py -a resnet18 --dummy

## Multi-processing Distributed Data Parallel Training

-You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
+If running on CUDA, you should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
+
+For XPU, multiprocessing is not supported as of PyTorch 2.6.

### Single node, multiple GPUs:

@@ -59,7 +61,7 @@ python main.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backen

```bash
usage: main.py [-h] [-a ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N] [--lr LR] [--momentum M] [--wd W] [-p N] [--resume PATH] [-e] [--pretrained] [--world-size WORLD_SIZE] [--rank RANK]
-[--dist-url DIST_URL] [--dist-backend DIST_BACKEND] [--seed SEED] [--gpu GPU] [--multiprocessing-distributed] [--dummy]
+[--dist-url DIST_URL] [--dist-backend DIST_BACKEND] [--seed SEED] [--gpu GPU] [--no-accel] [--multiprocessing-distributed] [--dummy]
[DIR]

PyTorch ImageNet Training
@@ -96,6 +98,7 @@ optional arguments:
distributed backend
--seed SEED seed for initializing training.
--gpu GPU GPU id to use.
+--no-accel disables accelerator
--multiprocessing-distributed
Use multi-processing distributed training to launch N processes per node, which has N GPUs. This is the fastest way to use PyTorch for either single node or multi node data parallel
training
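For illustration, two hypothetical invocations based on the flags documented above (the architecture and `--dummy` data flag are taken from the existing README examples; the exact behavior of `--no-accel` is as described in this diff):

```bash
# Use the available accelerator (CUDA, XPU, MPS, ...) by default
python main.py -a resnet18 --dummy

# Opt out of the accelerator and train on CPU
python main.py -a resnet18 --no-accel --dummy
```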
111 changes: 64 additions & 47 deletions imagenet/main.py
@@ -71,6 +71,8 @@
help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use.')
+parser.add_argument('--no-accel', action='store_true',
+                    help='disables accelerator')
parser.add_argument('--multiprocessing-distributed', action='store_true',
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
@@ -104,8 +106,17 @@ def main():

args.distributed = args.world_size > 1 or args.multiprocessing_distributed

-    if torch.cuda.is_available():
-        ngpus_per_node = torch.cuda.device_count()
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")
+
+    print(f"Using device: {device}")
+
+    if device.type == 'cuda':
+        ngpus_per_node = torch.accelerator.device_count()
if ngpus_per_node == 1 and args.dist_backend == "nccl":
warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
else:
@@ -127,8 +138,15 @@ def main_worker(gpu, ngpus_per_node, args):
global best_acc1
args.gpu = gpu

-    if args.gpu is not None:
-        print("Use GPU: {} for training".format(args.gpu))
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        if args.gpu is not None:
+            torch.accelerator.set_device_index(args.gpu)
+            print("Use GPU: {} for training".format(args.gpu))
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")

if args.distributed:
if args.dist_url == "env://" and args.rank == -1:
@@ -147,16 +165,16 @@ def main_worker(gpu, ngpus_per_node, args):
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()

-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not use_accel:
print('using CPU, this will be slow')
elif args.distributed:
# For multiprocessing distributed, DistributedDataParallel constructor
# should always set the single device scope, otherwise,
# DistributedDataParallel will use all available devices.
-        if torch.cuda.is_available():
+        if device.type == 'cuda':
if args.gpu is not None:
torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
+                model.cuda(device)
# When using a single GPU per process and per
# DistributedDataParallel, we need to divide the batch size
# ourselves based on the total number of GPUs of the current node.
@@ -168,29 +186,17 @@ def main_worker(gpu, ngpus_per_node, args):
# DistributedDataParallel will divide and allocate batch_size to all
# available GPUs if device_ids are not set
model = torch.nn.parallel.DistributedDataParallel(model)
-    elif args.gpu is not None and torch.cuda.is_available():
-        torch.cuda.set_device(args.gpu)
-        model = model.cuda(args.gpu)
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-        model = model.to(device)
-    else:
+    elif device.type == 'cuda':
# DataParallel will divide and allocate batch_size to all available GPUs
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
model.features = torch.nn.DataParallel(model.features)
model.cuda()
else:
model = torch.nn.DataParallel(model).cuda()

-    if torch.cuda.is_available():
-        if args.gpu:
-            device = torch.device('cuda:{}'.format(args.gpu))
-        else:
-            device = torch.device("cuda")
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-    else:
-        device = torch.device("cpu")
+    model.to(device)


# define loss function (criterion), optimizer, and learning rate scheduler
criterion = nn.CrossEntropyLoss().to(device)

@@ -207,9 +213,9 @@ def main_worker(gpu, ngpus_per_node, args):
print("=> loading checkpoint '{}'".format(args.resume))
if args.gpu is None:
checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
+            else:
# Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
+                loc = f'{device.type}:{args.gpu}'
checkpoint = torch.load(args.resume, map_location=loc)
args.start_epoch = checkpoint['epoch']
best_acc1 = checkpoint['best_acc1']
@@ -302,11 +308,14 @@ def main_worker(gpu, ngpus_per_node, args):


def train(train_loader, model, criterion, optimizer, epoch, device, args):
-    batch_time = AverageMeter('Time', ':6.3f')
-    data_time = AverageMeter('Data', ':6.3f')
-    losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    data_time = AverageMeter('Data', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.NONE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.NONE)
progress = ProgressMeter(
len(train_loader),
[batch_time, data_time, losses, top1, top5],
@@ -349,18 +358,27 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):

def validate(val_loader, model, criterion, args):

+    use_accel = not args.no_accel and torch.accelerator.is_available()

def run_validate(loader, base_progress=0):

+        if use_accel:
+            device = torch.accelerator.current_accelerator()
+        else:
+            device = torch.device("cpu")

with torch.no_grad():
end = time.time()
for i, (images, target) in enumerate(loader):
i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
+                if use_accel:
+                    if args.gpu is not None and device.type == 'cuda':
+                        torch.accelerator.set_device_index(args.gpu)
+                        images = images.cuda(args.gpu, non_blocking=True)
+                        target = target.cuda(args.gpu, non_blocking=True)
+                    else:
+                        images = images.to(device)
+                        target = target.to(device)

# compute output
output = model(images)
Expand All @@ -379,10 +397,10 @@ def run_validate(loader, base_progress=0):
if i % args.print_freq == 0:
progress.display(i + 1)

-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.AVERAGE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.AVERAGE)
progress = ProgressMeter(
len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
[batch_time, losses, top1, top5],
@@ -422,8 +440,9 @@ class Summary(Enum):

class AverageMeter(object):
"""Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, use_accel, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
+        self.use_accel = use_accel
self.fmt = fmt
self.summary_type = summary_type
self.reset()
@@ -440,11 +459,9 @@ def update(self, val, n=1):
self.count += n
self.avg = self.sum / self.count

    def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
+        if self.use_accel:
+            device = torch.accelerator.current_accelerator()
        else:
            device = torch.device("cpu")
total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
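As a rough sketch of the device-selection pattern this diff applies in `main()`, `main_worker()`, and `validate()` (assuming a PyTorch build that ships the `torch.accelerator` API; the helper name `pick_device` is illustrative and not part of the PR):

```python
import torch

def pick_device(no_accel: bool = False) -> torch.device:
    # Prefer the current accelerator (cuda, xpu, mps, ...) unless the user
    # opted out with --no-accel or no accelerator backend is available.
    use_accel = not no_accel and torch.accelerator.is_available()
    if use_accel:
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

device = pick_device()
print(f"Using device: {device}")
model = torch.nn.Linear(8, 2).to(device)  # modules and tensors are then moved with .to(device)
```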
4 changes: 2 additions & 2 deletions imagenet/requirements.txt
@@ -1,2 +1,2 @@
-torch
-torchvision==0.20.0
+torch>=2.6
+torchvision
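If it helps, the updated dependencies can be installed the usual way (the `>=2.6` pin above is what the diff requires for the `torch.accelerator` API):

```bash
pip install -r imagenet/requirements.txt
```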