diff --git a/lib/iris/analysis/_interpolation.py b/lib/iris/analysis/_interpolation.py
index 6904c5ae4f..24df2b1b56 100644
--- a/lib/iris/analysis/_interpolation.py
+++ b/lib/iris/analysis/_interpolation.py
@@ -8,10 +8,12 @@
 from itertools import product
 import operator
 
+import dask.array
 import numpy as np
 from numpy.lib.stride_tricks import as_strided
 import numpy.ma as ma
 
+from iris._lazy_data import is_lazy_data, is_masked_data
 from iris.coords import AuxCoord, DimCoord
 import iris.util
 
@@ -200,13 +202,8 @@ def __init__(self, src_cube, coords, method, extrapolation_mode):
             set to NaN.
 
         """
-        # Trigger any deferred loading of the source cube's data and snapshot
-        # its state to ensure that the interpolator is impervious to external
-        # changes to the original source cube. The data is loaded to prevent
-        # the snapshot having lazy data, avoiding the potential for the
-        # same data to be loaded again and again.
-        if src_cube.has_lazy_data():
-            src_cube.data
+        # Snapshot the cube state to ensure that the interpolator is impervious
+        # to external changes to the original source cube.
         self._src_cube = src_cube.copy()
         # Coordinates defining the dimensions to be interpolated.
         self._src_coords = [self._src_cube.coord(coord) for coord in coords]
@@ -315,6 +312,7 @@ def _interpolate(self, data, interp_points):
            data = data.astype(dtype)
 
         mode = EXTRAPOLATION_MODES[self._mode]
+        _data = _get_data(data)
         if self._interpolator is None:
             # Cache the interpolator instance.
             # NB. The constructor of the _RegularGridInterpolator class does
@@ -322,13 +320,13 @@ def _interpolate(self, data, interp_points):
             # so we set it afterwards instead. Sneaky. ;-)
             self._interpolator = _RegularGridInterpolator(
                 self._src_points,
-                data,
+                _data,
                 method=self.method,
                 bounds_error=mode.bounds_error,
                 fill_value=None,
             )
         else:
-            self._interpolator.values = data
+            self._interpolator.values = _data
 
         # We may be re-using a cached interpolator, so ensure the fill
         # value is set appropriately for extrapolating data values.
@@ -341,17 +339,16 @@ def _interpolate(self, data, interp_points):
         # interpolation points.
         result = result.astype(data.dtype)
 
-        if np.ma.isMaskedArray(data) or mode.force_mask:
-            # NB. np.ma.getmaskarray returns an array of `False` if
+        if _is_masked_array(data) or mode.force_mask:
+            # NB. getmaskarray returns an array of `False` if
             # `data` is not a masked array.
-            src_mask = np.ma.getmaskarray(data)
+            src_mask = _get_mask_array(data)
             # Switch the extrapolation to work with mask values.
             self._interpolator.fill_value = mode.mask_fill_value
             self._interpolator.values = src_mask
             mask_fraction = self._interpolator(interp_points)
             new_mask = mask_fraction > 0
-            if ma.isMaskedArray(data) or np.any(new_mask):
-                result = np.ma.MaskedArray(result, new_mask)
+            result = iris.util._mask_array(result, new_mask)
 
         return result
 
@@ -592,7 +589,7 @@ def __call__(self, sample_points, collapse_scalar=True):
 
         sample_points = _canonical_sample_points(self._src_coords, sample_points)
 
-        data = self._src_cube.data
+        data = self._src_cube.core_data()
         # Interpolate the cube payload.
         interpolated_data = self._points(sample_points, data)
 
@@ -668,3 +665,30 @@ def gen_new_cube():
             new_cube = new_cube[tuple(dim_slices)]
 
         return new_cube
+
+
+def _is_masked_array(array):
+    """Equivalent to :func:`numpy.ma.isMaskedArray`, but works for both lazy AND realised arrays."""
+    if is_lazy_data(array):
+        is_masked_array = is_masked_data(array)
+    else:
+        is_masked_array = np.ma.isMaskedArray(array)
+    return is_masked_array
+
+
+def _get_data(array):
+    """Equivalent to :func:`numpy.ma.getdata`, but works for both lazy AND realised arrays."""
+    if is_lazy_data(array):
+        result = dask.array.ma.getdata(array)
+    else:
+        result = np.ma.getdata(array)
+    return result
+
+
+def _get_mask_array(array):
+    """Equivalent to :func:`numpy.ma.getmaskarray`, but works for both lazy AND realised arrays."""
+    if is_lazy_data(array):
+        result = dask.array.ma.getmaskarray(array)
+    else:
+        result = np.ma.getmaskarray(array)
+    return result
diff --git a/lib/iris/analysis/_scipy_interpolate.py b/lib/iris/analysis/_scipy_interpolate.py
index 251fb4bf70..afa589d933 100644
--- a/lib/iris/analysis/_scipy_interpolate.py
+++ b/lib/iris/analysis/_scipy_interpolate.py
@@ -1,7 +1,10 @@
 import itertools
 
+import dask.array
 import numpy as np
-from scipy.sparse import csr_matrix
+from sparse import GCXS
+
+from iris._lazy_data import is_lazy_data
 
 # ============================================================================
 # | Copyright SciPy |
@@ -218,7 +221,7 @@ def compute_interp_weights(self, xi, method=None):
         n_result_values = len(indices[0])
         n_non_zero = n_result_values * n_src_values_per_result_value
         weights = np.ones(n_non_zero, dtype=norm_distances[0].dtype)
-        col_indices = np.empty(n_non_zero)
+        col_indices = np.empty(n_non_zero, dtype=int)
         row_ptrs = np.arange(
             0,
             n_non_zero + n_src_values_per_result_value,
@@ -238,11 +241,13 @@ def compute_interp_weights(self, xi, method=None):
                 weights[i::n_src_values_per_result_value] *= cw
 
         n_src_values = np.prod(list(map(len, self.grid)))
-        sparse_matrix = csr_matrix(
+        sparse_matrix = GCXS(
             (weights, col_indices, row_ptrs),
+            compressed_axes=[0],
             shape=(n_result_values, n_src_values),
         )
-
+        if is_lazy_data(self.values):
+            sparse_matrix = dask.array.from_array(sparse_matrix)
         prepared = (xi_shape, method, sparse_matrix, None, out_of_bounds)
 
         return prepared
@@ -289,10 +294,10 @@ def interp_using_pre_computed_weights(self, computed_weights):
     def _evaluate_linear_sparse(self, sparse_matrix):
         ndim = len(self.grid)
         if ndim == self.values.ndim:
-            result = sparse_matrix * self.values.reshape(-1)
+            result = sparse_matrix @ self.values.reshape(-1)
         else:
             shape = (sparse_matrix.shape[1], -1)
-            result = sparse_matrix * self.values.reshape(shape)
+            result = sparse_matrix @ self.values.reshape(shape)
 
         return result
 
@@ -300,7 +305,12 @@ def _evaluate_nearest(self, indices, norm_distances, out_of_bounds):
         idx_res = []
         for i, yi in zip(indices, norm_distances):
             idx_res.append(np.where(yi <= 0.5, i, i + 1))
-        return self.values[tuple(idx_res)]
+        if is_lazy_data(self.values):
+            # dask arrays do not (yet) support fancy indexing
+            indexer = self.values.vindex
+        else:
+            indexer = self.values
+        return indexer[tuple(idx_res)]
 
     def _find_indices(self, xi):
         # find relevant edges between which xi are situated
diff --git a/requirements/py310.yml b/requirements/py310.yml
index 309f8aa9a2..f8b5d56041 100644
--- a/requirements/py310.yml
+++ b/requirements/py310.yml
@@ -23,6 +23,7 @@ dependencies:
   - pyproj
   - scipy
   - shapely !=1.8.3
+  - sparse
 
   # Optional dependencies.
   - esmpy >=7.0
diff --git a/requirements/py311.yml b/requirements/py311.yml
index 58d71ddd52..9c902eac9a 100644
--- a/requirements/py311.yml
+++ b/requirements/py311.yml
@@ -23,6 +23,7 @@ dependencies:
   - pyproj
   - scipy
   - shapely !=1.8.3
+  - sparse
 
   # Optional dependencies.
   - esmpy >=7.0
diff --git a/requirements/py312.yml b/requirements/py312.yml
index e1e62e52d9..6e25555251 100644
--- a/requirements/py312.yml
+++ b/requirements/py312.yml
@@ -23,6 +23,7 @@ dependencies:
   - pyproj
   - scipy
   - shapely !=1.8.3
+  - sparse
 
   # Optional dependencies.
   - esmpy >=7.0
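
Note (not part of the patch): the new `_get_data` / `_get_mask_array` helpers in `_interpolation.py` lean on `dask.array.ma` mirroring the `numpy.ma` API, so the same mask handling applies whether the payload is lazy or realised. A minimal sketch of that dispatch pattern, using invented example arrays:

```python
import dask.array as da
import numpy as np

# A realised masked array and a lazy equivalent built from the same payload.
realised = np.ma.masked_array([1.0, 2.0, 3.0], mask=[False, True, False])
lazy = da.ma.masked_array(
    da.from_array(realised.data, chunks=2), mask=[False, True, False]
)

# dask.array.ma mirrors numpy.ma, so the helpers can branch on laziness
# and otherwise share identical logic.
print(np.ma.getmaskarray(realised))        # [False  True False]
print(da.ma.getmaskarray(lazy).compute())  # [False  True False]
```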
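The switch from `scipy.sparse.csr_matrix` to `sparse.GCXS` is what lets `compute_interp_weights` wrap the weights matrix with `dask.array.from_array` and apply it to lazy values with `@`: SciPy sparse matrices cannot serve as dask chunks, while pydata/sparse arrays can. A sketch of the mechanism with invented weights and shapes (two result values, each a weighted pair of four source values):

```python
import dask.array as da
import numpy as np
from sparse import GCXS

# CSR-style (data, indices, indptr) triple, as built in compute_interp_weights.
weights = np.array([0.25, 0.75, 0.5, 0.5])
col_indices = np.array([0, 1, 2, 3])
row_ptrs = np.array([0, 2, 4])

# compressed_axes=[0] makes GCXS behave like CSR.
matrix = GCXS(
    (weights, col_indices, row_ptrs), compressed_axes=[0], shape=(2, 4)
)

values = da.from_array(np.arange(4.0), chunks=4)  # stand-in for lazy source data
result = da.from_array(matrix) @ values           # stays lazy until computed
print(result.compute())                           # [0.75 2.5]
```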
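Similarly, the `_evaluate_nearest` change routes indexing through `.vindex` because plain dask arrays reject NumPy-style pointwise ("fancy") indexing with multiple integer arrays. A small illustration with invented data:

```python
import dask.array as da
import numpy as np

values = da.from_array(np.arange(12).reshape(3, 4), chunks=2)
rows = np.array([0, 2, 1])
cols = np.array([1, 3, 0])

# values[rows, cols] raises NotImplementedError on a dask array;
# vindex provides the pointwise NumPy semantics instead.
print(values.vindex[rows, cols].compute())  # [ 1 11  4]
```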