3
3
import os
4
4
from enum import IntEnum , auto
5
5
from platform import uname
6
+ from typing import NamedTuple
6
7
7
8
from packaging .version import parse as parse_version
8
9
@@ -29,6 +30,17 @@ class NVMLState(IntEnum):
29
30
"""PyNVML and NVML available, but on WSL and the driver version is insufficient"""
30
31
31
32
33
+ class CudaDeviceInfo (NamedTuple ):
34
+ uuid : bytes | None = None
35
+ device_index : int | None = None
36
+ mig_index : int | None = None
37
+
38
+
39
+ class CudaContext (NamedTuple ):
40
+ has_context : bool
41
+ device_info : CudaDeviceInfo | None = None
42
+
43
+
32
44
# Initialisation must occur per-process, so an initialised state is a
33
45
# (state, pid) pair
34
46
NVML_STATE = (
@@ -147,27 +159,138 @@ def _pynvml_handles():
147
159
return pynvml .nvmlDeviceGetHandleByIndex (gpu_idx )
148
160
149
161
162
+ def _running_process_matches (handle ):
163
+ """Check whether the current process is same as that of handle
164
+
165
+ Parameters
166
+ ----------
167
+ handle : pyvnml.nvml.LP_struct_c_nvmlDevice_t
168
+ NVML handle to CUDA device
169
+
170
+ Returns
171
+ -------
172
+ out : bool
173
+ Whether the device handle has a CUDA context on the running process.
174
+ """
175
+ init_once ()
176
+ if hasattr (pynvml , "nvmlDeviceGetComputeRunningProcesses_v2" ):
177
+ running_processes = pynvml .nvmlDeviceGetComputeRunningProcesses_v2 (handle )
178
+ else :
179
+ running_processes = pynvml .nvmlDeviceGetComputeRunningProcesses (handle )
180
+ return any (os .getpid () == proc .pid for proc in running_processes )
181
+
182
+
150
183
def has_cuda_context ():
151
184
"""Check whether the current process already has a CUDA context created.
152
185
153
186
Returns
154
187
-------
155
- ``False`` if current process has no CUDA context created, otherwise returns the
156
- index of the device for which there's a CUDA context.
188
+ out : CudaContext
189
+ Object containing information as to whether the current process has a CUDA
190
+ context created, and in the positive case containing also information about
191
+ the device the context belongs to.
157
192
"""
158
193
init_once ()
159
- if not is_initialized ():
160
- return False
161
- for index in range (device_get_count ()):
162
- handle = pynvml .nvmlDeviceGetHandleByIndex (index )
163
- if hasattr (pynvml , "nvmlDeviceGetComputeRunningProcesses_v2" ):
164
- running_processes = pynvml .nvmlDeviceGetComputeRunningProcesses_v2 (handle )
165
- else :
166
- running_processes = pynvml .nvmlDeviceGetComputeRunningProcesses (handle )
167
- for proc in running_processes :
168
- if os .getpid () == proc .pid :
169
- return index
170
- return False
194
+ if is_initialized ():
195
+ for index in range (device_get_count ()):
196
+ handle = pynvml .nvmlDeviceGetHandleByIndex (index )
197
+ try :
198
+ mig_current_mode , mig_pending_mode = pynvml .nvmlDeviceGetMigMode (handle )
199
+ except pynvml .NVMLError_NotSupported :
200
+ mig_current_mode = pynvml .NVML_DEVICE_MIG_DISABLE
201
+ if mig_current_mode == pynvml .NVML_DEVICE_MIG_ENABLE :
202
+ for mig_index in range (pynvml .nvmlDeviceGetMaxMigDeviceCount (handle )):
203
+ try :
204
+ mig_handle = pynvml .nvmlDeviceGetMigDeviceHandleByIndex (
205
+ handle , mig_index
206
+ )
207
+ except pynvml .NVMLError_NotFound :
208
+ # No MIG device with that index
209
+ continue
210
+ if _running_process_matches (mig_handle ):
211
+ uuid = pynvml .nvmlDeviceGetUUID (mig_handle )
212
+ return CudaContext (
213
+ has_context = True ,
214
+ device_info = CudaDeviceInfo (
215
+ uuid = uuid , device_index = index , mig_index = mig_index
216
+ ),
217
+ )
218
+ else :
219
+ if _running_process_matches (handle ):
220
+ uuid = pynvml .nvmlDeviceGetUUID (handle )
221
+ return CudaContext (
222
+ has_context = True ,
223
+ device_info = CudaDeviceInfo (uuid = uuid , device_index = index ),
224
+ )
225
+ return CudaContext (has_context = False )
226
+
227
+
228
+ def get_device_index_and_uuid (device ):
229
+ """Get both device index and UUID from device index or UUID
230
+
231
+ Parameters
232
+ ----------
233
+ device : int, bytes or str
234
+ An ``int`` with the index of a GPU, or ``bytes`` or ``str`` with the UUID
235
+ of a CUDA (either GPU or MIG) device.
236
+
237
+ Returns
238
+ -------
239
+ out : CudaDeviceInfo
240
+ Object containing information about the device.
241
+
242
+ Examples
243
+ --------
244
+ >>> get_device_index_and_uuid(0) # doctest: +SKIP
245
+ {'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
246
+
247
+ >>> get_device_index_and_uuid('GPU-e1006a74-5836-264f-5c26-53d19d212dfe') # doctest: +SKIP
248
+ {'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
249
+
250
+ >>> get_device_index_and_uuid('MIG-7feb6df5-eccf-5faa-ab00-9a441867e237') # doctest: +SKIP
251
+ {'device-index': 0, 'uuid': b'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
252
+ """
253
+ init_once ()
254
+ try :
255
+ device_index = int (device )
256
+ device_handle = pynvml .nvmlDeviceGetHandleByIndex (device_index )
257
+ uuid = pynvml .nvmlDeviceGetUUID (device_handle )
258
+ except ValueError :
259
+ uuid = device if isinstance (device , bytes ) else bytes (device , "utf-8" )
260
+
261
+ # Validate UUID, get index and UUID as seen with `nvidia-smi -L`
262
+ uuid_handle = pynvml .nvmlDeviceGetHandleByUUID (uuid )
263
+ device_index = pynvml .nvmlDeviceGetIndex (uuid_handle )
264
+ uuid = pynvml .nvmlDeviceGetUUID (uuid_handle )
265
+
266
+ return CudaDeviceInfo (uuid = uuid , device_index = device_index )
267
+
268
+
269
+ def get_device_mig_mode (device ):
270
+ """Get MIG mode for a device index or UUID
271
+
272
+ Parameters
273
+ ----------
274
+ device: int, bytes or str
275
+ An ``int`` with the index of a GPU, or ``bytes`` or ``str`` with the UUID
276
+ of a CUDA (either GPU or MIG) device.
277
+
278
+ Returns
279
+ -------
280
+ out : list
281
+ A ``list`` with two integers ``[current_mode, pending_mode]``.
282
+ """
283
+ init_once ()
284
+ try :
285
+ device_index = int (device )
286
+ handle = pynvml .nvmlDeviceGetHandleByIndex (device_index )
287
+ except ValueError :
288
+ uuid = device if isinstance (device , bytes ) else bytes (device , "utf-8" )
289
+ handle = pynvml .nvmlDeviceGetHandleByUUID (uuid )
290
+ try :
291
+ return pynvml .nvmlDeviceGetMigMode (handle )
292
+ except pynvml .NVMLError_NotSupported :
293
+ return [0 , 0 ]
171
294
172
295
173
296
def _get_utilization (h ):
0 commit comments