from gym.envs.registration import load
import gym
import numpy as np
from gym import Env
from gym import spaces
import os


def mujoco_wrapper(entry_point, **kwargs):
    # Load the environment from its entry point
    env_cls = load(entry_point)
    env = env_cls(**kwargs)
    return env
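
# Usage sketch (the entry point string below is hypothetical; pass one of the
# entry points actually registered by this package):
#   env = mujoco_wrapper('some_package.envs:SomeMujocoEnv', some_kwarg=1)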


class VariBadWrapper(gym.Wrapper):
    def __init__(self,
                 env,
                 episodes_per_task
                 ):
        """
        Wrapper that creates a multi-episode (BA)MDP around a one-episode MDP. Automatically deals with
        - horizons H in the MDP vs horizons H+ in the BAMDP,
        - resetting the tasks,
        - normalized actions in case of a continuous action space,
        - adding the timestep / done info to the state (might be needed to make states Markov).
        """
        super().__init__(env)

        # if continuous actions, make sure they are in [-1, 1]
        if isinstance(self.env.action_space, gym.spaces.Box):
            self._normalize_actions = True
        else:
            self._normalize_actions = False

        if episodes_per_task > 1:
            self.add_done_info = True
        else:
            self.add_done_info = False
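
        # when a task spans multiple episodes, append a binary "done" flag to the
        # observation so that states remain Markov across episode boundaries
        # (cf. the docstring above)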
        if self.add_done_info:
            if isinstance(self.observation_space, spaces.Box):
                if len(self.observation_space.shape) > 1:
                    raise ValueError  # can't add additional info for obs of more than 1D
                self.observation_space = spaces.Box(low=np.array([*self.observation_space.low, 0]),  # shape will be deduced from this
                                                    high=np.array([*self.observation_space.high, 1]),
                                                    dtype=np.float32)
            else:
                # TODO: add something similar for the other possible spaces,
                # "Space", "Discrete", "MultiDiscrete", "MultiBinary", "Tuple", "Dict", "flatdim", "flatten", "unflatten"
                raise NotImplementedError

        # calculate horizon length H^+
        self.episodes_per_task = episodes_per_task
        # counts the number of episodes
        self.episode_count = 0

        # count timesteps in BAMDP
        self.step_count_bamdp = 0.0

        # the horizon in the BAMDP is the one in the MDP times the number of episodes per task,
        # and since we train a policy that maximises the return over all episodes,
        # we add transitions to the reset start in-between episodes
        try:
            self.horizon_bamdp = self.episodes_per_task * self.env._max_episode_steps
        except AttributeError:
            self.horizon_bamdp = self.episodes_per_task * self.env.unwrapped._max_episode_steps

        # add dummy timesteps in-between episodes for resetting the MDP
        self.horizon_bamdp += self.episodes_per_task - 1
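        # e.g. (illustrative numbers): with _max_episode_steps = 200 and
        # episodes_per_task = 3, horizon_bamdp = 3 * 200 + (3 - 1) = 602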

        # this tells us if we have reached the horizon in the underlying MDP
        self.done_mdp = True

    # def reset(self, task):
    def reset(self, task=None):
        # reset task -- this sets goal and state -- sets self.env._goal and self.env._state
        self.env.reset_task(task)

        self.episode_count = 0
        self.step_count_bamdp = 0

        # normal reset
        try:
            state = self.env.reset()
        except AttributeError:
            state = self.env.unwrapped.reset()

        if self.add_done_info:
            state = np.concatenate((state, [0.0]))

        self.done_mdp = False

        return state

    def reset_mdp(self):
        state = self.env.reset()

        # if self.add_timestep:
        #     state = np.concatenate((state, [self.step_count_bamdp / self.horizon_bamdp]))

        if self.add_done_info:
            state = np.concatenate((state, [0.0]))

        self.done_mdp = False
        return state

    def step(self, action):

        if self._normalize_actions:  # from [-1, 1] to [lb, ub]
            lb = self.env.action_space.low
            ub = self.env.action_space.high
            action = lb + (action + 1.) * 0.5 * (ub - lb)
            action = np.clip(action, lb, ub)
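            # e.g. (illustrative numbers): for lb = -2 and ub = 2, an input action
            # of 0.5 is rescaled to -2 + 1.5 * 0.5 * 4 = 1.0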

        # do normal environment step in MDP
        state, reward, self.done_mdp, info = self.env.step(action)
        info['done_mdp'] = self.done_mdp

        # if self.add_timestep:
        #     state = np.concatenate((state, [self.step_count_bamdp / self.horizon_bamdp]))

        if self.add_done_info:
            state = np.concatenate((state, [float(self.done_mdp)]))

        self.step_count_bamdp += 1

        # if we want to maximise performance over multiple episodes,
        # only say "done" when we collected enough episodes in this task
        done_bamdp = False
        if self.done_mdp:
            self.episode_count += 1
            if self.episode_count == self.episodes_per_task:
                done_bamdp = True

        if self.done_mdp and not done_bamdp:
            info['start_state'] = self.reset_mdp()

        return state, reward, done_bamdp, info


class TimeLimitMask(gym.Wrapper):
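    # mark transitions where the episode ended only because the time limit was hit
    # (self.env._elapsed_steps reached _max_episode_steps) rather than because a true
    # terminal state was reached, so downstream code can treat such "done" signals
    # differently (e.g. still bootstrap the value function on these transitions)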
    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        if done and self.env._max_episode_steps == self.env._elapsed_steps:
            info['bad_transition'] = True
        return obs, rew, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
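

if __name__ == '__main__':
    # Illustrative usage sketch only: _DummyTaskEnv below is a stand-in for one of
    # the task-based envs in this repo (anything exposing reset_task, reset, step
    # and _max_episode_steps); it is not part of the original module.
    class _DummyTaskEnv(gym.Env):
        _max_episode_steps = 5

        def __init__(self):
            self.action_space = spaces.Box(low=-np.ones(1), high=np.ones(1), dtype=np.float32)
            self.observation_space = spaces.Box(low=-np.inf * np.ones(2), high=np.inf * np.ones(2), dtype=np.float32)
            self._t = 0

        def reset_task(self, task=None):
            # a real env would set / sample its goal parameters here
            self._goal = task

        def reset(self):
            self._t = 0
            return np.zeros(2, dtype=np.float32)

        def step(self, action):
            self._t += 1
            done = self._t >= self._max_episode_steps
            return np.zeros(2, dtype=np.float32), 0.0, done, {}

    # wrap the single-episode MDP into a 2-episode BAMDP and roll it out once;
    # done_bamdp only becomes True after both MDP episodes have finished
    env = VariBadWrapper(_DummyTaskEnv(), episodes_per_task=2)
    state = env.reset()  # resets the task and the underlying MDP
    done_bamdp = False
    while not done_bamdp:
        state, reward, done_bamdp, info = env.step(env.action_space.sample())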