@@ -178,7 +178,7 @@ def _check_api_version(api_version):
178
178
elif api_version is not None and api_version != '2022.12' :
179
179
raise ValueError ("Only the 2022.12 version of the array API specification is currently supported" )
180
180
181
- def array_namespace (* xs , api_version = None , _use_compat = True ):
181
+ def array_namespace (* xs , api_version = None , use_compat = None ):
182
182
"""
183
183
Get the array API compatible namespace for the arrays `xs`.
184
184
@@ -191,6 +191,12 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
191
191
The newest version of the spec that you need support for (currently
192
192
the compat library wrapped APIs support v2022.12).
193
193
194
+ use_compat: bool or None
195
+ If None (the default), the native namespace will be returned if it is
196
+ already array API compatible, otherwise a compat wrapper is used. If
197
+ True, the compat library wrapped library will be returned. If False,
198
+ the native library namespace is returned.
199
+
194
200
Returns
195
201
-------
196
202
@@ -234,46 +240,66 @@ def your_function(x, y):
234
240
is_jax_array
235
241
236
242
"""
243
+ if use_compat not in [None , True , False ]:
244
+ raise ValueError ("use_compat must be None, True, or False" )
245
+
246
+ _use_compat = use_compat in [None , True ]
247
+
237
248
namespaces = set ()
238
249
for x in xs :
239
250
if is_numpy_array (x ):
240
- _check_api_version (api_version )
241
- if _use_compat :
242
- from .. import numpy as numpy_namespace
251
+ from .. import numpy as numpy_namespace
252
+ import numpy as np
253
+ if use_compat is True :
254
+ _check_api_version (api_version )
243
255
namespaces .add (numpy_namespace )
244
- else :
245
- import numpy as np
256
+ elif use_compat is False :
246
257
namespaces .add (np )
258
+ else :
259
+ # numpy 2.0 has __array_namespace__ and is fully array API
260
+ # compatible.
261
+ if hasattr (x , '__array_namespace__' ):
262
+ namespaces .add (x .__array_namespace__ (api_version = api_version ))
263
+ else :
264
+ namespaces .add (numpy_namespace )
247
265
elif is_cupy_array (x ):
248
- _check_api_version (api_version )
249
266
if _use_compat :
267
+ _check_api_version (api_version )
250
268
from .. import cupy as cupy_namespace
251
269
namespaces .add (cupy_namespace )
252
270
else :
253
271
import cupy as cp
254
272
namespaces .add (cp )
255
273
elif is_torch_array (x ):
256
- _check_api_version (api_version )
257
274
if _use_compat :
275
+ _check_api_version (api_version )
258
276
from .. import torch as torch_namespace
259
277
namespaces .add (torch_namespace )
260
278
else :
261
279
import torch
262
280
namespaces .add (torch )
263
281
elif is_dask_array (x ):
264
- _check_api_version (api_version )
265
282
if _use_compat :
283
+ _check_api_version (api_version )
266
284
from ..dask import array as dask_namespace
267
285
namespaces .add (dask_namespace )
268
286
else :
269
- raise TypeError ("_use_compat cannot be False if input array is a dask array!" )
287
+ import dask .array as da
288
+ namespaces .add (da )
270
289
elif is_jax_array (x ):
271
- _check_api_version (api_version )
272
- # jax.experimental.array_api is already an array namespace. We do
273
- # not have a wrapper submodule for it.
274
- import jax .experimental .array_api as jnp
290
+ if use_compat is True :
291
+ _check_api_version (api_version )
292
+ raise ValueError ("JAX does not have an array-api-compat wrapper" )
293
+ elif use_compat is False :
294
+ import jax .numpy as jnp
295
+ else :
296
+ # jax.experimental.array_api is already an array namespace. We do
297
+ # not have a wrapper submodule for it.
298
+ import jax .experimental .array_api as jnp
275
299
namespaces .add (jnp )
276
300
elif hasattr (x , '__array_namespace__' ):
301
+ if use_compat is True :
302
+ raise ValueError ("The given array does not have an array-api-compat wrapper" )
277
303
namespaces .add (x .__array_namespace__ (api_version = api_version ))
278
304
else :
279
305
# TODO: Support Python scalars?
0 commit comments