1
1
import gym
2
- from typing import Callable , Dict , List , Tuple , Type , Optional , Union
2
+ import logging
3
+ from typing import Callable , Dict , List , Tuple , Type , Optional , Union , Set
3
4
4
5
from ray .rllib .env .base_env import BaseEnv
5
6
from ray .rllib .env .env_context import EnvContext
6
- from ray .rllib .utils .annotations import ExperimentalAPI , override , PublicAPI
7
+ from ray .rllib .utils .annotations import ExperimentalAPI , override , PublicAPI , \
8
+ DeveloperAPI
7
9
from ray .rllib .utils .typing import AgentID , EnvID , EnvType , MultiAgentDict , \
8
10
MultiEnvDict
9
11
10
12
# If the obs space is Dict type, look for the global state under this key.
11
13
ENV_STATE = "state"
12
14
15
+ logger = logging .getLogger (__name__ )
16
+
13
17
14
18
@PublicAPI
15
19
class MultiAgentEnv (gym .Env ):
@@ -20,6 +24,15 @@ class MultiAgentEnv(gym.Env):
20
24
referred to as "agents" or "RL agents".
21
25
"""
22
26
27
+ def __init__ (self ):
28
+ self .observation_space = None
29
+ self .action_space = None
30
+ self ._agent_ids = {}
31
+
32
+ # do the action and observation spaces map from agent ids to spaces
33
+ # for the individual agents?
34
+ self ._spaces_in_preferred_format = None
35
+
23
36
@PublicAPI
24
37
def reset (self ) -> MultiAgentDict :
25
38
"""Resets the env and returns observations from ready agents.
@@ -81,20 +94,127 @@ def step(
81
94
"""
82
95
raise NotImplementedError
83
96
97
+ @ExperimentalAPI
98
def observation_space_contains(self, x: MultiAgentDict) -> bool:
    """Checks whether the observation space contains the observations in x.

    Args:
        x: Observations to check.

    Returns:
        The result of `self.observation_space.contains(x)` if the spaces
        map agent ids to per-agent sub-spaces. Otherwise a warning is
        logged and True is returned, because no check is implemented for
        that case.
    """
    # Lazily determine (and cache) the layout of the spaces. The attribute
    # may be missing entirely, hence the `getattr` with a default.
    if getattr(self, "_spaces_in_preferred_format", None) is None:
        self._spaces_in_preferred_format = (
            self._check_if_space_maps_agent_id_to_sub_space())

    if not self._spaces_in_preferred_format:
        logger.warning(
            "observation_space_contains() has not been implemented")
        return True

    return self.observation_space.contains(x)
117
+
118
+ @ExperimentalAPI
119
def action_space_contains(self, x: MultiAgentDict) -> bool:
    """Checks whether the action space contains the actions in x.

    Args:
        x: Actions to check.

    Returns:
        The result of `self.action_space.contains(x)` if the spaces map
        agent ids to per-agent sub-spaces. Otherwise a warning is logged
        and True is returned, because no check is implemented for that
        case.
    """
    # Lazily determine (and cache) the layout of the spaces. The attribute
    # may be missing entirely, hence the `getattr` with a default.
    if getattr(self, "_spaces_in_preferred_format", None) is None:
        self._spaces_in_preferred_format = (
            self._check_if_space_maps_agent_id_to_sub_space())

    if not self._spaces_in_preferred_format:
        logger.warning("action_space_contains() has not been implemented")
        return True

    return self.action_space.contains(x)
137
+
138
+ @ExperimentalAPI
139
def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
    """Samples a random action for each requested agent.

    Args:
        agent_ids: Ids of the agents to sample actions for. If None,
            actions are returned for all agents reported by
            `get_agent_ids()`. An empty list yields an empty dict.

    Returns:
        A dict mapping each requested agent id to its entry in one sample
        drawn from `self.action_space`. If the spaces do not map agent
        ids to per-agent sub-spaces, a warning is logged and an empty
        dict is returned.
    """
    # Lazily determine (and cache) the layout of the spaces. The attribute
    # may be missing entirely, hence the `getattr` with a default.
    if getattr(self, "_spaces_in_preferred_format", None) is None:
        self._spaces_in_preferred_format = (
            self._check_if_space_maps_agent_id_to_sub_space())

    if not self._spaces_in_preferred_format:
        logger.warning("action_space_sample() has not been implemented")
        return {}

    requested = self.get_agent_ids() if agent_ids is None else agent_ids
    # One joint sample is drawn; only the requested entries are kept.
    joint_sample = self.action_space.sample()
    return {agent_id: joint_sample[agent_id] for agent_id in requested}
163
+
164
+ @ExperimentalAPI
165
def observation_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
    """Samples a random observation for each requested agent.

    Args:
        agent_ids: Ids of the agents to sample observations for. If None,
            observations are returned for all agents reported by
            `get_agent_ids()`. An empty list yields an empty dict.

    Returns:
        A dict mapping each requested agent id to its entry in one sample
        drawn from `self.observation_space`. If the spaces do not map
        agent ids to per-agent sub-spaces, a warning is logged and an
        empty dict is returned.
    """
    # Lazily determine (and cache) the layout of the spaces. The attribute
    # may be missing entirely, hence the `getattr` with a default.
    if getattr(self, "_spaces_in_preferred_format", None) is None:
        self._spaces_in_preferred_format = (
            self._check_if_space_maps_agent_id_to_sub_space())

    if not self._spaces_in_preferred_format:
        logger.warning("observation_space_sample() has not been implemented")
        return {}

    if agent_ids is None:
        agent_ids = self.get_agent_ids()
    # One joint sample is drawn; only the requested entries are kept.
    joint_sample = self.observation_space.sample()
    return {agent_id: joint_sample[agent_id] for agent_id in agent_ids}
192
+
193
+ @PublicAPI
194
def get_agent_ids(self) -> Set[AgentID]:
    """Returns a set of agent ids in the environment.

    Returns:
        Set of agent ids. The set is empty if no ids have been stored on
        the instance.
    """
    # The attribute can be absent (e.g. `__init__` of this class was never
    # run) or None; treat both as "no agents" instead of raising.
    agent_ids = getattr(self, "_agent_ids", None)
    if not isinstance(agent_ids, set):
        # Normalize any other iterable to a set once and cache the result.
        agent_ids = set() if agent_ids is None else set(agent_ids)
        self._agent_ids = agent_ids
    return agent_ids
203
+
84
204
@PublicAPI
85
205
def render(self, mode=None) -> None:
    """Tries to render the environment.

    This implementation does nothing and ignores `mode`.

    Args:
        mode: Unused here.
    """
    # Nothing to render by default.
    return None
90
210
91
- # yapf: disable
92
- # __grouping_doc_begin__
211
+ # yapf: disable
212
+ # __grouping_doc_begin__
93
213
@ExperimentalAPI
94
214
def with_agent_groups (
95
- self ,
96
- groups : Dict [str , List [AgentID ]],
97
- obs_space : gym .Space = None ,
215
+ self ,
216
+ groups : Dict [str , List [AgentID ]],
217
+ obs_space : gym .Space = None ,
98
218
act_space : gym .Space = None ) -> "MultiAgentEnv" :
99
219
"""Convenience method for grouping together agents in this env.
100
220
@@ -132,8 +252,9 @@ def with_agent_groups(
132
252
from ray .rllib .env .wrappers .group_agents_wrapper import \
133
253
GroupAgentsWrapper
134
254
return GroupAgentsWrapper (self , groups , obs_space , act_space )
135
- # __grouping_doc_end__
136
- # yapf: enable
255
+
256
+ # __grouping_doc_end__
257
+ # yapf: enable
137
258
138
259
@PublicAPI
139
260
def to_base_env (
@@ -182,6 +303,20 @@ def to_base_env(
182
303
183
304
return env
184
305
306
+ @DeveloperAPI
307
+ def _check_if_space_maps_agent_id_to_sub_space (self ) -> bool :
308
+ # do the action and observation spaces map from agent ids to spaces
309
+ # for the individual agents?
310
+ obs_space_check = (
311
+ hasattr (self , "observation_space" )
312
+ and isinstance (self .observation_space , gym .spaces .Dict )
313
+ and set (self .observation_space .keys ()) == self .get_agent_ids ())
314
+ action_space_check = (
315
+ hasattr (self , "action_space" )
316
+ and isinstance (self .action_space , gym .spaces .Dict )
317
+ and set (self .action_space .keys ()) == self .get_agent_ids ())
318
+ return obs_space_check and action_space_check
319
+
185
320
186
321
def make_multi_agent (
187
322
env_name_or_creator : Union [str , Callable [[EnvContext ], EnvType ]],
@@ -242,6 +377,40 @@ def __init__(self, config=None):
242
377
self .dones = set ()
243
378
self .observation_space = self .agents [0 ].observation_space
244
379
self .action_space = self .agents [0 ].action_space
380
+ self ._agent_ids = set (range (num ))
381
+
382
+ @override (MultiAgentEnv )
383
def observation_space_sample(self,
                             agent_ids: list = None) -> MultiAgentDict:
    """Samples one observation per requested agent.

    Args:
        agent_ids: Ids to sample for. If None, the indices
            `0 .. len(self.agents) - 1` are used.

    Returns:
        A dict mapping each id to an independent sample drawn from
        `self.observation_space`.
    """
    if agent_ids is None:
        agent_ids = list(range(len(self.agents)))

    obs = {}
    for agent_id in agent_ids:
        # A fresh sample is drawn for every agent.
        obs[agent_id] = self.observation_space.sample()
    return obs
393
+
394
+ @override (MultiAgentEnv )
395
def action_space_sample(self,
                        agent_ids: list = None) -> MultiAgentDict:
    """Samples one action per requested agent.

    Args:
        agent_ids: Ids to sample for. If None, the indices
            `0 .. len(self.agents) - 1` are used.

    Returns:
        A dict mapping each id to an independent sample drawn from
        `self.action_space`.
    """
    if agent_ids is None:
        agent_ids = list(range(len(self.agents)))

    actions = {}
    for agent_id in agent_ids:
        # A fresh sample is drawn for every agent.
        actions[agent_id] = self.action_space.sample()
    return actions
405
+
406
+ @override (MultiAgentEnv )
407
def action_space_contains(self, x: MultiAgentDict) -> bool:
    """Checks every action in x against `self.action_space`.

    Args:
        x: Dict whose values are the actions to check.

    Returns:
        True if `self.action_space.contains()` accepts every value of x
        (also when x is empty), False as soon as one value is rejected.
    """
    for action in x.values():
        if not self.action_space.contains(action):
            return False
    return True
409
+
410
+ @override (MultiAgentEnv )
411
def observation_space_contains(self, x: MultiAgentDict) -> bool:
    """Checks every observation in x against `self.observation_space`.

    Args:
        x: Dict whose values are the observations to check.

    Returns:
        True if `self.observation_space.contains()` accepts every value
        of x (also when x is empty), False as soon as one value is
        rejected.
    """
    for observation in x.values():
        if not self.observation_space.contains(observation):
            return False
    return True
245
414
246
415
@override (MultiAgentEnv )
247
416
def reset (self ):
@@ -277,7 +446,7 @@ def __init__(self, make_env: Callable[[int], EnvType],
277
446
278
447
Args:
279
448
make_env (Callable[[int], EnvType]): Factory that produces a new
280
- MultiAgentEnv intance . Must be defined, if the number of
449
+ MultiAgentEnv instance . Must be defined, if the number of
281
450
existing envs is less than num_envs.
282
451
existing_envs (List[MultiAgentEnv]): List of already existing
283
452
multi-agent envs.
@@ -355,18 +524,31 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None:
355
524
@override (BaseEnv )
356
525
@PublicAPI
357
526
def observation_space(self) -> gym.spaces.Dict:
    """Returns the observation space of the first sub-environment.

    Returns:
        `self.envs[0].observation_space`.
    """
    # The `return` is essential: without it the expression is evaluated
    # and discarded, and the property always yields None.
    return self.envs[0].observation_space
363
528
364
529
@property
365
530
@override (BaseEnv )
366
531
@PublicAPI
367
532
def action_space(self) -> gym.Space:
    """Returns the action space of the first sub-environment.

    Returns:
        `self.envs[0].action_space`.
    """
    first_env = self.envs[0]
    return first_env.action_space
534
+
535
+ @override (BaseEnv )
536
def observation_space_contains(self, x: MultiEnvDict) -> bool:
    """Checks each entry of x with the first sub-env's containment check.

    Args:
        x: Dict whose values are each passed to
            `self.envs[0].observation_space_contains()`.

    Returns:
        True if the first sub-env accepts every value of x (also when x
        is empty), False as soon as one value is rejected.
    """
    for env_obs in x.values():
        if not self.envs[0].observation_space_contains(env_obs):
            return False
    return True
539
+
540
+ @override (BaseEnv )
541
def action_space_contains(self, x: MultiEnvDict) -> bool:
    """Checks each entry of x with the first sub-env's containment check.

    Args:
        x: Dict whose values are each passed to
            `self.envs[0].action_space_contains()`.

    Returns:
        True if the first sub-env accepts every value of x (also when x
        is empty), False as soon as one value is rejected.
    """
    for env_actions in x.values():
        if not self.envs[0].action_space_contains(env_actions):
            return False
    return True
544
+
545
+ @override (BaseEnv )
546
def observation_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
    """Samples observations from the first sub-env, keyed by env id.

    Args:
        agent_ids: Forwarded unchanged to the first sub-env's
            `observation_space_sample()`.

    Returns:
        A dict with the single entry `{0: <sample of the first sub-env>}`.
    """
    # Wrap the sub-env's sample under env id 0 so the result has the
    # declared MultiEnvDict shape and can be validated by
    # `observation_space_contains()`, which iterates over per-env values.
    return {0: self.envs[0].observation_space_sample(agent_ids)}
548
+
549
+ @override (BaseEnv )
550
def action_space_sample(self, agent_ids: list = None) -> MultiEnvDict:
    """Samples actions from the first sub-env, keyed by env id.

    Args:
        agent_ids: Forwarded unchanged to the first sub-env's
            `action_space_sample()`.

    Returns:
        A dict with the single entry `{0: <sample of the first sub-env>}`.
    """
    # Wrap the sub-env's sample under env id 0 so the result has the
    # declared MultiEnvDict shape and can be validated by
    # `action_space_contains()`, which iterates over per-env values.
    return {0: self.envs[0].action_space_sample(agent_ids)}
370
552
371
553
372
554
class _MultiAgentEnvState :
0 commit comments