Skip to content
This repository was archived by the owner on Jun 4, 2024. It is now read-only.

Commit 39f8072

Browse files
authored
[RLlib] [MultiAgentEnv Refactor #2] Change space types for BaseEnvs and MultiAgentEnvs (#21063)
1 parent 8b4cb45 commit 39f8072

5 files changed

+277
-60
lines changed

rllib/env/base_env.py

+19-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
3-
Union
3+
Union, Set
44

55
import gym
66
import ray
@@ -198,14 +198,13 @@ def get_sub_environments(
198198
return []
199199

200200
@PublicAPI
201-
def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]:
202-
"""Return the agent ids for each sub-environment.
201+
def get_agent_ids(self) -> Set[AgentID]:
202+
"""Return the agent ids for the sub_environment.
203203
204204
Returns:
205-
A dict mapping from env_id to a list of agent_ids.
205+
All agent ids for the environment.
206206
"""
207-
logger.warning("get_agent_ids() has not been implemented")
208-
return {}
207+
return {_DUMMY_AGENT_ID}
209208

210209
@PublicAPI
211210
def try_render(self, env_id: Optional[EnvID] = None) -> None:
@@ -234,8 +233,8 @@ def get_unwrapped(self) -> List[EnvType]:
234233

235234
@PublicAPI
236235
@property
237-
def observation_space(self) -> gym.spaces.Dict:
238-
"""Returns the observation space for each environment.
236+
def observation_space(self) -> gym.Space:
237+
"""Returns the observation space for each agent.
239238
240239
Note: samples from the observation space need to be preprocessed into a
241240
`MultiEnvDict` before being used by a policy.
@@ -248,7 +247,7 @@ def observation_space(self) -> gym.spaces.Dict:
248247
@PublicAPI
249248
@property
250249
def action_space(self) -> gym.Space:
251-
"""Returns the action space for each environment.
250+
"""Returns the action space for each agent.
252251
253252
Note: samples from the action space need to be preprocessed into a
254253
`MultiEnvDict` before being passed to `send_actions`.
@@ -270,6 +269,7 @@ def action_space_sample(self, agent_id: list = None) -> MultiEnvDict:
270269
Returns:
271270
A random action for each environment.
272271
"""
272+
logger.warning("action_space_sample() has not been implemented")
273273
del agent_id
274274
return {}
275275

@@ -286,6 +286,7 @@ def observation_space_sample(self, agent_id: list = None) -> MultiEnvDict:
286286
A random action for each environment.
287287
"""
288288
logger.warning("observation_space_sample() has not been implemented")
289+
del agent_id
289290
return {}
290291

291292
@PublicAPI
@@ -326,8 +327,7 @@ def action_space_contains(self, x: MultiEnvDict) -> bool:
326327
"""
327328
return self._space_contains(self.action_space, x)
328329

329-
@staticmethod
330-
def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
330+
def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
331331
"""Check if the given space contains the observations of x.
332332
333333
Args:
@@ -337,17 +337,14 @@ def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
337337
Returns:
338338
True if the observations of x are contained in space.
339339
"""
340-
# this removes the agent_id key and inner dicts
341-
# in MultiEnvDicts
342-
flattened_obs = {
343-
env_id: list(obs.values())
344-
for env_id, obs in x.items()
345-
}
346-
ret = True
347-
for env_id in flattened_obs:
348-
for obs in flattened_obs[env_id]:
349-
ret = ret and space[env_id].contains(obs)
350-
return ret
340+
agents = set(self.get_agent_ids())
341+
for multi_agent_dict in x.values():
342+
for agent_id, obs in multi_agent_dict:
343+
if (agent_id not in agents) or (
344+
not space[agent_id].contains(obs)):
345+
return False
346+
347+
return True
351348

352349

353350
# Fixed agent identifier when there is only the single agent in the env

rllib/env/multi_agent_env.py

+199-17
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import gym
2-
from typing import Callable, Dict, List, Tuple, Type, Optional, Union
2+
import logging
3+
from typing import Callable, Dict, List, Tuple, Type, Optional, Union, Set
34

45
from ray.rllib.env.base_env import BaseEnv
56
from ray.rllib.env.env_context import EnvContext
6-
from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI
7+
from ray.rllib.utils.annotations import ExperimentalAPI, override, PublicAPI, \
8+
DeveloperAPI
79
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
810
MultiEnvDict
911

1012
# If the obs space is Dict type, look for the global state under this key.
1113
ENV_STATE = "state"
1214

15+
logger = logging.getLogger(__name__)
16+
1317

1418
@PublicAPI
1519
class MultiAgentEnv(gym.Env):
@@ -20,6 +24,15 @@ class MultiAgentEnv(gym.Env):
2024
referred to as "agents" or "RL agents".
2125
"""
2226

27+
def __init__(self):
28+
self.observation_space = None
29+
self.action_space = None
30+
self._agent_ids = {}
31+
32+
# do the action and observation spaces map from agent ids to spaces
33+
# for the individual agents?
34+
self._spaces_in_preferred_format = None
35+
2336
@PublicAPI
2437
def reset(self) -> MultiAgentDict:
2538
"""Resets the env and returns observations from ready agents.
@@ -81,20 +94,127 @@ def step(
8194
"""
8295
raise NotImplementedError
8396

97+
@ExperimentalAPI
98+
def observation_space_contains(self, x: MultiAgentDict) -> bool:
99+
"""Checks if the observation space contains the given observations.
100+
101+
Args:
102+
x: Observations to check.
103+
104+
Returns:
105+
True if the observation space contains all observations
106+
in x.
107+
"""
108+
if not hasattr(self, "_spaces_in_preferred_format") or \
109+
self._spaces_in_preferred_format is None:
110+
self._spaces_in_preferred_format = \
111+
self._check_if_space_maps_agent_id_to_sub_space()
112+
if self._spaces_in_preferred_format:
113+
return self.observation_space.contains(x)
114+
115+
logger.warning("observation_space_contains() has not been implemented")
116+
return True
117+
118+
@ExperimentalAPI
119+
def action_space_contains(self, x: MultiAgentDict) -> bool:
120+
"""Checks if the action space contains the given action.
121+
122+
Args:
123+
x: Actions to check.
124+
125+
Returns:
126+
True if the action space contains all actions in x.
127+
"""
128+
if not hasattr(self, "_spaces_in_preferred_format") or \
129+
self._spaces_in_preferred_format is None:
130+
self._spaces_in_preferred_format = \
131+
self._check_if_space_maps_agent_id_to_sub_space()
132+
if self._spaces_in_preferred_format:
133+
return self.action_space.contains(x)
134+
135+
logger.warning("action_space_contains() has not been implemented")
136+
return True
137+
138+
@ExperimentalAPI
139+
def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
140+
"""Returns a random action for each environment, and potentially each
141+
agent in that environment.
142+
143+
Args:
144+
agent_ids: List of agent ids to sample actions for. If None or
145+
empty list, sample actions for all agents in the
146+
environment.
147+
148+
Returns:
149+
A random action for each environment.
150+
"""
151+
if not hasattr(self, "_spaces_in_preferred_format") or \
152+
self._spaces_in_preferred_format is None:
153+
self._spaces_in_preferred_format = \
154+
self._check_if_space_maps_agent_id_to_sub_space()
155+
if self._spaces_in_preferred_format:
156+
if agent_ids is None:
157+
agent_ids = self.get_agent_ids()
158+
samples = self.action_space.sample()
159+
return {agent_id: samples[agent_id] for agent_id in agent_ids}
160+
logger.warning("action_space_sample() has not been implemented")
161+
del agent_ids
162+
return {}
163+
164+
@ExperimentalAPI
165+
def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
166+
"""Returns a random observation from the observation space for each
167+
agent if agent_ids is None, otherwise returns a random observation for
168+
the agents in agent_ids.
169+
170+
Args:
171+
agent_ids: List of agent ids to sample actions for. If None or
172+
empty list, sample actions for all agents in the
173+
environment.
174+
175+
Returns:
176+
A random observation for each environment.
177+
"""
178+
179+
if not hasattr(self, "_spaces_in_preferred_format") or \
180+
self._spaces_in_preferred_format is None:
181+
self._spaces_in_preferred_format = \
182+
self._check_if_space_maps_agent_id_to_sub_space()
183+
if self._spaces_in_preferred_format:
184+
if agent_ids is None:
185+
agent_ids = self.get_agent_ids()
186+
samples = self.observation_space.sample()
187+
samples = {agent_id: samples[agent_id] for agent_id in agent_ids}
188+
return samples
189+
logger.warning("observation_space_sample() has not been implemented")
190+
del agent_ids
191+
return {}
192+
193+
@PublicAPI
194+
def get_agent_ids(self) -> Set[AgentID]:
195+
"""Returns a set of agent ids in the environment.
196+
197+
Returns:
198+
set of agent ids.
199+
"""
200+
if not isinstance(self._agent_ids, set):
201+
self._agent_ids = set(self._agent_ids)
202+
return self._agent_ids
203+
84204
@PublicAPI
85205
def render(self, mode=None) -> None:
86206
"""Tries to render the environment."""
87207

88208
# By default, do nothing.
89209
pass
90210

91-
# yapf: disable
92-
# __grouping_doc_begin__
211+
# yapf: disable
212+
# __grouping_doc_begin__
93213
@ExperimentalAPI
94214
def with_agent_groups(
95-
self,
96-
groups: Dict[str, List[AgentID]],
97-
obs_space: gym.Space = None,
215+
self,
216+
groups: Dict[str, List[AgentID]],
217+
obs_space: gym.Space = None,
98218
act_space: gym.Space = None) -> "MultiAgentEnv":
99219
"""Convenience method for grouping together agents in this env.
100220
@@ -132,8 +252,9 @@ def with_agent_groups(
132252
from ray.rllib.env.wrappers.group_agents_wrapper import \
133253
GroupAgentsWrapper
134254
return GroupAgentsWrapper(self, groups, obs_space, act_space)
135-
# __grouping_doc_end__
136-
# yapf: enable
255+
256+
# __grouping_doc_end__
257+
# yapf: enable
137258

138259
@PublicAPI
139260
def to_base_env(
@@ -182,6 +303,20 @@ def to_base_env(
182303

183304
return env
184305

306+
@DeveloperAPI
307+
def _check_if_space_maps_agent_id_to_sub_space(self) -> bool:
308+
# do the action and observation spaces map from agent ids to spaces
309+
# for the individual agents?
310+
obs_space_check = (
311+
hasattr(self, "observation_space")
312+
and isinstance(self.observation_space, gym.spaces.Dict)
313+
and set(self.observation_space.keys()) == self.get_agent_ids())
314+
action_space_check = (
315+
hasattr(self, "action_space")
316+
and isinstance(self.action_space, gym.spaces.Dict)
317+
and set(self.action_space.keys()) == self.get_agent_ids())
318+
return obs_space_check and action_space_check
319+
185320

186321
def make_multi_agent(
187322
env_name_or_creator: Union[str, Callable[[EnvContext], EnvType]],
@@ -242,6 +377,40 @@ def __init__(self, config=None):
242377
self.dones = set()
243378
self.observation_space = self.agents[0].observation_space
244379
self.action_space = self.agents[0].action_space
380+
self._agent_ids = set(range(num))
381+
382+
@override(MultiAgentEnv)
383+
def observation_space_sample(self,
384+
agent_ids: list = None) -> MultiAgentDict:
385+
if agent_ids is None:
386+
agent_ids = list(range(len(self.agents)))
387+
obs = {
388+
agent_id: self.observation_space.sample()
389+
for agent_id in agent_ids
390+
}
391+
392+
return obs
393+
394+
@override(MultiAgentEnv)
395+
def action_space_sample(self,
396+
agent_ids: list = None) -> MultiAgentDict:
397+
if agent_ids is None:
398+
agent_ids = list(range(len(self.agents)))
399+
actions = {
400+
agent_id: self.action_space.sample()
401+
for agent_id in agent_ids
402+
}
403+
404+
return actions
405+
406+
@override(MultiAgentEnv)
407+
def action_space_contains(self, x: MultiAgentDict) -> bool:
408+
return all(self.action_space.contains(val) for val in x.values())
409+
410+
@override(MultiAgentEnv)
411+
def observation_space_contains(self, x: MultiAgentDict) -> bool:
412+
return all(
413+
self.observation_space.contains(val) for val in x.values())
245414

246415
@override(MultiAgentEnv)
247416
def reset(self):
@@ -277,7 +446,7 @@ def __init__(self, make_env: Callable[[int], EnvType],
277446
278447
Args:
279448
make_env (Callable[[int], EnvType]): Factory that produces a new
280-
MultiAgentEnv intance. Must be defined, if the number of
449+
MultiAgentEnv instance. Must be defined, if the number of
281450
existing envs is less than num_envs.
282451
existing_envs (List[MultiAgentEnv]): List of already existing
283452
multi-agent envs.
@@ -355,18 +524,31 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None:
355524
@override(BaseEnv)
356525
@PublicAPI
357526
def observation_space(self) -> gym.spaces.Dict:
358-
space = {
359-
_id: env.observation_space
360-
for _id, env in enumerate(self.envs)
361-
}
362-
return gym.spaces.Dict(space)
527+
self.envs[0].observation_space
363528

364529
@property
365530
@override(BaseEnv)
366531
@PublicAPI
367532
def action_space(self) -> gym.Space:
368-
space = {_id: env.action_space for _id, env in enumerate(self.envs)}
369-
return gym.spaces.Dict(space)
533+
return self.envs[0].action_space
534+
535+
@override(BaseEnv)
536+
def observation_space_contains(self, x: MultiEnvDict) -> bool:
537+
return all(
538+
self.envs[0].observation_space_contains(val) for val in x.values())
539+
540+
@override(BaseEnv)
541+
def action_space_contains(self, x: MultiEnvDict) -> bool:
542+
return all(
543+
self.envs[0].action_space_contains(val) for val in x.values())
544+
545+
@override(BaseEnv)
546+
def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
547+
return self.envs[0].observation_space_sample(agent_ids)
548+
549+
@override(BaseEnv)
550+
def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
551+
return self.envs[0].action_space_sample(agent_ids)
370552

371553

372554
class _MultiAgentEnvState:

0 commit comments

Comments
 (0)