
Commit a46d068

Support MIG parsing during CUDA context creation in UCX initialization (#6720)
1 parent ef13425 commit a46d068

6 files changed: +233 -39 lines changed

codecov.yml
+5

@@ -1,6 +1,11 @@
 codecov:
   require_ci_to_pass: yes
 
+ignore:
+  # Files that exercise GPU-only functionality or are only tested by gpuCI
+  # but don't interact with codecov are ignored.
+  - "distributed/comm/ucx.py"
+
 coverage:
   precision: 2
   round: down

distributed/comm/tests/test_ucx.py
+38 -11

@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import asyncio
+import os
+from unittest.mock import patch
 
 import pytest
 
@@ -15,7 +17,12 @@
 from distributed.comm.core import CommClosedError
 from distributed.comm.registry import backends, get_backend
 from distributed.deploy.local import LocalCluster
-from distributed.diagnostics.nvml import has_cuda_context
+from distributed.diagnostics.nvml import (
+    device_get_count,
+    get_device_index_and_uuid,
+    get_device_mig_mode,
+    has_cuda_context,
+)
 from distributed.protocol import to_serialize
 from distributed.utils_test import gen_test, inc
 
@@ -320,20 +327,40 @@ async def test_simple(
     assert await client.submit(lambda x: x + 1, 10) == 11
 
 
+@pytest.mark.xfail(reason="If running on Docker, requires --pid=host")
 @gen_test()
 async def test_cuda_context(
     ucx_loop,
 ):
-    with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}):
-        async with LocalCluster(
-            protocol="ucx", n_workers=1, asynchronous=True
-        ) as cluster:
-            async with Client(cluster, asynchronous=True) as client:
-                assert cluster.scheduler_address.startswith("ucx://")
-                assert has_cuda_context() == 0
-                worker_cuda_context = await client.run(has_cuda_context)
-                assert len(worker_cuda_context) == 1
-                assert list(worker_cuda_context.values())[0] == 0
+    try:
+        device_info = get_device_index_and_uuid(
+            next(
+                filter(
+                    lambda i: get_device_mig_mode(i)[0] == 0, range(device_get_count())
+                )
+            )
+        )
+    except StopIteration:
+        pytest.skip("No CUDA device in non-MIG mode available")
+
+    with patch.dict(
+        os.environ, {"CUDA_VISIBLE_DEVICES": device_info.uuid.decode("utf-8")}
+    ):
+        with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}):
+            async with LocalCluster(
+                protocol="ucx", n_workers=1, asynchronous=True
+            ) as cluster:
+                async with Client(cluster, asynchronous=True) as client:
+                    assert cluster.scheduler_address.startswith("ucx://")
+                    ctx = has_cuda_context()
+                    assert ctx.has_context and ctx.device_info == device_info
+                    worker_cuda_context = await client.run(has_cuda_context)
+                    assert len(worker_cuda_context) == 1
+                    worker_cuda_context = list(worker_cuda_context.values())
+                    assert (
+                        worker_cuda_context[0].has_context
+                        and worker_cuda_context[0].device_info == device_info
+                    )
 
 
 @gen_test()
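The updated test picks the first GPU whose MIG mode is disabled and pins the worker to it by UUID before creating the cluster. Below is a minimal standalone sketch of that selection step outside the pytest harness; the helpers come from distributed.diagnostics.nvml as added in this commit, while the surrounding script is illustrative only and assumes pynvml plus at least one visible GPU.

import os

from distributed.diagnostics.nvml import (
    device_get_count,
    get_device_index_and_uuid,
    get_device_mig_mode,
)

# Pick the first device whose current MIG mode is disabled (mode 0).
try:
    device_info = get_device_index_and_uuid(
        next(i for i in range(device_get_count()) if get_device_mig_mode(i)[0] == 0)
    )
except StopIteration:
    raise SystemExit("No CUDA device in non-MIG mode available")

# Pin this process to that device by UUID, as the test does via patch.dict.
os.environ["CUDA_VISIBLE_DEVICES"] = device_info.uuid.decode("utf-8")
print(device_info.device_index, device_info.uuid)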

distributed/comm/ucx.py
+28 -12

@@ -28,7 +28,11 @@
     host_array,
     to_frames,
 )
-from distributed.diagnostics.nvml import has_cuda_context
+from distributed.diagnostics.nvml import (
+    CudaDeviceInfo,
+    get_device_index_and_uuid,
+    has_cuda_context,
+)
 from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes
 
 logger = logging.getLogger(__name__)
@@ -57,17 +61,27 @@
 )
 
 
-def _warn_existing_cuda_context(ctx, pid):
+def _get_device_and_uuid_str(device_info: CudaDeviceInfo) -> str:
+    return f"{device_info.device_index} ({str(device_info.uuid)})"
+
+
+def _warn_existing_cuda_context(device_info: CudaDeviceInfo, pid: int) -> None:
+    device_uuid_str = _get_device_and_uuid_str(device_info)
     logger.warning(
-        f"A CUDA context for device {ctx} already exists on process ID {pid}. {_warning_suffix}"
+        f"A CUDA context for device {device_uuid_str} already exists "
+        f"on process ID {pid}. {_warning_suffix}"
     )
 
 
-def _warn_cuda_context_wrong_device(ctx_expected, ctx_actual, pid):
+def _warn_cuda_context_wrong_device(
+    device_info_expected: CudaDeviceInfo, device_info_actual: CudaDeviceInfo, pid: int
+) -> None:
+    expected_device_uuid_str = _get_device_and_uuid_str(device_info_expected)
+    actual_device_uuid_str = _get_device_and_uuid_str(device_info_actual)
     logger.warning(
         f"Worker with process ID {pid} should have a CUDA context assigned to device "
-        f"{ctx_expected}, but instead the CUDA context is on device {ctx_actual}. "
-        f"{_warning_suffix}"
+        f"{expected_device_uuid_str}, but instead the CUDA context is on device "
+        f"{actual_device_uuid_str}. {_warning_suffix}"
     )
 
 
@@ -116,22 +130,24 @@ def init_once():
                 "CUDA support with UCX requires Numba for context management"
             )
 
-        cuda_visible_device = int(
+        cuda_visible_device = get_device_index_and_uuid(
            os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
         )
         pre_existing_cuda_context = has_cuda_context()
-        if pre_existing_cuda_context is not False:
-            _warn_existing_cuda_context(pre_existing_cuda_context, os.getpid())
+        if pre_existing_cuda_context.has_context:
+            _warn_existing_cuda_context(
+                pre_existing_cuda_context.device_info, os.getpid()
+            )
 
         numba.cuda.current_context()
 
         cuda_context_created = has_cuda_context()
         if (
-            cuda_context_created is not False
-            and cuda_context_created != cuda_visible_device
+            cuda_context_created.has_context
+            and cuda_context_created.device_info.uuid != cuda_visible_device.uuid
         ):
             _warn_cuda_context_wrong_device(
-                cuda_visible_device, cuda_context_created, os.getpid()
+                cuda_visible_device, cuda_context_created.device_info, os.getpid()
             )
 
         import ucp as _ucp
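With these changes init_once() compares device UUIDs rather than integer indices when checking that the CUDA context landed on the expected device, which is what makes MIG instances (same parent device index, distinct UUID) distinguishable. A rough sketch of that comparison in isolation, assuming a single entry in CUDA_VISIBLE_DEVICES; has_cuda_context and get_device_index_and_uuid are the real helpers from this commit, the rest is illustrative.

import os

from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context

expected = get_device_index_and_uuid(
    os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
)
ctx = has_cuda_context()
if ctx.has_context and ctx.device_info.uuid != expected.uuid:
    # Mirrors the check in init_once() that triggers _warn_cuda_context_wrong_device:
    # UUIDs are compared because a MIG instance shares its parent's device index
    # but carries its own UUID.
    print(
        f"CUDA context is on {ctx.device_info.device_index} ({ctx.device_info.uuid}), "
        f"expected {expected.device_index} ({expected.uuid})"
    )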

distributed/diagnostics/nvml.py
+137 -14

@@ -3,6 +3,7 @@
 import os
 from enum import IntEnum, auto
 from platform import uname
+from typing import NamedTuple
 
 from packaging.version import parse as parse_version
 
@@ -29,6 +30,17 @@ class NVMLState(IntEnum):
     """PyNVML and NVML available, but on WSL and the driver version is insufficient"""
 
 
+class CudaDeviceInfo(NamedTuple):
+    uuid: bytes | None = None
+    device_index: int | None = None
+    mig_index: int | None = None
+
+
+class CudaContext(NamedTuple):
+    has_context: bool
+    device_info: CudaDeviceInfo | None = None
+
+
 # Initialisation must occur per-process, so an initialised state is a
 # (state, pid) pair
 NVML_STATE = (
@@ -147,27 +159,138 @@ def _pynvml_handles():
     return pynvml.nvmlDeviceGetHandleByIndex(gpu_idx)
 
 
+def _running_process_matches(handle):
+    """Check whether the current process is same as that of handle
+
+    Parameters
+    ----------
+    handle : pyvnml.nvml.LP_struct_c_nvmlDevice_t
+        NVML handle to CUDA device
+
+    Returns
+    -------
+    out : bool
+        Whether the device handle has a CUDA context on the running process.
+    """
+    init_once()
+    if hasattr(pynvml, "nvmlDeviceGetComputeRunningProcesses_v2"):
+        running_processes = pynvml.nvmlDeviceGetComputeRunningProcesses_v2(handle)
+    else:
+        running_processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+    return any(os.getpid() == proc.pid for proc in running_processes)
+
+
 def has_cuda_context():
     """Check whether the current process already has a CUDA context created.
 
     Returns
     -------
-    ``False`` if current process has no CUDA context created, otherwise returns the
-    index of the device for which there's a CUDA context.
+    out : CudaContext
+        Object containing information as to whether the current process has a CUDA
+        context created, and in the positive case containing also information about
+        the device the context belongs to.
     """
     init_once()
-    if not is_initialized():
-        return False
-    for index in range(device_get_count()):
-        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
-        if hasattr(pynvml, "nvmlDeviceGetComputeRunningProcesses_v2"):
-            running_processes = pynvml.nvmlDeviceGetComputeRunningProcesses_v2(handle)
-        else:
-            running_processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
-        for proc in running_processes:
-            if os.getpid() == proc.pid:
-                return index
-    return False
+    if is_initialized():
+        for index in range(device_get_count()):
+            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
+            try:
+                mig_current_mode, mig_pending_mode = pynvml.nvmlDeviceGetMigMode(handle)
+            except pynvml.NVMLError_NotSupported:
+                mig_current_mode = pynvml.NVML_DEVICE_MIG_DISABLE
+            if mig_current_mode == pynvml.NVML_DEVICE_MIG_ENABLE:
+                for mig_index in range(pynvml.nvmlDeviceGetMaxMigDeviceCount(handle)):
+                    try:
+                        mig_handle = pynvml.nvmlDeviceGetMigDeviceHandleByIndex(
+                            handle, mig_index
+                        )
+                    except pynvml.NVMLError_NotFound:
+                        # No MIG device with that index
+                        continue
+                    if _running_process_matches(mig_handle):
+                        uuid = pynvml.nvmlDeviceGetUUID(mig_handle)
+                        return CudaContext(
+                            has_context=True,
+                            device_info=CudaDeviceInfo(
+                                uuid=uuid, device_index=index, mig_index=mig_index
+                            ),
+                        )
+            else:
+                if _running_process_matches(handle):
+                    uuid = pynvml.nvmlDeviceGetUUID(handle)
+                    return CudaContext(
+                        has_context=True,
+                        device_info=CudaDeviceInfo(uuid=uuid, device_index=index),
+                    )
+    return CudaContext(has_context=False)
+
+
+def get_device_index_and_uuid(device):
+    """Get both device index and UUID from device index or UUID
+
+    Parameters
+    ----------
+    device : int, bytes or str
+        An ``int`` with the index of a GPU, or ``bytes`` or ``str`` with the UUID
+        of a CUDA (either GPU or MIG) device.
+
+    Returns
+    -------
+    out : CudaDeviceInfo
+        Object containing information about the device.
+
+    Examples
+    --------
+    >>> get_device_index_and_uuid(0)  # doctest: +SKIP
+    {'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
+
+    >>> get_device_index_and_uuid('GPU-e1006a74-5836-264f-5c26-53d19d212dfe')  # doctest: +SKIP
+    {'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
+
+    >>> get_device_index_and_uuid('MIG-7feb6df5-eccf-5faa-ab00-9a441867e237')  # doctest: +SKIP
+    {'device-index': 0, 'uuid': b'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
+    """
+    init_once()
+    try:
+        device_index = int(device)
+        device_handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
+        uuid = pynvml.nvmlDeviceGetUUID(device_handle)
+    except ValueError:
+        uuid = device if isinstance(device, bytes) else bytes(device, "utf-8")
+
+    # Validate UUID, get index and UUID as seen with `nvidia-smi -L`
+    uuid_handle = pynvml.nvmlDeviceGetHandleByUUID(uuid)
+    device_index = pynvml.nvmlDeviceGetIndex(uuid_handle)
+    uuid = pynvml.nvmlDeviceGetUUID(uuid_handle)
+
+    return CudaDeviceInfo(uuid=uuid, device_index=device_index)
+
+
+def get_device_mig_mode(device):
+    """Get MIG mode for a device index or UUID
+
+    Parameters
+    ----------
+    device: int, bytes or str
+        An ``int`` with the index of a GPU, or ``bytes`` or ``str`` with the UUID
+        of a CUDA (either GPU or MIG) device.
+
+    Returns
+    -------
+    out : list
+        A ``list`` with two integers ``[current_mode, pending_mode]``.
+    """
+    init_once()
+    try:
+        device_index = int(device)
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
+    except ValueError:
+        uuid = device if isinstance(device, bytes) else bytes(device, "utf-8")
+        handle = pynvml.nvmlDeviceGetHandleByUUID(uuid)
+    try:
+        return pynvml.nvmlDeviceGetMigMode(handle)
+    except pynvml.NVMLError_NotSupported:
+        return [0, 0]
 
 
 def _get_utilization(h):
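For reference, a small sketch of how the new helpers are meant to be consumed: has_cuda_context() now returns a CudaContext NamedTuple instead of False or a bare index, and the two new functions accept either an index or a GPU/MIG UUID (str or bytes). The snippet assumes pynvml is installed and at least one GPU is visible; the printed values depend on the machine.

from distributed.diagnostics import nvml

# has_cuda_context() now returns a CudaContext instead of False or a device index.
ctx = nvml.has_cuda_context()
if ctx.has_context:
    info = ctx.device_info  # CudaDeviceInfo(uuid=..., device_index=..., mig_index=...)
    print(info.device_index, info.uuid, info.mig_index)

# Accepts an index, a "GPU-..." UUID or a "MIG-..." UUID, as str or bytes.
info = nvml.get_device_index_and_uuid(0)

# Returns [current_mode, pending_mode]; [0, 0] if the device does not support MIG.
current, pending = nvml.get_device_mig_mode(info.device_index)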

distributed/diagnostics/tests/test_nvml.py
+8 -2

@@ -56,19 +56,25 @@ def test_wsl_monitoring_enabled():
 
 def run_has_cuda_context(queue):
     try:
-        assert not nvml.has_cuda_context()
+        assert not nvml.has_cuda_context().has_context
 
         import numba.cuda
 
         numba.cuda.current_context()
-        assert nvml.has_cuda_context() == 0
+        ctx = nvml.has_cuda_context()
+        assert (
+            ctx.has_context
+            and ctx.device_info.device_index == 0
+            and isinstance(ctx.device_info.uuid, bytes)
+        )
 
         queue.put(None)
 
     except Exception as e:
         queue.put(e)
 
 
+@pytest.mark.xfail(reason="If running on Docker, requires --pid=host")
 def test_has_cuda_context():
     if nvml.device_get_count() < 1:
         pytest.skip("No GPUs available")

distributed/utils_test.py
+17

@@ -2130,6 +2130,23 @@ def ucx_loop():
     ucp.reset()
     loop.close()
 
+    # Reset also Distributed's UCX initialization, i.e., revert the effects of
+    # `distributed.comm.ucx.init_once()`.
+    import distributed.comm.ucx
+
+    distributed.comm.ucx.ucp = None
+    # If the test created a context, clean it up.
+    # TODO: should we check if there's already a context _before_ the test runs?
+    # I think that would be useful.
+    from distributed.diagnostics.nvml import has_cuda_context
+
+    ctx = has_cuda_context()
+    if ctx.has_context:
+        import numba.cuda
+
+        ctx = numba.cuda.current_context()
+        ctx.device.reset()
+
 
 def wait_for_log_line(
     match: bytes, stream: IO[bytes] | None, max_lines: int | None = 10
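Any UCX test that requests the ucx_loop fixture now gets this teardown automatically: the module-level ucp handle in distributed.comm.ucx is reset and a leftover CUDA context, if any, is destroyed through Numba. A hypothetical test using the fixture, mirroring the pattern of test_cuda_context above; the test body is illustrative and not part of this commit.

import dask
from distributed import Client, LocalCluster
from distributed.utils_test import gen_test


@gen_test()
async def test_ucx_roundtrip(ucx_loop):
    # Any CUDA context created while the cluster is up is torn down by the
    # ucx_loop fixture's teardown shown above.
    with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}):
        async with LocalCluster(
            protocol="ucx", n_workers=1, asynchronous=True
        ) as cluster:
            async with Client(cluster, asynchronous=True) as client:
                assert await client.submit(lambda x: x + 1, 1) == 2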
