Skip to content

Commit adad7a2

Browse files
authored
Add create_agent factory method to Agent (#2351)
1 parent b0a36ec commit adad7a2

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

mesa/agent.py

+48
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,54 @@ def step(self) -> None:
8585
def advance(self) -> None: # noqa: D102
8686
pass
8787

88+
@classmethod
89+
def create_agents(cls, model: Model, n: int, *args, **kwargs) -> AgentSet[Agent]:
90+
"""Create N agents.
91+
92+
Args:
93+
model: the model to which the agents belong
94+
args: arguments to pass onto agent instances
95+
each arg is either a single object or a sequence of length n
96+
n: the number of agents to create
97+
kwargs: keyword arguments to pass onto agent instances
98+
each keyword arg is either a single object or a sequence of length n
99+
100+
Returns:
101+
AgentSet containing the agents created.
102+
103+
"""
104+
105+
class ListLike:
106+
"""Helper class to make default arguments act as if they are in a list of length N."""
107+
108+
def __init__(self, value):
109+
self.value = value
110+
111+
def __getitem__(self, i):
112+
return self.value
113+
114+
listlike_args = []
115+
for arg in args:
116+
if isinstance(arg, (list | np.ndarray | tuple)) and len(arg) == n:
117+
listlike_args.append(arg)
118+
else:
119+
listlike_args.append(ListLike(arg))
120+
121+
listlike_kwargs = {}
122+
for k, v in kwargs.items():
123+
if isinstance(v, (list | np.ndarray | tuple)) and len(v) == n:
124+
listlike_kwargs[k] = v
125+
else:
126+
listlike_kwargs[k] = ListLike(v)
127+
128+
agents = []
129+
for i in range(n):
130+
instance_args = [arg[i] for arg in listlike_args]
131+
instance_kwargs = {k: v[i] for k, v in listlike_kwargs.items()}
132+
agent = cls(model, *instance_args, **instance_kwargs)
133+
agents.append(agent)
134+
return AgentSet(agents, random=model.random)
135+
88136
@property
89137
def random(self) -> Random:
90138
"""Return a seeded stdlib rng."""

tests/test_agent.py

+24
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,30 @@ def test_agent_rng():
170170
assert agent.rng is model.rng
171171

172172

173+
def test_agent_create():
174+
"""Test create agent factory method."""
175+
176+
class TestAgent(Agent):
177+
def __init__(self, model, attr, def_attr, a=0, b=0):
178+
super().__init__(model)
179+
self.some_attribute = attr
180+
self.some_default_value = def_attr
181+
self.a = a
182+
self.b = b
183+
184+
model = Model(seed=42)
185+
n = 10
186+
some_attribute = model.rng.random(n)
187+
a = tuple([model.random.random() for _ in range(n)])
188+
TestAgent.create_agents(model, n, some_attribute, 5, a=a, b=7)
189+
190+
for agent, value, a_i in zip(model.agents, some_attribute, a):
191+
assert agent.some_attribute == value
192+
assert agent.some_default_value == 5
193+
assert agent.a == a_i
194+
assert agent.b == 7
195+
196+
173197
def test_agent_add_remove_discard():
174198
"""Test adding, removing and discarding agents from AgentSet."""
175199
model = Model()

0 commit comments

Comments
 (0)