ENH: lazy_apply
#86
Conversation
src/array_api_extra/_apply.py
Outdated
``core_indices`` is a safety measure to prevent incorrect results on
Dask along chunked axes. Consider this::
This design was informed by https://docs.xarray.dev/en/latest/generated/xarray.apply_ufunc.html
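For context, here is a minimal sketch of the `xarray.apply_ufunc` pattern referenced above (not part of this PR; the kernel and dimension names are made up for illustration):

```python
import numpy as np
import xarray as xr

def last_minus_first(a: np.ndarray) -> np.ndarray:
    # Eager NumPy kernel operating on the trailing ("core") dimension.
    return a[..., -1] - a[..., 0]

arr = xr.DataArray(np.arange(12).reshape(3, 4), dims=["x", "time"])
out = xr.apply_ufunc(
    last_minus_first,
    arr,
    input_core_dims=[["time"]],  # dims the kernel consumes, analogous to core indices
    output_core_dims=[[]],       # the kernel returns one scalar per remaining "x"
)
```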
src/array_api_extra/_apply.py
Outdated
The dask graph won't be computed. As a special limitation, `func` must return
exactly one output.
This limitation is straightforward to fix in Dask (at the cost of API duplication).
Until then, however, I suspect it will be a major roadblock for Dask support in scipy.
It can also be hacked outside of Dask, but I'm hesitant to do that for the sake of robustness, as it would rely on deliberately triggering key collisions between diverging graph branches.
src/array_api_extra/_apply.py
Outdated
`input_indices`, `output_indices`, and `core_indices`, but you may also need
`adjust_chunks` and `new_axes` depending on the function.

Read `dask.array.blockwise`:
src/array_api_extra/_apply.py
Outdated
- ``output_indices[0]`` maps to the ``out_ind`` parameter
- ``adjust_chunks[0]`` maps to the ``adjust_chunks`` parameter
- ``new_axes[0]`` maps to the ``new_axes`` parameter
These are all lists for forward compatibility with a yet-to-be-written `da.blockwise` variant that supports multiple outputs.
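For illustration, a minimal sketch of how such indices map onto `dask.array.blockwise` (the `out_ind` parameter named above); the kernel is made up and this is not the code in the PR:

```python
import dask.array as da

def matmul_kernel(a, b):
    # Eager kernel applied to (concatenated) blocks.
    return a @ b

x = da.ones((4, 6), chunks=(2, 6))
y = da.ones((6, 5), chunks=(6, 5))

# "ij" plays the role of output_indices[0] -> out_ind;
# the contracted index "k" never appears in the output.
z = da.blockwise(
    matmul_kernel, "ij",
    x, "ik",
    y, "kj",
    concatenate=True,
    dtype=x.dtype,
)
```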
src/array_api_extra/_apply.py
Outdated
If `func` returns a single (non-sequence) output, this must be a sequence
with a single element.
I tried overloading but it was a big headache when validating inputs. I found this approach much simpler.
I think I am missing some context here. Why are we wrapping arbitrary NumPy functions, instead of, e.g., considering the individual functions we need one by one?
There are many points in scipy that look like this:

```python
x = np.asarray(x)
y = np.asarray(y)
z = some_cython_kernel(x, y)
z = xp.asarray(z)
```

None of them will ever work with arrays on GPU devices, of course, and they'll either need a pure-array-API alternative or a dispatch to cupy.scipy etc. None of them work with jitted, CPU-based JAX either, because `np.asarray` raises on traced arrays. With Dask, they technically work because there is no materialization guard, but most times you would prefer it if there was one.

However, it is possible to make these pieces of Cython code work thanks to this PR. There are two competing functions in Dask to achieve this, with different APIs: `map_blocks` and `blockwise`. `map_blocks` is a variant of `blockwise` with a simplified API, which can only work on broadcastable inputs. This problem has already been dealt with by xarray, with https://docs.xarray.dev/en/latest/generated/xarray.apply_ufunc.html. Note that the xarray API is more user-friendly thanks to each dimension being labelled at all times, so `apply_ufunc` can do fancy tricks like auto-transposing the inputs and pushing the dimensions that `func` doesn't know about to the left. What I tried to implement here is equivalent to …
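To make that concrete, here is a rough sketch of how such a scipy code path could be wrapped, using the `shape=`/`dtype=` keywords that `lazy_apply` ends up with later in this PR; `some_cython_kernel` and `some_scipy_func` are stand-ins:

```python
import numpy as np
import array_api_extra as xpx

def some_cython_kernel(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    # Stand-in for compiled code that only understands NumPy arrays.
    return x + y

def some_scipy_func(x, y):
    # x and y may be Dask or JAX arrays; lazy_apply converts them to NumPy
    # only inside each chunk / host callback and re-wraps the result lazily.
    return xpx.lazy_apply(
        some_cython_kernel,
        x,
        y,
        shape=np.broadcast_shapes(x.shape, y.shape),
        dtype=np.result_type(x.dtype, y.dtype),
    )
```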
Okay, thanks. When you say … in the docstring, that isn't strictly true, right? It relies on …

I understand the utility of this PR now.
I hadn't envisioned tackling this yet. In my mind, getting Cython kernels working with dask/jax jit has been in the same "for later" basket as handling device transfers or writing new implementations to delegate to. But if the implementation works, it makes sense to tackle it.
Correct. Nominally it will fail when densifying sparse arrays and moving data from GPU to CPU. An end user can however force their way through, if they want to, by deliberately suppressing the transfer/densification guards for the time necessary to run the scipy function. Either that, or do an explicit device-to-CPU transfer / densification beforehand.

There is, however, nothing a JAX or Dask user can do today, short of completely getting out of the graph generation phase.
Thanks @crusaderky!
> The enormous amount of extra complication needed to make it work with Dask makes me uncomfortable.

Yes indeed, it does. That looks like it's way too much. I don't think we would like to use all those extra keywords and Dask-specific functions in SciPy. If you'd drop those, does it make Dask completely non-working, or is there a subset of functionality that would still work?

I'd say that JAX shows that it can be straightforward, and a similar callback mechanism could be used for PyTorch/MLX/ndonnx as well - if that were to exist in those libraries.
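For reference, the JAX callback mechanism mentioned above is roughly this (a minimal sketch using the public `jax.pure_callback` API; the kernel is made up):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_kernel(x: np.ndarray) -> np.ndarray:
    # Runs eagerly on the host with a real NumPy array (e.g. a Cython routine).
    return np.sin(x)

@jax.jit
def f(x):
    # The output shape/dtype must be declared up front so tracing can proceed.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_kernel, out_spec, x)

print(f(jnp.arange(3.0)))
```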
src/array_api_extra/_apply.py
Outdated
Sparse
By default, sparse prevents implicit densification through ``np.asarray``.
`This safety mechanism can be disabled
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
Fine to leave as is for now I'd say. Once sparse adds a better API for this (an env var doesn't work), it seems reasonable to add a `force=False` option to this function. There are various reasons why one may want to force an expensive conversion; that kind of thing should always be opt-in on a case-by-case basis.
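A rough sketch of what such an opt-in conversion might look like (hypothetical; `force` is not a parameter of this function today and `_densify` is an illustrative helper):

```python
def _densify(x, *, force: bool = False):
    # Only densify a sparse array when the caller explicitly asks for it.
    if hasattr(x, "todense"):
        if not force:
            raise TypeError(
                "refusing to densify a sparse array implicitly; pass force=True"
            )
        return x.todense()
    return x
```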
Shouldn't this be tackled by a more general design pattern?

```python
with disable_guards():
    y = scipy.somefunc(x)
```

where `disable_guards` is backend-specific. Applies to torch/cupy/jax arrays on GPU, sparse arrays, etc.
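A rough sketch of that suggested pattern (hypothetical; neither `disable_guards` nor the per-backend guard toggles exist as public APIs):

```python
import contextlib

@contextlib.contextmanager
def disable_guards(xp):
    # Hypothetical: temporarily allow implicit densification / device transfers
    # for the backend behind `xp`, then restore the previous state.
    previous_state = None  # backend-specific guard state would be captured here
    try:
        yield
    finally:
        del previous_state  # backend-specific restore would happen here

# Intended usage:
# with disable_guards(xp):
#     y = scipy.somefunc(x)
```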
Do you have a SciPy branch with this function being used @crusaderky? I'd be interested in playing with it.
I could make it work by rechunking all the inputs to a single chunk. In other words, the whole calculation would need to fit in memory at once on a single worker.
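Roughly, a sketch of that fallback with plain dask (illustrative only; this is not the code in the PR):

```python
import dask.array as da

def apply_on_single_chunk(func, *args):
    # Merge every axis of every input into one chunk, so func sees full arrays
    # at once on a single worker.
    single = [a.rechunk(-1) if isinstance(a, da.Array) else a for a in args]
    return da.map_blocks(func, *single)
```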
Not yet
@rgommers I rewrote it to do exactly this and now it's a lot cleaner. I'll keep my eyes open for patterns in scipy we can leverage to improve Dask support (e.g. if there are frequent elementwise functions that could be trivially served by `map_blocks`).
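For example, an elementwise NumPy-only kernel could be served like this (illustrative sketch; `elementwise_kernel` is made up):

```python
import numpy as np
import dask.array as da

def elementwise_kernel(x: np.ndarray) -> np.ndarray:
    # Stand-in for a compiled elementwise routine.
    return np.log1p(x)

x = da.random.random((1_000, 1_000), chunks=(250, 250))
# Elementwise => output chunks match input chunks, so map_blocks is enough.
y = da.map_blocks(elementwise_kernel, x, dtype=x.dtype)
```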
@allcontributors, please add @crusaderky for bug

let me just try this once from this PR...
I haven't tried to test this, but did go through it in a bit more detail now - overall looks good, a few comments. Looking forward to trying it out!
I've reworked the design a bit.
FYI, when I move …
src/array_api_extra/_lib/_apply.py
Outdated
```python
if any(s is None for shape in shapes for s in shape):
    # Unknown output shape. Won't work with jax.jit, but it
    # can work with eager jax.
    # Raises jax.errors.TracerArrayConversionError if we're inside jax.jit.
```
Offline conversation:
how do you see scipy functions with unknown output size work with jax.jit? e.g. scipy.cluster.leaders? Should we do like in jnp.unique_all and add a size=None optional parameter, which becomes mandatory when the jit is on?
I'm not sure that we should add support for those functions. My assumption is that it's only a few functions, and that those are inherently pretty clunky with JAX. I don't really want to think about extending the signatures (yet at least), because the current jax.jit support is experimental and behind a flag, and adding keywords is public and not reversible.
Perhaps making a note on the tracking issue about this being an option, but not done because of the reason above (could be done in the future, if JAX usage takes off)?
If in the future we want to support these functions, we'll have to modify this point to catch `jax.errors.TracerArrayConversionError` and reraise a backend-agnostic exception, so that `scipy.cluster.leaders` and similar can then catch it and reraise an informative error message about `size=` being mandatory.
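Sketched roughly (hypothetical; the exception translation and the `size=` handling described above are not implemented in this PR, and `_call_eager` is an illustrative name):

```python
def _call_eager(func, *args):
    # Inside lazy_apply: func calls np.asarray() on its inputs, which fails
    # on traced JAX arrays when the output shape is unknown.
    try:
        return func(*args)
    except Exception as exc:
        if type(exc).__name__ == "TracerArrayConversionError":
            # Re-raise something backend-agnostic that callers such as
            # scipy.cluster.leaders could catch and turn into a message
            # about size= being mandatory under jax.jit.
            raise ValueError("data-dependent output shape under jax.jit") from exc
        raise
```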
`scipy.cluster.leaders` is a function which in the near future will work in eager JAX but can't work in `jax.jit` short of a public API change, because its output arrays' shape is `xp.unique_values(input).shape`.

@pearu @vfdev-5 a while ago you asked offline how we can run a function such as this inside `jax.jit`. It will be possible for an end user to do so, on the condition that they consume its output and revert to a known shape, for example:
```python
import jax
import array_api_extra as xpx
from array_api_compat import array_namespace
from scipy.cluster.hierarchy import leaders

def _eager(x):
    a, b = leaders(x)  # shapes = (None, ), (None, )
    xp = array_namespace(a, b)
    # silly example; probably won't make sense functionally
    return xp.max(a), xp.max(b)

# This is just an example that makes little sense;
# in practice @jax.jit will be much higher in the call stack
@jax.jit
def f(x):
    return xpx.lazy_apply(
        _eager, x, shape=((), ()), dtype=(x.dtype, x.dtype)
    )
```
I think I want to have some evidence that the whole thing works in practice before I finalize this PR.
Thanks @crusaderky, and apologies for the long delay. I think this is good to go - a few questions left from me only. I think we should merge this, then update scipy/scipy#22342 and see how we like that PR.
I skimmed the diff quickly—happy to merge given Ralf's approval!
I've managed to avoid having arrays in the kwargs in scipy.
Both the JAX and the Dask code paths look clean and understandable now - very nice! No more comments; let's ship it, I'd say!
Ah, there is a merge conflict.
release inbound
Thanks Guido & Lucas!
* ENH: New function `lazy_apply`
* Update docs/api-lazy.md
  Co-authored-by: Ralf Gommers <[email protected]>
* Update src/array_api_extra/_lib/_lazy.py
  Co-authored-by: Ralf Gommers <[email protected]>
* Code review
* Remove kwargs introspection; support None | complex args
* Don't always import numpy
* update lockfile
* appease mypy
---------
Co-authored-by: Ralf Gommers <[email protected]>
Co-authored-by: Lucas Colley <[email protected]>
Wrapper around `jax.pure_callback` with added support for Dask.
CC @rgommers