diff --git a/imagenet/README.md b/imagenet/README.md
index 9b280f087e..0d029ee948 100644
--- a/imagenet/README.md
+++ b/imagenet/README.md
@@ -33,7 +33,9 @@ python main.py -a resnet18 --dummy
 
 ## Multi-processing Distributed Data Parallel Training
 
-You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
+If running on CUDA, you should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
+
+For XPU, multiprocessing distributed training is not supported as of PyTorch 2.6.
 
 ### Single node, multiple GPUs:
 
diff --git a/imagenet/main.py b/imagenet/main.py
index cc32d50733..bf50586c49 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -147,7 +147,7 @@ def main_worker(gpu, ngpus_per_node, args):
         print("=> creating model '{}'".format(args.arch))
         model = models.__dict__[args.arch]()
 
-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not torch.cuda.is_available() and not torch.backends.mps.is_available() and not torch.xpu.is_available():
         print('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
@@ -171,6 +171,9 @@ def main_worker(gpu, ngpus_per_node, args):
     elif args.gpu is not None and torch.cuda.is_available():
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
+        model = model.to(device)
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
         model = model.to(device)
@@ -187,10 +190,15 @@ def main_worker(gpu, ngpus_per_node, args):
             device = torch.device('cuda:{}'.format(args.gpu))
         else:
             device = torch.device("cuda")
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
     else:
         device = torch.device("cpu")
+
+    print(f"Device to use: {device.type}")
+
     # define loss function (criterion), optimizer, and learning rate scheduler
     criterion = nn.CrossEntropyLoss().to(device)
 
@@ -354,14 +362,19 @@ def run_validate(loader, base_progress=0):
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
+                if torch.cuda.is_available():
+                    if args.gpu is not None:
+                        images = images.cuda(args.gpu, non_blocking=True)
                     target = target.cuda(args.gpu, non_blocking=True)
+                elif torch.xpu.is_available():
+                    images = images.to("xpu")
+                    target = target.to("xpu")
+                elif torch.backends.mps.is_available():
+                    images = images.to('mps')
+                    target = target.to('mps')
+
 
                 # compute output
                 output = model(images)
                 loss = criterion(output, target)
 
@@ -443,6 +456,8 @@ def update(self, val, n=1):
     def all_reduce(self):
         if torch.cuda.is_available():
             device = torch.device("cuda")
+        elif torch.xpu.is_available():
+            device = torch.device("xpu")
         elif torch.backends.mps.is_available():
             device = torch.device("mps")
         else:
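A note on the README change: picking NCCL is a one-line decision at process-group setup. Below is a minimal sketch of that selection, assuming a CUDA-or-CPU host; `init_distributed` and its `dist_url` default are illustrative names, not part of the example's actual CLI.

```python
import torch
import torch.distributed as dist

def init_distributed(rank: int, world_size: int, dist_url: str = "env://") -> None:
    # NCCL currently gives the best distributed performance on CUDA;
    # gloo is the usual fallback for CPU-only hosts. XPU is left out on
    # purpose: per the README note above, XPU multiprocessing is not
    # supported as of PyTorch 2.6.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method=dist_url,
                            world_size=world_size, rank=rank)
```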
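The new `elif torch.xpu.is_available()` branches in `main_worker` all encode the same priority order: CUDA, then XPU, then MPS, then CPU. A standalone sketch of that chain as one helper (`pick_device` is a hypothetical name, not code from this patch):

```python
import torch

def pick_device(gpu=None):
    # Mirrors the CUDA > XPU > MPS > CPU priority used throughout main.py.
    if torch.cuda.is_available():
        # Honor an explicit GPU index when one was requested via --gpu.
        return torch.device(f"cuda:{gpu}") if gpu is not None else torch.device("cuda")
    if torch.xpu.is_available():
        return torch.device("xpu")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

With a helper like this, the model placement and the loss/tensor placement could share one call, e.g. `model = model.to(pick_device(args.gpu))`.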
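In the reworked validation loop, only the CUDA branch passes `non_blocking=True`; that flag only overlaps the host-to-device copy with compute when the batch lives in pinned host memory. A small sketch of the pairing, assuming a CUDA machine (the example's DataLoaders already pass `pin_memory=True`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy data standing in for ImageNet batches.
dataset = TensorDataset(torch.randn(64, 3, 224, 224), torch.randint(0, 1000, (64,)))
# pin_memory=True keeps batches in page-locked host memory, which is what
# makes the non_blocking copies below genuinely asynchronous.
loader = DataLoader(dataset, batch_size=16, pin_memory=True)

for images, target in loader:
    images = images.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)
```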
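The final hunk only extends the device choice inside `AverageMeter.all_reduce`; the reduction itself (outside this hunk) packs `sum` and `count` into one tensor so a single collective aggregates both. A sketch of that pattern, assuming the process group is already initialized (`reduce_stats` is an illustrative name):

```python
import torch
import torch.distributed as dist

def reduce_stats(total_sum, count, device):
    # Packing both values into one tensor needs only one all_reduce call.
    stats = torch.tensor([total_sum, count], dtype=torch.float32, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    return stats[0].item(), int(stats[1].item())
```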