20
20
21
21
from typing import (
22
22
Any ,
23
+ Callable ,
23
24
Dict ,
24
25
Iterable ,
25
26
Iterator ,
28
29
Set ,
29
30
Sequence ,
30
31
Tuple ,
32
+ TypeVar ,
31
33
Union ,
32
34
cast ,
33
35
overload ,
36
38
# For Mypy
37
39
from .agent import Agent
38
40
from numbers import Real
41
+ import numpy .typing as npt
39
42
40
43
Coordinate = Tuple [int , int ]
41
- GridContent = Union [Optional [Agent ], Set [Agent ]]
42
44
# used in ContinuousSpace
43
- FloatCoordinate = Union [Tuple [float , float ], np . ndarray ]
45
+ FloatCoordinate = Union [Tuple [float , float ], npt . NDArray [ float ] ]
44
46
NetworkCoordinate = int
45
47
46
48
Position = Union [Coordinate , FloatCoordinate , NetworkCoordinate ]
47
49
50
+ GridContent = Optional [Agent ]
51
+ MultiGridContent = List [Agent ]
52
+
53
+ F = TypeVar ("F" , bound = Callable [..., Any ])
54
+
55
+
48
56
def clamp (x : float , lowest : float , highest : float ) -> float :
49
57
# This should be faster than np.clip for a scalar x.
50
58
# TODO: measure how much faster this function is.
51
59
return max (lowest , min (x , highest ))
52
60
53
61
54
- def accept_tuple_argument (wrapped_function ) :
62
+ def accept_tuple_argument (wrapped_function : F ) -> F :
55
63
"""Decorator to allow grid methods that take a list of (x, y) coord tuples
56
64
to also handle a single position, by automatically wrapping tuple in
57
65
single-item list rather than forcing user to do it."""
58
66
59
- def wrapper (* args : Any ):
67
+ def wrapper (* args : Any ) -> Any :
60
68
if isinstance (args [1 ], tuple ) and len (args [1 ]) == 2 :
61
69
return wrapped_function (args [0 ], [args [1 ]])
62
70
else :
63
71
return wrapped_function (* args )
64
72
65
- return wrapper
73
+ return cast ( F , wrapper )
66
74
67
75
68
76
def is_integer (x : Real ) -> bool :
@@ -140,8 +148,7 @@ def __getitem__(
140
148
if isinstance (index , int ):
141
149
# grid[x]
142
150
return self .grid [index ]
143
-
144
- if isinstance (index [0 ], tuple ):
151
+ elif isinstance (index [0 ], tuple ):
145
152
# grid[(x1, y1), (x2, y2)]
146
153
index = cast (Sequence [Coordinate ], index )
147
154
@@ -564,7 +571,7 @@ class MultiGrid(Grid):
564
571
"""
565
572
566
573
@staticmethod
567
- def default_val () -> Set [ Agent ] :
574
+ def default_val () -> MultiGridContent :
568
575
"""Default value for new cell elements."""
569
576
return []
570
577
@@ -585,7 +592,7 @@ def _remove_agent(self, pos: Coordinate, agent: Agent) -> None:
585
592
@accept_tuple_argument
586
593
def iter_cell_list_contents (
587
594
self , cell_list : Iterable [Coordinate ]
588
- ) -> Iterator [GridContent ]:
595
+ ) -> Iterator [MultiGridContent ]:
589
596
"""Returns an iterator of the contents of the
590
597
cells identified in cell_list.
591
598
@@ -786,7 +793,7 @@ def __init__(
786
793
self .size = np .array ((self .width , self .height ))
787
794
self .torus = torus
788
795
789
- self ._agent_points = None
796
+ self ._agent_points : Optional [ npt . NDArray [ FloatCoordinate ]] = None
790
797
self ._index_to_agent : Dict [int , Agent ] = {}
791
798
self ._agent_to_index : Dict [Agent , int ] = {}
792
799
0 commit comments