diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 713da24f577..1aad039320a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -67,6 +67,16 @@ jobs: - env: "flaky" python-version: "3.13" os: ubuntu-latest + # The mypy tests must be executed using only 1 process in order to guarantee + # predictable mypy output messages for comparison to expectations. + - env: "mypy" + python-version: "3.10" + numprocesses: 1 + os: ubuntu-latest + - env: "mypy" + python-version: "3.13" + numprocesses: 1 + os: ubuntu-latest steps: - uses: actions/checkout@v4 with: @@ -88,6 +98,10 @@ jobs: then echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV echo "PYTEST_ADDOPTS=-m 'flaky or network' --run-flaky --run-network-tests -W default" >> $GITHUB_ENV + elif [[ "${{ matrix.env }}" == "mypy" ]] ; + then + echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV + echo "PYTEST_ADDOPTS=-n 1 -m 'mypy' --run-mypy -W default" >> $GITHUB_ENV else echo "CONDA_ENV_FILE=ci/requirements/${{ matrix.env }}.yml" >> $GITHUB_ENV fi @@ -144,7 +158,7 @@ jobs: save-always: true - name: Run tests - run: python -m pytest -n 4 + run: python -m pytest -n ${{ matrix.numprocesses || 4 }} --timeout 180 --cov=xarray --cov-report=xml diff --git a/ci/minimum_versions.py b/ci/minimum_versions.py index c226e304769..cc115789d0f 100644 --- a/ci/minimum_versions.py +++ b/ci/minimum_versions.py @@ -26,8 +26,9 @@ "pytest", "pytest-cov", "pytest-env", - "pytest-xdist", + "pytest-mypy-plugins", "pytest-timeout", + "pytest-xdist", "hypothesis", ] diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index b7bf167188f..ca4943bddb1 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -30,8 +30,9 @@ dependencies: - pytest - pytest-cov - pytest-env - - pytest-xdist + - pytest-mypy-plugins - pytest-timeout + - pytest-xdist - rasterio - scipy - seaborn diff --git a/ci/requirements/all-but-numba.yml b/ci/requirements/all-but-numba.yml index 17a657eb32b..fa7ad81f198 100644 --- a/ci/requirements/all-but-numba.yml +++ b/ci/requirements/all-but-numba.yml @@ -43,8 +43,9 @@ dependencies: - pytest - pytest-cov - pytest-env - - pytest-xdist + - pytest-mypy-plugins - pytest-timeout + - pytest-xdist - rasterio - scipy - seaborn diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml index d9590d95165..02e99d34af2 100644 --- a/ci/requirements/bare-minimum.yml +++ b/ci/requirements/bare-minimum.yml @@ -9,8 +9,9 @@ dependencies: - pytest - pytest-cov - pytest-env - - pytest-xdist + - pytest-mypy-plugins - pytest-timeout + - pytest-xdist - numpy=1.24 - packaging=23.1 - pandas=2.1 diff --git a/ci/requirements/environment-3.14.yml b/ci/requirements/environment-3.14.yml index cca3a7a746b..cebae38bc83 100644 --- a/ci/requirements/environment-3.14.yml +++ b/ci/requirements/environment-3.14.yml @@ -29,6 +29,7 @@ dependencies: - opt_einsum - packaging - pandas + - pandas-stubs # - pint>=0.22 - pip - pooch @@ -38,14 +39,25 @@ dependencies: - pytest - pytest-cov - pytest-env - - pytest-xdist + - pytest-mypy-plugins - pytest-timeout + - pytest-xdist - rasterio - scipy - seaborn # - sparse - toolz + - types-colorama + - types-docutils + - types-psutil + - types-Pygments + - types-python-dateutil + - types-pytz + - types-PyYAML + - types-setuptools - typing_extensions - zarr - pip: - jax # no way to get cpu-only jaxlib from conda if gpu is present + - types-defusedxml + - types-pexpect diff --git a/ci/requirements/environment-windows-3.14.yml b/ci/requirements/environment-windows-3.14.yml index c7f67d2efac..31c91b24b6d 100644 --- a/ci/requirements/environment-windows-3.14.yml +++ b/ci/requirements/environment-windows-3.14.yml @@ -25,6 +25,7 @@ dependencies: - numpy - packaging - pandas + - pandas-stubs # - pint>=0.22 - pip - pre-commit @@ -33,12 +34,24 @@ dependencies: - pytest - pytest-cov - pytest-env - - pytest-xdist + - pytest-mypy-plugins - pytest-timeout + - pytest-xdist - rasterio - scipy - seaborn # - sparse - toolz + - types-colorama + - types-docutils + - types-psutil + - types-Pygments + - types-python-dateutil + - types-pytz + - types-PyYAML + - types-setuptools - typing_extensions - zarr + - pip: + - types-defusedxml + - types-pexpect diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index a2ecef43d07..f8eb80f6c75 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -25,6 +25,7 @@ dependencies: - numpy - packaging - pandas + - pandas-stubs # - pint>=0.22 - pip - pre-commit @@ -33,12 +34,24 @@ dependencies: - pytest - pytest-cov - pytest-env - - pytest-xdist + - pytest-mypy-plugins - pytest-timeout + - pytest-xdist - rasterio - scipy - seaborn - sparse - toolz + - types-colorama + - types-docutils + - types-psutil + - types-Pygments + - types-python-dateutil + - types-pytz + - types-PyYAML + - types-setuptools - typing_extensions - zarr + - pip: + - types-defusedxml + - types-pexpect diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 321dbe75c38..f1465f5a7e7 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -29,6 +29,7 @@ dependencies: - opt_einsum - packaging - pandas + - pandas-stubs # - pint>=0.22 - pip - pooch @@ -39,14 +40,25 @@ dependencies: - pytest - pytest-cov - pytest-env - - pytest-xdist + - pytest-mypy-plugins - pytest-timeout + - pytest-xdist - rasterio - scipy - seaborn - sparse - toolz + - types-colorama + - types-docutils + - types-psutil + - types-Pygments + - types-python-dateutil + - types-pytz + - types-PyYAML + - types-setuptools - typing_extensions - zarr - pip: - jax # no way to get cpu-only jaxlib from conda if gpu is present + - types-defusedxml + - types-pexpect diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index f3dab2e5bbf..52c7f9b18e3 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -46,8 +46,9 @@ dependencies: - pytest - pytest-cov - pytest-env - - pytest-xdist + - pytest-mypy-plugins - pytest-timeout + - pytest-xdist - rasterio=1.3 - scipy=1.11 - seaborn=0.13 diff --git a/conftest.py b/conftest.py index 24b7530b220..532a7badd91 100644 --- a/conftest.py +++ b/conftest.py @@ -3,7 +3,7 @@ import pytest -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser): """Add command-line flags for pytest.""" parser.addoption("--run-flaky", action="store_true", help="runs flaky tests") parser.addoption( @@ -11,6 +11,7 @@ def pytest_addoption(parser): action="store_true", help="runs tests requiring a network connection", ) + parser.addoption("--run-mypy", action="store_true", help="runs mypy tests") def pytest_runtest_setup(item): @@ -21,6 +22,21 @@ def pytest_runtest_setup(item): pytest.skip( "set --run-network-tests to run test requiring an internet connection" ) + if "mypy" in item.keywords and not item.config.getoption("--run-mypy"): + pytest.skip("set --run-mypy option to run mypy tests") + + +# See https://docs.pytest.org/en/stable/example/markers.html#automatically-adding-markers-based-on-test-names +def pytest_collection_modifyitems(items): + for item in items: + if "mypy" in item.nodeid: + # IMPORTANT: mypy type annotation tests leverage the pytest-mypy-plugins + # plugin, and are thus written in test_*.yml files. As such, there are + # no explicit test functions on which we can apply a pytest.mark.mypy + # decorator. Therefore, we mark them via this name-based, automatic + # marking approach, meaning that each test case must contain "mypy" in the + # name. + item.add_marker(pytest.mark.mypy) @pytest.fixture(autouse=True) diff --git a/pyproject.toml b/pyproject.toml index 32b0bce1322..817fda6c328 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,8 +44,9 @@ dev = [ "pytest", "pytest-cov", "pytest-env", - "pytest-xdist", + "pytest-mypy-plugins", "pytest-timeout", + "pytest-xdist", "ruff>=0.8.0", "sphinx", "sphinx_autosummary_accessors", @@ -304,7 +305,12 @@ known-first-party = ["xarray"] ban-relative-imports = "all" [tool.pytest.ini_options] -addopts = ["--strict-config", "--strict-markers"] +addopts = [ + "--strict-config", + "--strict-markers", + "--mypy-only-local-stub", + "--mypy-pyproject-toml-file=pyproject.toml", +] # We want to forbid warnings from within xarray in our tests — instead we should # fix our own code, or mark the test itself as expecting a warning. So this: @@ -361,6 +367,7 @@ filterwarnings = [ log_cli_level = "INFO" markers = [ "flaky: flaky tests", + "mypy: type annotation tests", "network: tests requiring a network connection", "slow: slow tests", "slow_hypothesis: slow hypothesis tests", diff --git a/xarray/core/common.py b/xarray/core/common.py index 01c02a8d14f..ceaae42356a 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -6,7 +6,7 @@ from contextlib import suppress from html import escape from textwrap import dedent -from typing import TYPE_CHECKING, Any, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, Union, overload import numpy as np import pandas as pd @@ -60,6 +60,7 @@ T_Resample = TypeVar("T_Resample", bound="Resample") C = TypeVar("C") T = TypeVar("T") +P = ParamSpec("P") class ImplementsArrayReduce: @@ -718,11 +719,27 @@ def assign_attrs(self, *args: Any, **kwargs: Any) -> Self: out.attrs.update(*args, **kwargs) return out + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: ... + + @overload def pipe( self, - func: Callable[..., T] | tuple[Callable[..., T], str], + func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any, + ) -> T: ... + + def pipe( + self, + func: Callable[Concatenate[Self, P], T] | tuple[Callable[P, T], str], + *args: P.args, + **kwargs: P.kwargs, ) -> T: """ Apply ``func(self, *args, **kwargs)`` @@ -840,15 +857,19 @@ def pipe( pandas.DataFrame.pipe """ if isinstance(func, tuple): - func, target = func + # Use different var when unpacking function from tuple because the type + # signature of the unpacked function differs from the expected type + # signature in the case where only a function is given, rather than a tuple. + # This makes type checkers happy at both call sites below. + f, target = func if target in kwargs: raise ValueError( f"{target} is both the pipe target and a keyword argument" ) kwargs[target] = self - return func(*args, **kwargs) - else: - return func(self, *args, **kwargs) + return f(*args, **kwargs) + + return func(self, *args, **kwargs) def rolling_exp( self: T_DataWithCoords, diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1a388919f0c..61340ac99ad 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -12,7 +12,17 @@ Mapping, ) from html import escape -from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + Literal, + NoReturn, + ParamSpec, + TypeVar, + Union, + overload, +) from xarray.core import utils from xarray.core._aggregations import DataTreeAggregations @@ -80,18 +90,23 @@ # """ # DEVELOPERS' NOTE # ---------------- -# The idea of this module is to create a `DataTree` class which inherits the tree structure from TreeNode, and also copies -# the entire API of `xarray.Dataset`, but with certain methods decorated to instead map the dataset function over every -# node in the tree. As this API is copied without directly subclassing `xarray.Dataset` we instead create various Mixin -# classes (in ops.py) which each define part of `xarray.Dataset`'s extensive API. +# The idea of this module is to create a `DataTree` class which inherits the tree +# structure from TreeNode, and also copies the entire API of `xarray.Dataset`, but with +# certain methods decorated to instead map the dataset function over every node in the +# tree. As this API is copied without directly subclassing `xarray.Dataset` we instead +# create various Mixin classes (in ops.py) which each define part of `xarray.Dataset`'s +# extensive API. # -# Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine to inherit unaltered -# (normally because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new -# tree) and some will get overridden by the class definition of DataTree. +# Some of these methods must be wrapped to map over all nodes in the subtree. Others are +# fine to inherit unaltered (normally because they (a) only call dataset properties and +# (b) don't return a dataset that should be nested into a new tree) and some will get +# overridden by the class definition of DataTree. # """ T_Path = Union[str, NodePath] +T = TypeVar("T") +P = ParamSpec("P") def _collect_data_and_coord_variables( @@ -1465,9 +1480,28 @@ def map_over_datasets( # TODO fix this typing error return map_over_datasets(func, self, *args, kwargs=kwargs) + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: ... + + @overload + def pipe( + self, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: ... + def pipe( - self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any - ) -> Any: + self, + func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: """Apply ``func(self, *args, **kwargs)`` This method replicates the pandas method of the same name. @@ -1487,7 +1521,7 @@ def pipe( Returns ------- - object : Any + object : T the return type of ``func``. Notes @@ -1515,15 +1549,19 @@ def pipe( """ if isinstance(func, tuple): - func, target = func + # Use different var when unpacking function from tuple because the type + # signature of the unpacked function differs from the expected type + # signature in the case where only a function is given, rather than a tuple. + # This makes type checkers happy at both call sites below. + f, target = func if target in kwargs: raise ValueError( f"{target} is both the pipe target and a keyword argument" ) kwargs[target] = self - else: - args = (self,) + args - return func(*args, **kwargs) + return f(*args, **kwargs) + + return func(self, *args, **kwargs) # TODO some kind of .collapse() or .flatten() method to merge a subtree diff --git a/xarray/tests/test_dataarray_typing.yml b/xarray/tests/test_dataarray_typing.yml new file mode 100644 index 00000000000..ae3356f9d7c --- /dev/null +++ b/xarray/tests/test_dataarray_typing.yml @@ -0,0 +1,190 @@ +- case: test_mypy_pipe_lambda_noarg_return_type + main: | + from xarray import DataArray + + da = DataArray().pipe(lambda data: data) + + reveal_type(da) # N: Revealed type is "xarray.core.dataarray.DataArray" + +- case: test_mypy_pipe_lambda_posarg_return_type + main: | + from xarray import DataArray + + da = DataArray().pipe(lambda data, arg: arg, "foo") + + reveal_type(da) # N: Revealed type is "builtins.str" + +- case: test_mypy_pipe_lambda_chaining_return_type + main: | + from xarray import DataArray + + answer = DataArray().pipe(lambda data, arg: arg, "foo").count("o") + + reveal_type(answer) # N: Revealed type is "builtins.int" + +- case: test_mypy_pipe_lambda_missing_arg + main: | + from xarray import DataArray + + # Call to pipe missing argument for lambda parameter `arg` + da = DataArray().pipe(lambda data, arg: data) + out: | + main:4: error: No overload variant of "pipe" of "DataWithCoords" matches argument type "Callable[[Any, Any], Any]" [call-overload] + main:4: note: Possible overload variants: + main:4: note: def [P`2, T] pipe(self, func: Callable[[DataArray, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:4: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_lambda_extra_arg + main: | + from xarray import DataArray + + # Call to pipe with extra argument for lambda + da = DataArray().pipe(lambda data: data, "oops!") + out: | + main:4: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[Any], Any]", "str" [call-overload] + main:4: note: Possible overload variants: + main:4: note: def [P`2, T] pipe(self, func: Callable[[DataArray, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:4: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_posarg + main: | + from xarray import DataArray + + def f(da: DataArray, arg: int) -> DataArray: + return da + + # Call to pipe missing argument for function parameter `arg` + da = DataArray().pipe(f) + out: | + main:7: error: No overload variant of "pipe" of "DataWithCoords" matches argument type "Callable[[DataArray, int], DataArray]" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[DataArray, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_extra_posarg + main: | + from xarray import DataArray + + def f(da: DataArray, arg: int) -> DataArray: + return da + + # Call to pipe missing keyword for kwonly parameter `kwonly` + da = DataArray().pipe(f, 42, "oops!") + out: | + main:7: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[DataArray, int], DataArray]", "int", "str" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[DataArray, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_kwarg + main: | + from xarray import DataArray + + def f(da: DataArray, arg: int, *, kwonly: int) -> DataArray: + return da + + # Call to pipe missing argument for kwonly parameter `kwonly` + da = DataArray().pipe(f, 42) + out: | + main:7: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[DataArray, int, NamedArg(int, 'kwonly')], DataArray]", "int" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[DataArray, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_keyword + main: | + from xarray import DataArray + + def f(da: DataArray, arg: int, *, kwonly: int) -> DataArray: + return da + + # Call to pipe missing keyword for kwonly parameter `kwonly` + da = DataArray().pipe(f, 42, 99) + out: | + main:7: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[DataArray, int, NamedArg(int, 'kwonly')], DataArray]", "int", "int" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[DataArray, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_unexpected_keyword + main: | + from xarray import DataArray + + def f(da: DataArray, arg: int, *, kwonly: int) -> DataArray: + return da + + # Call to pipe using wrong keyword: `kw` instead of `kwonly` + da = DataArray().pipe(f, 42, kw=99) + out: | + main:7: error: Unexpected keyword argument "kw" for "pipe" of "DataWithCoords" [call-arg] + +- case: test_mypy_pipe_tuple_return_type_dataarray + main: | + from xarray import DataArray + + def f(arg: int, da: DataArray) -> DataArray: + return da + + da = DataArray().pipe((f, "da"), 42) + reveal_type(da) # N: Revealed type is "xarray.core.dataarray.DataArray" + +- case: test_mypy_pipe_tuple_return_type_other + main: | + from xarray import DataArray + + def f(arg: int, da: DataArray) -> int: + return arg + + answer = DataArray().pipe((f, "da"), 42) + + reveal_type(answer) # N: Revealed type is "builtins.int" + +- case: test_mypy_pipe_tuple_missing_arg + main: | + from xarray import DataArray + + def f(arg: int, da: DataArray) -> DataArray: + return da + + # Since we cannot provide a precise type annotation when passing a tuple to + # pipe, there's not enough information for type analysis to indicate that + # we are missing an argument for parameter `arg`, so we get no error here. + + da = DataArray().pipe((f, "da")) + reveal_type(da) # N: Revealed type is "xarray.core.dataarray.DataArray" + + # Rather than passing a tuple, passing a lambda that calls `f` with args in + # the correct order allows for proper type analysis, indicating (perhaps + # somewhat cryptically) that we failed to pass an argument for `arg`. + + da = DataArray().pipe(lambda data, arg: f(arg, data)) + out: | + main:17: error: No overload variant of "pipe" of "DataWithCoords" matches argument type "Callable[[Any, Any], DataArray]" [call-overload] + main:17: note: Possible overload variants: + main:17: note: def [P`9, T] pipe(self, func: Callable[[DataArray, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:17: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_tuple_extra_arg + main: | + from xarray import DataArray + + def f(arg: int, da: DataArray) -> DataArray: + return da + + # Since we cannot provide a precise type annotation when passing a tuple to + # pipe, there's not enough information for type analysis to indicate that + # we are providing too many args for `f`, so we get no error here. + + da = DataArray().pipe((f, "da"), 42, "foo") + reveal_type(da) # N: Revealed type is "xarray.core.dataarray.DataArray" + + # Rather than passing a tuple, passing a lambda that calls `f` with args in + # the correct order allows for proper type analysis, indicating (perhaps + # somewhat cryptically) that we passed too many arguments. + + da = DataArray().pipe(lambda data, arg: f(arg, data), 42, "foo") + out: | + main:17: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[Any, Any], DataArray]", "int", "str" [call-overload] + main:17: note: Possible overload variants: + main:17: note: def [P`9, T] pipe(self, func: Callable[[DataArray, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:17: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T diff --git a/xarray/tests/test_dataset_typing.yml b/xarray/tests/test_dataset_typing.yml new file mode 100644 index 00000000000..3b62f81d361 --- /dev/null +++ b/xarray/tests/test_dataset_typing.yml @@ -0,0 +1,190 @@ +- case: test_mypy_pipe_lambda_noarg_return_type + main: | + from xarray import Dataset + + ds = Dataset().pipe(lambda data: data) + + reveal_type(ds) # N: Revealed type is "xarray.core.dataset.Dataset" + +- case: test_mypy_pipe_lambda_posarg_return_type + main: | + from xarray import Dataset + + ds = Dataset().pipe(lambda data, arg: arg, "foo") + + reveal_type(ds) # N: Revealed type is "builtins.str" + +- case: test_mypy_pipe_lambda_chaining_return_type + main: | + from xarray import Dataset + + answer = Dataset().pipe(lambda data, arg: arg, "foo").count("o") + + reveal_type(answer) # N: Revealed type is "builtins.int" + +- case: test_mypy_pipe_lambda_missing_arg + main: | + from xarray import Dataset + + # Call to pipe missing argument for lambda parameter `arg` + ds = Dataset().pipe(lambda data, arg: data) + out: | + main:4: error: No overload variant of "pipe" of "DataWithCoords" matches argument type "Callable[[Any, Any], Any]" [call-overload] + main:4: note: Possible overload variants: + main:4: note: def [P`2, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:4: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_lambda_extra_arg + main: | + from xarray import Dataset + + # Call to pipe with extra argument for lambda + ds = Dataset().pipe(lambda data: data, "oops!") + out: | + main:4: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[Any], Any]", "str" [call-overload] + main:4: note: Possible overload variants: + main:4: note: def [P`2, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:4: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_posarg + main: | + from xarray import Dataset + + def f(ds: Dataset, arg: int) -> Dataset: + return ds + + # Call to pipe missing argument for function parameter `arg` + ds = Dataset().pipe(f) + out: | + main:7: error: No overload variant of "pipe" of "DataWithCoords" matches argument type "Callable[[Dataset, int], Dataset]" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_extra_posarg + main: | + from xarray import Dataset + + def f(ds: Dataset, arg: int) -> Dataset: + return ds + + # Call to pipe missing keyword for kwonly parameter `kwonly` + ds = Dataset().pipe(f, 42, "oops!") + out: | + main:7: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[Dataset, int], Dataset]", "int", "str" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_kwarg + main: | + from xarray import Dataset + + def f(ds: Dataset, arg: int, *, kwonly: int) -> Dataset: + return ds + + # Call to pipe missing argument for kwonly parameter `kwonly` + ds = Dataset().pipe(f, 42) + out: | + main:7: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[Dataset, int, NamedArg(int, 'kwonly')], Dataset]", "int" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_keyword + main: | + from xarray import Dataset + + def f(ds: Dataset, arg: int, *, kwonly: int) -> Dataset: + return ds + + # Call to pipe missing keyword for kwonly parameter `kwonly` + ds = Dataset().pipe(f, 42, 99) + out: | + main:7: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[Dataset, int, NamedArg(int, 'kwonly')], Dataset]", "int", "int" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_unexpected_keyword + main: | + from xarray import Dataset + + def f(ds: Dataset, arg: int, *, kwonly: int) -> Dataset: + return ds + + # Call to pipe using wrong keyword: `kw` instead of `kwonly` + ds = Dataset().pipe(f, 42, kw=99) + out: | + main:7: error: Unexpected keyword argument "kw" for "pipe" of "DataWithCoords" [call-arg] + +- case: test_mypy_pipe_tuple_return_type_dataset + main: | + from xarray import Dataset + + def f(arg: int, ds: Dataset) -> Dataset: + return ds + + ds = Dataset().pipe((f, "ds"), 42) + reveal_type(ds) # N: Revealed type is "xarray.core.dataset.Dataset" + +- case: test_mypy_pipe_tuple_return_type_other + main: | + from xarray import Dataset + + def f(arg: int, ds: Dataset) -> int: + return arg + + answer = Dataset().pipe((f, "ds"), 42) + + reveal_type(answer) # N: Revealed type is "builtins.int" + +- case: test_mypy_pipe_tuple_missing_arg + main: | + from xarray import Dataset + + def f(arg: int, ds: Dataset) -> Dataset: + return ds + + # Since we cannot provide a precise type annotation when passing a tuple to + # pipe, there's not enough information for type analysis to indicate that + # we are missing an argument for parameter `arg`, so we get no error here. + + ds = Dataset().pipe((f, "ds")) + reveal_type(ds) # N: Revealed type is "xarray.core.dataset.Dataset" + + # Rather than passing a tuple, passing a lambda that calls `f` with args in + # the correct order allows for proper type analysis, indicating (perhaps + # somewhat cryptically) that we failed to pass an argument for `arg`. + + ds = Dataset().pipe(lambda data, arg: f(arg, data)) + out: | + main:17: error: No overload variant of "pipe" of "DataWithCoords" matches argument type "Callable[[Any, Any], Dataset]" [call-overload] + main:17: note: Possible overload variants: + main:17: note: def [P`9, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:17: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_tuple_extra_arg + main: | + from xarray import Dataset + + def f(arg: int, ds: Dataset) -> Dataset: + return ds + + # Since we cannot provide a precise type annotation when passing a tuple to + # pipe, there's not enough information for type analysis to indicate that + # we are providing too many args for `f`, so we get no error here. + + ds = Dataset().pipe((f, "ds"), 42, "foo") + reveal_type(ds) # N: Revealed type is "xarray.core.dataset.Dataset" + + # Rather than passing a tuple, passing a lambda that calls `f` with args in + # the correct order allows for proper type analysis, indicating (perhaps + # somewhat cryptically) that we passed too many arguments. + + ds = Dataset().pipe(lambda data, arg: f(arg, data), 42, "foo") + out: | + main:17: error: No overload variant of "pipe" of "DataWithCoords" matches argument types "Callable[[Any, Any], Dataset]", "int", "str" [call-overload] + main:17: note: Possible overload variants: + main:17: note: def [P`9, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:17: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 715d80e084a..c87a1e1329e 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1,7 +1,7 @@ import re import sys import typing -from collections.abc import Mapping +from collections.abc import Callable, Mapping from copy import copy, deepcopy from textwrap import dedent @@ -1589,27 +1589,53 @@ def test_assign(self) -> None: class TestPipe: - def test_noop(self, create_test_datatree) -> None: + def test_noop(self, create_test_datatree: Callable[[], DataTree]) -> None: dt = create_test_datatree() actual = dt.pipe(lambda tree: tree) assert actual.identical(dt) - def test_params(self, create_test_datatree) -> None: + def test_args(self, create_test_datatree: Callable[[], DataTree]) -> None: dt = create_test_datatree() - def f(tree, **attrs): - return tree.assign(arr_with_attrs=xr.Variable("dim0", [], attrs=attrs)) + def f(tree: DataTree, x: int, y: int) -> DataTree: + return tree.assign( + arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y)) + ) + + actual = dt.pipe(f, 1, 2) + assert actual["arr_with_attrs"].attrs == dict(x=1, y=2) + + def test_kwargs(self, create_test_datatree: Callable[[], DataTree]) -> None: + dt = create_test_datatree() + + def f(tree: DataTree, *, x: int, y: int, z: int) -> DataTree: + return tree.assign( + arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y, z=z)) + ) attrs = {"x": 1, "y": 2, "z": 3} actual = dt.pipe(f, **attrs) assert actual["arr_with_attrs"].attrs == attrs - def test_named_self(self, create_test_datatree) -> None: + def test_args_kwargs(self, create_test_datatree: Callable[[], DataTree]) -> None: + dt = create_test_datatree() + + def f(tree: DataTree, x: int, *, y: int, z: int) -> DataTree: + return tree.assign( + arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y, z=z)) + ) + + attrs = {"x": 1, "y": 2, "z": 3} + + actual = dt.pipe(f, attrs["x"], y=attrs["y"], z=attrs["z"]) + assert actual["arr_with_attrs"].attrs == attrs + + def test_named_self(self, create_test_datatree: Callable[[], DataTree]) -> None: dt = create_test_datatree() - def f(x, tree, y): + def f(x: int, tree: DataTree, y: int): tree.attrs.update({"x": x, "y": y}) return tree diff --git a/xarray/tests/test_datatree_typing.yml b/xarray/tests/test_datatree_typing.yml new file mode 100644 index 00000000000..fac7fe8ab65 --- /dev/null +++ b/xarray/tests/test_datatree_typing.yml @@ -0,0 +1,190 @@ +- case: test_mypy_pipe_lambda_noarg_return_type + main: | + from xarray import DataTree + + dt = DataTree().pipe(lambda data: data) + + reveal_type(dt) # N: Revealed type is "xarray.core.datatree.DataTree" + +- case: test_mypy_pipe_lambda_posarg_return_type + main: | + from xarray import DataTree + + dt = DataTree().pipe(lambda data, arg: arg, "foo") + + reveal_type(dt) # N: Revealed type is "builtins.str" + +- case: test_mypy_pipe_lambda_chaining_return_type + main: | + from xarray import DataTree + + answer = DataTree().pipe(lambda data, arg: arg, "foo").count("o") + + reveal_type(answer) # N: Revealed type is "builtins.int" + +- case: test_mypy_pipe_lambda_missing_arg + main: | + from xarray import DataTree + + # Call to pipe missing argument for lambda parameter `arg` + dt = DataTree().pipe(lambda data, arg: data) + out: | + main:4: error: No overload variant of "pipe" of "DataTree" matches argument type "Callable[[Any, Any], Any]" [call-overload] + main:4: note: Possible overload variants: + main:4: note: def [P`2, T] pipe(self, func: Callable[[DataTree, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:4: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_lambda_extra_arg + main: | + from xarray import DataTree + + # Call to pipe with extra argument for lambda + dt = DataTree().pipe(lambda data: data, "oops!") + out: | + main:4: error: No overload variant of "pipe" of "DataTree" matches argument types "Callable[[Any], Any]", "str" [call-overload] + main:4: note: Possible overload variants: + main:4: note: def [P`2, T] pipe(self, func: Callable[[DataTree, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:4: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_posarg + main: | + from xarray import DataTree + + def f(dt: DataTree, arg: int) -> DataTree: + return dt + + # Call to pipe missing argument for function parameter `arg` + dt = DataTree().pipe(f) + out: | + main:7: error: No overload variant of "pipe" of "DataTree" matches argument type "Callable[[DataTree, int], DataTree]" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[DataTree, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_extra_posarg + main: | + from xarray import DataTree + + def f(dt: DataTree, arg: int) -> DataTree: + return dt + + # Call to pipe missing keyword for kwonly parameter `kwonly` + dt = DataTree().pipe(f, 42, "oops!") + out: | + main:7: error: No overload variant of "pipe" of "DataTree" matches argument types "Callable[[DataTree, int], DataTree]", "int", "str" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[DataTree, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_kwarg + main: | + from xarray import DataTree + + def f(dt: DataTree, arg: int, *, kwonly: int) -> DataTree: + return dt + + # Call to pipe missing argument for kwonly parameter `kwonly` + dt = DataTree().pipe(f, 42) + out: | + main:7: error: No overload variant of "pipe" of "DataTree" matches argument types "Callable[[DataTree, int, NamedArg(int, 'kwonly')], DataTree]", "int" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[DataTree, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_missing_keyword + main: | + from xarray import DataTree + + def f(dt: DataTree, arg: int, *, kwonly: int) -> DataTree: + return dt + + # Call to pipe missing keyword for kwonly parameter `kwonly` + dt = DataTree().pipe(f, 42, 99) + out: | + main:7: error: No overload variant of "pipe" of "DataTree" matches argument types "Callable[[DataTree, int, NamedArg(int, 'kwonly')], DataTree]", "int", "int" [call-overload] + main:7: note: Possible overload variants: + main:7: note: def [P`2, T] pipe(self, func: Callable[[DataTree, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:7: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_function_unexpected_keyword + main: | + from xarray import DataTree + + def f(dt: DataTree, arg: int, *, kwonly: int) -> DataTree: + return dt + + # Call to pipe using wrong keyword: `kw` instead of `kwonly` + dt = DataTree().pipe(f, 42, kw=99) + out: | + main:7: error: Unexpected keyword argument "kw" for "pipe" of "DataTree" [call-arg] + +- case: test_mypy_pipe_tuple_return_type_datatree + main: | + from xarray import DataTree + + def f(arg: int, dt: DataTree) -> DataTree: + return dt + + dt = DataTree().pipe((f, "dt"), 42) + reveal_type(dt) # N: Revealed type is "xarray.core.datatree.DataTree" + +- case: test_mypy_pipe_tuple_return_type_other + main: | + from xarray import DataTree + + def f(arg: int, dt: DataTree) -> int: + return arg + + answer = DataTree().pipe((f, "dt"), 42) + + reveal_type(answer) # N: Revealed type is "builtins.int" + +- case: test_mypy_pipe_tuple_missing_arg + main: | + from xarray import DataTree + + def f(arg: int, dt: DataTree) -> DataTree: + return dt + + # Since we cannot provide a precise type annotation when passing a tuple to + # pipe, there's not enough information for type analysis to indicate that + # we are missing an argument for parameter `arg`, so we get no error here. + + dt = DataTree().pipe((f, "dt")) + reveal_type(dt) # N: Revealed type is "xarray.core.datatree.DataTree" + + # Rather than passing a tuple, passing a lambda that calls `f` with args in + # the correct order allows for proper type analysis, indicating (perhaps + # somewhat cryptically) that we failed to pass an argument for `arg`. + + dt = DataTree().pipe(lambda data, arg: f(arg, data)) + out: | + main:17: error: No overload variant of "pipe" of "DataTree" matches argument type "Callable[[Any, Any], DataTree]" [call-overload] + main:17: note: Possible overload variants: + main:17: note: def [P`9, T] pipe(self, func: Callable[[DataTree, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:17: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T + +- case: test_mypy_pipe_tuple_extra_arg + main: | + from xarray import DataTree + + def f(arg: int, dt: DataTree) -> DataTree: + return dt + + # Since we cannot provide a precise type annotation when passing a tuple to + # pipe, there's not enough information for type analysis to indicate that + # we are providing too many args for `f`, so we get no error here. + + dt = DataTree().pipe((f, "dt"), 42, "foo") + reveal_type(dt) # N: Revealed type is "xarray.core.datatree.DataTree" + + # Rather than passing a tuple, passing a lambda that calls `f` with args in + # the correct order allows for proper type analysis, indicating (perhaps + # somewhat cryptically) that we passed too many arguments. + + dt = DataTree().pipe(lambda data, arg: f(arg, data), 42, "foo") + out: | + main:17: error: No overload variant of "pipe" of "DataTree" matches argument types "Callable[[Any, Any], DataTree]", "int", "str" [call-overload] + main:17: note: Possible overload variants: + main:17: note: def [P`9, T] pipe(self, func: Callable[[DataTree, **P], T], *args: P.args, **kwargs: P.kwargs) -> T + main:17: note: def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T