
Commit fd0cd82

araffin and qgallouedec authored

Update outdated custom env doc (DLR-RM#1490)

* Update outdated custom env doc
* fix render_mode and term/trunc/reset_info
* gym -> gymnasium

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 9cebedc commit fd0cd82

14 files changed: +62 -68 lines
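The commit message names the two API shifts that drive most of the hunks below: the render mode is now chosen at construction time via ``gym.make(..., render_mode=...)``, and ``reset()``/``step()`` follow the Gymnasium signatures (``reset`` returns ``(obs, info)``, ``step`` returns a ``terminated``/``truncated`` pair instead of a single ``done`` flag). A minimal sketch of the new calling convention, using CartPole-v1 purely as an illustrative environment id:

```python
import gymnasium as gym

# Render mode is fixed when the environment is created, not passed to render()
env = gym.make("CartPole-v1", render_mode="rgb_array")

# reset() accepts an optional seed and returns (observation, info)
obs, info = env.reset(seed=0)

for _ in range(100):
    action = env.action_space.sample()
    # step() returns five values: terminated/truncated replace the single done flag
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()

env.close()
```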

README.md (+1, -1)

@@ -139,7 +139,7 @@ for i in range(1000):
 env.close()
 ```

-Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
+Or just train a model with a one liner if [the environment is registered in Gymnasium](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):

 ```python
 from stable_baselines3 import PPO
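For reference, the one-liner the README points to works because SB3 also accepts a registered Gymnasium environment id directly; a minimal sketch (CartPole-v1 is just a placeholder id, not part of this diff):

```python
from stable_baselines3 import PPO

# Passing a registered Gymnasium id lets SB3 create and wrap the env internally;
# learn() returns the model, so training can be chained onto the constructor
model = PPO("MlpPolicy", "CartPole-v1", verbose=1).learn(total_timesteps=10_000)
```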

docs/guide/custom_env.rst (+15, -14)

@@ -3,18 +3,19 @@
 Using Custom Environments
 ==========================

-To use the RL baselines with custom environments, they just need to follow the *gym* interface.
-That is to say, your environment must implement the following methods (and inherits from OpenAI Gym Class):
+To use the RL baselines with custom environments, they just need to follow the *gymnasium* `interface <https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#sphx-glr-tutorials-gymnasium-basics-environment-creation-py>`_.
+That is to say, your environment must implement the following methods (and inherits from Gym Class):


 .. note::
-    If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
-    By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
-    Images can be either channel-first or channel-last.
+
+    If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
+    By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
+    Images can be either channel-first or channel-last.

     If you want to use ``CnnPolicy`` or ``MultiInputPolicy`` with image-like observation (3D tensor) that are already normalized, you must pass ``normalize_images=False``
-    to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
-    and make sure your image is in the **channel-first** format.
+    to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
+    and make sure your image is in the **channel-first** format.


 .. note::

@@ -34,7 +35,7 @@ That is to say, your environment must implement the following methods (and inher
 class CustomEnv(gym.Env):
     """Custom Environment that follows gym interface."""

-    metadata = {"render.modes": ["human"]}
+    metadata = {"render_modes": ["human"], "render_fps": 30}

     def __init__(self, arg1, arg2, ...):
         super().__init__()

@@ -48,11 +49,11 @@ That is to say, your environment must implement the following methods (and inher

     def step(self, action):
         ...
-        return observation, reward, done, info
+        return observation, reward, terminated, truncated, info

-    def reset(self):
+    def reset(self, seed=None, options=None):
         ...
-        return observation # reward, done, info can't be included
+        return observation, info

     def render(self):
         ...

@@ -81,11 +82,11 @@ To check that your environment follows the Gym interface that SB3 supports, plea
 # It will check your custom environment and output additional warnings if needed
 check_env(env)

-Gym also have its own `env checker <https://www.gymlibrary.ml/content/api/#checking-api-conformity>`_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
+Gymnasium also have its own `env checker <https://gymnasium.farama.org/api/utils/#gymnasium.utils.env_checker.check_env>`_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features).

-We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface.
+We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb>`_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface.

-Alternatively, you may look at OpenAI Gym `built-in environments <https://www.gymlibrary.ml/>`_. However, the readers are cautioned as per OpenAI Gym `official wiki <https://github.com/openai/gym/wiki/FAQ>`_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them.
+Alternatively, you may look at Gymnasium `built-in environments <https://gymnasium.farama.org>`_.

 Optionally, you can also register the environment with gym, that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env):

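Assembling the pieces this file changes, a complete custom environment under the Gymnasium interface would look roughly like the sketch below; the spaces, shapes and placeholder observations are illustrative only, not taken from the diff:

```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces


class CustomEnv(gym.Env):
    """Custom Environment that follows the Gymnasium interface."""

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        # Placeholder spaces: replace with whatever your task needs
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        observation = np.zeros(3, dtype=np.float32)
        info = {}
        return observation, info

    def step(self, action):
        observation = np.zeros(3, dtype=np.float32)
        reward = 0.0
        terminated = False  # reached a terminal state
        truncated = False   # cut short, e.g. by a time limit
        info = {}
        return observation, reward, terminated, truncated, info

    def render(self):
        pass

    def close(self):
        pass
```

Such an environment can then be passed to ``check_env`` from ``stable_baselines3.common.env_checker``, as the hunk above does.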

docs/guide/examples.rst (+3, -4)

@@ -71,7 +71,7 @@ In the following example, we will train, save and load a DQN model on the Lunar


 # Create environment
-env = gym.make("LunarLander-v2")
+env = gym.make("LunarLander-v2", render_mode="rgb_array")

 # Instantiate the agent
 model = DQN("MlpPolicy", env, verbose=1)

@@ -99,7 +99,7 @@ In the following example, we will train, save and load a DQN model on the Lunar
 for i in range(1000):
     action, _states = model.predict(obs, deterministic=True)
     obs, rewards, dones, info = vec_env.step(action)
-    vec_env.render()
+    vec_env.render("human")


 Multiprocessing: Unleashing the Power of Vectorized Environments

@@ -116,7 +116,6 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
 .. code-block:: python

     import gymnasium as gym
-    import numpy as np

     from stable_baselines3 import PPO
     from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

@@ -512,6 +511,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
 # Load saved model
 # Because it needs access to `env.compute_reward()`
 # HER must be loaded with the env
+env = gym.make("parking-v0", render_mode="human") # Change the render mode
 model = SAC.load("her_sac_highway", env=env)

 obs, info = env.reset()

@@ -521,7 +521,6 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
 for _ in range(100):
     action, _ = model.predict(obs, deterministic=True)
     obs, reward, terminated, truncated, info = env.step(action)
-    env.render()
     episode_reward += reward
     if terminated or truncated or info.get("is_success", False):
         print("Reward:", episode_reward, "Success?", info.get("is_success", False))

docs/guide/quickstart.rst (+4, -4)

@@ -20,7 +20,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment:

 from stable_baselines3 import A2C

-env = gym.make("CartPole-v1")
+env = gym.make("CartPole-v1", render_mode="rgb_array")

 model = A2C("MlpPolicy", env, verbose=1)
 model.learn(total_timesteps=10_000)

@@ -30,7 +30,7 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
 for i in range(1000):
     action, _state = model.predict(obs, deterministic=True)
     obs, reward, done, info = vec_env.step(action)
-    vec_env.render()
+    vec_env.render("human")
     # VecEnv resets automatically
     # if done:
     #   obs = vec_env.reset()

@@ -40,8 +40,8 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
 You can find explanations about the logger output and names in the :ref:`Logger <logger>` section.


-Or just train a model with a one liner if
-`the environment is registered in Gym <https://github.com/openai/gym/wiki/Environments>`_ and if
+Or just train a model with a one line if
+`the environment is registered in Gymnasium <https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/#registering-envs>`_ and if
 the policy is registered:

 .. code-block:: python

docs/guide/rl_tips.rst (+3, -3)

@@ -210,14 +210,14 @@ If you want to quickly try a random agent on your environment, you can also do:
 .. code-block:: python

     env = YourEnv()
-    obs = env.reset()
+    obs, info = env.reset()
     n_steps = 10
     for _ in range(n_steps):
         # Random action
         action = env.action_space.sample()
-        obs, reward, done, info = env.step(action)
+        obs, reward, terminated, truncated, info = env.step(action)
         if done:
-            obs = env.reset()
+            obs, info = env.reset()


 **Why should I normalize the action space?**
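A self-contained version of that random-agent snippet, substituting a concrete Gymnasium environment for the hypothetical ``YourEnv()``:

```python
import gymnasium as gym

# Placeholder environment; use your own env here
env = gym.make("CartPole-v1")

obs, info = env.reset()
n_steps = 10
for _ in range(n_steps):
    # Random action
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    # An episode ends when it is either terminated or truncated
    if terminated or truncated:
        obs, info = env.reset()
```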

docs/misc/changelog.rst (+1)

@@ -70,6 +70,7 @@ Documentation:
 - Make it more explicit when using ``VecEnv`` vs Gym env
 - Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn)
 - Added ``EvalCallback`` example (@sidney-tio)
+- Update custom env documentation


 Release 1.8.0 (2023-04-07)

docs/modules/a2c.rst (+5, -7)

@@ -53,27 +53,25 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.

 .. code-block:: python

-    import gymnasium as gym
-
     from stable_baselines3 import A2C
     from stable_baselines3.common.env_util import make_vec_env

     # Parallel environments
-    env = make_vec_env("CartPole-v1", n_envs=4)
+    vec_env = make_vec_env("CartPole-v1", n_envs=4)

-    model = A2C("MlpPolicy", env, verbose=1)
+    model = A2C("MlpPolicy", vec_env, verbose=1)
     model.learn(total_timesteps=25000)
     model.save("a2c_cartpole")

     del model # remove to demonstrate saving and loading

     model = A2C.load("a2c_cartpole")

-    obs = env.reset()
+    obs = vec_env.reset()
     while True:
         action, _states = model.predict(obs)
-        obs, rewards, dones, info = env.step(action)
-        env.render()
+        obs, rewards, dones, info = vec_env.step(action)
+        vec_env.render("human")


 .. note::
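One point worth keeping in mind when reading these module hunks: ``make_vec_env`` returns an SB3 ``VecEnv``, which keeps the older-style API (``reset()`` returns only the observations and ``step()`` returns a 4-tuple with ``dones``), which is why this example does not switch to ``terminated``/``truncated``. A minimal headless sketch of that rollout pattern, with the training budget shortened for illustration:

```python
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env

# An SB3 VecEnv: reset() -> obs only, step() -> (obs, rewards, dones, infos)
vec_env = make_vec_env("CartPole-v1", n_envs=4)

model = A2C("MlpPolicy", vec_env, verbose=1).learn(total_timesteps=2_000)

obs = vec_env.reset()
for _ in range(100):
    action, _states = model.predict(obs)
    obs, rewards, dones, infos = vec_env.step(action)
    # Sub-environments reset automatically when done, so no manual reset is needed
vec_env.close()
```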

docs/modules/ddpg.rst (+5, -5)

@@ -67,7 +67,7 @@ This example is only to demonstrate the use of the library and its functions, an
 from stable_baselines3 import DDPG
 from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

-env = gym.make("Pendulum-v1")
+env = gym.make("Pendulum-v1", render_mode="rgb_array")

 # The noise objects for DDPG
 n_actions = env.action_space.shape[-1]

@@ -76,17 +76,17 @@ This example is only to demonstrate the use of the library and its functions, an
 model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
 model.learn(total_timesteps=10000, log_interval=10)
 model.save("ddpg_pendulum")
-env = model.get_env()
+vec_env = model.get_env()


 del model # remove to demonstrate saving and loading

 model = DDPG.load("ddpg_pendulum")

-obs = env.reset()
+obs = vec_env.reset()
 while True:
     action, _states = model.predict(obs)
-    obs, rewards, dones, info = env.step(action)
-    env.render()
+    obs, rewards, dones, info = vec_env.step(action)
+    env.render("human")

 Results
 -------
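The DDPG hunks elide how ``action_noise`` is constructed; a minimal sketch filling that gap (the ``sigma=0.1`` scale is a typical choice here, not a value taken from this diff):

```python
import numpy as np
import gymnasium as gym

from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise

env = gym.make("Pendulum-v1", render_mode="rgb_array")

# Gaussian noise added to the deterministic actions for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=1_000, log_interval=10)
```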

docs/modules/dqn.rst (+5, -6)

@@ -60,7 +60,7 @@ This example is only to demonstrate the use of the library and its functions, an

 from stable_baselines3 import DQN

-env = gym.make("CartPole-v1")
+env = gym.make("CartPole-v1", render_mode="human")

 model = DQN("MlpPolicy", env, verbose=1)
 model.learn(total_timesteps=10000, log_interval=4)

@@ -70,13 +70,12 @@ This example is only to demonstrate the use of the library and its functions, an

 model = DQN.load("dqn_cartpole")

-obs = env.reset()
+obs, info = env.reset()
 while True:
     action, _states = model.predict(obs, deterministic=True)
-    obs, reward, done, info = env.step(action)
-    env.render()
-    if done:
-        obs = env.reset()
+    obs, reward, terminated, truncated, info = env.step(action)
+    if terminated or truncated:
+        obs, info = env.reset()


 Results

docs/modules/her.rst (+4, -6)

@@ -65,7 +65,6 @@ This example is only to demonstrate the use of the library and its functions, an
 from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
 from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
 from stable_baselines3.common.envs import BitFlippingEnv
-from stable_baselines3.common.vec_env import DummyVecEnv

 model_class = DQN # works also with SAC, DDPG and TD3
 N_BITS = 15

@@ -96,13 +95,12 @@ This example is only to demonstrate the use of the library and its functions, an
 # HER must be loaded with the env
 model = model_class.load("./her_bit_env", env=env)

-obs = env.reset()
+obs, info = env.reset()
 for _ in range(100):
     action, _ = model.predict(obs, deterministic=True)
-    obs, reward, done, _ = env.step(action)
-
-    if done:
-        obs = env.reset()
+    obs, reward, terminated, truncated, _ = env.step(action)
+    if terminated or truncated:
+        obs, info = env.reset()


 Results
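For context, the part of the HER example that this diff does not show is how the model and environment are set up. A rough sketch of that setup, assuming the usual SB3 pattern of passing ``HerReplayBuffer`` through ``replay_buffer_class``; ``n_sampled_goal=4`` and the ``"future"`` strategy are typical values, not taken from this diff:

```python
from stable_baselines3 import DQN, HerReplayBuffer
from stable_baselines3.common.envs import BitFlippingEnv

N_BITS = 15
env = BitFlippingEnv(n_bits=N_BITS, continuous=False, max_steps=N_BITS)

# HER is enabled by using HerReplayBuffer as the replay buffer of an off-policy algorithm
model = DQN(
    "MultiInputPolicy",  # dict observation (observation / achieved_goal / desired_goal)
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
    verbose=1,
)
model.learn(total_timesteps=1_000)
model.save("./her_bit_env")
```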

docs/modules/ppo.rst (+5, -5)

@@ -71,21 +71,21 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
 from stable_baselines3.common.env_util import make_vec_env

 # Parallel environments
-env = make_vec_env("CartPole-v1", n_envs=4)
+vec_env = make_vec_env("CartPole-v1", n_envs=4)

-model = PPO("MlpPolicy", env, verbose=1)
+model = PPO("MlpPolicy", vec_env, verbose=1)
 model.learn(total_timesteps=25000)
 model.save("ppo_cartpole")

 del model # remove to demonstrate saving and loading

 model = PPO.load("ppo_cartpole")

-obs = env.reset()
+obs = vec_env.reset()
 while True:
     action, _states = model.predict(obs)
-    obs, rewards, dones, info = env.step(action)
-    env.render()
+    obs, rewards, dones, info = vec_env.step(action)
+    vec_env.render("human")


 Results

docs/modules/sac.rst (+5, -7)

@@ -69,11 +69,10 @@ This example is only to demonstrate the use of the library and its functions, an
 .. code-block:: python

     import gymnasium as gym
-    import numpy as np

     from stable_baselines3 import SAC

-    env = gym.make("Pendulum-v1")
+    env = gym.make("Pendulum-v1", render_mode="human")

     model = SAC("MlpPolicy", env, verbose=1)
     model.learn(total_timesteps=10000, log_interval=4)

@@ -83,13 +82,12 @@ This example is only to demonstrate the use of the library and its functions, an

     model = SAC.load("sac_pendulum")

-    obs = env.reset()
+    obs, info = env.reset()
     while True:
         action, _states = model.predict(obs, deterministic=True)
-        obs, reward, done, info = env.step(action)
-        env.render()
-        if done:
-            obs = env.reset()
+        obs, reward, terminated, truncated, info = env.step(action)
+        if terminated or truncated:
+            obs, info = env.reset()


 Results

docs/modules/td3.rst (+5, -5)

@@ -67,7 +67,7 @@ This example is only to demonstrate the use of the library and its functions, an
 from stable_baselines3 import TD3
 from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

-env = gym.make("Pendulum-v1")
+env = gym.make("Pendulum-v1", render_mode="rgb_array")

 # The noise objects for TD3
 n_actions = env.action_space.shape[-1]

@@ -76,17 +76,17 @@ This example is only to demonstrate the use of the library and its functions, an
 model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
 model.learn(total_timesteps=10000, log_interval=10)
 model.save("td3_pendulum")
-env = model.get_env()
+vec_env = model.get_env()


 del model # remove to demonstrate saving and loading

 model = TD3.load("td3_pendulum")

-obs = env.reset()
+obs = vec_env.reset()
 while True:
     action, _states = model.predict(obs)
-    obs, rewards, dones, info = env.step(action)
-    env.render()
+    obs, rewards, dones, info = vec_env.step(action)
+    vec_env.render("human")

 Results
 -------

setup.py (+1, -1)

@@ -149,7 +149,7 @@
 url="https://github.com/DLR-RM/stable-baselines3",
 author_email="[email protected]",
 keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
-"gym openai stable baselines toolbox python data-science",
+"gymnasium gym openai stable baselines toolbox python data-science",
 license="MIT",
 long_description=long_description,
 long_description_content_type="text/markdown",
