Skip to content

Commit 237b0c6

Browse files
rhttpike3
authored andcommitted
mypy: Improve space.py annotations
1 parent 6884c9d commit 237b0c6

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

mesa/space.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from typing import (
2222
Any,
23+
Callable,
2324
Dict,
2425
Iterable,
2526
Iterator,
@@ -28,6 +29,7 @@
2829
Set,
2930
Sequence,
3031
Tuple,
32+
TypeVar,
3133
Union,
3234
cast,
3335
overload,
@@ -36,33 +38,39 @@
3638
# For Mypy
3739
from .agent import Agent
3840
from numbers import Real
41+
import numpy.typing as npt
3942

4043
Coordinate = Tuple[int, int]
41-
GridContent = Union[Optional[Agent], Set[Agent]]
4244
# used in ContinuousSpace
43-
FloatCoordinate = Union[Tuple[float, float], np.ndarray]
45+
FloatCoordinate = Union[Tuple[float, float], npt.NDArray[float]]
4446
NetworkCoordinate = int
4547

4648
Position = Union[Coordinate, FloatCoordinate, NetworkCoordinate]
4749

50+
GridContent = Optional[Agent]
51+
MultiGridContent = List[Agent]
52+
53+
F = TypeVar("F", bound=Callable[..., Any])
54+
55+
4856
def clamp(x: float, lowest: float, highest: float) -> float:
4957
# This should be faster than np.clip for a scalar x.
5058
# TODO: measure how much faster this function is.
5159
return max(lowest, min(x, highest))
5260

5361

54-
def accept_tuple_argument(wrapped_function):
62+
def accept_tuple_argument(wrapped_function: F) -> F:
5563
"""Decorator to allow grid methods that take a list of (x, y) coord tuples
5664
to also handle a single position, by automatically wrapping tuple in
5765
single-item list rather than forcing user to do it."""
5866

59-
def wrapper(*args: Any):
67+
def wrapper(*args: Any) -> Any:
6068
if isinstance(args[1], tuple) and len(args[1]) == 2:
6169
return wrapped_function(args[0], [args[1]])
6270
else:
6371
return wrapped_function(*args)
6472

65-
return wrapper
73+
return cast(F, wrapper)
6674

6775

6876
def is_integer(x: Real) -> bool:
@@ -140,8 +148,7 @@ def __getitem__(
140148
if isinstance(index, int):
141149
# grid[x]
142150
return self.grid[index]
143-
144-
if isinstance(index[0], tuple):
151+
elif isinstance(index[0], tuple):
145152
# grid[(x1, y1), (x2, y2)]
146153
index = cast(Sequence[Coordinate], index)
147154

@@ -564,7 +571,7 @@ class MultiGrid(Grid):
564571
"""
565572

566573
@staticmethod
567-
def default_val() -> Set[Agent]:
574+
def default_val() -> MultiGridContent:
568575
"""Default value for new cell elements."""
569576
return []
570577

@@ -585,7 +592,7 @@ def _remove_agent(self, pos: Coordinate, agent: Agent) -> None:
585592
@accept_tuple_argument
586593
def iter_cell_list_contents(
587594
self, cell_list: Iterable[Coordinate]
588-
) -> Iterator[GridContent]:
595+
) -> Iterator[MultiGridContent]:
589596
"""Returns an iterator of the contents of the
590597
cells identified in cell_list.
591598
@@ -786,7 +793,7 @@ def __init__(
786793
self.size = np.array((self.width, self.height))
787794
self.torus = torus
788795

789-
self._agent_points = None
796+
self._agent_points: Optional[npt.NDArray[FloatCoordinate]] = None
790797
self._index_to_agent: Dict[int, Agent] = {}
791798
self._agent_to_index: Dict[Agent, int] = {}
792799

0 commit comments

Comments
 (0)