
Commit 4dd189e

Lukasz Kaiser authored and Ryan Sepassi committed
internal
PiperOrigin-RevId: 187661295
1 parent 0584c15 commit 4dd189e

6 files changed (+292 -44 lines)

Diff for: tensor2tensor/data_generators/gym.py

+77 -9

@@ -19,25 +19,28 @@
 from __future__ import division
 from __future__ import print_function

-import os
+import functools

 # Dependency imports

+import gym
+import numpy as np
+
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
+from tensor2tensor.models.research import rl
+from tensor2tensor.rl.envs import atari_wrappers
 from tensor2tensor.utils import registry

 import tensorflow as tf



-def gym_lib():
-  """Access to gym to allow for import of this file without a gym install."""
-  try:
-    import gym  # pylint: disable=g-import-not-at-top
-  except ImportError:
-    raise ImportError("pip install gym to use gym-based Problems")
-  return gym
+
+flags = tf.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("model_path", "", "File with model for pong")


 class GymDiscreteProblem(problem.Problem):

@@ -55,7 +58,7 @@ def env_name(self):
   @property
   def env(self):
     if self._env is None:
-      self._env = gym_lib().make(self.env_name)
+      self._env = gym.make(self.env_name)
     return self._env

   @property

@@ -143,3 +146,68 @@ def num_rewards(self):
   @property
   def num_steps(self):
     return 5000
+
+
+@registry.register_problem
+class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
+  """Pong game, loaded actions."""
+
+  def __init__(self, event_dir, *args, **kwargs):
+    super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
+    self._env = None
+    self._event_dir = event_dir
+    env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
+        gym.make("PongNoFrameskip-v4"),
+        warp=False,
+        frame_skip=4,
+        frame_stack=False)
+    hparams = rl.atari_base()
+    with tf.variable_scope("train"):
+      policy_lambda = hparams.network
+      policy_factory = tf.make_template(
+          "network",
+          functools.partial(policy_lambda, env_spec().action_space, hparams))
+      self._max_frame_pl = tf.placeholder(
+          tf.float32, self.env.observation_space.shape)
+      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
+          self._max_frame_pl, 0), 0))
+      policy = actor_critic.policy
+      self._last_policy_op = policy.mode()
+      self._last_action = self.env.action_space.sample()
+      self._skip = 4
+      self._skip_step = 0
+      self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
+                                  dtype=np.uint8)
+      self._sess = tf.Session()
+      model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
+      model_saver.restore(self._sess, FLAGS.model_path)
+
+  # TODO(blazej0): For training of atari agents wrappers are usually used.
+  # Below we have a hacky solution which is a workaround to be used together
+  # with atari_wrappers.MaxAndSkipEnv.
+  def get_action(self, observation=None):
+    if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation
+    if self._skip_step == self._skip - 1: self._obs_buffer[1] = observation
+    self._skip_step = (self._skip_step + 1) % self._skip
+    if self._skip_step == 0:
+      max_frame = self._obs_buffer.max(axis=0)
+      self._last_action = int(self._sess.run(
+          self._last_policy_op,
+          feed_dict={self._max_frame_pl: max_frame})[0, 0])
+    return self._last_action
+
+  @property
+  def env_name(self):
+    return "PongNoFrameskip-v4"
+
+  @property
+  def num_actions(self):
+    return 4
+
+  @property
+  def num_rewards(self):
+    return 2
+
+  @property
+  def num_steps(self):
+    return 5000
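
Note on the new `get_action` method: per the TODO in the diff, it duplicates on the data-generation side the frame-skip/max bookkeeping that `atari_wrappers.MaxAndSkipEnv` normally applies inside the environment. Below is a minimal standalone sketch of that bookkeeping, not code from this commit; `FrameMaxer` and `observe` are hypothetical names used only for illustration.

```python
# Sketch (assumption: plain NumPy, hypothetical class) of the skip/max
# bookkeeping performed by GymPongTrajectoriesFromPolicy.get_action.
import numpy as np


class FrameMaxer(object):
  """Buffers the last two frames of each `skip`-long window and exposes
  their element-wise max once per window, as get_action does before
  running the policy."""

  def __init__(self, obs_shape, skip=4):
    self._skip = skip
    self._skip_step = 0
    self._obs_buffer = np.zeros((2,) + obs_shape, dtype=np.uint8)
    self._last_max = None

  def observe(self, observation):
    # Keep only the two most recent frames of the current window.
    if self._skip_step == self._skip - 2:
      self._obs_buffer[0] = observation
    if self._skip_step == self._skip - 1:
      self._obs_buffer[1] = observation
    self._skip_step = (self._skip_step + 1) % self._skip
    if self._skip_step == 0:
      # Max over the two frames, as in MaxAndSkipEnv, to remove sprite flicker.
      self._last_max = self._obs_buffer.max(axis=0)
    return self._last_max


if __name__ == "__main__":
  maxer = FrameMaxer(obs_shape=(210, 160, 3))
  for _ in range(8):
    frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
    pooled = maxer.observe(frame)  # None until the first full window completes
```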

Diff for: tensor2tensor/models/__init__.py

+1

@@ -44,6 +44,7 @@
 from tensor2tensor.models.research import cycle_gan
 from tensor2tensor.models.research import gene_expression
 from tensor2tensor.models.research import multimodel
+from tensor2tensor.models.research import rl
 from tensor2tensor.models.research import super_lm
 from tensor2tensor.models.research import transformer_moe
 from tensor2tensor.models.research import transformer_revnet

Diff for: tensor2tensor/models/research/rl.py

+49 -30

@@ -49,6 +49,7 @@ def ppo_base_v1():
   hparams.add_hparam("eval_every_epochs", 10)
   hparams.add_hparam("num_eval_agents", 3)
   hparams.add_hparam("video_during_eval", True)
+  hparams.add_hparam("save_models_every_epochs", 30)
   return hparams


@@ -66,7 +67,23 @@ def discrete_action_base():
   return hparams


-# Neural networks for actor-critic algorithms
+@registry.register_hparams
+def atari_base():
+  """Atari base parameters."""
+  hparams = discrete_action_base()
+  hparams.learning_rate = 16e-5
+  hparams.num_agents = 5
+  hparams.epoch_length = 200
+  hparams.gae_gamma = 0.985
+  hparams.gae_lambda = 0.985
+  hparams.entropy_loss_coef = 0.002
+  hparams.value_loss_coef = 0.025
+  hparams.optimization_epochs = 10
+  hparams.epochs_num = 10000
+  hparams.num_eval_agents = 1
+  hparams.network = feed_forward_cnn_small_categorical_fun
+  return hparams
+

 NetworkOutput = collections.namedtuple(
     "NetworkOutput", "policy, value, action_postprocessing")

@@ -85,23 +102,24 @@ def feed_forward_gaussian_fun(action_space, config, observations):
       tf.shape(observations)[0], tf.shape(observations)[1],
       functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])

-  with tf.variable_scope("policy"):
-    x = flat_observations
-    for size in config.policy_layers:
-      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    mean = tf.contrib.layers.fully_connected(
-        x, action_space.shape[0], tf.tanh,
-        weights_initializer=mean_weights_initializer)
-    logstd = tf.get_variable(
-        "logstd", mean.shape[2:], tf.float32, logstd_initializer)
-    logstd = tf.tile(
-        logstd[None, None],
-        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
-  with tf.variable_scope("value"):
-    x = flat_observations
-    for size in config.value_layers:
-      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
+  with tf.variable_scope("network_parameters"):
+    with tf.variable_scope("policy"):
+      x = flat_observations
+      for size in config.policy_layers:
+        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+      mean = tf.contrib.layers.fully_connected(
+          x, action_space.shape[0], tf.tanh,
+          weights_initializer=mean_weights_initializer)
+      logstd = tf.get_variable(
+          "logstd", mean.shape[2:], tf.float32, logstd_initializer)
+      logstd = tf.tile(
+          logstd[None, None],
+          [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+    with tf.variable_scope("value"):
+      x = flat_observations
+      for size in config.value_layers:
+        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+      value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
   mean = tf.check_numerics(mean, "mean")
   logstd = tf.check_numerics(logstd, "logstd")
   value = tf.check_numerics(value, "value")

@@ -119,17 +137,18 @@ def feed_forward_categorical_fun(action_space, config, observations):
   flat_observations = tf.reshape(observations, [
       tf.shape(observations)[0], tf.shape(observations)[1],
       functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
-  with tf.variable_scope("policy"):
-    x = flat_observations
-    for size in config.policy_layers:
-      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    logits = tf.contrib.layers.fully_connected(x, action_space.n,
-                                               activation_fn=None)
-  with tf.variable_scope("value"):
-    x = flat_observations
-    for size in config.value_layers:
-      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
+  with tf.variable_scope("network_parameters"):
+    with tf.variable_scope("policy"):
+      x = flat_observations
+      for size in config.policy_layers:
+        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+      logits = tf.contrib.layers.fully_connected(x, action_space.n,
+                                                 activation_fn=None)
+    with tf.variable_scope("value"):
+      x = flat_observations
+      for size in config.value_layers:
+        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+      value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
   policy = tf.contrib.distributions.Categorical(logits=logits)
   return NetworkOutput(policy, value, lambda a: a)

@@ -141,7 +160,7 @@ def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
   obs_shape = observations.shape.as_list()
   x = tf.reshape(observations, [-1] + obs_shape[2:])

-  with tf.variable_scope("policy"):
+  with tf.variable_scope("network_parameters"):
     x = tf.to_float(x) / 255.0
     x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
                                  activation_fn=tf.nn.relu, padding="SAME")
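
The recurring change in this file is wrapping the policy and value networks in a `network_parameters` variable scope. The point is checkpoint filtering: `GymPongTrajectoriesFromPolicy` in gym.py builds its `tf.train.Saver` from `tf.global_variables(".*network_parameters.*")`, so only variables under that scope are restored from `--model_path`. A small illustrative sketch (TF 1.x; the variable names here are made up for the example, not taken from the commit):

```python
import tensorflow as tf

# Variables created under the scope pick up "network_parameters/" in their names.
with tf.variable_scope("network_parameters"):
  with tf.variable_scope("policy"):
    w = tf.get_variable("w", shape=[4, 2])

# A variable created outside the scope is not matched by the filter below.
unrelated = tf.get_variable("optimizer_slot", shape=[2])

# tf.global_variables() accepts a scope pattern matched with re.match, which is
# how gym.py restricts its Saver to the policy/value parameters only.
matched = tf.global_variables(".*network_parameters.*")
print([v.name for v in matched])  # ['network_parameters/policy/w:0']
```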

Diff for: tensor2tensor/rl/README.md

+11 -1

@@ -7,6 +7,16 @@ for now and under heavy development.
 
 Currently the only supported algorithm is Proximy Policy Optimization - PPO.
 
-## Sample usage - training in Pendulum-v0 environment.
+## Sample usage - training in the Pendulum-v0 environment.
 
 ```python rl/t2t_rl_trainer.py --problems=Pendulum-v0 --hparams_set continuous_action_base [--output_dir dir_location]```
+
+## Sample usage - training in the PongNoFrameskip-v0 environment.
+
+```python tensor2tensor/rl/t2t_rl_trainer.py --problem stacked_pong --hparams_set atari_base --hparams num_agents=5 --output_dir /tmp/pong`date +%Y%m%d_%H%M%S````
+
+## Sample usage - generation of a model
+
+```python tensor2tensor/bin/t2t-trainer --generate_data --data_dir=~/t2t_data --problems=gym_pong_trajectories_from_policy --hparams_set=base_atari --model_path [model]```
+
+```python tensor2tensor/bin/t2t-datagen --data_dir=~/t2t_data --tmp_dir=~/t2t_data/tmp --problem=gym_pong_trajectories_from_policy --model_path [model]```

Diff for: tensor2tensor/rl/envs/atari_wrappers.py

+139

@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Various wrappers copied for Gym Baselines."""
+
+from collections import deque
+import gym
+import numpy as np
+
+
+# Adapted from the link below.
+# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+
+
+class WarpFrame(gym.ObservationWrapper):
+  """Wrap a frame."""
+
+  def __init__(self, env):
+    """Warp frames to 84x84 as done in the Nature paper and later work."""
+    gym.ObservationWrapper.__init__(self, env)
+    self.width = 84
+    self.height = 84
+    self.observation_space = gym.spaces.Box(
+        low=0, high=255,
+        shape=(self.height, self.width, 1), dtype=np.uint8)
+
+  def observation(self, frame):
+    import cv2  # pylint: disable=g-import-not-at-top
+    cv2.ocl.setUseOpenCL(False)
+    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+    frame = cv2.resize(frame, (self.width, self.height),
+                       interpolation=cv2.INTER_AREA)
+    return frame[:, :, None]
+
+
+class LazyFrames(object):
+  """Lazy frame storage."""
+
+  def __init__(self, frames):
+    """Lazy frame storage.
+
+    This object ensures that common frames between the observations
+    are only stored once. It exists purely to optimize memory usage
+    which can be huge for DQN's 1M frames replay buffers.
+    This object should only be converted to numpy array before being passed
+    to the model.
+
+    Args:
+      frames: the frames.
+    """
+    self._frames = frames
+
+  def __array__(self, dtype=None):
+    out = np.concatenate(self._frames, axis=2)
+    if dtype is not None:
+      out = out.astype(dtype)
+    return out
+
+
+class FrameStack(gym.Wrapper):
+  """Stack frames."""
+
+  def __init__(self, env, k):
+    """Stack k last frames. Returns lazy array, memory efficient."""
+    gym.Wrapper.__init__(self, env)
+    self.k = k
+    self.frames = deque([], maxlen=k)
+    shp = env.observation_space.shape
+    self.observation_space = gym.spaces.Box(
+        low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)
+
+  def reset(self):
+    ob = self.env.reset()
+    for _ in range(self.k):
+      self.frames.append(ob)
+    return self._get_ob()
+
+  def step(self, action):
+    ob, reward, done, info = self.env.step(action)
+    self.frames.append(ob)
+    return self._get_ob(), reward, done, info
+
+  def _get_ob(self):
+    assert len(self.frames) == self.k
+    return LazyFrames(list(self.frames))
+
+
+class MaxAndSkipEnv(gym.Wrapper):
+  """Max and skip env."""
+
+  def __init__(self, env, skip=4):
+    """Return only every `skip`-th frame."""
+    gym.Wrapper.__init__(self, env)
+    # Most recent raw observations (for max pooling across time steps).
+    self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
+                                dtype=np.uint8)
+    self._skip = skip
+
+  def reset(self, **kwargs):
+    return self.env.reset(**kwargs)
+
+  def step(self, action):
+    """Repeat action, sum reward, and max over last observations."""
+    total_reward = 0.0
+    done = None
+    for i in range(self._skip):
+      obs, reward, done, info = self.env.step(action)
+      if i == self._skip - 2: self._obs_buffer[0] = obs
+      if i == self._skip - 1: self._obs_buffer[1] = obs
+      total_reward += reward
+      if done:
+        break
+    # Note that the observation on the done=True frame
+    # doesn't matter.
+    max_frame = self._obs_buffer.max(axis=0)
+
+    return max_frame, total_reward, done, info
+
+
+def wrap_atari(env, warp=False, frame_skip=False, frame_stack=False):
+  if warp:
+    env = WarpFrame(env)
+  if frame_skip:
+    env = MaxAndSkipEnv(env, frame_skip)
+  if frame_stack:
+    env = FrameStack(env, frame_stack)
+  return env
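
For reference, a short usage sketch of the new `wrap_atari` helper, mirroring the wrapper stack that the `env_spec` lambda in gym.py requests. This is not part of the commit and assumes `gym` with the Atari Pong ROM is installed locally.

```python
import gym
from tensor2tensor.rl.envs import atari_wrappers

# Build the same configuration used by GymPongTrajectoriesFromPolicy:
# no frame warping, MaxAndSkipEnv with a skip of 4, no frame stacking.
env = atari_wrappers.wrap_atari(
    gym.make("PongNoFrameskip-v4"),
    warp=False,        # keep full-resolution RGB frames
    frame_skip=4,      # repeat each action 4 times, max-pool the last 2 frames
    frame_stack=False)

obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(obs.shape, reward, done)  # max-pooled frame, summed reward, done flag
```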
