Use shuffle in groupby binary ops. #9896

Draft: wants to merge 5 commits into main
Changes from 1 commit
use map instead
dcherian committed Dec 16, 2024
commit 9294b9408ffe40aaec9bca24ccd9450876072613
xarray/core/groupby.py (66 changes: 25 additions & 41 deletions)
@@ -20,7 +20,6 @@
 from xarray.core.alignment import align, broadcast
 from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
 from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
-from xarray.core.computation import apply_ufunc
 from xarray.core.concat import concat
 from xarray.core.coordinates import Coordinates, _coordinates_from_variable
 from xarray.core.duck_array_ops import where
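
The dropped apply_ufunc import reflects the core of this commit: per-variable dispatch now goes through Dataset.map (third hunk below) instead of apply_ufunc. A minimal sketch of that API, with toy data that is not part of this PR:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.arange(4)), "b": ("x", np.ones(4))})

    # Dataset.map calls the function once per data variable, passing each
    # variable as a DataArray and forwarding extra kwargs -- the same shape
    # as other_as_dataset.map(_vindex_like, dim=name, indexer=codes) below.
    def scale(da, factor=1.0):
        return da * factor

    scaled = ds.map(scale, factor=2.0)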
@@ -900,32 +899,48 @@ def _binary_op(self, other, f, reflexive=False):
         group = group.where(~mask, drop=True)
         codes = codes.where(~mask, drop=True).astype(int)

-        def _vindex_wrapper(array, idxr, like):
+        def _vindex_like(da: DataArray, dim, indexer: DataArray):
             # we want to use the fact that we know the chunksizes for the output (matches obj)
             # so we can't just use Variable's indexing

+            array = da._variable._data
+            like_da = obj_as_dataset.get(da.name)
+            if not is_duck_dask_array(array):
+                if like_da is None or not is_duck_dask_array(like_da._variable._data):
+                    return da.isel({dim: indexer})
+                else:
+                    da = da.chunk("auto")
+            like = like_da._variable._data
+            array = da._variable._data
+
             import dask
             from dask.array.core import slices_from_chunks
             from dask.graph_manipulation import clone

             array = clone(array)  # FIXME: add to dask

-            assert array.ndim == 1
-            to_shape = like.shape[-1:]
-            to_chunks = like.chunks[-1:]
+            dims = indexer.dims
+            axes = tuple(like_da.get_axis_num(dim) for dim in dims)
+            to_shape = tuple(size for ax, size in enumerate(like.shape) if ax in axes)
+            to_chunks = tuple(
+                chunksize for ax, chunksize in enumerate(like.chunks) if ax in axes
+            )
+            idxr = indexer._variable._data

             # shuffle indices that can be reshaped blockwise to desired shape
             flat_indices = [
                 idxr[slicer].ravel().tolist()
                 for slicer in slices_from_chunks(to_chunks)
             ]
-            # FIXME: figure out axis
             shuffled = dask.array.shuffle(
-                array, flat_indices, axis=array.ndim - 1, chunks="auto"
+                array, flat_indices, axis=da.get_axis_num(dim), chunks="auto"
             )
             if shuffled.shape != to_shape:
-                return dask.array.reshape_blockwise(
+                shuffled = dask.array.reshape_blockwise(
                     shuffled, shape=to_shape, chunks=to_chunks
                 )
-            else:
-                return shuffled
+            return DataArray(dims=like_da.dims[-1:], data=shuffled, attrs=da.attrs)
+
+        # codes are defined for coord, so we align `other` with `coord`
+        # before indexing
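
For context, a toy sketch (not from this diff) of the shuffle trick _vindex_like relies on: slices_from_chunks(to_chunks) splits the indexer into one group of indices per output block, dask.array.shuffle then moves whole groups between chunks, and reshape_blockwise only fixes up the shape afterwards without a global rechunk. This assumes a dask recent enough (late 2024) to have dask.array.shuffle:

    import dask.array
    import numpy as np

    x = dask.array.from_array(np.arange(10), chunks=5)

    # One list of indices per target block; shuffle keeps each group
    # within a single chunk of the result.
    indexer = [[9, 1, 3], [0, 2, 4, 6], [5, 7, 8]]
    shuffled = dask.array.shuffle(x, indexer, axis=0, chunks="auto")

    # Roughly x[np.concatenate(indexer)], realized by moving whole groups
    # between chunks rather than a general pointwise gather.
    print(shuffled.compute())  # [9 1 3 0 2 4 6 5 7 8]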
@@ -935,38 +950,7 @@ def _vindex_wrapper(array, idxr, like):
             other._to_temp_dataset() if isinstance(other, DataArray) else other
         )
         obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj
-        dask_vars = []
-        non_dask_vars = []
-        for varname, var in other_as_dataset._variables.items():
-            if is_duck_dask_array(var._data):
-                dask_vars.append(varname)
-            else:
-                non_dask_vars.append(varname)
-        expanded = other_as_dataset[non_dask_vars].isel({name: codes})
-        if dask_vars:
-            other_dims = other_as_dataset[dask_vars].dims
-            obj_dims = obj_as_dataset[dask_vars].dims
-            expanded = expanded.merge(
-                apply_ufunc(
-                    _vindex_wrapper,
-                    other_as_dataset[dask_vars],
-                    codes,
-                    obj_as_dataset[dask_vars],
-                    input_core_dims=[
-                        tuple(other_dims),  # FIXME: ..., name
-                        tuple(codes.dims),
-                        tuple(obj_dims),
-                    ],
-                    # When other is the result of a reduction over Ellipsis
-                    # obj.dims is a superset of other.dims, and contains
-                    # dims not present in the output
-                    exclude_dims=set(obj_dims) - set(other_dims),
-                    output_core_dims=[tuple(codes.dims)],
-                    dask="allowed",
-                    join=OPTIONS["arithmetic_join"],
-                )
-            )

+        expanded = other_as_dataset.map(_vindex_like, dim=name, indexer=codes)
         if isinstance(other, DataArray):
             expanded = other._from_temp_dataset(expanded)
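
For reference, a hypothetical example (not from the PR) of the code path this change serves: a groupby binary op between a dask-backed dataset and a per-group reduction.

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"t": ("time", np.random.rand(12))},
        coords={"month": ("time", np.tile(np.arange(1, 5), 3))},
    ).chunk({"time": 4})

    gb = ds.groupby("month")

    # Subtracting the group mean broadcasts 4 group values back onto 12
    # time steps; _binary_op above performs that expansion, now routed
    # through _vindex_like for dask-backed variables.
    anomaly = gb - gb.mean()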