Skip to content

Commit dc652f3

Browse files
crusaderkyrgommerslucascolley
authored andcommitted
ENH: lazy_apply (data-apis#86)
* ENH: New function `lazy_apply` * Update docs/api-lazy.md Co-authored-by: Ralf Gommers <[email protected]> * Update src/array_api_extra/_lib/_lazy.py Co-authored-by: Ralf Gommers <[email protected]> * Code review * Remove kwargs introspection; support None | complex args * Don't always import numpy * update lockfile * appease mypy --------- Co-authored-by: Ralf Gommers <[email protected]> Co-authored-by: Lucas Colley <[email protected]>
1 parent a09b40b commit dc652f3

14 files changed

+3339
-2988
lines changed

docs/api-lazy.md

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Tools for lazy backends
2+
3+
These additional functions are meant to be used to support compatibility with
4+
lazy backends, e.g. Dask or JAX:
5+
6+
```{eval-rst}
7+
.. currentmodule:: array_api_extra
8+
.. autosummary::
9+
:nosignatures:
10+
:toctree: generated
11+
12+
lazy_apply
13+
testing.lazy_xp_function
14+
testing.patch_lazy_xp_functions
15+
```

docs/index.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
:hidden:
66
self
77
api-reference.md
8-
testing-utils.md
8+
api-lazy.md
99
contributing.md
1010
contributors.md
1111
```

docs/testing-utils.md

-14
This file was deleted.

pixi.lock

+2,508-2,951
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ sphinx-autodoc-typehints = "*"
108108
dask-core = "*"
109109
pytest = "*"
110110
typing-extensions = "*"
111+
numpy = "*"
111112

112113
[tool.pixi.feature.docs.tasks]
113114
docs = { cmd = "sphinx-build . build/", cwd = "docs" }

src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
setdiff1d,
1515
sinc,
1616
)
17+
from ._lib._lazy import lazy_apply
1718

1819
__version__ = "0.7.0.dev0"
1920

@@ -29,6 +30,7 @@
2930
"expand_dims",
3031
"isclose",
3132
"kron",
33+
"lazy_apply",
3234
"nunique",
3335
"pad",
3436
"setdiff1d",

src/array_api_extra/_lib/_funcs.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737

3838
@overload
39-
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
39+
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
4040
cond: Array,
4141
args: Array | tuple[Array, ...],
4242
f1: Callable[..., Array],
@@ -48,7 +48,7 @@ def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ig
4848

4949

5050
@overload
51-
def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
51+
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
5252
cond: Array,
5353
args: Array | tuple[Array, ...],
5454
f1: Callable[..., Array],
@@ -59,7 +59,7 @@ def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ig
5959
) -> Array: ...
6060

6161

62-
def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
62+
def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
6363
cond: Array,
6464
args: Array | tuple[Array, ...],
6565
f1: Callable[..., Array],
@@ -145,7 +145,7 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
145145
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
146146

147147

148-
def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
148+
def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
149149
cond: Array,
150150
f1: Callable[..., Array],
151151
f2: Callable[..., Array] | None,
@@ -743,7 +743,7 @@ def pad(
743743
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))
744744

745745
# https://github.com/python/typeshed/issues/13376
746-
slices: list[slice] = [] # type: ignore[no-any-explicit]
746+
slices: list[slice] = [] # type: ignore[explicit-any]
747747
newshape: list[int] = []
748748
for ax, w_tpl in enumerate(pad_width_seq):
749749
if len(w_tpl) != 2:

0 commit comments

Comments
 (0)