Skip to content

Commit 2e97b6f

Browse files
committed
WIP apply_numpy_func
1 parent 1708482 commit 2e97b6f

File tree

8 files changed

+369
-1
lines changed

8 files changed

+369
-1
lines changed

Diff for: docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
apply_numpy_func
910
at
1011
atleast_nd
1112
cov

Diff for: docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
5656
"jax": ("https://jax.readthedocs.io/en/latest", None),
57+
"dask": ("https://docs.dask.org/en/stable", None),
5758
}
5859

5960
nitpick_ignore = [

Diff for: src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3+
from ._apply import apply_numpy_func
34
from ._funcs import (
45
at,
56
atleast_nd,
@@ -17,6 +18,7 @@
1718
# pylint: disable=duplicate-code
1819
__all__ = [
1920
"__version__",
21+
"apply_numpy_func",
2022
"at",
2123
"atleast_nd",
2224
"cov",

Diff for: src/array_api_extra/_apply.py

+351
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
"""Public API Functions."""
2+
3+
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
4+
from __future__ import annotations
5+
6+
from collections.abc import Callable, Hashable, Mapping, Sequence
7+
from functools import wraps
8+
from types import ModuleType
9+
from typing import TYPE_CHECKING, Any, cast
10+
11+
from ._lib._compat import (
12+
array_namespace,
13+
is_dask_namespace,
14+
is_jax_namespace,
15+
)
16+
from ._lib._typing import Array, DType
17+
18+
if TYPE_CHECKING:
19+
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
20+
from typing import TypeAlias
21+
22+
import numpy as np
23+
import numpy.typing as npt
24+
25+
NumPyObject: TypeAlias = npt.NDArray[DType] | np.generic # type: ignore[no-any-explicit]
26+
27+
28+
def apply_numpy_func(  # type: ignore[no-any-explicit]
    func: Callable[..., NumPyObject | Sequence[NumPyObject]],
    *args: Array,
    shapes: Sequence[tuple[int, ...]] | None = None,
    dtypes: Sequence[DType] | None = None,
    xp: ModuleType | None = None,
    input_indices: Sequence[Sequence[Hashable]] | None = None,
    core_indices: Sequence[Hashable] | None = None,
    output_indices: Sequence[Sequence[Hashable]] | None = None,
    # Mapping (not dict) matches both the docstring and the runtime
    # isinstance(..., Mapping) misuse guards below.
    adjust_chunks: Sequence[Mapping[Hashable, Callable[[int], int]]] | None = None,
    new_axes: Sequence[Mapping[Hashable, int]] | None = None,
    **kwargs: Any,
) -> tuple[Array, ...]:
    """
    Apply a function that operates on NumPy arrays to Array API compliant arrays.

    Parameters
    ----------
    func : callable
        The function to apply. It must accept one or more NumPy arrays or generics as
        positional arguments and return either a single NumPy array or generic, or a
        tuple or list thereof.

        It must be a pure function, i.e. without side effects such as disk output,
        as depending on the backend it may be executed more than once.
    *args : Array
        One or more Array API compliant arrays. You need to be able to apply
        ``np.asarray()`` to them to convert them to numpy; read notes below about
        specific backends.
    shapes : Sequence[tuple[int, ...]], optional
        Sequence of output shapes, one for each output of `func`.
        If `func` returns a single (non-sequence) output, this must be a sequence
        with a single element.
        Default: assume a single output and broadcast shapes of the input arrays.
    dtypes : Sequence[DType], optional
        Sequence of output dtypes, one for each output of `func`.
        If `func` returns a single (non-sequence) output, this must be a sequence
        with a single element.
        Default: infer the result type(s) from the input arrays.
    xp : array_namespace, optional
        The standard-compatible namespace for `args`. Default: infer.
    input_indices : Sequence[Sequence[Hashable]], optional
        **Dask specific.**
        Axes labels for each input array, e.g. if there are two args with respectively
        ndim=3 and 1, `input_indices` could be ``['ijk', 'j']`` or ``[(0, 1, 2),
        (1,)]``.
        Default: disallow Dask.
    core_indices : Sequence[Hashable], optional
        **Dask specific.**
        Axes of the input arrays that cannot be broken into chunks.
        Default: disallow Dask.
    output_indices : Sequence[Sequence[Hashable]], optional
        **Dask specific.**
        Axes labels for each output array. If `func` returns a single (non-sequence)
        output, this must be a sequence containing a single sequence of labels, e.g.
        ``['ijk']``.
        Default: disallow Dask.
    adjust_chunks : Sequence[Mapping[Hashable, Callable[[int], int]]], optional
        **Dask specific.**
        Sequence of dicts, one per output, mapping index to function to be applied to
        each chunk to determine the output size. The total must add up to the output
        shape.
        Default: on Dask, the size along each index cannot change.
    new_axes : Sequence[Mapping[Hashable, int]], optional
        **Dask specific.**
        New indexes and their dimension lengths, one per output.
        Default: on Dask, there can't be `output_indices` that don't appear in
        `input_indices`.
    **kwargs : Any, optional
        Additional keyword arguments to pass verbatim to `func`.
        Any array objects in them won't be converted to NumPy.

    Returns
    -------
    tuple[Array, ...]
        The result(s) of `func` applied to the input arrays.
        This is always a tuple, even if `func` returns a single output.

    Notes
    -----
    JAX
        This allows applying eager functions to jitted JAX arrays, which are lazy.
        The function won't be applied until the JAX array is materialized.

        The `JAX transfer guard
        <https://jax.readthedocs.io/en/latest/transfer_guard.html>`_
        may prevent arrays on a GPU device from being transferred back to CPU.
        This is treated as an implicit transfer.

    PyTorch, CuPy
        These backends raise by default if you attempt to convert arrays on a GPU device
        to NumPy.

    Sparse
        By default, sparse prevents implicit densification through ``np.asarray()``.
        `This safety mechanism can be disabled
        <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.

    Dask
        This allows applying eager functions to the individual chunks of dask arrays.
        The dask graph won't be computed. As a special limitation, `func` must return
        exactly one output.

        In order to enable running on Dask you need to specify at least
        `input_indices`, `output_indices`, and `core_indices`, but you may also need
        `adjust_chunks` and `new_axes` depending on the function.

        Read `dask.array.blockwise`:
        - ``input_indices`` map to the even ``*args`` of `dask.array.blockwise`
        - ``output_indices[0]`` maps to the ``out_ind`` parameter
        - ``adjust_chunks[0]`` maps to the ``adjust_chunks`` parameter
        - ``new_axes[0]`` maps to the ``new_axes`` parameter

        ``core_indices`` is a safety measure to prevent incorrect results on
        Dask along chunked axes. Consider this::

            >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
            ...                  input_indices=['ij'], output_indices=['ij'])

        The above example would produce incorrect results if x is a dask array with more
        than one chunk along axis 0, as each chunk will calculate its own local
        subtotal. To prevent this, we need to declare the first axis of ``args[0]`` as a
        *core axis*::

            >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
            ...                  input_indices=['ij'], output_indices=['ij'],
            ...                  core_indices='i')

        This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
        along multiple chunks, thus forcing the final user to rechunk ahead of time:

            >>> x = x.chunk({0: -1})

        This needs to always be a conscious decision on behalf of the final user, as the
        new chunks will be larger than the old and may cause memory issues, unless chunk
        size is reduced along a different, non-core axis.
    """
    if xp is None:
        xp = array_namespace(*args)
    if shapes is None:
        # Single-output default: outputs broadcast of the input shapes.
        shapes = [xp.broadcast_shapes(*(arg.shape for arg in args))]
    if dtypes is None:
        dtypes = [xp.result_type(*args)] * len(shapes)

    if len(shapes) != len(dtypes):
        msg = f"got {len(shapes)} shapes and {len(dtypes)} dtypes"
        raise ValueError(msg)
    if len(shapes) == 0:
        msg = "Must have at least one output array"
        raise ValueError(msg)

    if is_dask_namespace(xp):
        # General validation
        if len(shapes) > 1:
            msg = "dask.array.map_blocks() does not support multiple outputs"
            raise NotImplementedError(msg)
        if input_indices is None or output_indices is None or core_indices is None:
            msg = (
                "Dask is disallowed unless one declares input_indices, "
                "output_indices, and core_indices"
            )
            raise ValueError(msg)
        if len(input_indices) != len(args):
            msg = f"got {len(input_indices)} input_indices and {len(args)} args"
            raise ValueError(msg)
        if len(output_indices) != len(shapes):
            # FIX: message previously reported "input_indices" here.
            msg = f"got {len(output_indices)} output_indices and {len(shapes)} shapes"
            raise NotImplementedError(msg)
        # Catch a single mapping passed where a sequence of mappings is expected.
        if isinstance(adjust_chunks, Mapping):
            msg = "adjust_chunks must be a sequence of mappings"
            raise ValueError(msg)
        if adjust_chunks is not None and len(adjust_chunks) != len(shapes):
            msg = f"got {len(adjust_chunks)} adjust_chunks and {len(shapes)} shapes"
            raise ValueError(msg)
        if isinstance(new_axes, Mapping):
            msg = "new_axes must be a sequence of mappings"
            raise ValueError(msg)
        if new_axes is not None and len(new_axes) != len(shapes):
            msg = f"got {len(new_axes)} new_axes and {len(shapes)} shapes"
            raise ValueError(msg)

        # core_indices validation: every core index must be whole (one chunk)
        # on every input that carries it.
        for inp_idx, arg in zip(input_indices, args, strict=True):
            for i, chunks in zip(inp_idx, arg.chunks, strict=True):
                if i in core_indices and len(chunks) > 1:
                    msg = f"Core index {i} is broken into multiple chunks"
                    raise ValueError(msg)

        # Namespace of the chunks' meta arrays (e.g. numpy, cupy).
        # NOTE(review): getattr default of None is forwarded to array_namespace
        # when an arg has no .meta — presumably array_namespace tolerates it;
        # verify against _lib._compat.
        meta_xp = array_namespace(*(getattr(arg, "meta", None) for arg in args))
        wrapped = _npfunc_single_output_wrapper(func, meta_xp)
        # blockwise expects alternating (array, indices) positional args.
        dask_args = []
        for arg, inp_idx in zip(args, input_indices, strict=True):
            dask_args += [arg, inp_idx]

        out = xp.blockwise(
            wrapped,
            output_indices[0],
            *dask_args,
            dtype=dtypes[0],
            adjust_chunks=adjust_chunks[0] if adjust_chunks is not None else None,
            new_axes=new_axes[0] if new_axes is not None else None,
            **kwargs,
        )
        if out.shape != shapes[0]:
            msg = f"expected shape {shapes[0]}, but got {out.shape} from indices"
            raise ValueError(msg)
        return (out,)

    wrapped = _npfunc_tuple_output_wrapper(func, xp)
    if is_jax_namespace(xp):
        # If we're inside jax.jit, we can't eagerly convert
        # the JAX tracer objects to numpy.
        # Instead, we delay calling wrapped, which will receive
        # as arguments and will return JAX eager arrays.
        import jax  # type: ignore[import-not-found]  # pylint: disable=import-outside-toplevel,import-error  # pyright: ignore[reportMissingImports]

        return cast(
            tuple[Array, ...],
            jax.pure_callback(
                wrapped,
                tuple(
                    jax.ShapeDtypeStruct(s, dt)  # pyright: ignore[reportUnknownArgumentType]
                    for s, dt in zip(shapes, dtypes, strict=True)
                ),
                *args,
                **kwargs,
            ),
        )

    # Eager backends
    out = wrapped(*args, **kwargs)

    # Output validation
    if len(out) != len(shapes):
        msg = f"func was declared to return {len(shapes)} outputs, got {len(out)}"
        raise ValueError(msg)
    for out_i, shape_i, dtype_i in zip(out, shapes, dtypes, strict=True):
        if out_i.shape != shape_i:
            msg = f"expected shape {shape_i}, got {out_i.shape}"
            raise ValueError(msg)
        if not xp.isdtype(out_i.dtype, dtype_i):
            msg = f"expected dtype {dtype_i}, got {out_i.dtype}"
            raise ValueError(msg)
    return out  # type: ignore[no-any-return]
def _npfunc_tuple_output_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
275+
func: Callable[..., NumPyObject | Sequence[NumPyObject]],
276+
xp: ModuleType,
277+
) -> Callable[..., tuple[Array, ...]]:
278+
"""
279+
Helper of `apply_numpy_func`.
280+
281+
Given a function that accepts one or more numpy arrays as positional arguments and
282+
returns a single numpy array or a sequence of numpy arrays,
283+
return a function that accepts the same number of Array API arrays and always
284+
returns a tuple of Array API array.
285+
286+
Any keyword arguments are passed through verbatim to the wrapped function.
287+
288+
Raise if np.asarray() raises on any input. This typically happens if the input is
289+
lazy and has a guard against being implicitly turned into a NumPy array (e.g.
290+
densification for sparse arrays, device->host transfer for cupy and torch arrays).
291+
"""
292+
293+
@wraps(func)
294+
def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
295+
*args: Array, **kwargs: Any
296+
) -> tuple[Array, ...]: # numpydoc ignore=GL08
297+
import numpy as np # pylint: disable=import-outside-toplevel
298+
299+
args = tuple(np.asarray(arg) for arg in args)
300+
out = func(*args, **kwargs)
301+
302+
if isinstance(out, np.ndarray | np.generic):
303+
out = (out,)
304+
elif not isinstance(out, Sequence): # pyright: ignore[reportUnnecessaryIsInstance]
305+
msg = (
306+
"apply_numpy_func: func must return a numpy object or a "
307+
f"sequence of numpy objects; got {out}"
308+
)
309+
raise TypeError(msg)
310+
311+
return tuple(xp.asarray(o) for o in out)
312+
313+
return wrapper
314+
315+
316+
def _npfunc_single_output_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
317+
func: Callable[..., NumPyObject | Sequence[NumPyObject]],
318+
xp: ModuleType,
319+
) -> Callable[..., Array]:
320+
"""
321+
Dask-specific helper of `apply_numpy_func`.
322+
323+
Variant of `_npfunc_tuple_output_wrapper`, to be used with Dask which, at the time
324+
of writing, does not support multiple outputs in `dask.array.blockwise`.
325+
326+
func may return a single numpy object or a sequence with exactly one numpy object.
327+
The wrapper returns a single Array object, with no tuple wrapping.
328+
"""
329+
330+
# @wraps causes the generated dask key to contain the name of the wrapped function
331+
@wraps(func)
332+
def wrapper( # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08
333+
*args: Array, **kwargs: Any
334+
) -> Array:
335+
import numpy as np # pylint: disable=import-outside-toplevel
336+
337+
args = tuple(np.asarray(arg) for arg in args)
338+
out = func(*args, **kwargs)
339+
340+
if not isinstance(out, np.ndarray | np.generic):
341+
if not isinstance(out, Sequence) or len(out) != 1: # pyright: ignore[reportUnnecessaryIsInstance]
342+
msg = (
343+
"apply_numpy_func: func must return a single numpy object or a "
344+
f"sequence with exactly one numpy object; got {out}"
345+
)
346+
raise ValueError(msg)
347+
out = out[0]
348+
349+
return xp.asarray(out)
350+
351+
return wrapper

0 commit comments

Comments
 (0)