
Commit 3d2cd41 (authored Oct 4, 2022)

Migrate imitation envs to seals (#58)

1 parent: 7def17c
10 files changed: +753 -56 lines
 

ci/code_checks.sh

+1 -1

@@ -8,7 +8,7 @@ set -e # quit immediately on error
 
 echo "Source format checking"
 flake8 ${SRC_FILES[@]}
-black --check ${SRC_FILES}
+black --check ${SRC_FILES[@]}
 codespell -I .codespell.skip --skip='*.pyc' ${SRC_FILES[@]}
 
 if [ -x "`which circleci`" ]; then
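Note: in bash, ${SRC_FILES} expands to only the first element of the SRC_FILES array, so black had been checking a single path; ${SRC_FILES[@]} expands to every element, matching the flake8 and codespell invocations above.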

mypy.ini

+2 -0

@@ -0,0 +1,2 @@
+[mypy]
+ignore_missing_imports = true
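With ignore_missing_imports = true, mypy treats imports of packages that ship no type stubs (e.g. gym) as Any rather than failing the type check on a missing-stub error.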

setup.py

+2 -1

@@ -107,6 +107,7 @@ def get_readme() -> str:
         "flake8-docstrings",
         "flake8-isort",
         "isort",
+        "matplotlib",
         "mypy",
         "pydocstyle",
         "pytest",
@@ -137,7 +138,7 @@ def get_readme() -> str:
     packages=find_packages("src"),
     package_dir={"": "src"},
     package_data={"seals": ["py.typed"]},
-    install_requires=["gym"],
+    install_requires=["gym", "numpy"],
     tests_require=TESTS_REQUIRE,
     extras_require={
         # recommended packages for development
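numpy moves into install_requires because src/seals/base_envs.py (below) now uses it at runtime rather than only in tests; matplotlib appears to be added alongside the other test-and-lint requirements.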

src/seals/base_envs.py

+230 -31
@@ -23,6 +23,12 @@ class ResettablePOMDP(gym.Env, abc.ABC, Generic[State, Observation, Action]):
     meet these criteria.
     """
 
+    _state_space: gym.Space
+    _observation_space: gym.Space
+    _action_space: gym.Space
+    _cur_state: Optional[State]
+    _n_actions_taken: Optional[int]
+
     def __init__(
         self,
         *,
@@ -41,8 +47,8 @@ def __init__(
         self._observation_space = observation_space
         self._action_space = action_space
 
-        self.cur_state: Optional[State] = None
-        self._n_actions_taken: Optional[int] = None
+        self._cur_state = None
+        self._n_actions_taken = None
         self.seed()
 
     @abc.abstractmethod
@@ -86,6 +92,19 @@ def n_actions_taken(self) -> int:
         assert self._n_actions_taken is not None
         return self._n_actions_taken
 
+    @property
+    def state(self) -> State:
+        """Current state."""
+        assert self._cur_state is not None
+        return self._cur_state
+
+    @state.setter
+    def state(self, state: State):
+        """Set current state."""
+        if state not in self.state_space:
+            raise ValueError(f"{state} not in {self.state_space}")
+        self._cur_state = state
+
     def seed(self, seed=None) -> Sequence[int]:
         """Set random seed."""
         if seed is None:
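The new state property centralizes the membership checks that reset() and step() previously performed with asserts; invalid assignments now raise ValueError. A minimal sketch of the contract, using the concrete TabularModelMDP added later in this diff (the two-state chain is hypothetical):

    import numpy as np
    from seals.base_envs import TabularModelMDP

    # Hypothetical 2-state, 1-action chain: the single action moves to state 1.
    transition = np.zeros((2, 1, 2))
    transition[:, 0, 1] = 1.0
    reward = np.array([0.0, 1.0])  # rank-1 reward: depends only on origin state

    env = TabularModelMDP(transition_matrix=transition, reward_matrix=reward)
    env.reset()
    env.state = 1  # accepted: 1 is in Discrete(2)
    env.state = 7  # raises ValueError: 7 not in Discrete(2)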
@@ -97,32 +116,57 @@ def seed(self, seed=None) -> Sequence[int]:
 
     def reset(self) -> Observation:
         """Reset episode and return initial observation."""
-        self.cur_state = self.initial_state()
-        assert self.cur_state in self.state_space, f"unexpected state {self.cur_state}"
+        self.state = self.initial_state()
         self._n_actions_taken = 0
-        return self.obs_from_state(self.cur_state)
+        return self.obs_from_state(self.state)
 
     def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
         """Transition state using given action."""
-        if self.cur_state is None or self._n_actions_taken is None:
+        if self._cur_state is None or self._n_actions_taken is None:
             raise ValueError("Need to call reset() before first step()")
         if action not in self.action_space:
             raise ValueError(f"{action} not in {self.action_space}")
 
-        old_state = self.cur_state
-        self.cur_state = self.transition(self.cur_state, action)
-        assert self.cur_state in self.state_space, f"unexpected state {self.cur_state}"
-        obs = self.obs_from_state(self.cur_state)
-        assert obs in self.observation_space, f"{obs} not in {self.observation_space}"
-        rew = self.reward(old_state, action, self.cur_state)
-        done = self.terminal(self.cur_state, self._n_actions_taken)
+        old_state = self.state
+        self.state = self.transition(self.state, action)
+        obs = self.obs_from_state(self.state)
+        assert obs in self.observation_space
+        reward = self.reward(old_state, action, self.state)
         self._n_actions_taken += 1
+        done = self.terminal(self.state, self.n_actions_taken)
+
+        infos = {"old_state": old_state, "new_state": self._cur_state}
+        return obs, reward, done, infos
+
+
+class ExposePOMDPStateWrapper(gym.Wrapper, Generic[State, Observation, Action]):
+    """A wrapper that exposes the current state of the POMDP as the observation."""
+
+    def __init__(self, env: ResettablePOMDP[State, Observation, Action]) -> None:
+        """Build wrapper.
+
+        Args:
+            env: POMDP to wrap.
+        """
+        super().__init__(env)
+        self._observation_space = env.state_space
+
+    def reset(self) -> State:
+        """Reset environment and return initial state."""
+        self.env.reset()
+        return self.env.state
 
-        infos = {"old_state": old_state, "new_state": self.cur_state}
-        return obs, rew, done, infos
+    def step(self, action) -> Tuple[State, float, bool, dict]:
+        """Transition state using given action."""
+        obs, reward, done, info = self.env.step(action)
+        return self.env.state, reward, done, info
 
 
-class ResettableMDP(ResettablePOMDP[State, State, Action], Generic[State, Action]):
+class ResettableMDP(
+    ResettablePOMDP[State, State, Action],
+    abc.ABC,
+    Generic[State, Action],
+):
     """ABC for MDPs that are resettable."""
 
     def __init__(
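ExposePOMDPStateWrapper converts a POMDP into its fully observable counterpart by returning env.state where the underlying environment would return an observation. A minimal sketch, assuming seals at this commit and the TabularModelPOMDP class added further down (the matrices are hypothetical):

    import numpy as np
    from seals.base_envs import ExposePOMDPStateWrapper, TabularModelPOMDP

    transition = np.eye(2)[:, None, :]  # shape (2, 1, 2): the single action keeps the state
    observation = np.array([[1.0, 0.0, 0.0],
                            [0.0, 1.0, 1.0]])  # 2 states -> 3-dim observations
    reward = np.zeros(2)

    pomdp = TabularModelPOMDP(
        transition_matrix=transition,
        observation_matrix=observation,
        reward_matrix=reward,
    )
    wrapped = ExposePOMDPStateWrapper(pomdp)
    state = wrapped.reset()  # integer state, not the 3-vector observation
    next_state, rew, done, info = wrapped.step(0)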
@@ -148,8 +192,20 @@ def obs_from_state(self, state: State) -> State:
         return state
 
 
-class TabularModelMDP(ResettableMDP[int, int]):
-    """Base class for tabular environments with known dynamics."""
+# TODO(juan) this does not implement the .render() method,
+# so in theory it should not be instantiated directly.
+# Not sure why this is not raising an error?
+class BaseTabularModelPOMDP(ResettablePOMDP[int, Observation, int]):
+    """Base class for tabular environments with known dynamics.
+
+    This is the general class that also allows subclassing for creating
+    MDP (where observation == state) or POMDP (where observation != state).
+    """
+
+    transition_matrix: np.ndarray
+    reward_matrix: np.ndarray
+
+    state_space: spaces.Discrete
 
     def __init__(
         self,
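Regarding the TODO above: abc only blocks instantiation when a method is declared with @abc.abstractmethod. gym.Env.render is an ordinary method that raises NotImplementedError when called, so instantiating a subclass that never overrides it is legal and only fails at render time.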
@@ -179,14 +235,28 @@ def __init__(
             ValueError: `transition_matrix`, `reward_matrix` or
                 `initial_state_dist` have shapes different to specified above.
         """
-        n_states, n_actions, n_next_states = transition_matrix.shape
-        if n_states != n_next_states:
+        # The following matrices should conform to the shapes below:
+
+        # transition matrix: n_states x n_actions x n_states
+        n_states = transition_matrix.shape[0]
+        if n_states != transition_matrix.shape[2]:
             raise ValueError(
                 "Malformed transition_matrix:\n"
                 f"transition_matrix.shape: {transition_matrix.shape}\n"
-                f"{n_states} != {n_next_states}",
+                f"{n_states} != {transition_matrix.shape[2]}",
+            )
+
+        # reward matrix: n_states x n_actions x n_states
+        #     OR n_states x n_actions
+        #     OR n_states
+        if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]:
+            raise ValueError(
+                "transition_matrix and reward_matrix are not compatible:\n"
+                f"transition_matrix.shape: {transition_matrix.shape}\n"
+                f"reward_matrix.shape: {reward_matrix.shape}",
             )
 
+        # initial state dist: n_states
         if initial_state_dist is None:
             initial_state_dist = util.one_hot_encoding(0, n_states)
         if initial_state_dist.ndim != 1:
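The prefix-shape check above accepts reward matrices of rank 1, 2, or 3, i.e. R(s), R(s, a), or R(s, a, s'). A quick numpy-only sketch of what passes and fails (sizes are hypothetical):

    import numpy as np

    n_states, n_actions = 4, 2
    transition = np.full((n_states, n_actions, n_states), 1 / n_states)

    r_s = np.zeros((n_states,))                        # R(s)
    r_sa = np.zeros((n_states, n_actions))             # R(s, a)
    r_sas = np.zeros((n_states, n_actions, n_states))  # R(s, a, s')
    for r in (r_s, r_sa, r_sas):
        assert r.shape == transition.shape[: len(r.shape)]  # all accepted

    r_bad = np.zeros((n_actions, n_states))  # wrong leading dimension
    assert r_bad.shape != transition.shape[: len(r_bad.shape)]  # rejected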
@@ -197,28 +267,32 @@ def __init__(
         if initial_state_dist.shape[0] != n_states:
             raise ValueError(
                 "transition_matrix and initial_state_dist are not compatible:\n"
-                f"n_states = {n_states}\n"
+                f"number of states = {n_states}\n"
                 f"len(initial_state_dist) = {len(initial_state_dist)}",
             )
 
-        if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]:
-            raise ValueError(
-                "transition_matrix and reward_matrix are not compatible:\n"
-                f"transition_matrix.shape: {transition_matrix.shape}\n"
-                f"reward_matrix.shape: {reward_matrix.shape}",
-            )
-
         self.transition_matrix = transition_matrix
         self.reward_matrix = reward_matrix
         self._feature_matrix = None
         self.horizon = horizon
         self.initial_state_dist = initial_state_dist
 
         super().__init__(
-            state_space=spaces.Discrete(n_states),
-            action_space=spaces.Discrete(n_actions),
+            state_space=self._construct_state_space(),
+            action_space=self._construct_action_space(),
+            observation_space=self._construct_observation_space(),
         )
 
+    def _construct_state_space(self) -> gym.Space:
+        return spaces.Discrete(self.state_dim)
+
+    def _construct_action_space(self) -> gym.Space:
+        return spaces.Discrete(self.action_dim)
+
+    @abc.abstractmethod
+    def _construct_observation_space(self) -> gym.Space:
+        pass  # pragma: no cover
+
     def initial_state(self) -> int:
         """Samples from the initial state distribution."""
         return util.sample_distribution(
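_construct_observation_space is the template-method hook here: state and action spaces stay Discrete, and each subclass decides what observations look like. A hypothetical subclass sketch (IdentityObsPOMDP is not part of this commit; the concrete TabularModelPOMDP and TabularModelMDP below fill the hook in the same way):

    import gym
    from gym import spaces
    from seals.base_envs import BaseTabularModelPOMDP

    class IdentityObsPOMDP(BaseTabularModelPOMDP[int]):
        """Hypothetical subclass whose observations are the raw state indices."""

        def _construct_observation_space(self) -> gym.Space:
            return spaces.Discrete(self.state_dim)

        def obs_from_state(self, state: int) -> int:
            return state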
@@ -250,3 +324,128 @@ def feature_matrix(self):
             n_states = self.state_space.n
             self._feature_matrix = np.eye(n_states)
         return self._feature_matrix
+
+    @property
+    def state_dim(self) -> int:
+        """Number of states in this MDP (int)."""
+        return self.transition_matrix.shape[0]
+
+    @property
+    def action_dim(self) -> int:
+        """Number of actions in this MDP (int)."""
+        return self.transition_matrix.shape[1]
+
+
+class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]):
+    """Tabular model POMDP.
+
+    This class is specifically for environments where observation != state,
+    both from a typing perspective and in defining the method that draws
+    observations from the state.
+
+    The tabular model is deterministic in drawing observations from the
+    state, in that given a certain state, the observation is always the
+    same: a vector with self.obs_dim entries.
+    """
+
+    observation_matrix: np.ndarray
+
+    def __init__(
+        self,
+        *,
+        transition_matrix: np.ndarray,
+        observation_matrix: np.ndarray,
+        reward_matrix: np.ndarray,
+        horizon: float = np.inf,
+        initial_state_dist: Optional[np.ndarray] = None,
+    ):
+        """Initializes a tabular model POMDP."""
+        self.observation_matrix = observation_matrix
+        super().__init__(
+            transition_matrix=transition_matrix,
+            reward_matrix=reward_matrix,
+            horizon=horizon,
+            initial_state_dist=initial_state_dist,
+        )
+
+        # observation matrix: n_states x n_observations
+        if observation_matrix.shape[0] != self.state_dim:
+            raise ValueError(
+                "transition_matrix and observation_matrix are not compatible:\n"
+                f"transition_matrix.shape[0]: {self.state_dim}\n"
+                f"observation_matrix.shape[0]: {observation_matrix.shape[0]}",
+            )
+
+    def _construct_observation_space(self) -> gym.Space:
+        min_val: float
+        max_val: float
+        try:
+            dtype_iinfo = np.iinfo(self.obs_dtype)
+            min_val, max_val = dtype_iinfo.min, dtype_iinfo.max
+        except ValueError:
+            min_val = -np.inf
+            max_val = np.inf
+        return spaces.Box(
+            low=min_val,
+            high=max_val,
+            shape=(self.obs_dim,),
+            dtype=self.obs_dtype,
+        )
+
+    def obs_from_state(self, state: int) -> np.ndarray:
+        """Computes observation from state."""
+        # Copy so it can't be mutated in-place (updates would be reflected in
+        # self.observation_matrix!)
+        obs = self.observation_matrix[state].copy()
+        assert obs.ndim == 1, obs.shape
+        return obs
+
+    @property
+    def obs_dim(self) -> int:
+        """Size of observation vectors for this MDP."""
+        return self.observation_matrix.shape[1]
+
+    @property
+    def obs_dtype(self) -> np.dtype:
+        """Data type of observation vectors (e.g. np.float32)."""
+        return self.observation_matrix.dtype
+
+
+class TabularModelMDP(BaseTabularModelPOMDP[int]):
+    """Tabular model MDP.
+
+    A tabular model MDP is a tabular MDP where the transition and reward
+    matrices are constant.
+    """
+
+    def __init__(
+        self,
+        *,
+        transition_matrix: np.ndarray,
+        reward_matrix: np.ndarray,
+        horizon: float = np.inf,
+        initial_state_dist: Optional[np.ndarray] = None,
+    ):
+        """Initializes a tabular model MDP.
+
+        Args:
+            transition_matrix: Matrix of shape `(n_states, n_actions, n_states)`
+                containing transition probabilities.
+            reward_matrix: Matrix of shape `(n_states, n_actions, n_states)`
+                containing reward values.
+            horizon: Maximum number of steps to take in an episode.
+            initial_state_dist: Distribution over initial states. Shape `(n_states,)`.
+        """
+        super().__init__(
+            transition_matrix=transition_matrix,
+            reward_matrix=reward_matrix,
+            horizon=horizon,
+            initial_state_dist=initial_state_dist,
+        )
+
+    def obs_from_state(self, state: int) -> int:
+        """Identity, since observation == state in an MDP."""
+        return state
+
+    def _construct_observation_space(self) -> gym.Space:
+        return self._construct_state_space()
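End to end, the refactored hierarchy lets a tabular MDP be rolled out directly from its matrices. A minimal sketch, assuming seals at this commit (the 3-state cycle is hypothetical):

    import numpy as np
    from seals.base_envs import TabularModelMDP

    n_states, n_actions = 3, 2
    transition = np.zeros((n_states, n_actions, n_states))
    idx = np.arange(n_states)
    transition[idx, 0, idx] = 1.0                   # action 0: stay put
    transition[idx, 1, (idx + 1) % n_states] = 1.0  # action 1: advance the cycle
    reward = np.array([0.0, 0.0, 1.0])  # rank-1 reward: acting from state 2 pays 1

    env = TabularModelMDP(transition_matrix=transition, reward_matrix=reward, horizon=5)
    obs = env.reset()  # observation == state for an MDP
    done = False
    while not done:
        obs, rew, done, info = env.step(1)
        print(obs, rew, info["old_state"], info["new_state"])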
