
Commit 6dbf398

initial commit
1 parent 71e7133 commit 6dbf398

14 files changed: +379 -0 lines changed

.gitignore

+1
@@ -0,0 +1 @@
__pycache__

README.md

+62
@@ -0,0 +1,62 @@
# Overview

In this tutorial we are interested in reproducible reinforcement learning research. The experiments in this repository aim to reproduce some deep reinforcement learning results from the paper [Learning Value Functions in Deep Policy Gradients using Residual Variance](https://arxiv.org/pdf/2010.04440). To do so we follow a specific [empirical protocol](https://arxiv.org/abs/2304.01315) and use open-source libraries that we introduce next.

### Empirical reinforcement learning research

![protocol](imgs/ExpFlowChart.png "Empirical protocol from Andrew Patterson, Samuel Neumann, Martha White and Adam White")

### Stable deep reinforcement learning agents and study tools (seeding, plotting, agent comparison)

#### Stable-baselines3
#### rlberry
- seeding
- agent manager (see the minimal sketch below)
- hyperparameter optimization
#### Adastop
- statistically significant comparisons

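The sketch below shows how these pieces fit together: a single rlberry `Seeder` and an `ExperimentManager` wrapping a Stable-Baselines3 agent. It is distilled from `training.py` in this repository; the tiny `fit_budget` and `n_fit` values are illustrative only.

```python
# Minimal sketch of rlberry seeding + agent management (see training.py for the full version).
from rlberry.manager import ExperimentManager
from rlberry.envs import gym_make
from rlberry.agents.stable_baselines import StableBaselinesAgent
from rlberry.seeding import Seeder
from stable_baselines3 import PPO

seeder = Seeder(42)  # one seed object controls all sources of randomness

xp = ExperimentManager(
    StableBaselinesAgent,               # rlberry wrapper around SB3 algorithms
    (gym_make, dict(id="Acrobot-v1")),  # environment constructor and kwargs
    fit_budget=1_000,                   # illustrative budget; training.py uses 5e4
    init_kwargs=dict(algo_cls=PPO),     # which SB3 algorithm to wrap
    n_fit=2,                            # illustrative; training.py fits 5 agents
    seed=seeder,
)
xp.fit()
```
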
### Usage

Tested on Python 3.10.

```bash
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
cd empirical_rl
```

##### Training
```bash
python3 training.py
```
##### Plotting
```bash
python3 plotting.py
```
##### Evaluating
```bash
python3 evaluating.py
```
##### Adastop (long)
```bash
python3 statistical_comparing.py
```


## Expected Results

![rewards](imgs/rewards.png)
![value_loss](imgs/value_loss.png)
![var](imgs/explained_variance.png)
![eval](imgs/evaluations.png)

#### Adastop expected results
```bash
[INFO] 13:10: Test finished
[INFO] 13:10: Results are
   Agent1 vs Agent2          mean Agent1  mean Agent2  mean diff  decisions
0  default_ppo vs avec_ppo       -86.636    -118.6952    32.0592      equal
```
The `equal` decision means that AdaStop does not detect a statistically significant difference between the mean performances of `default_ppo` and `avec_ppo` at this training budget.


# TODOs
- Ant-v4
- Loop over hyperparameters and expand their bounds (hyperparameter optimization as per Patterson et al., 2023)
- Docstrings?
- Fix the data-loading bug in the plotting script.

empirical_rl/avec_ppo/__init__.py

+1
@@ -0,0 +1 @@
from .avec_ppo import AVECPPO

empirical_rl/avec_ppo/avec_ppo.py

+143
@@ -0,0 +1,143 @@
from typing import TypeVar

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.utils import explained_variance

SelfAVECPPO = TypeVar("SelfAVECPPO", bound="AVECPPO")


class AVECPPO(PPO):
    """
    PPO variant implementing the value-function loss from "Learning Value
    Functions in Deep Policy Gradients using Residual Variance" (AVEC).
    Paper: https://arxiv.org/abs/2010.04440

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
    Full PPO documentation: https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
    """

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                # value_loss = F.mse_loss(rollout_data.returns, values_pred)

                # NOTE: here is the AVEC change. The value loss is the variance of the
                # residuals (returns - predicted values) instead of their mean squared error.
                value_loss = th.var(rollout_data.returns - values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)
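
Since `AVECPPO` overrides only `train()`, it can also be used directly through the Stable-Baselines3 API it inherits from `PPO`, outside of rlberry. Below is a minimal sketch, assuming the script is run from `empirical_rl/` so that the `avec_ppo` package is importable; the environment id and budget are illustrative only.

```python
# Minimal sketch: AVECPPO keeps the standard SB3 PPO constructor and learn() API;
# only the value-loss computation inside train() differs.
from avec_ppo import AVECPPO

model = AVECPPO("MlpPolicy", "Acrobot-v1", verbose=1)  # illustrative environment
model.learn(total_timesteps=10_000)  # small illustrative budget
```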

empirical_rl/evaluating.py

+12
@@ -0,0 +1,12 @@
from rlberry.manager import evaluate_agents
import matplotlib.pyplot as plt


# aliases for exp results
# NOTE: both aliases currently point to the same manager data directory.
avec_ppo = "rlberry_data/temp/manager_data/avec-ppo_2024-07-03_11-09-17_06970bd6"
default_ppo = "rlberry_data/temp/manager_data/avec-ppo_2024-07-03_11-09-17_06970bd6"
_ = evaluate_agents(
    [default_ppo, avec_ppo], n_simulations=50, show=False,
)  # Evaluate the trained agents.
plt.savefig("evaluations")

empirical_rl/plotting.py

+29
@@ -0,0 +1,29 @@
from rlberry.manager import plot_writer_data


# aliases for exp results
default_ppo = "/data_training_default_ppo/manager_data/ppo-default_2024-07-03_11-26-14_b0045e63/agent_handlers/"
avec_ppo = "data_training_avec_ppo/manager_data/avec-ppo_2024-07-03_11-26-14_9e4b15e4/"

_ = plot_writer_data([default_ppo, avec_ppo],
                     tag="rollout/ep_rew_mean",
                     title="Training Episode Cumulative Rewards",
                     show=False,
                     savefig_fname="rewards",
                     )

_ = plot_writer_data([default_ppo, avec_ppo],
                     tag="train/explained_variance",
                     title="Training Explained Variance",
                     show=False,
                     savefig_fname="explained_variance",
                     )


_ = plot_writer_data([default_ppo, avec_ppo],
                     tag="train/value_loss",
                     title="Training Value Loss",
                     show=False,
                     savefig_fname="value_loss",
                     )

empirical_rl/statistical_comparing.py

+40
@@ -0,0 +1,40 @@
from rlberry.manager import AdastopComparator
from rlberry.agents.stable_baselines import StableBaselinesAgent
from rlberry.envs import gym_make
from rlberry.seeding import Seeder

from stable_baselines3 import PPO
from avec_ppo import AVECPPO


seed = Seeder(42)

managers = [
    dict(
        agent_class=StableBaselinesAgent,  # The Agent class.
        train_env=(gym_make, dict(id="Acrobot-v1")),  # The Environment to solve.
        fit_budget=5e4,  # Number of agent-environment interactions during training.
        init_kwargs=dict(algo_cls=PPO),  # Init value for StableBaselinesAgent.
        eval_kwargs=dict(eval_horizon=500),  # Number of agent-environment interactions during evaluations.
        agent_name="default_ppo",  # The agent's name.
    ),
    dict(
        agent_class=StableBaselinesAgent,  # The Agent class.
        train_env=(gym_make, dict(id="Acrobot-v1")),  # The Environment to solve.
        fit_budget=5e4,  # Number of agent-environment interactions during training.
        init_kwargs=dict(algo_cls=AVECPPO),  # Init value for StableBaselinesAgent.
        eval_kwargs=dict(eval_horizon=500),  # Number of agent-environment interactions during evaluations.
        agent_name="avec_ppo",  # The agent's name.
    ),
]

# Comparing distributions
comparator = AdastopComparator(seed=42)
comparator.compare(managers)
print(comparator.managers_paths)

empirical_rl/training.py

+85
@@ -0,0 +1,85 @@
from rlberry.manager import ExperimentManager
from rlberry.envs import gym_make
from rlberry.agents.stable_baselines import StableBaselinesAgent
from rlberry.seeding import Seeder

from stable_baselines3 import PPO
from avec_ppo import AVECPPO

seeder = Seeder(42)

# The ExperimentManager class is a compact way of experimenting with a deep RL agent.
default_xp = ExperimentManager(
    StableBaselinesAgent,  # The Agent class.
    (gym_make, dict(id="Acrobot-v1")),  # The Environment to solve.
    fit_budget=5e4,  # Number of agent-environment interactions during training.
    init_kwargs=dict(algo_cls=PPO),  # Init value for StableBaselinesAgent.
    eval_kwargs=dict(eval_horizon=500),  # Number of agent-environment interactions during evaluations.
    n_fit=5,  # Number of agents to train; more than one because training is stochastic.
    seed=seeder,
    agent_name="default_ppo",  # The agent's name.
    output_dir="data_training_default_ppo",
)

avec_xp = ExperimentManager(
    StableBaselinesAgent,  # The Agent class.
    (gym_make, dict(id="Acrobot-v1")),  # The Environment to solve.
    fit_budget=5e4,  # Number of agent-environment interactions during training.
    init_kwargs=dict(algo_cls=AVECPPO),  # Init value for StableBaselinesAgent.
    eval_kwargs=dict(eval_horizon=500),  # Number of agent-environment interactions during evaluations.
    n_fit=5,  # Number of agents to train; more than one because training is stochastic.
    seed=seeder,
    agent_name="avec_ppo",  # The agent's name.
    output_dir="data_training_avec_ppo",
)

default_xp.fit()
avec_xp.fit()


# FOR TESTING PURPOSES
from rlberry.manager import plot_writer_data

_ = plot_writer_data([default_xp, avec_xp],
                     tag="rollout/ep_rew_mean",
                     title="Training Episode Cumulative Rewards",
                     show=False,
                     savefig_fname="rewards",
                     )

_ = plot_writer_data([default_xp, avec_xp],
                     tag="train/explained_variance",
                     title="Training Explained Variance",
                     show=False,
                     savefig_fname="explained_variance",
                     )

_ = plot_writer_data([default_xp, avec_xp],
                     tag="train/value_loss",
                     title="Training Value Loss",
                     show=False,
                     savefig_fname="value_loss",
                     )

from rlberry.manager import evaluate_agents
import matplotlib.pyplot as plt

# Comparing means
_ = evaluate_agents(
    [default_xp, avec_xp], n_simulations=50, show=False,
)  # Evaluate the trained agents.
plt.savefig("evaluations")

imgs/ExpFlowChart.png (77.5 KB)

imgs/evaluations.png (17.4 KB)

imgs/explained_variance.png (51.4 KB)

imgs/rewards.png (50.4 KB)

imgs/value_loss.png (39.9 KB)

requirements.txt

+6
@@ -0,0 +1,6 @@
rlberry==0.7
numpy==1.25.2
torch==2.3.1
tensorboard==2.17.0
stable-baselines3==2.2.1
gymnasium[mujoco]==0.29.1
