diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 715d80e084a..c87a1e1329e 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1,7 +1,7 @@ import re import sys import typing -from collections.abc import Mapping +from collections.abc import Callable, Mapping from copy import copy, deepcopy from textwrap import dedent @@ -1589,27 +1589,53 @@ def test_assign(self) -> None: class TestPipe: - def test_noop(self, create_test_datatree) -> None: + def test_noop(self, create_test_datatree: Callable[[], DataTree]) -> None: dt = create_test_datatree() actual = dt.pipe(lambda tree: tree) assert actual.identical(dt) - def test_params(self, create_test_datatree) -> None: + def test_args(self, create_test_datatree: Callable[[], DataTree]) -> None: dt = create_test_datatree() - def f(tree, **attrs): - return tree.assign(arr_with_attrs=xr.Variable("dim0", [], attrs=attrs)) + def f(tree: DataTree, x: int, y: int) -> DataTree: + return tree.assign( + arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y)) + ) + + actual = dt.pipe(f, 1, 2) + assert actual["arr_with_attrs"].attrs == dict(x=1, y=2) + + def test_kwargs(self, create_test_datatree: Callable[[], DataTree]) -> None: + dt = create_test_datatree() + + def f(tree: DataTree, *, x: int, y: int, z: int) -> DataTree: + return tree.assign( + arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y, z=z)) + ) attrs = {"x": 1, "y": 2, "z": 3} actual = dt.pipe(f, **attrs) assert actual["arr_with_attrs"].attrs == attrs - def test_named_self(self, create_test_datatree) -> None: + def test_args_kwargs(self, create_test_datatree: Callable[[], DataTree]) -> None: + dt = create_test_datatree() + + def f(tree: DataTree, x: int, *, y: int, z: int) -> DataTree: + return tree.assign( + arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y, z=z)) + ) + + attrs = {"x": 1, "y": 2, "z": 3} + + actual = dt.pipe(f, attrs["x"], y=attrs["y"], z=attrs["z"]) + assert actual["arr_with_attrs"].attrs == attrs + + def test_named_self(self, create_test_datatree: Callable[[], DataTree]) -> None: dt = create_test_datatree() - def f(x, tree, y): + def f(x: int, tree: DataTree, y: int): tree.attrs.update({"x": x, "y": y}) return tree