|
| 1 | +"""Public API Functions.""" |
| 2 | + |
| 3 | +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +from collections.abc import Callable, Hashable, Mapping, Sequence |
| 7 | +from functools import wraps |
| 8 | +from types import ModuleType |
| 9 | +from typing import TYPE_CHECKING, Any, cast |
| 10 | + |
| 11 | +from ._lib._compat import ( |
| 12 | + array_namespace, |
| 13 | + is_dask_namespace, |
| 14 | + is_jax_namespace, |
| 15 | +) |
| 16 | +from ._lib._typing import Array, DType |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 |
| 20 | + from typing import TypeAlias |
| 21 | + |
| 22 | + import numpy as np |
| 23 | + import numpy.typing as npt |
| 24 | + |
| 25 | + NumPyObject: TypeAlias = npt.NDArray[DType] | np.generic # type: ignore[no-any-explicit] |
| 26 | + |
| 27 | + |
def apply_numpy_func(  # type: ignore[no-any-explicit]
    func: Callable[..., NumPyObject | Sequence[NumPyObject]],
    *args: Array,
    shapes: Sequence[tuple[int, ...]] | None = None,
    dtypes: Sequence[DType] | None = None,
    xp: ModuleType | None = None,
    input_indices: Sequence[Sequence[Hashable]] | None = None,
    core_indices: Sequence[Hashable] | None = None,
    output_indices: Sequence[Sequence[Hashable]] | None = None,
    adjust_chunks: Sequence[Mapping[Hashable, Callable[[int], int]]] | None = None,
    new_axes: Sequence[Mapping[Hashable, int]] | None = None,
    **kwargs: Any,
) -> tuple[Array, ...]:
    """
    Apply a function that operates on NumPy arrays to Array API compliant arrays.

    Parameters
    ----------
    func : callable
        The function to apply. It must accept one or more NumPy arrays or generics as
        positional arguments and return either a single NumPy array or generic, or a
        tuple or list thereof.

        It must be a pure function, i.e. without side effects such as disk output,
        as depending on the backend it may be executed more than once.
    *args : Array
        One or more Array API compliant arrays. You need to be able to apply
        ``np.asarray()`` to them to convert them to numpy; read notes below about
        specific backends.
    shapes : Sequence[tuple[int, ...]], optional
        Sequence of output shapes, one for each output of `func`.
        If `func` returns a single (non-sequence) output, this must be a sequence
        with a single element.
        Default: assume a single output and broadcast shapes of the input arrays.
    dtypes : Sequence[DType], optional
        Sequence of output dtypes, one for each output of `func`.
        If `func` returns a single (non-sequence) output, this must be a sequence
        with a single element.
        Default: infer the result type(s) from the input arrays.
    xp : array_namespace, optional
        The standard-compatible namespace for `args`. Default: infer.
    input_indices : Sequence[Sequence[Hashable]], optional
        **Dask specific.**
        Axes labels for each input array, e.g. if there are two args with respectively
        ndim=3 and 1, `input_indices` could be ``['ijk', 'j']`` or ``[(0, 1, 2),
        (1,)]``.
        Default: disallow Dask.
    core_indices : Sequence[Hashable], optional
        **Dask specific.**
        Axes of the input arrays that cannot be broken into chunks.
        Default: disallow Dask.
    output_indices : Sequence[Sequence[Hashable]], optional
        **Dask specific.**
        Axes labels for each output array. If `func` returns a single (non-sequence)
        output, this must be a sequence containing a single sequence of labels, e.g.
        ``['ijk']``.
        Default: disallow Dask.
    adjust_chunks : Sequence[Mapping[Hashable, Callable[[int], int]]], optional
        **Dask specific.**
        Sequence of dicts, one per output, mapping index to function to be applied to
        each chunk to determine the output size. The total must add up to the output
        shape.
        Default: on Dask, the size along each index cannot change.
    new_axes : Sequence[Mapping[Hashable, int]], optional
        **Dask specific.**
        New indexes and their dimension lengths, one per output.
        Default: on Dask, there can't be `output_indices` that don't appear in
        `input_indices`.
    **kwargs : Any, optional
        Additional keyword arguments to pass verbatim to `func`.
        Any array objects in them won't be converted to NumPy.

    Returns
    -------
    tuple[Array, ...]
        The result(s) of `func` applied to the input arrays.
        This is always a tuple, even if `func` returns a single output.

    Notes
    -----
    JAX
        This allows applying eager functions to jitted JAX arrays, which are lazy.
        The function won't be applied until the JAX array is materialized.

        The `JAX transfer guard
        <https://jax.readthedocs.io/en/latest/transfer_guard.html>`_
        may prevent arrays on a GPU device from being transferred back to CPU.
        This is treated as an implicit transfer.

    PyTorch, CuPy
        These backends raise by default if you attempt to convert arrays on a GPU
        device to NumPy.

    Sparse
        By default, sparse prevents implicit densification through ``np.asarray()``.
        `This safety mechanism can be disabled
        <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.

    Dask
        This allows applying eager functions to the individual chunks of dask arrays.
        The dask graph won't be computed. As a special limitation, `func` must return
        exactly one output.

        In order to enable running on Dask you need to specify at least
        `input_indices`, `output_indices`, and `core_indices`, but you may also need
        `adjust_chunks` and `new_axes` depending on the function.

        Read `dask.array.blockwise`:
        - ``input_indices`` map to the even ``*args`` of `dask.array.blockwise`
        - ``output_indices[0]`` maps to the ``out_ind`` parameter
        - ``adjust_chunks[0]`` maps to the ``adjust_chunks`` parameter
        - ``new_axes[0]`` maps to the ``new_axes`` parameter

        ``core_indices`` is a safety measure to prevent incorrect results on
        Dask along chunked axes. Consider this::

            >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
            ...                  input_indices=['ij'], output_indices=['ij'])

        The above example would produce incorrect results if x is a dask array with
        more than one chunk along axis 0, as each chunk will calculate its own local
        subtotal. To prevent this, we need to declare the first axis of ``args[0]``
        as a *core axis*::

            >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
            ...                  input_indices=['ij'], output_indices=['ij'],
            ...                  core_indices='i')

        This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
        along multiple chunks, thus forcing the final user to rechunk ahead of time:

        >>> x = x.chunk({0: -1})

        This needs to always be a conscious decision on behalf of the final user, as
        the new chunks will be larger than the old and may cause memory issues, unless
        chunk size is reduced along a different, non-core axis.
    """
    if xp is None:
        xp = array_namespace(*args)
    if shapes is None:
        # Default: a single output, broadcasting the shapes of all inputs
        shapes = [xp.broadcast_shapes(*(arg.shape for arg in args))]
    if dtypes is None:
        dtypes = [xp.result_type(*args)] * len(shapes)

    if len(shapes) != len(dtypes):
        msg = f"got {len(shapes)} shapes and {len(dtypes)} dtypes"
        raise ValueError(msg)
    if len(shapes) == 0:
        msg = "Must have at least one output array"
        raise ValueError(msg)

    if is_dask_namespace(xp):
        # General validation
        if len(shapes) > 1:
            # NOTE: message fixed — this code path uses blockwise, not map_blocks
            msg = "dask.array.blockwise() does not support multiple outputs"
            raise NotImplementedError(msg)
        if input_indices is None or output_indices is None or core_indices is None:
            msg = (
                "Dask is disallowed unless one declares input_indices, "
                "output_indices, and core_indices"
            )
            raise ValueError(msg)
        if len(input_indices) != len(args):
            msg = f"got {len(input_indices)} input_indices and {len(args)} args"
            raise ValueError(msg)
        if len(output_indices) != len(shapes):
            # len(shapes) == 1 here, so >1 output_indices means multiple outputs
            msg = f"got {len(output_indices)} output_indices and {len(shapes)} shapes"
            raise NotImplementedError(msg)
        if isinstance(adjust_chunks, Mapping):
            msg = "adjust_chunks must be a sequence of mappings"
            raise ValueError(msg)
        if adjust_chunks is not None and len(adjust_chunks) != len(shapes):
            msg = f"got {len(adjust_chunks)} adjust_chunks and {len(shapes)} shapes"
            raise ValueError(msg)
        if isinstance(new_axes, Mapping):
            msg = "new_axes must be a sequence of mappings"
            raise ValueError(msg)
        if new_axes is not None and len(new_axes) != len(shapes):
            msg = f"got {len(new_axes)} new_axes and {len(shapes)} shapes"
            raise ValueError(msg)

        # core_indices validation: a core axis must live in exactly one chunk,
        # otherwise each chunk would compute an incorrect local result
        for inp_idx, arg in zip(input_indices, args, strict=True):
            for i, chunks in zip(inp_idx, arg.chunks, strict=True):
                if i in core_indices and len(chunks) > 1:
                    msg = f"Core index {i} is broken into multiple chunks"
                    raise ValueError(msg)

        # Wrap with the namespace of the chunks (meta), not of the dask arrays
        meta_xp = array_namespace(*(getattr(arg, "meta", None) for arg in args))
        wrapped = _npfunc_single_output_wrapper(func, meta_xp)
        # blockwise interleaves arrays and their indices positionally
        dask_args = []
        for arg, inp_idx in zip(args, input_indices, strict=True):
            dask_args += [arg, inp_idx]

        out = xp.blockwise(
            wrapped,
            output_indices[0],
            *dask_args,
            dtype=dtypes[0],
            adjust_chunks=adjust_chunks[0] if adjust_chunks is not None else None,
            new_axes=new_axes[0] if new_axes is not None else None,
            **kwargs,
        )
        if out.shape != shapes[0]:
            msg = f"expected shape {shapes[0]}, but got {out.shape} from indices"
            raise ValueError(msg)
        return (out,)

    wrapped = _npfunc_tuple_output_wrapper(func, xp)
    if is_jax_namespace(xp):
        # If we're inside jax.jit, we can't eagerly convert
        # the JAX tracer objects to numpy.
        # Instead, we delay calling wrapped, which will receive
        # as arguments and will return JAX eager arrays.
        import jax  # type: ignore[import-not-found]  # pylint: disable=import-outside-toplevel,import-error  # pyright: ignore[reportMissingImports]

        return cast(
            tuple[Array, ...],
            jax.pure_callback(
                wrapped,
                tuple(
                    jax.ShapeDtypeStruct(s, dt)  # pyright: ignore[reportUnknownArgumentType]
                    for s, dt in zip(shapes, dtypes, strict=True)
                ),
                *args,
                **kwargs,
            ),
        )

    # Eager backends
    out = wrapped(*args, **kwargs)

    # Output validation
    if len(out) != len(shapes):
        msg = f"func was declared to return {len(shapes)} outputs, got {len(out)}"
        raise ValueError(msg)
    for out_i, shape_i, dtype_i in zip(out, shapes, dtypes, strict=True):
        if out_i.shape != shape_i:
            msg = f"expected shape {shape_i}, got {out_i.shape}"
            raise ValueError(msg)
        if not xp.isdtype(out_i.dtype, dtype_i):
            msg = f"expected dtype {dtype_i}, got {out_i.dtype}"
            raise ValueError(msg)
    return out  # type: ignore[no-any-return]
| 272 | + |
| 273 | + |
| 274 | +def _npfunc_tuple_output_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 |
| 275 | + func: Callable[..., NumPyObject | Sequence[NumPyObject]], |
| 276 | + xp: ModuleType, |
| 277 | +) -> Callable[..., tuple[Array, ...]]: |
| 278 | + """ |
| 279 | + Helper of `apply_numpy_func`. |
| 280 | +
|
| 281 | + Given a function that accepts one or more numpy arrays as positional arguments and |
| 282 | + returns a single numpy array or a sequence of numpy arrays, |
| 283 | + return a function that accepts the same number of Array API arrays and always |
| 284 | + returns a tuple of Array API array. |
| 285 | +
|
| 286 | + Any keyword arguments are passed through verbatim to the wrapped function. |
| 287 | +
|
| 288 | + Raise if np.asarray() raises on any input. This typically happens if the input is |
| 289 | + lazy and has a guard against being implicitly turned into a NumPy array (e.g. |
| 290 | + densification for sparse arrays, device->host transfer for cupy and torch arrays). |
| 291 | + """ |
| 292 | + |
| 293 | + @wraps(func) |
| 294 | + def wrapper( # type: ignore[no-any-decorated,no-any-explicit] |
| 295 | + *args: Array, **kwargs: Any |
| 296 | + ) -> tuple[Array, ...]: # numpydoc ignore=GL08 |
| 297 | + import numpy as np # pylint: disable=import-outside-toplevel |
| 298 | + |
| 299 | + args = tuple(np.asarray(arg) for arg in args) |
| 300 | + out = func(*args, **kwargs) |
| 301 | + |
| 302 | + if isinstance(out, np.ndarray | np.generic): |
| 303 | + out = (out,) |
| 304 | + elif not isinstance(out, Sequence): # pyright: ignore[reportUnnecessaryIsInstance] |
| 305 | + msg = ( |
| 306 | + "apply_numpy_func: func must return a numpy object or a " |
| 307 | + f"sequence of numpy objects; got {out}" |
| 308 | + ) |
| 309 | + raise TypeError(msg) |
| 310 | + |
| 311 | + return tuple(xp.asarray(o) for o in out) |
| 312 | + |
| 313 | + return wrapper |
| 314 | + |
| 315 | + |
| 316 | +def _npfunc_single_output_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 |
| 317 | + func: Callable[..., NumPyObject | Sequence[NumPyObject]], |
| 318 | + xp: ModuleType, |
| 319 | +) -> Callable[..., Array]: |
| 320 | + """ |
| 321 | + Dask-specific helper of `apply_numpy_func`. |
| 322 | +
|
| 323 | + Variant of `_npfunc_tuple_output_wrapper`, to be used with Dask which, at the time |
| 324 | + of writing, does not support multiple outputs in `dask.array.blockwise`. |
| 325 | +
|
| 326 | + func may return a single numpy object or a sequence with exactly one numpy object. |
| 327 | + The wrapper returns a single Array object, with no tuple wrapping. |
| 328 | + """ |
| 329 | + |
| 330 | + # @wraps causes the generated dask key to contain the name of the wrapped function |
| 331 | + @wraps(func) |
| 332 | + def wrapper( # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08 |
| 333 | + *args: Array, **kwargs: Any |
| 334 | + ) -> Array: |
| 335 | + import numpy as np # pylint: disable=import-outside-toplevel |
| 336 | + |
| 337 | + args = tuple(np.asarray(arg) for arg in args) |
| 338 | + out = func(*args, **kwargs) |
| 339 | + |
| 340 | + if not isinstance(out, np.ndarray | np.generic): |
| 341 | + if not isinstance(out, Sequence) or len(out) != 1: # pyright: ignore[reportUnnecessaryIsInstance] |
| 342 | + msg = ( |
| 343 | + "apply_numpy_func: func must return a single numpy object or a " |
| 344 | + f"sequence with exactly one numpy object; got {out}" |
| 345 | + ) |
| 346 | + raise ValueError(msg) |
| 347 | + out = out[0] |
| 348 | + |
| 349 | + return xp.asarray(out) |
| 350 | + |
| 351 | + return wrapper |
0 commit comments