Skip to content

Commit 24a1ab7

Browse files
authoredAug 9, 2023
Refine typing with NumPy plugin of mypy (#385)
* Make numpy plugin of mypy happy * Finish linting! * Add comment for backward comp * Explicitly define union * Exception for Number before Python 3.9 * Clarify backward behaviour * Linting * Try with union everywhere * Fix dtype union * Fix bool type * Fix bool again * Temporarily ignore boolean inconsistencies with Raster * Remove RGB example in colorbar test * Add comment * Linting
1 parent fac1de0 commit 24a1ab7

13 files changed

+318
-181
lines changed
 

‎.pre-commit-config.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ repos:
4848
hooks:
4949
- id: mypy
5050
args: [
51+
--config-file=mypy.ini,
5152
--strict,
53+
--implicit-optional,
5254
--ignore-missing-imports, # Don't warn about stubs since pre-commit runs in a limited env
5355
--allow-untyped-calls, # Dynamic function/method calls are okay. Untyped function definitions are not okay.
5456
--show-error-codes,
@@ -58,7 +60,7 @@ repos:
5860
--disable-error-code=var-annotated,
5961
--disable-error-code=no-any-return
6062
]
61-
additional_dependencies: [tokenize-rt==3.2.0]
63+
additional_dependencies: [tokenize-rt==3.2.0, numpy==1.22]
6264
files: ^(geoutils|tests)
6365

6466
# Sort imports using isort

‎geoutils/_typing.py

+37-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,46 @@
11
"""Typing aliases for internal use."""
22
from __future__ import annotations
33

4+
import sys
45
from typing import Any, List, Tuple, Union
56

67
import numpy as np
78

8-
# Make an array-like type (since the array-like numpy type only exists in numpy>=1.20)
9-
ArrayLike = Union[np.ndarray, np.ma.masked_array, List[Any], Tuple[Any]]
9+
# Only for Python >= 3.9
10+
if sys.version_info.minor >= 9:
1011

11-
DTypeLike = Union[str, type, np.dtype]
12+
from numpy.typing import ( # this syntax works starting on Python 3.9
13+
ArrayLike,
14+
DTypeLike,
15+
NDArray,
16+
)
1217

13-
# Mypy has issues with the builtin Number type (https://github.com/python/mypy/issues/3186)
14-
AnyNumber = Union[int, float, np.number]
18+
# Mypy has issues with the builtin Number type (https://github.com/python/mypy/issues/3186)
19+
Number = Union[int, float, np.integer[Any], np.floating[Any]]
20+
21+
# Simply define here if they exist
22+
DTypeLike = DTypeLike
23+
ArrayLike = ArrayLike
24+
25+
# Use NDArray wrapper to easily define numerical (float or int) N-D array types, and boolean N-D array types
26+
NDArrayNum = NDArray[Union[np.floating[Any], np.integer[Any]]]
27+
NDArrayBool = NDArray[np.bool_]
28+
# Define numerical (float or int) masked N-D array type
29+
MArrayNum = np.ma.masked_array[Any, np.dtype[Union[np.floating[Any], np.integer[Any]]]]
30+
MArrayBool = np.ma.masked_array[Any, np.dtype[np.bool_]]
31+
32+
# For backward compatibility before Python 3.9
33+
else:
34+
35+
# Mypy has issues with the builtin Number type (https://github.com/python/mypy/issues/3186)
36+
Number = Union[int, float, np.integer, np.floating] # type: ignore
37+
38+
# Make an array-like type (since the array-like numpy type only exists in numpy>=1.20)
39+
DTypeLike = Union[str, type, np.dtype] # type: ignore
40+
ArrayLike = Union[np.ndarray, np.ma.masked_array, List[Any], Tuple[Any]] # type: ignore
41+
42+
# Define generic types for NumPy array and masked-array (behaves as "Any" before 3.9 and plugin)
43+
NDArrayNum = np.ndarray # type: ignore
44+
NDArrayBool = np.ndarray # type: ignore
45+
MArrayNum = np.ma.masked_array # type: ignore
46+
MArrayBool = np.ma.masked_array # type: ignore

‎geoutils/projtools.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from math import ceil, floor
99
from typing import Literal
1010

11-
import geopandas
1211
import geopandas as gpd
1312
import numpy as np
1413
import pyproj
@@ -18,8 +17,10 @@
1817
from shapely.geometry.base import BaseGeometry
1918
from shapely.geometry.polygon import Polygon
2019

20+
from geoutils._typing import NDArrayNum, Number
2121

22-
def latlon_to_utm(lat: float, lon: float) -> str:
22+
23+
def latlon_to_utm(lat: Number, lon: Number) -> str:
2324
"""
2425
Get UTM zone for a given latitude and longitude coordinates.
2526
@@ -108,37 +109,37 @@ def _get_utm_ups_crs(df: gpd.GeoDataFrame, method: Literal["centroid"] | Literal
108109

109110

110111
def bounds2poly(
111-
boundsGeom: list[float] | rio.io.DatasetReader,
112+
bounds_geom: list[float] | rio.io.DatasetReader,
112113
in_crs: CRS | None = None,
113114
out_crs: CRS | None = None,
114115
) -> Polygon:
115116
"""
116117
Converts self's bounds into a shapely Polygon. Optionally, returns it into a different CRS.
117118
118-
:param boundsGeom: A geometry with bounds. Can be either a list of coordinates (xmin, ymin, xmax, ymax),\
119+
:param bounds_geom: A geometry with bounds. Can be either a list of coordinates (xmin, ymin, xmax, ymax),\
119120
a rasterio/Raster object, a geoPandas/Vector object
120121
:param in_crs: Input CRS
121122
:param out_crs: Output CRS
122123
123124
:returns: Output polygon
124125
"""
125126
# If boundsGeom is a GeoPandas or Vector object (warning, has both total_bounds and bounds attributes)
126-
if hasattr(boundsGeom, "total_bounds"):
127-
xmin, ymin, xmax, ymax = boundsGeom.total_bounds # type: ignore
128-
in_crs = boundsGeom.crs # type: ignore
127+
if hasattr(bounds_geom, "total_bounds"):
128+
xmin, ymin, xmax, ymax = bounds_geom.total_bounds # type: ignore
129+
in_crs = bounds_geom.crs # type: ignore
129130
# If boundsGeom is a rasterio or Raster object
130-
elif hasattr(boundsGeom, "bounds"):
131-
xmin, ymin, xmax, ymax = boundsGeom.bounds # type: ignore
132-
in_crs = boundsGeom.crs # type: ignore
131+
elif hasattr(bounds_geom, "bounds"):
132+
xmin, ymin, xmax, ymax = bounds_geom.bounds # type: ignore
133+
in_crs = bounds_geom.crs # type: ignore
133134
# if a list of coordinates
134-
elif isinstance(boundsGeom, (list, tuple)):
135-
xmin, ymin, xmax, ymax = boundsGeom
135+
elif isinstance(bounds_geom, (list, tuple)):
136+
xmin, ymin, xmax, ymax = bounds_geom
136137
else:
137138
raise ValueError(
138139
"boundsGeom must a list/tuple of coordinates or an object with attributes bounds or total_bounds."
139140
)
140141

141-
corners = ((xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax))
142+
corners = np.array([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
142143

143144
if (in_crs is not None) & (out_crs is not None):
144145
corners = np.transpose(reproject_points(np.transpose(corners), in_crs, out_crs))
@@ -176,7 +177,7 @@ def merge_bounds(
176177
"or total_bounds"
177178
)
178179

179-
output_poly = bounds2poly(boundsGeom=bounds_list[0])
180+
output_poly = bounds2poly(bounds_geom=bounds_list[0])
180181

181182
# Compute the merging
182183
for boundsGeom in bounds_list[1:]:
@@ -238,7 +239,9 @@ def align_bounds(
238239
return (left, bottom, right, top)
239240

240241

241-
def reproject_points(pts: list[list[float]] | np.ndarray, in_crs: CRS, out_crs: CRS) -> tuple[list[float], list[float]]:
242+
def reproject_points(
243+
pts: list[list[float]] | tuple[list[float], list[float]] | NDArrayNum, in_crs: CRS, out_crs: CRS
244+
) -> tuple[list[float], list[float]]:
242245
"""
243246
Reproject a set of point from input_crs to output_crs.
244247
@@ -262,7 +265,7 @@ def reproject_points(pts: list[list[float]] | np.ndarray, in_crs: CRS, out_crs:
262265

263266

264267
def reproject_to_latlon(
265-
pts: list[list[float]] | np.ndarray, in_crs: CRS, round_: int = 8
268+
pts: list[list[float]] | NDArrayNum, in_crs: CRS, round_: int = 8
266269
) -> tuple[list[float], list[float]]:
267270
"""
268271
Reproject a set of point from in_crs to lat/lon.
@@ -279,7 +282,7 @@ def reproject_to_latlon(
279282

280283

281284
def reproject_from_latlon(
282-
pts: list[list[float]] | tuple[list[float], list[float]] | np.ndarray, out_crs: CRS, round_: int = 2
285+
pts: list[list[float]] | tuple[list[float], list[float]] | NDArrayNum, out_crs: CRS, round_: int = 2
283286
) -> tuple[list[float], list[float]]:
284287
"""
285288
Reproject a set of point from lat/lon to out_crs.

‎geoutils/raster/array.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import numpy as np
88

99
import geoutils as gu
10+
from geoutils._typing import MArrayNum, NDArrayNum
1011

1112

12-
def get_mask(array: np.ndarray | np.ma.masked_array) -> np.ndarray:
13+
def get_mask(array: NDArrayNum | MArrayNum) -> NDArrayNum:
1314
"""
1415
Return the mask of invalid values, whether array is a ndarray with NaNs or a np.ma.masked_array.
1516
@@ -22,8 +23,8 @@ def get_mask(array: np.ndarray | np.ma.masked_array) -> np.ndarray:
2223

2324

2425
def get_array_and_mask(
25-
array: np.ndarray | np.ma.masked_array, check_shape: bool = True, copy: bool = True
26-
) -> tuple[np.ndarray, np.ndarray]:
26+
array: NDArrayNum | MArrayNum, check_shape: bool = True, copy: bool = True
27+
) -> tuple[NDArrayNum, NDArrayNum]:
2728
"""
2829
Return array with masked values set to NaN and the associated mask.
2930
Works whether array is a ndarray with NaNs or a np.ma.masked_array.
@@ -65,7 +66,7 @@ def get_array_and_mask(
6566
return array_data, invalid_mask
6667

6768

68-
def get_valid_extent(array: np.ndarray | np.ma.masked_array) -> tuple[int, ...]:
69+
def get_valid_extent(array: NDArrayNum | MArrayNum) -> tuple[int, ...]:
6970
"""
7071
Return (rowmin, rowmax, colmin, colmax), the first/last row/column of array with valid pixels
7172
"""
@@ -78,7 +79,7 @@ def get_valid_extent(array: np.ndarray | np.ma.masked_array) -> tuple[int, ...]:
7879
return rows_nonzero[0], rows_nonzero[-1], cols_nonzero[0], cols_nonzero[-1]
7980

8081

81-
def get_xy_rotated(raster: gu.Raster, along_track_angle: float) -> tuple[np.ndarray, np.ndarray]:
82+
def get_xy_rotated(raster: gu.Raster, along_track_angle: float) -> tuple[NDArrayNum, NDArrayNum]:
8283
"""
8384
Rotate x, y axes of image to get along- and cross-track distances.
8485
:param raster: Raster to get x,y positions from.

‎geoutils/raster/multiraster.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tqdm import tqdm
1111

1212
import geoutils as gu
13+
from geoutils._typing import NDArrayNum
1314
from geoutils.misc import resampling_method_from_str
1415
from geoutils.raster import Raster, RasterType, get_array_and_mask
1516
from geoutils.raster.raster import _default_nodata
@@ -167,7 +168,7 @@ def stack_rasters(
167168
)
168169

169170
# Make a data list and add all of the reprojected rasters into it.
170-
data: list[np.ndarray] = []
171+
data: list[NDArrayNum] = []
171172

172173
for raster in tqdm(rasters, disable=not progress):
173174
# Check that data is loaded, otherwise temporarily load it
@@ -209,7 +210,7 @@ def stack_rasters(
209210
nodata = reference_raster.nodata
210211
else:
211212
nodata = _default_nodata(data.dtype)
212-
data[np.isnan(data)] = nodata
213+
data[np.isnan(data)] = nodata # type: ignore
213214

214215
# Save as gu.Raster - needed as some child classes may not accept multiple bands
215216
r = gu.Raster.from_array(

‎geoutils/raster/raster.py

+177-121
Large diffs are not rendered by default.

‎geoutils/raster/sampling.py

+41-6
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,52 @@
22

33
from __future__ import annotations
44

5+
from typing import Literal, overload
6+
57
import numpy as np
68

9+
from geoutils._typing import MArrayNum, NDArrayNum
710
from geoutils.raster.array import get_mask
811

912

13+
@overload
14+
def subsample_array(
15+
array: NDArrayNum | MArrayNum,
16+
subsample: float | int,
17+
return_indices: Literal[False] = False,
18+
*,
19+
random_state: np.random.RandomState | int | None = None,
20+
) -> NDArrayNum:
21+
...
22+
23+
24+
@overload
25+
def subsample_array(
26+
array: NDArrayNum | MArrayNum,
27+
subsample: float | int,
28+
return_indices: Literal[True],
29+
*,
30+
random_state: np.random.RandomState | int | None = None,
31+
) -> tuple[NDArrayNum, ...]:
32+
...
33+
34+
35+
@overload
36+
def subsample_array(
37+
array: NDArrayNum | MArrayNum,
38+
subsample: float | int,
39+
return_indices: bool = False,
40+
random_state: np.random.RandomState | int | None = None,
41+
) -> NDArrayNum | tuple[NDArrayNum, ...]:
42+
...
43+
44+
1045
def subsample_array(
11-
array: np.ndarray | np.ma.masked_array,
46+
array: NDArrayNum | MArrayNum,
1247
subsample: float | int,
1348
return_indices: bool = False,
14-
random_state: np.random.RandomState | np.random.Generator | int | None = None,
15-
) -> np.ndarray:
49+
random_state: np.random.RandomState | int | None = None,
50+
) -> NDArrayNum | tuple[NDArrayNum, ...]:
1651
"""
1752
Randomly subsample a 1D or 2D array by a subsampling factor, taking only non NaN/masked values.
1853
@@ -26,8 +61,8 @@ def subsample_array(
2661
"""
2762
# Define state for random subsampling (to fix results during testing)
2863
if random_state is None:
29-
rnd = np.random.default_rng()
30-
elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
64+
rnd: np.random.RandomState | np.random.Generator = np.random.default_rng()
65+
elif isinstance(random_state, np.random.RandomState):
3166
rnd = random_state
3267
else:
3368
rnd = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(random_state)))
@@ -98,7 +133,7 @@ def _get_closest_rectangle(size: int) -> tuple[int, int]:
98133
raise NotImplementedError(f"Function criteria not met for rectangle of size: {size}")
99134

100135

101-
def subdivide_array(shape: tuple[int, ...], count: int) -> np.ndarray:
136+
def subdivide_array(shape: tuple[int, ...], count: int) -> NDArrayNum:
102137
"""
103138
Create indices for subdivison of an array in a number of blocks.
104139

‎geoutils/raster/satimg.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import rasterio as rio
1515

16+
from geoutils._typing import NDArrayNum
1617
from geoutils.raster import Raster, RasterType
1718

1819
lsat_sensor = {"C": "OLI/TIRS", "E": "ETM+", "T": "TM", "M": "MSS", "O": "OLI", "TI": "TIRS"}
@@ -427,7 +428,7 @@ def __parse_metadata_from_file(self, fn_meta: str | None) -> None:
427428

428429
return None
429430

430-
def copy(self, new_array: np.ndarray | None = None) -> SatelliteImage:
431+
def copy(self, new_array: NDArrayNum | None = None) -> SatelliteImage:
431432
new_satimg = super().copy(new_array=new_array) # type: ignore
432433
# all objects here are immutable so no need for a copy method (string and datetime)
433434
# satimg_attrs = ['satellite', 'sensor', 'product', 'version', 'tile_name', 'datetime'] #taken outside of class

‎geoutils/vector.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from shapely.geometry.polygon import Polygon
3939

4040
import geoutils as gu
41+
from geoutils._typing import NDArrayBool, NDArrayNum
4142
from geoutils.misc import copy_doc
4243
from geoutils.projtools import (
4344
_get_bounds_projected,
@@ -1092,7 +1093,7 @@ def create_mask(
10921093
xres: float | None = None,
10931094
yres: float | None = None,
10941095
bounds: tuple[float, float, float, float] | None = None,
1095-
buffer: int | float | np.number = 0,
1096+
buffer: int | float | np.integer[Any] | np.floating[Any] = 0,
10961097
*,
10971098
as_array: Literal[False] = False,
10981099
) -> gu.Mask:
@@ -1106,10 +1107,10 @@ def create_mask(
11061107
xres: float | None = None,
11071108
yres: float | None = None,
11081109
bounds: tuple[float, float, float, float] | None = None,
1109-
buffer: int | float | np.number = 0,
1110+
buffer: int | float | np.integer[Any] | np.floating[Any] = 0,
11101111
*,
11111112
as_array: Literal[True],
1112-
) -> np.ndarray:
1113+
) -> NDArrayNum:
11131114
...
11141115

11151116
def create_mask(
@@ -1119,9 +1120,9 @@ def create_mask(
11191120
xres: float | None = None,
11201121
yres: float | None = None,
11211122
bounds: tuple[float, float, float, float] | None = None,
1122-
buffer: int | float | np.number = 0,
1123+
buffer: int | float | np.integer[Any] | np.floating[Any] = 0,
11231124
as_array: bool = False,
1124-
) -> gu.Mask | np.ndarray:
1125+
) -> gu.Mask | NDArrayBool:
11251126
"""
11261127
Create a mask from the vector features.
11271128

‎mypy.ini

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[mypy]
2+
plugins = numpy.typing.mypy_plugin

‎tests/test_projtools.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_latlon_reproject(self, example: str) -> None:
130130
randx = np.random.randint(low=img.bounds.left, high=img.bounds.right, size=(nsample,))
131131
randy = np.random.randint(low=img.bounds.bottom, high=img.bounds.top, size=(nsample,))
132132

133-
lat, lon = pt.reproject_to_latlon([randx, randy], img.crs)
133+
lat, lon = pt.reproject_to_latlon([list(randx), list(randy)], img.crs)
134134
x, y = pt.reproject_from_latlon([lat, lon], img.crs)
135135

136136
assert np.all(x == randx)

‎tests/test_raster.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import geoutils as gu
2121
import geoutils.projtools as pt
2222
from geoutils import examples
23+
from geoutils._typing import MArrayNum, NDArrayNum
2324
from geoutils.misc import resampling_method_from_str
2425
from geoutils.projtools import reproject_to_latlon
2526
from geoutils.raster.raster import _default_nodata, _default_rio_attrs
@@ -29,7 +30,7 @@
2930

3031
def run_gdal_proximity(
3132
input_raster: gu.Raster, target_values: list[float] | None, distunits: str = "GEO"
32-
) -> np.ndarray:
33+
) -> NDArrayNum:
3334
"""Run GDAL's ComputeProximity and return the read numpy array."""
3435
# Rasterio strongly recommends against importing gdal along rio, so this is done here instead.
3536
from osgeo import gdal, gdalconst
@@ -412,8 +413,8 @@ def test_data_setter(self, dtype: str, nodata_init: str | None) -> None:
412413

413414
# Create random values between the lower and upper limit of the data type, max absolute 99999 for floats
414415
if "int" in dtype:
415-
val_min = np.iinfo(int_type=dtype).min
416-
val_max = np.iinfo(int_type=dtype).max
416+
val_min: int = np.iinfo(int_type=dtype).min # type: ignore
417+
val_max: int = np.iinfo(int_type=dtype).max # type: ignore
417418
randint_dtype = dtype
418419
else:
419420
val_min = -99999
@@ -422,7 +423,9 @@ def test_data_setter(self, dtype: str, nodata_init: str | None) -> None:
422423

423424
# Fix the random seed
424425
np.random.seed(42)
425-
arr = np.random.randint(low=val_min, high=val_max, size=(width, height), dtype=randint_dtype).astype(dtype)
426+
arr = np.random.randint(
427+
low=val_min, high=val_max, size=(width, height), dtype=randint_dtype # type: ignore
428+
).astype(dtype)
426429
mask = np.random.randint(0, 2, size=(width, height), dtype=bool)
427430

428431
# Check that we are actually masking stuff
@@ -1675,7 +1678,7 @@ def test_xy2ij_and_interp(self) -> None:
16751678
assert img[int(i), int(j)] == val
16761679

16771680
# Finally, check that interp convert to latlon
1678-
lat, lon = gu.projtools.reproject_to_latlon((x, y), in_crs=r.crs)
1681+
lat, lon = gu.projtools.reproject_to_latlon([[x], [y]], in_crs=r.crs)
16791682
val_latlon = r.interp_points([(lat, lon)], order=1, input_latlon=True)[0]
16801683
assert val == pytest.approx(val_latlon, abs=0.0001)
16811684

@@ -1718,7 +1721,7 @@ def test_value_at_coords(self) -> None:
17181721
# -- Tests 2: check arguments work as intended --
17191722

17201723
# 1/ Lat-lon argument check by getting the coordinates of our last test point
1721-
lat, lon = reproject_to_latlon(pts=[xtest0, ytest0], in_crs=r.crs)
1724+
lat, lon = reproject_to_latlon(pts=[[xtest0], [ytest0]], in_crs=r.crs)
17221725
z_val_2 = r.value_at_coords(lon, lat, latlon=True)
17231726
assert z_val == z_val_2
17241727

@@ -2039,13 +2042,13 @@ def test_default_nodata(self) -> None:
20392042

20402043
# Check it works with most frequent np.dtypes too
20412044
assert _default_nodata(np.dtype("uint8")) == np.iinfo("uint8").max
2042-
for dtype in [np.dtype("int32"), np.dtype("float32"), np.dtype("float64")]:
2043-
assert _default_nodata(dtype) == -99999
2045+
for dtype_obj in [np.dtype("int32"), np.dtype("float32"), np.dtype("float64")]:
2046+
assert _default_nodata(dtype_obj) == -99999 # type: ignore
20442047

20452048
# Check it works with most frequent types too
20462049
assert _default_nodata(np.uint8) == np.iinfo("uint8").max
2047-
for dtype in [np.int32, np.float32, np.float64]:
2048-
assert _default_nodata(dtype) == -99999
2050+
for dtype_obj in [np.int32, np.float32, np.float64]:
2051+
assert _default_nodata(dtype_obj) == -99999
20492052

20502053
# Check that an error is raised for other types
20512054
expected_message = "No default nodata value set for dtype."
@@ -2105,9 +2108,8 @@ def test_astype(self) -> None:
21052108
assert np.dtype(rout.dtypes[0]) == dtype
21062109
assert rout.data.dtype == dtype
21072110

2108-
@pytest.mark.parametrize(
2109-
"example", [landsat_b4_path, landsat_b4_crop_path, landsat_rgb_path, aster_dem_path]
2110-
) # type: ignore
2111+
# The multi-band example will not have a colorbar, so not used in tests
2112+
@pytest.mark.parametrize("example", [landsat_b4_path, landsat_b4_crop_path, aster_dem_path]) # type: ignore
21112113
@pytest.mark.parametrize("figsize", np.arange(2, 20, 2)) # type: ignore
21122114
def test_show_cbar(self, example, figsize) -> None:
21132115
"""
@@ -3216,7 +3218,7 @@ def test_reflectivity(self, ops: list[str]) -> None:
32163218
@classmethod
32173219
def from_array(
32183220
cls: type[TestArithmetic],
3219-
data: np.ndarray | np.ma.masked_array,
3221+
data: NDArrayNum | MArrayNum,
32203222
rst_ref: gu.RasterType,
32213223
nodata: int | float | list[int] | list[float] | None = None,
32223224
) -> gu.Raster:
@@ -3595,7 +3597,7 @@ def test_array_ufunc_1nin_1nout(self, ufunc_str: str, nodata_init: None | str, d
35953597
warnings.filterwarnings("ignore", category=RuntimeWarning)
35963598

35973599
# Check if our input dtype is possible on this ufunc, if yes check that outputs are identical
3598-
if com_dtype in (str(np.dtype(t[0])) for t in ufunc.types):
3600+
if com_dtype in [str(np.dtype(t[0])) for t in ufunc.types]: # noqa
35993601
# For a single output
36003602
if ufunc.nout == 1:
36013603
assert np.ma.allequal(ufunc(rst.data), ufunc(rst).data)
@@ -3680,7 +3682,7 @@ def test_array_ufunc_2nin_1nout(
36803682
warnings.filterwarnings("ignore", category=UserWarning)
36813683

36823684
# Check if both our input dtypes are possible on this ufunc, if yes check that outputs are identical
3683-
if com_dtype_tuple in ((str(np.dtype(t[0])), str(np.dtype(t[1]))) for t in ufunc.types):
3685+
if com_dtype_tuple in [(np.dtype(t[0]), np.dtype(t[1])) for t in ufunc.types]: # noqa
36843686
# For a single output
36853687
if ufunc.nout == 1:
36863688
# There exists a single exception due to negative integers as exponent of integers in "power"

‎tests/test_sampling.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
import geoutils as gu
9+
from geoutils._typing import NDArrayNum
910

1011

1112
class TestSubsampling:
@@ -33,7 +34,7 @@ class TestSubsampling:
3334
assert np.count_nonzero(array3D.mask) > 0
3435

3536
@pytest.mark.parametrize("array", [array1D, array2D, array3D]) # type: ignore
36-
def test_subsample(self, array: np.ndarray) -> None:
37+
def test_subsample(self, array: NDArrayNum) -> None:
3738
"""
3839
Test gu.raster.subsample_array.
3940
"""

0 commit comments

Comments
 (0)
Please sign in to comment.