
[XPU] Support XCCL on deepspeed side #7113

Open · wants to merge 3 commits into master
7 changes: 5 additions & 2 deletions accelerator/real_accelerator.py
@@ -131,11 +131,14 @@ def get_accelerator():
     if accelerator_name is None:
         try:
             import intel_extension_for_pytorch as ipex
-
             if ipex._C._has_xpu():
                 accelerator_name = "xpu"
         except ImportError as e:
-            pass
+            import torch
+            if torch.xpu.is_available():
+                accelerator_name = "xpu"
+            else:
+                pass
     if accelerator_name is None:
         try:
             import torch_npu  # noqa: F401,F811 # type: ignore
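The net effect of this hunk is that XPU autodetection no longer hard-requires IPEX. A minimal sketch of the resulting detection order, assuming a recent PyTorch build with XPU support (the helper name is illustrative; IPEX is still preferred when installed):

# Sketch only: prefer IPEX's probe, otherwise fall back to stock PyTorch.
def detect_xpu_accelerator():
    try:
        import intel_extension_for_pytorch as ipex
        if ipex._C._has_xpu():
            return "xpu"
    except ImportError:
        import torch
        if torch.xpu.is_available():
            return "xpu"
    return None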
42 changes: 31 additions & 11 deletions accelerator/xpu_accelerator.py
@@ -5,19 +5,33 @@

 import torch
 from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
-import intel_extension_for_pytorch as ipex  # noqa: F401 # type: ignore
-import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
 import functools

 import importlib
 import inspect
+
+try:
+    import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
+    oneccl_imported_p = True
+except ImportError as e:
+    oneccl_imported_p = False
+
+try:
+    import intel_extension_for_pytorch as ipex  # noqa: F401 # type: ignore
+    ipex_imported_p = True
+except ImportError as e:
+    ipex_imported_p = False


 class XPU_Accelerator(DeepSpeedAccelerator):

     def __init__(self):
         self._name = 'xpu'
-        self._communication_backend_name = 'ccl'
+        if oneccl_imported_p:
+            self._communication_backend_name = 'ccl'
+        else:
+            # changed to xccl if not using torch-CCL on XPU device
+            self._communication_backend_name = 'xccl'
         self._compile_backend = "inductor"
         self.aligned_tensors = []
         self.class_dict = None
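In short, the constructor now derives the collective backend from whichever bindings import successfully; condensed, the selection amounts to the following (same identifiers as the diff; 'xccl' presumes a PyTorch build that ships the native XCCL backend for XPU):

# Sketch of the backend selection: torch-CCL bindings win when present,
# otherwise fall back to PyTorch's native XCCL backend name.
try:
    import oneccl_bindings_for_pytorch  # noqa: F401
    communication_backend_name = "ccl"
except ImportError:
    communication_backend_name = "xccl"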
@@ -26,11 +40,14 @@ def is_synchronized_device(self):
         return False

     def use_host_timers(self):
-        # WA XPU event will be consolidated in 2.6
-        if ipex.__version__ < '2.6':
-            return True
-        else:
+        if not ipex_imported_p:
             return self.is_synchronized_device()
+        else:
+            # WA XPU event will be consolidated in 2.6
+            if ipex.__version__ < '2.6':
+                return True
+            else:
+                return self.is_synchronized_device()

     def resolves_data_dependency(self):
         return self.is_synchronized_device()
@@ -290,10 +307,13 @@ def get_op_builder(self, class_name):
         return self.class_dict['NotImplementedBuilder']

     def build_extension(self):
-        try:
-            from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
-        except ImportError:
-            from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
+        if ipex_imported_p:
+            try:
+                from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
+            except ImportError:
+                from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
+        else:
+            from torch.utils.cpp_extension import DpcppBuildExtension
         return DpcppBuildExtension

     def export_envs(self):
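For context, a hedged usage sketch of how the selected backend name would typically reach torch.distributed on a single rank; get_accelerator() and communication_backend_name() are existing DeepSpeed accelerator APIs, while the rendezvous settings below are purely illustrative:

import os
import torch.distributed as dist
from deepspeed.accelerator import get_accelerator

# Single-process example: the process group backend is whatever the
# accelerator reports ('ccl' with torch-CCL installed, otherwise 'xccl').
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend=get_accelerator().communication_backend_name(), rank=0, world_size=1)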