Skip to content

Commit 1e67fe1

Browse files
committed
Remove render('human'), update seeding API, add frameskip validation
1 parent 80de842 commit 1e67fe1

File tree

3 files changed

+64
-46
lines changed

3 files changed

+64
-46
lines changed

src/gym/envs/atari/environment.py

+43-34
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
import warnings
2+
from typing import Optional, Union, Tuple, Dict, Any, List
3+
14
import numpy as np
25
import gym
36
import gym.logger as logger
47

58
from gym import error, spaces
69
from gym import utils
7-
from gym.utils import seeding
8-
9-
from typing import Optional, Union, Tuple, Dict, Any, List
1010

1111
import ale_py.roms as roms
1212
from ale_py._ale_py import ALEInterface, ALEState, Action, LoggerMode
@@ -74,11 +74,22 @@ def __init__(
7474
raise error.Error(
7575
f"Invalid observation type: {obs_type}. Expecting: rgb, grayscale, ram."
7676
)
77-
if not (
78-
isinstance(frameskip, int)
79-
or (isinstance(frameskip, tuple) and len(frameskip) == 2)
80-
):
81-
raise error.Error(f"Invalid frameskip type: {frameskip}")
77+
78+
if type(frameskip) not in (int, tuple):
79+
raise error.Error(f"Invalid frameskip type: {type(frameskip)}.")
80+
if isinstance(frameskip, int) and frameskip <= 0:
81+
raise error.Error(
82+
f"Invalid frameskip of {frameskip}, frameskip must be positive.")
83+
elif isinstance(frameskip, tuple) and len(frameskip) != 2:
84+
raise error.Error(
85+
f"Invalid stochastic frameskip length of {len(frameskip)}, expected length 2.")
86+
elif isinstance(frameskip, tuple) and frameskip[0] > frameskip[1]:
87+
raise error.Error(
88+
f"Invalid stochastic frameskip, lower bound is greater than upper bound.")
89+
elif isinstance(frameskip, tuple) and frameskip[0] <= 0:
90+
raise error.Error(
91+
f"Invalid stochastic frameskip, lower bound must be positive.")
92+
8293
if render_mode is not None and render_mode not in {"rgb_array", "human"}:
8394
raise error.Error(
8495
f"Render mode {render_mode} not supported (rgb_array, human)."
@@ -98,7 +109,6 @@ def __init__(
98109

99110
# Initialize ALE
100111
self.ale = ALEInterface()
101-
self.viewer = None
102112

103113
self._game = rom_id_to_name(game)
104114

@@ -112,7 +122,8 @@ def __init__(
112122
# Set logger mode to error only
113123
self.ale.setLoggerMode(LoggerMode.Error)
114124
# Config sticky action prob.
115-
self.ale.setFloat("repeat_action_probability", repeat_action_probability)
125+
self.ale.setFloat("repeat_action_probability",
126+
repeat_action_probability)
116127

117128
# If render mode is human we can display screen and sound
118129
if render_mode == "human":
@@ -146,7 +157,8 @@ def __init__(
146157
low=0, high=255, dtype=np.uint8, shape=image_shape
147158
)
148159
else:
149-
raise error.Error(f"Unrecognized observation type: {self._obs_type}")
160+
raise error.Error(
161+
f"Unrecognized observation type: {self._obs_type}")
150162

151163
def seed(self, seed: Optional[int] = None) -> Tuple[int, int]:
152164
"""
@@ -162,10 +174,13 @@ def seed(self, seed: Optional[int] = None) -> Tuple[int, int]:
162174
Returns:
163175
tuple[int, int] => (np seed, ALE seed)
164176
"""
165-
self.np_random, seed1 = seeding.np_random(seed)
166-
seed2 = seeding.hash_seed(seed1 + 1) % 2 ** 31
177+
ss = np.random.SeedSequence(seed)
178+
seed1, seed2 = ss.generate_state(n_words=2)
167179

168-
self.ale.setInt("random_seed", seed2)
180+
self.np_random = np.random.default_rng(seed1)
181+
# ALE only accepts signed integers for `setInt`; the value is converted back
182+
# to unsigned in StellaEnvironment.
183+
self.ale.setInt("random_seed", seed2.astype(np.int32))
169184

170185
if not hasattr(roms, self._game):
171186
raise error.Error(
@@ -212,7 +227,7 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]
212227
if isinstance(self._frameskip, int):
213228
frameskip = self._frameskip
214229
elif isinstance(self._frameskip, tuple):
215-
frameskip = self.np_random.randint(*self._frameskip)
230+
frameskip = self.np_random.integers(*self._frameskip)
216231
else:
217232
raise error.Error(f"Invalid frameskip type: {self._frameskip}")
218233

@@ -224,7 +239,7 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]
224239
return self._get_obs(), reward, terminal, self._get_info()
225240

226241
def reset(
227-
self, *, seed: Optional[int] = None, return_info: bool = False
242+
self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[Dict[str, Any]] = None
228243
) -> Union[Tuple[np.ndarray, Dict[str, Any]], np.ndarray]:
229244
"""
230245
Resets environment and returns initial observation.
@@ -247,7 +262,7 @@ def reset(
247262
else:
248263
return obs
249264

250-
def render(self, mode: str) -> None:
265+
def render(self, mode: str) -> Any:
251266
"""
252267
Render is not supported by ALE. We use a paradigm similar to
253268
Gym3 which allows you to specify `render_mode` during construction.
@@ -261,28 +276,21 @@ def render(self, mode: str) -> None:
261276
if mode == "rgb_array":
262277
return img
263278
elif mode == "human":
264-
from gym.envs.classic_control import rendering
265-
266-
if self.viewer is None:
267-
logger.warn(
268-
(
269-
"We strongly suggest supplying `render_mode` when "
270-
"constructing your environment, e.g., gym.make(ID, render_mode='human'). "
271-
"Using `render_mode` provides access to proper scaling, audio support, "
272-
"and proper framerates."
273-
)
279+
warnings.warn(
280+
(
281+
"render('human') is deprecated. Please supply `render_mode` when "
282+
"constructing your environment, e.g., gym.make(ID, render_mode='human'). "
283+
"The new `render_mode` keyword argument supports DPI scaling, "
284+
"audio support, and native framerates."
274285
)
275-
self.viewer = rendering.SimpleImageViewer()
276-
self.viewer.imshow(img)
277-
return self.viewer.isopen
286+
)
287+
return False
278288

279289
def close(self) -> None:
280290
"""
281291
Cleanup any leftovers by the environment
282292
"""
283-
if self.viewer is not None:
284-
self.viewer.close()
285-
self.viewer = None
293+
pass
286294

287295
def _get_obs(self) -> np.ndarray:
288296
"""
@@ -296,7 +304,8 @@ def _get_obs(self) -> np.ndarray:
296304
elif self._obs_type == "grayscale":
297305
return self.ale.getScreenGrayscale()
298306
else:
299-
raise error.Error(f"Unrecognized observation type: {self._obs_type}")
307+
raise error.Error(
308+
f"Unrecognized observation type: {self._obs_type}")
300309

301310
def _get_info(self) -> Dict[str, Any]:
302311
info = {

tests/python/gym/test_gym_interface.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
1+
# fmt: off
12
import pytest
23

34
pytest.importorskip("gym")
45
pytest.importorskip("gym.envs.atari")
56

6-
import numpy as np
7-
8-
from unittest.mock import patch
9-
from itertools import product
10-
11-
from gym import spaces
12-
from gym.envs.registration import registry
13-
from gym.core import Env
14-
from gym.utils.env_checker import check_env
15-
167
from ale_py.gym import (
178
register_legacy_gym_envs,
189
_register_gym_configs,
1910
register_gym_envs,
2011
)
12+
from gym import error
13+
from gym.utils.env_checker import check_env
14+
from gym.core import Env
15+
from gym.envs.registration import registry
16+
from gym.envs.atari.environment import AtariEnv
17+
from gym import spaces
18+
from itertools import product
19+
from unittest.mock import patch
20+
import numpy as np
21+
# fmt: on
2122

2223

2324
def test_register_legacy_env_id():
@@ -123,7 +124,8 @@ def test_register_gym_envs(test_rom_path):
123124
suffixes = []
124125
versions = ["-v5"]
125126

126-
all_ids = set(map("".join, product(games, obs_types, suffixes, versions)))
127+
all_ids = set(map("".join, product(
128+
games, obs_types, suffixes, versions)))
127129
assert all_ids.issubset(envids)
128130

129131

@@ -331,6 +333,13 @@ def test_gym_reset_with_infos(tetris_gym):
331333
assert "rgb" in info
332334

333335

336+
@pytest.mark.parametrize(
    "frameskip",
    [0, -1, 4.0, (-1, 5), (0, 5), (5, 2), (1, 2, 3)],
)
def test_frameskip_warnings(test_rom_path, frameskip):
    """Check that constructing `AtariEnv` with a bad `frameskip` raises `error.Error`.

    The parametrized cases cover non-positive integers, a float, tuples
    whose lower bound is non-positive, a tuple whose lower bound exceeds
    its upper bound, and a tuple that does not have exactly two elements.
    """
    # Expose the test ROM as `ale_py.roms.Tetris` for the duration of the
    # construction attempt.
    rom_patch = patch(
        "ale_py.roms.Tetris",
        create=True,
        new_callable=lambda: test_rom_path,
    )
    with rom_patch, pytest.raises(error.Error):
        AtariEnv("Tetris", frameskip=frameskip)
341+
342+
334343
def test_gym_compliance(tetris_gym):
335344
try:
336345
check_env(tetris_gym)

tests/python/gym/test_legacy_registration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_legacy_env_specs():
9999
"""
100100
for spec in specs:
101101
assert spec in registry.env_specs
102-
kwargs = registry.env_specs[spec]._kwargs
102+
kwargs = registry.env_specs[spec].kwargs
103103
max_episode_steps = registry.env_specs[spec].max_episode_steps
104104

105105
# Assert necessary parameters are set

0 commit comments

Comments
 (0)