Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 11f1ae4

Browse files
authored
Merge pull request #624 from deepsense-ai/pong_ext
Training on Pong environment and generation of frames from a gameplay - with cv2 fixes
2 parents 276a6fb + 62797a0 commit 11f1ae4

File tree

5 files changed

+254
-35
lines changed

5 files changed

+254
-35
lines changed

tensor2tensor/data_generators/gym.py

+72
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,25 @@
2323

2424
# Dependency imports
2525

26+
import numpy as np
27+
import functools
28+
import gym
29+
30+
from tensor2tensor.rl import rl_trainer_lib
31+
from tensor2tensor.rl.envs import atari_wrappers
32+
from tensor2tensor.models.research import rl
2633
from tensor2tensor.data_generators import generator_utils
2734
from tensor2tensor.data_generators import problem
2835
from tensor2tensor.utils import registry
2936

3037
import tensorflow as tf
3138

3239

40+
flags = tf.flags
41+
FLAGS = flags.FLAGS
42+
43+
flags.DEFINE_string("model_path", "", "File with model for pong")
44+
3345

3446
def gym_lib():
3547
"""Access to gym to allow for import of this file without a gym install."""
@@ -143,3 +155,63 @@ def num_rewards(self):
143155
@property
144156
def num_steps(self):
145157
return 5000
158+
159+
160+
161+
162+
@registry.register_problem
class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
  """Pong game, loaded actions.

  Generates Pong trajectories by running a pre-trained PPO policy
  (restored from FLAGS.model_path) against PongNoFrameskip-v4.
  """

  def __init__(self, event_dir, *args, **kwargs):
    super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
    self._env = None
    self._event_dir = event_dir

    def env_spec():
      # No warping/stacking here: raw frames are max-pooled manually in
      # get_action to mirror atari_wrappers.MaxAndSkipEnv.
      return atari_wrappers.wrap_atari(
          gym.make("PongNoFrameskip-v4"),
          warp=False, frame_skip=4, frame_stack=False)

    hparams = rl.atari_base()
    with tf.variable_scope("train"):
      # Template so the policy variables live under "train/network/...".
      policy_factory = tf.make_template(
          "network",
          functools.partial(hparams.network, env_spec().action_space,
                            hparams))
      # NOTE(review): self.env is presumably created lazily by
      # GymDiscreteProblem — confirm against the parent class.
      self._max_frame_pl = tf.placeholder(
          tf.float32, self.env.observation_space.shape)
      actor_critic = policy_factory(
          tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0))
      self._last_policy_op = actor_critic.policy.mode()
    self._last_action = self.env.action_space.sample()
    self._skip = 4
    self._skip_step = 0
    # Holds the last two raw observations for max-pooling across steps.
    self._obs_buffer = np.zeros(
        (2,) + self.env.observation_space.shape, dtype=np.uint8)
    self._sess = tf.Session()
    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    model_saver.restore(self._sess, FLAGS.model_path)

  # TODO(blazej0): For training of atari agents wrappers are usually used.
  # Below we have a hacky solution which is a temporary workaround to be
  # used together with atari_wrappers.MaxAndSkipEnv.
  def get_action(self, observation=None):
    """Buffer the two newest frames; run the policy on their pixel max."""
    if self._skip_step == self._skip - 2:
      self._obs_buffer[0] = observation
    if self._skip_step == self._skip - 1:
      self._obs_buffer[1] = observation
    self._skip_step = (self._skip_step + 1) % self._skip
    if self._skip_step == 0:
      max_frame = self._obs_buffer.max(axis=0)
      self._last_action = int(self._sess.run(
          self._last_policy_op,
          feed_dict={self._max_frame_pl: max_frame})[0, 0])
    return self._last_action

  @property
  def env_name(self):
    return "PongNoFrameskip-v4"

  @property
  def num_actions(self):
    return 4

  @property
  def num_rewards(self):
    return 2

  @property
  def num_steps(self):
    return 5000

tensor2tensor/models/research/rl.py

+48-30
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def ppo_base_v1():
4949
hparams.add_hparam("eval_every_epochs", 10)
5050
hparams.add_hparam("num_eval_agents", 3)
5151
hparams.add_hparam("video_during_eval", True)
52+
hparams.add_hparam("save_models_every_epochs", 30)
5253
return hparams
5354

5455

@@ -66,7 +67,22 @@ def discrete_action_base():
6667
return hparams
6768

6869

69-
# Neural networks for actor-critic algorithms
70+
@registry.register_hparams
def atari_base():
  """PPO hyperparameter overrides for Atari environments."""
  hparams = discrete_action_base()
  overrides = dict(
      learning_rate=16e-5,
      num_agents=5,
      epoch_length=200,
      gae_gamma=0.985,
      gae_lambda=0.985,
      entropy_loss_coef=0.002,
      value_loss_coef=0.025,
      optimization_epochs=10,
      epochs_num=10000,
      num_eval_agents=1,
      network=feed_forward_cnn_small_categorical_fun,
  )
  for name, value in overrides.items():
    setattr(hparams, name, value)
  return hparams
85+
7086

7187
NetworkOutput = collections.namedtuple(
7288
"NetworkOutput", "policy, value, action_postprocessing")
@@ -85,23 +101,24 @@ def feed_forward_gaussian_fun(action_space, config, observations):
85101
tf.shape(observations)[0], tf.shape(observations)[1],
86102
functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
87103

88-
with tf.variable_scope("policy"):
89-
x = flat_observations
90-
for size in config.policy_layers:
91-
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
92-
mean = tf.contrib.layers.fully_connected(
93-
x, action_space.shape[0], tf.tanh,
94-
weights_initializer=mean_weights_initializer)
95-
logstd = tf.get_variable(
96-
"logstd", mean.shape[2:], tf.float32, logstd_initializer)
97-
logstd = tf.tile(
98-
logstd[None, None],
99-
[tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
100-
with tf.variable_scope("value"):
101-
x = flat_observations
102-
for size in config.value_layers:
103-
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
104-
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
104+
with tf.variable_scope("network_parameters"):
105+
with tf.variable_scope("policy"):
106+
x = flat_observations
107+
for size in config.policy_layers:
108+
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
109+
mean = tf.contrib.layers.fully_connected(
110+
x, action_space.shape[0], tf.tanh,
111+
weights_initializer=mean_weights_initializer)
112+
logstd = tf.get_variable(
113+
"logstd", mean.shape[2:], tf.float32, logstd_initializer)
114+
logstd = tf.tile(
115+
logstd[None, None],
116+
[tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
117+
with tf.variable_scope("value"):
118+
x = flat_observations
119+
for size in config.value_layers:
120+
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
121+
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
105122
mean = tf.check_numerics(mean, "mean")
106123
logstd = tf.check_numerics(logstd, "logstd")
107124
value = tf.check_numerics(value, "value")
@@ -119,17 +136,18 @@ def feed_forward_categorical_fun(action_space, config, observations):
119136
flat_observations = tf.reshape(observations, [
120137
tf.shape(observations)[0], tf.shape(observations)[1],
121138
functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
122-
with tf.variable_scope("policy"):
123-
x = flat_observations
124-
for size in config.policy_layers:
125-
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
126-
logits = tf.contrib.layers.fully_connected(x, action_space.n,
127-
activation_fn=None)
128-
with tf.variable_scope("value"):
129-
x = flat_observations
130-
for size in config.value_layers:
131-
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
132-
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
139+
with tf.variable_scope("network_parameters"):
140+
with tf.variable_scope("policy"):
141+
x = flat_observations
142+
for size in config.policy_layers:
143+
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
144+
logits = tf.contrib.layers.fully_connected(x, action_space.n,
145+
activation_fn=None)
146+
with tf.variable_scope("value"):
147+
x = flat_observations
148+
for size in config.value_layers:
149+
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
150+
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
133151
policy = tf.contrib.distributions.Categorical(logits=logits)
134152
return NetworkOutput(policy, value, lambda a: a)
135153

@@ -141,7 +159,7 @@ def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
141159
obs_shape = observations.shape.as_list()
142160
x = tf.reshape(observations, [-1] + obs_shape[2:])
143161

144-
with tf.variable_scope("policy"):
162+
with tf.variable_scope("network_parameters"):
145163
x = tf.to_float(x) / 255.0
146164
x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
147165
activation_fn=tf.nn.relu, padding="SAME")

tensor2tensor/rl/README.md

+12-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@ for now and under heavy development.
77

88
Currently the only supported algorithm is Proximal Policy Optimization (PPO).
99

10-
## Sample usage - training in Pendulum-v0 environment.
10+
## Sample usage - training in the Pendulum-v0 environment.
1111

1212
```python rl/t2t_rl_trainer.py --problems=Pendulum-v0 --hparams_set continuous_action_base [--output_dir dir_location]```
13+
14+
## Sample usage - training in the PongNoFrameskip-v4 environment.
15+
16+
```python tensor2tensor/rl/t2t_rl_trainer.py --problems stacked_pong --hparams_set atari_base --hparams num_agents=5 --output_dir /tmp/pong`date +%Y%m%d_%H%M%S````
17+
18+
## Sample usage - generating trajectories from a trained model
19+
20+
```python tensor2tensor/bin/t2t-trainer --generate_data --data_dir=~/t2t_data --problems=gym_pong_trajectories_from_policy --hparams_set=atari_base --model_path [model]```
21+
22+
```python tensor2tensor/bin/t2t-datagen --data_dir=~/t2t_data --tmp_dir=~/t2t_data/tmp --problem=gym_pong_trajectories_from_policy --model_path [model]```
23+
+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copied from baselines
2+
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
3+
4+
# Various wrappers copied from Baselines
5+
import gym
6+
import numpy as np
7+
from collections import deque
8+
import gym
9+
from gym import spaces
10+
11+
12+
class WarpFrame(gym.ObservationWrapper):
  """Downscale observations to 84x84 grayscale (Nature-DQN preprocessing)."""

  def __init__(self, env):
    gym.ObservationWrapper.__init__(self, env)
    self.width = 84
    self.height = 84
    self.observation_space = spaces.Box(
        low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)

  def observation(self, frame):
    # cv2 is imported lazily so the module can load without OpenCV.
    import cv2
    cv2.ocl.setUseOpenCL(False)
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    small = cv2.resize(gray, (self.width, self.height),
                       interpolation=cv2.INTER_AREA)
    # Re-add a trailing channel axis to keep a rank-3 observation.
    return small[:, :, None]
27+
28+
class LazyFrames(object):
  """Memory-saving view over a list of frames.

  Frames shared between consecutive stacked observations are stored only
  once, which matters for large (e.g. 1M-frame) replay buffers. Convert
  to a numpy array only right before feeding the model.
  """

  def __init__(self, frames):
    self._frames = frames

  def __array__(self, dtype=None):
    # Channel-wise concatenation; the copy happens here, at access time.
    stacked = np.concatenate(self._frames, axis=2)
    return stacked if dtype is None else stacked.astype(dtype)
return out
42+
43+
class FrameStack(gym.Wrapper):
  """Stack the k most recent frames into one lazy observation.

  Returns a LazyFrames object instead of a materialized array, which is
  much more memory efficient (see LazyFrames).
  """

  def __init__(self, env, k):
    gym.Wrapper.__init__(self, env)
    self.k = k
    self.frames = deque([], maxlen=k)
    shp = env.observation_space.shape
    # Stacking multiplies the channel dimension by k.
    self.observation_space = spaces.Box(
        low=0, high=255, shape=(shp[0], shp[1], shp[2] * k),
        dtype=np.uint8)

  def reset(self):
    first = self.env.reset()
    # Seed the buffer by repeating the initial frame k times.
    for _ in range(self.k):
      self.frames.append(first)
    return self._get_ob()

  def step(self, action):
    obs, reward, done, info = self.env.step(action)
    self.frames.append(obs)
    return self._get_ob(), reward, done, info

  def _get_ob(self):
    assert len(self.frames) == self.k
    return LazyFrames(list(self.frames))
71+
72+
class MaxAndSkipEnv(gym.Wrapper):
  """Repeat each action `skip` frames; observe the max of the last two.

  Bug fix: the original defined `reset` twice — a zero-argument version
  and, later, a `**kwargs` version that silently shadowed it. Only the
  (effective) `**kwargs` variant is kept.
  """

  def __init__(self, env, skip=4):
    gym.Wrapper.__init__(self, env)
    # Most recent two raw observations (for max pooling across time steps).
    self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
                                dtype=np.uint8)
    self._skip = skip

  def reset(self, **kwargs):
    return self.env.reset(**kwargs)

  def step(self, action):
    """Repeat action, sum reward, and max over last observations."""
    total_reward = 0.0
    done = None
    for i in range(self._skip):
      obs, reward, done, info = self.env.step(action)
      if i == self._skip - 2:
        self._obs_buffer[0] = obs
      if i == self._skip - 1:
        self._obs_buffer[1] = obs
      total_reward += reward
      if done:
        break
    # Note that the observation on the done=True frame doesn't matter.
    max_frame = self._obs_buffer.max(axis=0)
    return max_frame, total_reward, done, info
102+
103+
def wrap_atari(env, warp=False, frame_skip=False, frame_stack=False):
  """Optionally apply the standard Atari preprocessing wrappers.

  Args:
    env: a gym environment.
    warp: if truthy, downscale observations to 84x84 grayscale.
    frame_skip: falsy to disable; otherwise the number of frames each
      action is repeated for (max-pooled over the last two).
    frame_stack: falsy to disable; otherwise the number of recent frames
      stacked into each observation.

  Returns:
    The (possibly wrapped) environment.
  """
  # Wrapper constructors are deferred via lambdas, so the names are only
  # resolved for the wrappers that are actually enabled.
  stages = (
      (warp, lambda e: WarpFrame(e)),
      (frame_skip, lambda e: MaxAndSkipEnv(e, frame_skip)),
      (frame_stack, lambda e: FrameStack(e, frame_stack)),
  )
  for enabled, apply_wrapper in stages:
    if enabled:
      env = apply_wrapper(env)
  return env

tensor2tensor/rl/rl_trainer_lib.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import absolute_import
1919

2020
import functools
21+
import os
2122

2223
# Dependency imports
2324

@@ -28,6 +29,7 @@
2829
from tensor2tensor.models.research import rl # pylint: disable=unused-import
2930
from tensor2tensor.rl import collect
3031
from tensor2tensor.rl import ppo
32+
from tensor2tensor.rl.envs import atari_wrappers
3133
from tensor2tensor.rl.envs import utils
3234

3335
import tensorflow as tf
@@ -69,19 +71,23 @@ def define_train(hparams, environment_spec, event_dir):
6971
utils.define_batch_env(wrapped_eval_env_lambda, hparams.num_eval_agents,
7072
xvfb=hparams.video_during_eval),
7173
hparams, eval_phase=True)
72-
return summary, eval_summary
74+
return summary, eval_summary, policy_factory
7375

7476

7577
def train(hparams, environment_spec, event_dir=None):
7678
"""Train."""
77-
train_summary_op, eval_summary_op = define_train(hparams, environment_spec,
78-
event_dir)
79-
79+
if environment_spec == "stacked_pong":
80+
environment_spec = lambda: atari_wrappers.wrap_atari(
81+
gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False)
82+
train_summary_op, eval_summary_op, _ = define_train(hparams, environment_spec,
83+
event_dir)
8084
if event_dir:
8185
summary_writer = tf.summary.FileWriter(
8286
event_dir, graph=tf.get_default_graph(), flush_secs=60)
87+
model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
8388
else:
8489
summary_writer = None
90+
model_saver = None
8591

8692
with tf.Session() as sess:
8793
sess.run(tf.global_variables_initializer())
@@ -94,3 +100,5 @@ def train(hparams, environment_spec, event_dir=None):
94100
summary = sess.run(eval_summary_op)
95101
if summary_writer:
96102
summary_writer.add_summary(summary, epoch_index)
103+
if model_saver and hparams.save_models_every_epochs and epoch_index % hparams.save_models_every_epochs == 0:
104+
model_saver.save(sess, os.path.join(event_dir, "model{}.ckpt".format(epoch_index)))

0 commit comments

Comments
 (0)