1
+ import warnings
2
+ from typing import Optional , Union , Tuple , Dict , Any , List
3
+
1
4
import numpy as np
2
5
import gym
3
6
import gym .logger as logger
4
7
5
8
from gym import error , spaces
6
9
from gym import utils
7
- from gym .utils import seeding
8
-
9
- from typing import Optional , Union , Tuple , Dict , Any , List
10
10
11
11
import ale_py .roms as roms
12
12
from ale_py ._ale_py import ALEInterface , ALEState , Action , LoggerMode
@@ -74,11 +74,22 @@ def __init__(
74
74
raise error .Error (
75
75
f"Invalid observation type: { obs_type } . Expecting: rgb, grayscale, ram."
76
76
)
77
- if not (
78
- isinstance (frameskip , int )
79
- or (isinstance (frameskip , tuple ) and len (frameskip ) == 2 )
80
- ):
81
- raise error .Error (f"Invalid frameskip type: { frameskip } " )
77
+
78
+ if type (frameskip ) not in (int , tuple ):
79
+ raise error .Error (f"Invalid frameskip type: { type (frameskip )} ." )
80
+ if isinstance (frameskip , int ) and frameskip <= 0 :
81
+ raise error .Error (
82
+ f"Invalid frameskip of { frameskip } , frameskip must be positive." )
83
+ elif isinstance (frameskip , tuple ) and len (frameskip ) != 2 :
84
+ raise error .Error (
85
+ f"Invalid stochastic frameskip length of { len (frameskip )} , expected length 2." )
86
+ elif isinstance (frameskip , tuple ) and frameskip [0 ] > frameskip [1 ]:
87
+ raise error .Error (
88
+ f"Invalid stochastic frameskip, lower bound is greater than upper bound." )
89
+ elif isinstance (frameskip , tuple ) and frameskip [0 ] <= 0 :
90
+ raise error .Error (
91
+ f"Invalid stochastic frameskip lower bound is greater than upper bound." )
92
+
82
93
if render_mode is not None and render_mode not in {"rgb_array" , "human" }:
83
94
raise error .Error (
84
95
f"Render mode { render_mode } not supported (rgb_array, human)."
@@ -98,7 +109,6 @@ def __init__(
98
109
99
110
# Initialize ALE
100
111
self .ale = ALEInterface ()
101
- self .viewer = None
102
112
103
113
self ._game = rom_id_to_name (game )
104
114
@@ -112,7 +122,8 @@ def __init__(
112
122
# Set logger mode to error only
113
123
self .ale .setLoggerMode (LoggerMode .Error )
114
124
# Config sticky action prob.
115
- self .ale .setFloat ("repeat_action_probability" , repeat_action_probability )
125
+ self .ale .setFloat ("repeat_action_probability" ,
126
+ repeat_action_probability )
116
127
117
128
# If render mode is human we can display screen and sound
118
129
if render_mode == "human" :
@@ -146,7 +157,8 @@ def __init__(
146
157
low = 0 , high = 255 , dtype = np .uint8 , shape = image_shape
147
158
)
148
159
else :
149
- raise error .Error (f"Unrecognized observation type: { self ._obs_type } " )
160
+ raise error .Error (
161
+ f"Unrecognized observation type: { self ._obs_type } " )
150
162
151
163
def seed (self , seed : Optional [int ] = None ) -> Tuple [int , int ]:
152
164
"""
@@ -162,10 +174,13 @@ def seed(self, seed: Optional[int] = None) -> Tuple[int, int]:
162
174
Returns:
163
175
tuple[int, int] => (np seed, ALE seed)
164
176
"""
165
- self . np_random , seed1 = seeding . np_random (seed )
166
- seed2 = seeding . hash_seed ( seed1 + 1 ) % 2 ** 31
177
+ ss = np . random . SeedSequence (seed )
178
+ seed1 , seed2 = ss . generate_state ( n_words = 2 )
167
179
168
- self .ale .setInt ("random_seed" , seed2 )
180
+ self .np_random = np .random .default_rng (seed1 )
181
+ # ALE only takes signed integers for `setInt`, it'll get converted back
182
+ # to unsigned in StellaEnvironment.
183
+ self .ale .setInt ("random_seed" , seed2 .astype (np .int32 ))
169
184
170
185
if not hasattr (roms , self ._game ):
171
186
raise error .Error (
@@ -212,7 +227,7 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]
212
227
if isinstance (self ._frameskip , int ):
213
228
frameskip = self ._frameskip
214
229
elif isinstance (self ._frameskip , tuple ):
215
- frameskip = self .np_random .randint (* self ._frameskip )
230
+ frameskip = self .np_random .integers (* self ._frameskip )
216
231
else :
217
232
raise error .Error (f"Invalid frameskip type: { self ._frameskip } " )
218
233
@@ -224,7 +239,7 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]
224
239
return self ._get_obs (), reward , terminal , self ._get_info ()
225
240
226
241
def reset (
227
- self , * , seed : Optional [int ] = None , return_info : bool = False
242
+ self , * , seed : Optional [int ] = None , return_info : bool = False , options : Optional [ Dict [ str , Any ]] = None
228
243
) -> Union [Tuple [np .ndarray , Dict [str , Any ]], np .ndarray ]:
229
244
"""
230
245
Resets environment and returns initial observation.
@@ -247,7 +262,7 @@ def reset(
247
262
else :
248
263
return obs
249
264
250
- def render (self , mode : str ) -> None :
265
+ def render (self , mode : str ) -> Any :
251
266
"""
252
267
Render is not supported by ALE. We use a paradigm similar to
253
268
Gym3 which allows you to specify `render_mode` during construction.
@@ -261,28 +276,21 @@ def render(self, mode: str) -> None:
261
276
if mode == "rgb_array" :
262
277
return img
263
278
elif mode == "human" :
264
- from gym .envs .classic_control import rendering
265
-
266
- if self .viewer is None :
267
- logger .warn (
268
- (
269
- "We strongly suggest supplying `render_mode` when "
270
- "constructing your environment, e.g., gym.make(ID, render_mode='human'). "
271
- "Using `render_mode` provides access to proper scaling, audio support, "
272
- "and proper framerates."
273
- )
279
+ warnings .warn (
280
+ (
281
+ "render('human') is deprecated. Please supply `render_mode` when "
282
+ "constructing your environment, e.g., gym.make(ID, render_mode='human'). "
283
+ "The new `render_mode` keyword argument supports DPI scaling, "
284
+ "audio support, and native framerates."
274
285
)
275
- self .viewer = rendering .SimpleImageViewer ()
276
- self .viewer .imshow (img )
277
- return self .viewer .isopen
286
+ )
287
+ return False
278
288
279
289
def close (self ) -> None :
280
290
"""
281
291
Cleanup any leftovers by the environment
282
292
"""
283
- if self .viewer is not None :
284
- self .viewer .close ()
285
- self .viewer = None
293
+ pass
286
294
287
295
def _get_obs (self ) -> np .ndarray :
288
296
"""
@@ -296,7 +304,8 @@ def _get_obs(self) -> np.ndarray:
296
304
elif self ._obs_type == "grayscale" :
297
305
return self .ale .getScreenGrayscale ()
298
306
else :
299
- raise error .Error (f"Unrecognized observation type: { self ._obs_type } " )
307
+ raise error .Error (
308
+ f"Unrecognized observation type: { self ._obs_type } " )
300
309
301
310
def _get_info (self ) -> Dict [str , Any ]:
302
311
info = {
0 commit comments