@@ -23,6 +23,12 @@ class ResettablePOMDP(gym.Env, abc.ABC, Generic[State, Observation, Action]):
     meet these criteria.
     """

+    _state_space: gym.Space
+    _observation_space: gym.Space
+    _action_space: gym.Space
+    _cur_state: Optional[State]
+    _n_actions_taken: Optional[int]
+
     def __init__(
         self,
         *,
@@ -41,8 +47,8 @@ def __init__(
         self._observation_space = observation_space
         self._action_space = action_space

-        self.cur_state: Optional[State] = None
-        self._n_actions_taken: Optional[int] = None
+        self._cur_state = None
+        self._n_actions_taken = None

         self.seed()

     @abc.abstractmethod
@@ -86,6 +92,19 @@ def n_actions_taken(self) -> int:
         assert self._n_actions_taken is not None
         return self._n_actions_taken

+    @property
+    def state(self) -> State:
+        """Current state."""
+        assert self._cur_state is not None
+        return self._cur_state
+
+    @state.setter
+    def state(self, state: State):
+        """Set current state."""
+        if state not in self.state_space:
+            raise ValueError(f"{state} not in {self.state_space}")
+        self._cur_state = state
+
     def seed(self, seed=None) -> Sequence[int]:
         """Set random seed."""
         if seed is None:
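# Usage sketch (not part of the diff): how the new validating `state` setter is
# expected to behave, using the TabularModelMDP defined further down in this
# file. The module name in the import is hypothetical, and the constructor
# arguments are illustrative values only.
import numpy as np

from tabular_pomdps import TabularModelMDP  # hypothetical module name

mdp = TabularModelMDP(
    transition_matrix=np.ones((2, 1, 2)) / 2.0,  # shape (n_states, n_actions, n_states)
    reward_matrix=np.zeros(2),                   # one reward entry per state
)
mdp.reset()
mdp.state = 1  # accepted: 1 is contained in spaces.Discrete(2)
try:
    mdp.state = 5  # rejected by the setter's membership check
except ValueError as err:
    print(err)  # "5 not in Discrete(2)"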
@@ -97,32 +116,57 @@ def seed(self, seed=None) -> Sequence[int]:

     def reset(self) -> Observation:
         """Reset episode and return initial observation."""
-        self.cur_state = self.initial_state()
-        assert self.cur_state in self.state_space, f"unexpected state {self.cur_state}"
+        self.state = self.initial_state()
         self._n_actions_taken = 0
-        return self.obs_from_state(self.cur_state)
+        return self.obs_from_state(self.state)

     def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
         """Transition state using given action."""
-        if self.cur_state is None or self._n_actions_taken is None:
+        if self._cur_state is None or self._n_actions_taken is None:
             raise ValueError("Need to call reset() before first step()")
         if action not in self.action_space:
             raise ValueError(f"{action} not in {self.action_space}")

-        old_state = self.cur_state
-        self.cur_state = self.transition(self.cur_state, action)
-        assert self.cur_state in self.state_space, f"unexpected state {self.cur_state}"
-        obs = self.obs_from_state(self.cur_state)
-        assert obs in self.observation_space, f"{obs} not in {self.observation_space}"
-        rew = self.reward(old_state, action, self.cur_state)
-        done = self.terminal(self.cur_state, self._n_actions_taken)
+        old_state = self.state
+        self.state = self.transition(self.state, action)
+        obs = self.obs_from_state(self.state)
+        assert obs in self.observation_space
+        reward = self.reward(old_state, action, self.state)
         self._n_actions_taken += 1
+        done = self.terminal(self.state, self.n_actions_taken)
+
+        infos = {"old_state": old_state, "new_state": self._cur_state}
+        return obs, reward, done, infos
+
+
+class ExposePOMDPStateWrapper(gym.Wrapper, Generic[State, Observation, Action]):
+    """A wrapper that exposes the current state of the POMDP as the observation."""
+
+    def __init__(self, env: ResettablePOMDP[State, Observation, Action]) -> None:
+        """Build wrapper.
+
+        Args:
+            env: POMDP to wrap.
+        """
+        super().__init__(env)
+        self._observation_space = env.state_space
+
+    def reset(self) -> State:
+        """Reset environment and return initial state."""
+        self.env.reset()
+        return self.env.state

-        infos = {"old_state": old_state, "new_state": self.cur_state}
-        return obs, rew, done, infos
+    def step(self, action) -> Tuple[State, float, bool, dict]:
+        """Transition state using given action."""
+        obs, reward, done, info = self.env.step(action)
+        return self.env.state, reward, done, info


-class ResettableMDP(ResettablePOMDP[State, State, Action], Generic[State, Action]):
+class ResettableMDP(
+    ResettablePOMDP[State, State, Action],
+    abc.ABC,
+    Generic[State, Action],
+):
     """ABC for MDPs that are resettable."""

     def __init__(
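# Usage sketch (not part of the diff): wrapping the TabularModelPOMDP defined
# later in this file with ExposePOMDPStateWrapper so the caller receives states
# rather than observations. The module name in the import is hypothetical and
# the matrices are illustrative values only.
import numpy as np

from tabular_pomdps import ExposePOMDPStateWrapper, TabularModelPOMDP  # hypothetical module name

env = TabularModelPOMDP(
    transition_matrix=np.array([[[0.0, 1.0]], [[1.0, 0.0]]]),  # (2, 1, 2): the single action flips the state
    observation_matrix=np.eye(2),                              # state i -> one-hot observation vector
    reward_matrix=np.array([0.0, 1.0]),                        # 1-D reward table, one entry per state
    horizon=5,
)
wrapped = ExposePOMDPStateWrapper(env)
state = wrapped.reset()                      # an int state, not the one-hot observation
state, reward, done, info = wrapped.step(0)  # info carries "old_state" / "new_state"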
@@ -148,8 +192,20 @@ def obs_from_state(self, state: State) -> State:
         return state


-class TabularModelMDP(ResettableMDP[int, int]):
-    """Base class for tabular environments with known dynamics."""
+# TODO(juan) this does not implement the .render() method,
+# so in theory it should not be instantiated directly.
+# Not sure why this is not raising an error?
+class BaseTabularModelPOMDP(ResettablePOMDP[int, Observation, int]):
+    """Base class for tabular environments with known dynamics.
+
+    This is the general class that also allows subclassing for creating
+    MDP (where observation == state) or POMDP (where observation != state).
+    """
+
+    transition_matrix: np.ndarray
+    reward_matrix: np.ndarray
+
+    state_space: spaces.Discrete

     def __init__(
         self,
@@ -179,14 +235,28 @@ def __init__(
             ValueError: `transition_matrix`, `reward_matrix` or
                 `initial_state_dist` have shapes different to specified above.
         """
-        n_states, n_actions, n_next_states = transition_matrix.shape
-        if n_states != n_next_states:
+        # The following matrices should conform to the shapes below:
+
+        # transition matrix: n_states x n_actions x n_states
+        n_states = transition_matrix.shape[0]
+        if n_states != transition_matrix.shape[2]:
             raise ValueError(
                 "Malformed transition_matrix:\n"
                 f"transition_matrix.shape: {transition_matrix.shape}\n"
-                f"{n_states} != {n_next_states}",
+                f"{n_states} != {transition_matrix.shape[2]}",
+            )
+
+        # reward matrix: n_states x n_actions x n_states
+        #   OR n_states x n_actions
+        #   OR n_states
+        if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]:
+            raise ValueError(
+                "transition_matrix and reward_matrix are not compatible:\n"
+                f"transition_matrix.shape: {transition_matrix.shape}\n"
+                f"reward_matrix.shape: {reward_matrix.shape}",
             )

+        # initial state dist: n_states
         if initial_state_dist is None:
             initial_state_dist = util.one_hot_encoding(0, n_states)
         if initial_state_dist.ndim != 1:
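# Worked example (not part of the diff): the prefix check above accepts any
# reward_matrix whose shape is a leading slice of transition_matrix.shape.
import numpy as np

transition_matrix = np.zeros((3, 2, 3))  # (n_states, n_actions, n_states)
for ok_shape in [(3, 2, 3), (3, 2), (3,)]:
    reward_matrix = np.zeros(ok_shape)
    # each of these shapes is a prefix of (3, 2, 3), so no ValueError is raised
    assert reward_matrix.shape == transition_matrix.shape[: len(reward_matrix.shape)]

bad = np.zeros((3, 3))  # (3, 3) is not a prefix of (3, 2, 3): this is the case that raises
assert bad.shape != transition_matrix.shape[: len(bad.shape)]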
@@ -197,28 +267,32 @@ def __init__(
         if initial_state_dist.shape[0] != n_states:
             raise ValueError(
                 "transition_matrix and initial_state_dist are not compatible:\n"
-                f"n_states = {n_states}\n"
+                f"number of states = {n_states}\n"
                 f"len(initial_state_dist) = {len(initial_state_dist)}",
             )

-        if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]:
-            raise ValueError(
-                "transition_matrix and reward_matrix are not compatible:\n"
-                f"transition_matrix.shape: {transition_matrix.shape}\n"
-                f"reward_matrix.shape: {reward_matrix.shape}",
-            )
-
         self.transition_matrix = transition_matrix
         self.reward_matrix = reward_matrix
         self._feature_matrix = None
         self.horizon = horizon
         self.initial_state_dist = initial_state_dist

         super().__init__(
-            state_space=spaces.Discrete(n_states),
-            action_space=spaces.Discrete(n_actions),
+            state_space=self._construct_state_space(),
+            action_space=self._construct_action_space(),
+            observation_space=self._construct_observation_space(),
         )

+    def _construct_state_space(self) -> gym.Space:
+        return spaces.Discrete(self.state_dim)
+
+    def _construct_action_space(self) -> gym.Space:
+        return spaces.Discrete(self.action_dim)
+
+    @abc.abstractmethod
+    def _construct_observation_space(self) -> gym.Space:
+        pass  # pragma: no cover
+
     def initial_state(self) -> int:
         """Samples from the initial state distribution."""
         return util.sample_distribution(
@@ -250,3 +324,128 @@ def feature_matrix(self):
         n_states = self.state_space.n
         self._feature_matrix = np.eye(n_states)
         return self._feature_matrix
+
+    @property
+    def state_dim(self):
+        """Number of states in this MDP (int)."""
+        return self.transition_matrix.shape[0]
+
+    @property
+    def action_dim(self) -> int:
+        """Number of actions in this MDP (int)."""
+        return self.transition_matrix.shape[1]
+
+
+class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]):
+    """Tabular model POMDP.
+
+    This class is specifically for environments where observation != state,
+    both from a typing perspective and by defining the method that draws
+    observations from the state.
+
+    The tabular model draws observations deterministically from the state:
+    a given state always yields the same observation, a vector with
+    self.obs_dim entries.
+    """
+
+    observation_matrix: np.ndarray
+
+    def __init__(
+        self,
+        *,
+        transition_matrix: np.ndarray,
+        observation_matrix: np.ndarray,
+        reward_matrix: np.ndarray,
+        horizon: float = np.inf,
+        initial_state_dist: Optional[np.ndarray] = None,
+    ):
+        """Initializes a tabular model POMDP."""
+        self.observation_matrix = observation_matrix
+        super().__init__(
+            transition_matrix=transition_matrix,
+            reward_matrix=reward_matrix,
+            horizon=horizon,
+            initial_state_dist=initial_state_dist,
+        )
+
+        # observation matrix: n_states x n_observations
+        if observation_matrix.shape[0] != self.state_dim:
+            raise ValueError(
+                "transition_matrix and observation_matrix are not compatible:\n"
+                f"transition_matrix.shape[0]: {self.state_dim}\n"
+                f"observation_matrix.shape[0]: {observation_matrix.shape[0]}",
+            )
+
+    def _construct_observation_space(self) -> gym.Space:
+        min_val: float
+        max_val: float
+        try:
+            dtype_iinfo = np.iinfo(self.obs_dtype)
+            min_val, max_val = dtype_iinfo.min, dtype_iinfo.max
+        except ValueError:
+            min_val = -np.inf
+            max_val = np.inf
+        return spaces.Box(
+            low=min_val,
+            high=max_val,
+            shape=(self.obs_dim,),
+            dtype=self.obs_dtype,
+        )
+
+    def obs_from_state(self, state: int) -> np.ndarray:
+        """Computes observation from state."""
+        # Copy so it can't be mutated in-place (updates will be reflected in
+        # self.observation_matrix!)
+        obs = self.observation_matrix[state].copy()
+        assert obs.ndim == 1, obs.shape
+        return obs
+
+    @property
+    def obs_dim(self) -> int:
+        """Size of observation vectors for this MDP."""
+        return self.observation_matrix.shape[1]
+
+    @property
+    def obs_dtype(self) -> np.dtype:
+        """Data type of observation vectors (e.g. np.float32)."""
+        return self.observation_matrix.dtype
+
+
+class TabularModelMDP(BaseTabularModelPOMDP[int]):
+    """Tabular model MDP.
+
+    A tabular model MDP is a tabular MDP where the transition and reward
+    matrices are constant.
+    """
+
+    def __init__(
+        self,
+        *,
+        transition_matrix: np.ndarray,
+        reward_matrix: np.ndarray,
+        horizon: float = np.inf,
+        initial_state_dist: Optional[np.ndarray] = None,
+    ):
+        """Initializes a tabular model MDP.
+
+        Args:
+            transition_matrix: Matrix of shape `(n_states, n_actions, n_states)`
+                containing transition probabilities.
+            reward_matrix: Matrix of shape `(n_states, n_actions, n_states)`
+                containing reward values.
+            initial_state_dist: Distribution over initial states. Shape `(n_states,)`.
+            horizon: Maximum number of steps to take in an episode.
+        """
+        super().__init__(
+            transition_matrix=transition_matrix,
+            reward_matrix=reward_matrix,
+            horizon=horizon,
+            initial_state_dist=initial_state_dist,
+        )
+
+    def obs_from_state(self, state: int) -> int:
+        """Identity since observation == state in an MDP."""
+        return state
+
+    def _construct_observation_space(self) -> gym.Space:
+        return self._construct_state_space()
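# End-to-end sketch (not part of the diff): build a tiny deterministic
# TabularModelMDP and roll it for one step. The module name in the import is
# hypothetical and the matrices are illustrative values only.
import numpy as np

from tabular_pomdps import TabularModelMDP  # hypothetical module name

# Two states, two actions: action 0 stays in place from state 0, action 1 moves to state 1.
transition_matrix = np.array(
    [
        [[1.0, 0.0], [0.0, 1.0]],
        [[0.0, 1.0], [0.0, 1.0]],
    ],
)  # shape (2, 2, 2)
reward_matrix = np.zeros((2, 2))  # one reward entry per (state, action) pair

mdp = TabularModelMDP(
    transition_matrix=transition_matrix,
    reward_matrix=reward_matrix,
    horizon=3,
)
obs = mdp.reset()                      # observation == state for an MDP, so obs is the initial state 0
obs, reward, done, info = mdp.step(1)  # deterministic move to state 1
assert obs == info["new_state"] == 1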