From d48f6a1286d5daba0bf66a84e00c113289abdc68 Mon Sep 17 00:00:00 2001
From: eromomon
Date: Mon, 19 May 2025 16:27:21 -0700
Subject: [PATCH] Add Accelerator API to ImageNet Example

Signed-off-by: eromomon
---
 imagenet/README.md        |   7 ++-
 imagenet/main.py          | 111 ++++++++++++++++++++++----------
 imagenet/requirements.txt |   4 +-
 3 files changed, 71 insertions(+), 51 deletions(-)

diff --git a/imagenet/README.md b/imagenet/README.md
index 9b280f087e..e3f66429b9 100644
--- a/imagenet/README.md
+++ b/imagenet/README.md
@@ -33,7 +33,9 @@ python main.py -a resnet18 --dummy
 
 ## Multi-processing Distributed Data Parallel Training
 
-You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
+If running on CUDA, you should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
+
+For XPU, multiprocessing is not supported as of PyTorch 2.6.
 
 ### Single node, multiple GPUs:
 
@@ -59,7 +61,7 @@ python main.py -a resnet50 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' --dist-backen
 ```bash
 usage: main.py [-h] [-a ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N] [--lr LR] [--momentum M] [--wd W] [-p N] [--resume PATH] [-e] [--pretrained] [--world-size WORLD_SIZE] [--rank RANK]
-               [--dist-url DIST_URL] [--dist-backend DIST_BACKEND] [--seed SEED] [--gpu GPU] [--multiprocessing-distributed] [--dummy]
+               [--dist-url DIST_URL] [--dist-backend DIST_BACKEND] [--seed SEED] [--gpu GPU] [--no-accel] [--multiprocessing-distributed] [--dummy]
                [DIR]
 
 PyTorch ImageNet Training
@@ -96,6 +98,7 @@ optional arguments:
                         distributed backend
   --seed SEED           seed for initializing training.
   --gpu GPU             GPU id to use.
+  --no-accel            disables accelerator
   --multiprocessing-distributed
                         Use multi-processing distributed training to launch N processes per node, which has N GPUs. This is the fastest way to use PyTorch for either single node or multi node data parallel training
diff --git a/imagenet/main.py b/imagenet/main.py
index cc32d50733..dd33470908 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -71,6 +71,8 @@
                     help='seed for initializing training. ')
 parser.add_argument('--gpu', default=None, type=int,
                     help='GPU id to use.')
+parser.add_argument('--no-accel', action='store_true',
+                    help='disables accelerator')
 parser.add_argument('--multiprocessing-distributed', action='store_true',
                     help='Use multi-processing distributed training to launch '
                          'N processes per node, which has N GPUs. This is the '
@@ -104,8 +106,17 @@ def main():
 
     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
 
-    if torch.cuda.is_available():
-        ngpus_per_node = torch.cuda.device_count()
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")
+
+    print(f"Using device: {device}")
+
+    if device.type == 'cuda':
+        ngpus_per_node = torch.accelerator.device_count()
         if ngpus_per_node == 1 and args.dist_backend == "nccl":
             warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
     else:
@@ -127,8 +138,15 @@ def main_worker(gpu, ngpus_per_node, args):
     global best_acc1
     args.gpu = gpu
 
-    if args.gpu is not None:
-        print("Use GPU: {} for training".format(args.gpu))
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        if args.gpu is not None:
+            torch.accelerator.set_device_index(args.gpu)
+            print("Use GPU: {} for training".format(args.gpu))
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")
 
     if args.distributed:
         if args.dist_url == "env://" and args.rank == -1:
@@ -147,16 +165,16 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = models.__dict__[args.arch]()
 
-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not use_accel:
         print('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
-        if torch.cuda.is_available():
+        if device.type == 'cuda':
             if args.gpu is not None:
                 torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
+                model.cuda(device)
                 # When using a single GPU per process and per
                 # DistributedDataParallel, we need to divide the batch size
                 # ourselves based on the total number of GPUs of the current node.
@@ -168,29 +186,17 @@ def main_worker(gpu, ngpus_per_node, args):
                 # DistributedDataParallel will divide and allocate batch_size to all
                 # available GPUs if device_ids are not set
                 model = torch.nn.parallel.DistributedDataParallel(model)
-    elif args.gpu is not None and torch.cuda.is_available():
-        torch.cuda.set_device(args.gpu)
-        model = model.cuda(args.gpu)
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-        model = model.to(device)
-    else:
+    elif device.type == 'cuda':
         # DataParallel will divide and allocate batch_size to all available GPUs
         if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
             model.features = torch.nn.DataParallel(model.features)
             model.cuda()
         else:
             model = torch.nn.DataParallel(model).cuda()
-
-    if torch.cuda.is_available():
-        if args.gpu:
-            device = torch.device('cuda:{}'.format(args.gpu))
-        else:
-            device = torch.device("cuda")
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
     else:
-        device = torch.device("cpu")
+        model.to(device)
+
+
     # define loss function (criterion), optimizer, and learning rate scheduler
     criterion = nn.CrossEntropyLoss().to(device)
@@ -207,9 +213,9 @@ def main_worker(gpu, ngpus_per_node, args):
             print("=> loading checkpoint '{}'".format(args.resume))
             if args.gpu is None:
                 checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
+            else:
                 # Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
+                loc = f'{device.type}:{args.gpu}'
                 checkpoint = torch.load(args.resume, map_location=loc)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
@@ -302,11 +308,14 @@ def main_worker(gpu, ngpus_per_node, args):
 
 
 def train(train_loader, model, criterion, optimizer, epoch, device, args):
-    batch_time = AverageMeter('Time', ':6.3f')
-    data_time = AverageMeter('Data', ':6.3f')
-    losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    data_time = AverageMeter('Data', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.NONE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.NONE)
     progress = ProgressMeter(
         len(train_loader),
         [batch_time, data_time, losses, top1, top5],
@@ -349,18 +358,27 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
 
 
 def validate(val_loader, model, criterion, args):
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
     def run_validate(loader, base_progress=0):
+
+        if use_accel:
+            device = torch.accelerator.current_accelerator()
+        else:
+            device = torch.device("cpu")
+
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
+                if use_accel:
+                    if args.gpu is not None and device.type == 'cuda':
+                        torch.accelerator.set_device_index(args.gpu)
+                        images = images.cuda(args.gpu, non_blocking=True)
+                        target = target.cuda(args.gpu, non_blocking=True)
+                    else:
+                        images = images.to(device)
+                        target = target.to(device)
 
                 # compute output
                 output = model(images)
@@ -379,10 +397,10 @@ def run_validate(loader, base_progress=0):
                 if i % args.print_freq == 0:
                     progress.display(i + 1)
 
-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.AVERAGE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.AVERAGE)
     progress = ProgressMeter(
         len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
         [batch_time, losses, top1, top5],
@@ -422,8 +440,9 @@ class Summary(Enum):
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, use_accel, fmt=':f', summary_type=Summary.AVERAGE):
         self.name = name
+        self.use_accel = use_accel
         self.fmt = fmt
         self.summary_type = summary_type
         self.reset()
@@ -440,11 +459,9 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count
 
-    def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
torch.device("mps") + def all_reduce(self): + if use_accel: + device = torch.accelerator.current_accelerator() else: device = torch.device("cpu") total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) diff --git a/imagenet/requirements.txt b/imagenet/requirements.txt index 6cec7414dc..9a083ba390 100644 --- a/imagenet/requirements.txt +++ b/imagenet/requirements.txt @@ -1,2 +1,2 @@ -torch -torchvision==0.20.0 +torch>=2.6 +torchvision