Skip to content

Commit d542732

Browse files
committed
Rename to stable-baselines3
1 parent 4a2c247 commit d542732

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+164
-164
lines changed

.coveragerc

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ omit =
44
tests/*
55
setup.py
66
# Require graphical interface
7-
torchy_baselines/common/results_plotter.py
7+
stable_baselines3/common/results_plotter.py
88

99
[report]
1010
exclude_lines =

.github/ISSUE_TEMPLATE/issue-template.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ If you are submitting a bug report, please fill in the following details.
1313
If your issue is related to a custom gym environment, please check it first using:
1414

1515
```python
16-
from torchy_baselines.common.env_checker import check_env
16+
from stable_baselines3.common.env_checker import check_env
1717

1818
env = CustomEnv(arg1, ...)
1919
# It will check your custom environment and output additional warnings if needed
@@ -30,7 +30,7 @@ Please use the [markdown code blocks](https://help.github.com/en/articles/creati
3030
for both code and stack traces.
3131

3232
```python
33-
from torchy_baselines import ...
33+
from stable_baselines3 import ...
3434

3535
```
3636

NOTICE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Large portion of the code of Torchy-Baselines (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines,
1+
Large portion of the code of Stable-Baselines3 (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines,
22
both licensed under the MIT License:
33

44
before the fork (June 2018):

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[![Build Status](https://travis-ci.com/hill-a/stable-baselines.svg?branch=master)](https://travis-ci.com/hill-a/stable-baselines) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines.readthedocs.io/en/master/?badge=master)
44

5-
# Torchy Baselines
5+
# Stable Baselines3
66

77
PyTorch version of [Stable Baselines](https://github.com/hill-a/stable-baselines), a set of improved implementations of reinforcement learning algorithms.
88

@@ -58,7 +58,7 @@ To cite this repository in publications:
5858
```
5959
@misc{torchy-baselines,
6060
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
61-
title = {Torchy Baselines},
61+
title = {Stable Baselines3},
6262
year = {2019},
6363
publisher = {GitHub},
6464
journal = {GitHub repository},

docs/conf.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,19 @@ def __getattr__(cls, name):
4444
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
4545

4646

47-
import torchy_baselines
47+
import stable_baselines3
4848

4949

5050
# -- Project information -----------------------------------------------------
5151

52-
project = 'Torchy Baselines'
53-
copyright = '2020, Torchy Baselines'
54-
author = 'Torchy Baselines Contributors'
52+
project = 'Stable Baselines3'
53+
copyright = '2020, Stable Baselines3'
54+
author = 'Stable Baselines3 Contributors'
5555

5656
# The short X.Y version
57-
version = 'master (' + torchy_baselines.__version__ + ' )'
57+
version = 'master (' + stable_baselines3.__version__ + ' )'
5858
# The full version, including alpha/beta/rc tags
59-
release = torchy_baselines.__version__
59+
release = stable_baselines3.__version__
6060

6161

6262
# -- General configuration ---------------------------------------------------
@@ -179,8 +179,8 @@ def setup(app):
179179
# (source start file, target name, title,
180180
# author, documentclass [howto, manual, or own class]).
181181
latex_documents = [
182-
(master_doc, 'TorchyBaselines.tex', 'Torchy Baselines Documentation',
183-
'Torchy Baselines Contributors', 'manual'),
182+
(master_doc, 'TorchyBaselines.tex', 'Stable Baselines3 Documentation',
183+
'Stable Baselines3 Contributors', 'manual'),
184184
]
185185

186186

@@ -189,7 +189,7 @@ def setup(app):
189189
# One entry per manual page. List of tuples
190190
# (source start file, name, description, authors, manual section).
191191
man_pages = [
192-
(master_doc, 'torchybaselines', 'Torchy Baselines Documentation',
192+
(master_doc, 'torchybaselines', 'Stable Baselines3 Documentation',
193193
[author], 1)
194194
]
195195

@@ -200,7 +200,7 @@ def setup(app):
200200
# (source start file, target name, title, author,
201201
# dir menu entry, description, category)
202202
texinfo_documents = [
203-
(master_doc, 'TorchyBaselines', 'Torchy Baselines Documentation',
203+
(master_doc, 'TorchyBaselines', 'Stable Baselines3 Documentation',
204204
author, 'TorchyBaselines', 'One line description of project.',
205205
'Miscellaneous'),
206206
]

docs/guide/quickstart.rst

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ Here is a quick example of how to train and run SAC on a Pendulum environment:
1212
1313
import gym
1414
15-
from torchy_baselines.sac.policies import MlpPolicy
16-
from torchy_baselines.common.vec_env import DummyVecEnv
17-
from torchy_baselines import SAC
15+
from stable_baselines3.sac.policies import MlpPolicy
16+
from stable_baselines3.common.vec_env import DummyVecEnv
17+
from stable_baselines3 import SAC
1818
1919
env = gym.make('Pendulum-v0')
2020
@@ -34,6 +34,6 @@ the policy is registered:
3434

3535
.. code-block:: python
3636
37-
from torchy_baselines import SAC
37+
from stable_baselines3 import SAC
3838
3939
model = SAC('MlpPolicy', 'Pendulum-v0').learn(10000)

docs/guide/vec_envs.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. _vec_env:
22

3-
.. automodule:: torchy_baselines.common.vec_env
3+
.. automodule:: stable_baselines3.common.vec_env
44

55
Vectorized Environments
66
=======================

docs/index.rst

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
You can adapt this file completely to your liking, but it should at least
44
contain the root `toctree` directive.
55
6-
Welcome to Torchy Baselines docs! - Pytorch RL Baselines
6+
Welcome to Stable Baselines3 docs! - Pytorch RL Baselines
77
========================================================
88

9-
`Torchy Baselines <https://github.com/hill-a/stable-baselines>`_ is the PyTorch version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_,
9+
`Stable Baselines3 <https://github.com/hill-a/stable-baselines>`_ is the PyTorch version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_,
1010
a set of improved implementations of reinforcement learning algorithms.
1111

1212
RL Baselines Zoo (collection of pre-trained agents): https://github.com/araffin/rl-baselines-zoo
@@ -41,15 +41,15 @@ RL Baselines zoo also offers a simple interface to train, evaluate agents and do
4141
misc/changelog
4242

4343

44-
Citing Torchy Baselines
44+
Citing Stable Baselines3
4545
-----------------------
4646
To cite this project in publications:
4747

4848
.. code-block:: bibtex
4949
5050
@misc{torchy-baselines,
5151
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
52-
title = {Torchy Baselines},
52+
title = {Stable Baselines3},
5353
year = {2019},
5454
publisher = {GitHub},
5555
journal = {GitHub repository},

docs/misc/changelog.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ Pre-Release 0.2.0 (2020-02-14)
111111

112112
Breaking Changes:
113113
^^^^^^^^^^^^^^^^^
114-
- Python 2 support was dropped, Torchy Baselines now requires Python 3.6 or above
114+
- Python 2 support was dropped, Stable Baselines3 now requires Python 3.6 or above
115115
- Return type of ``evaluation.evaluate_policy()`` has been changed
116116
- Refactored the replay buffer to avoid transformation between PyTorch and NumPy
117117
- Created `OffPolicyRLModel` base class
@@ -160,7 +160,7 @@ New Features:
160160
Maintainers
161161
-----------
162162

163-
Torchy-Baselines is currently maintained by `Antonin Raffin`_ (aka `@araffin`_).
163+
Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_).
164164

165165
.. _Antonin Raffin: https://araffin.github.io/
166166
.. _@araffin: https://github.com/araffin

docs/modules/a2c.rst

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. _a2c:
22

3-
.. automodule:: torchy_baselines.a2c
3+
.. automodule:: stable_baselines3.a2c
44

55

66
A2C
@@ -44,9 +44,9 @@ Train a A2C agent on `CartPole-v1` using 4 processes.
4444
4545
import gym
4646
47-
from torchy_baselines.common.policies import MlpPolicy
48-
from torchy_baselines.common import make_vec_env
49-
from torchy_baselines import A2C
47+
from stable_baselines3.common.policies import MlpPolicy
48+
from stable_baselines3.common import make_vec_env
49+
from stable_baselines3 import A2C
5050
5151
# Parallel environments
5252
env = make_vec_env('CartPole-v1', n_envs=4)

docs/modules/base.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. _base_algo:
22

3-
.. automodule:: torchy_baselines.common.base_class
3+
.. automodule:: stable_baselines3.common.base_class
44

55

66
Base RL Class

docs/modules/ppo.rst

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. _ppo2:
22

3-
.. automodule:: torchy_baselines.ppo
3+
.. automodule:: stable_baselines3.ppo
44

55
PPO
66
===
@@ -53,9 +53,9 @@ Train a PPO agent on `Pendulum-v0` using 4 processes.
5353
5454
import gym
5555
56-
from torchy_baselines.ppo.policies import MlpPolicy
57-
from torchy_baselines.common.vec_env import SubprocVecEnv
58-
from torchy_baselines import PPO
56+
from stable_baselines3.ppo.policies import MlpPolicy
57+
from stable_baselines3.common.vec_env import SubprocVecEnv
58+
from stable_baselines3 import PPO
5959
6060
# multiprocess environment
6161
n_cpu = 4

docs/modules/sac.rst

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. _sac:
22

3-
.. automodule:: torchy_baselines.sac
3+
.. automodule:: stable_baselines3.sac
44

55

66
SAC
@@ -14,7 +14,7 @@ A key feature of SAC, and a major difference with common RL algorithms, is that
1414

1515
.. warning::
1616

17-
The SAC model does not support ``torchy_baselines.common.policies`` because it uses double q-values
17+
The SAC model does not support ``stable_baselines3.common.policies`` because it uses double q-values
1818
and value estimation, as a result it must use its own policy models (see :ref:`sac_policies`).
1919

2020

@@ -72,9 +72,9 @@ Example
7272
import gym
7373
import numpy as np
7474
75-
from torchy_baselines.sac.policies import MlpPolicy
76-
from torchy_baselines.common.vec_env import DummyVecEnv
77-
from torchy_baselines import SAC
75+
from stable_baselines3.sac.policies import MlpPolicy
76+
from stable_baselines3.common.vec_env import DummyVecEnv
77+
from stable_baselines3 import SAC
7878
7979
env = gym.make('Pendulum-v0')
8080
env = DummyVecEnv([lambda: env])

docs/modules/td3.rst

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. _td3:
22

3-
.. automodule:: torchy_baselines.td3
3+
.. automodule:: stable_baselines3.td3
44

55

66
TD3
@@ -14,7 +14,7 @@ We recommend reading `OpenAI Spinning guide on TD3 <https://spinningup.openai.co
1414

1515
.. warning::
1616

17-
The TD3 model does not support ``torchy_baselines.common.policies`` because it uses double q-values
17+
The TD3 model does not support ``stable_baselines3.common.policies`` because it uses double q-values
1818
estimation, as a result it must use its own policy models (see :ref:`td3_policies`).
1919

2020

@@ -64,9 +64,9 @@ Example
6464
6565
import numpy as np
6666
67-
from torchy_baselines import TD3
68-
from torchy_baselines.td3.policies import MlpPolicy
69-
from torchy_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
67+
from stable_baselines3 import TD3
68+
from stable_baselines3.td3.policies import MlpPolicy
69+
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
7070
7171
# The noise objects for TD3
7272
n_actions = env.action_space.shape[-1]

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ filterwarnings =
1818
ignore::UserWarning:gym
1919

2020
[pytype]
21-
inputs = torchy_baselines
21+
inputs = stable_baselines3

setup.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import subprocess
44
from setuptools import setup, find_packages
55

6-
with open(os.path.join('torchy_baselines', 'version.txt'), 'r') as file_handler:
6+
with open(os.path.join('stable_baselines3', 'version.txt'), 'r') as file_handler:
77
__version__ = file_handler.read()
88

99

10-
setup(name='torchy_baselines',
10+
setup(name='stable_baselines3',
1111
packages=[package for package in find_packages()
12-
if package.startswith('torchy_baselines')],
12+
if package.startswith('stable_baselines3')],
1313
install_requires=[
1414
'gym[classic_control]>=0.11',
1515
'numpy',

torchy_baselines/__init__.py stable_baselines3/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22

3-
from torchy_baselines.a2c import A2C
4-
from torchy_baselines.ppo import PPO
5-
from torchy_baselines.sac import SAC
6-
from torchy_baselines.td3 import TD3
3+
from stable_baselines3.a2c import A2C
4+
from stable_baselines3.ppo import PPO
5+
from stable_baselines3.sac import SAC
6+
from stable_baselines3.td3 import TD3
77

88
# Read version from file
99
version_file = os.path.join(os.path.dirname(__file__), 'version.txt')

stable_baselines3/a2c/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from stable_baselines3.a2c.a2c import A2C
2+
from stable_baselines3.ppo.policies import MlpPolicy

torchy_baselines/a2c/a2c.py stable_baselines3/a2c/a2c.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from gym import spaces
44
from typing import Type, Union, Callable, Optional, Dict, Any
55

6-
from torchy_baselines.common import logger
7-
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
8-
from torchy_baselines.common.utils import explained_variance
9-
from torchy_baselines.ppo.policies import PPOPolicy
10-
from torchy_baselines.ppo.ppo import PPO
6+
from stable_baselines3.common import logger
7+
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
8+
from stable_baselines3.common.utils import explained_variance
9+
from stable_baselines3.ppo.policies import PPOPolicy
10+
from stable_baselines3.ppo.ppo import PPO
1111

1212

1313
class A2C(PPO):
File renamed without changes.

torchy_baselines/common/base_class.py stable_baselines3/common/base_class.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
import torch as th
1212
import numpy as np
1313

14-
from torchy_baselines.common import logger
15-
from torchy_baselines.common.policies import BasePolicy, get_policy_from_name
16-
from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device
17-
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
18-
from torchy_baselines.common.preprocessing import is_image_space
19-
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
20-
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback
21-
from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
22-
from torchy_baselines.common.monitor import Monitor
23-
from torchy_baselines.common.noise import ActionNoise
24-
from torchy_baselines.common.buffers import ReplayBuffer
14+
from stable_baselines3.common import logger
15+
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
16+
from stable_baselines3.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device
17+
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
18+
from stable_baselines3.common.preprocessing import is_image_space
19+
from stable_baselines3.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
20+
from stable_baselines3.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback
21+
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
22+
from stable_baselines3.common.monitor import Monitor
23+
from stable_baselines3.common.noise import ActionNoise
24+
from stable_baselines3.common.buffers import ReplayBuffer
2525

2626

2727
class BaseRLModel(ABC):

torchy_baselines/common/buffers.py stable_baselines3/common/buffers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import torch as th
55
from gym import spaces
66

7-
from torchy_baselines.common.vec_env import VecNormalize
8-
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
9-
from torchy_baselines.common.preprocessing import get_action_dim, get_obs_shape
7+
from stable_baselines3.common.vec_env import VecNormalize
8+
from stable_baselines3.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
9+
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
1010

1111

1212
class BaseBuffer(object):

0 commit comments

Comments
 (0)