Comparing changes

Choose two branches to see what’s changed or to start a new pull request.

base repository: DLR-RM/stable-baselines3 (base: master)
head repository: jwong8314/stable-baselines3 (compare: support_pyg)
Can’t automatically merge. Don’t worry, you can still create the pull request.
  • 11 commits
  • 9 files changed
  • 1 contributor

Commits on May 9, 2023

  1. 4552932
  2. ce737be
  3. 9f5de7c
  4. 70a699f
  5. fix batching issue (jwong8314, b58216b)
  6. refactor clipping (jwong8314, b127543)
  7. eda0d11
  8. 6364767
  9. minimal edits (jwong8314, 5d58b63)
  10. minimal patch (jwong8314, 7033a9e)

Commits on May 17, 2023

  1. support extra logging (jwong8314, 6f341b7)
218 changes: 214 additions & 4 deletions stable_baselines3/common/buffers.py
@@ -6,14 +6,18 @@
import torch as th
from gymnasium import spaces

from torch_geometric.data import Data, Batch

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
    DictReplayBufferSamples,
    DictRolloutBufferSamples,
    ReplayBufferSamples,
    RolloutBufferSamples,
    GraphRolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env.util import dict_to_obs, graph_copy_obs_dict
from stable_baselines3.common.vec_env import VecNormalize

try:
@@ -439,7 +443,6 @@ def add(

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs).copy()
        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
@@ -493,6 +496,214 @@ def _get_samples(
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))


class GraphRolloutBuffer(BaseBuffer):
    """
    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use the PPO objective, we also store the current value of each state
    and the log probability of each taken action.
    The term rollout here refers to the model-free notion and should not
    be confused with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.
    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    advantages: np.ndarray
    returns: np.ndarray
    episode_starts: np.ndarray
    log_probs: np.ndarray
    values: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        assert isinstance(self.observation_space, spaces.Graph), "GraphRolloutBuffer requires a Graph observation space"

        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.generator_ready = False
        self.observations = {"node": {}, "edge_weight": {}, "edge_index": {}}
        self.reset()

    def reset(self) -> None:
        self.observations = {"node": {}, "edge_weight": {}, "edge_index": {}}  # variable size
        self.actions = {}
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = {}
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.
        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain the Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted rewards with value bootstrap
        (because we don't always have a full episode), set ``gae_lambda=1.0`` during initialization.
        The TD(lambda) estimator also has two special cases:
        - TD(1) is the Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is the one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))
        For more information, see the discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.
        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action -- assumed to be an ndarray by the clipping code
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if isinstance(log_prob, List):
            if len(log_prob) == 1:
                log_prob = log_prob[0].cpu()
            else:
                raise NotImplementedError
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        if isinstance(action, List):
            if len(action) == 1:
                action = action[0]
            else:
                # Could loop through and add each entry independently into the buffer
                raise NotImplementedError
        else:
            action = action.reshape((self.n_envs, self.action_dim))

        assert isinstance(obs, Data)
        self.observations["node"][self.pos] = obs.x  # obs should be a PyG Data entry
        self.observations["edge_index"][self.pos] = obs.edge_index
        self.observations["edge_weight"][self.pos] = obs.w

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[GraphRolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            self.observations = dict_to_obs(self.observation_space, self.observations)

            assert all(isinstance(k, int) for k in self.actions.keys()), "Actions not indexed correctly"
            self.actions_flat = np.stack([self.actions[i] for i in range(len(self.actions.keys()))])
            self.actions = self.actions_flat

            assert all(isinstance(k, int) for k in self.log_probs.keys()), "Log probs not indexed correctly"
            self.log_probs_flat = np.stack([self.log_probs[i] for i in range(len(self.log_probs.keys()))])
            self.log_probs = self.log_probs_flat

            _tensor_names = [
                "values",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> GraphRolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
        data = (
            Batch.from_data_list(self.observations[batch_inds]),
            self.to_torch(self.actions[batch_inds]),
            self.to_torch(self.values[batch_inds].flatten()),
            self.to_torch(self.log_probs[batch_inds]),
            self.to_torch(self.advantages[batch_inds]),
            self.to_torch(self.returns[batch_inds].flatten()),
        )

        return GraphRolloutBufferSamples(*tuple(data))


class DictReplayBuffer(ReplayBuffer):
    """
    Dict Replay buffer used in off-policy algorithms like SAC/TD3.
@@ -681,8 +892,6 @@ class DictRolloutBuffer(RolloutBuffer):
    :param n_envs: Number of parallel environments
    """

    observations: Dict[str, np.ndarray]

    def __init__(
        self,
        buffer_size: int,
@@ -699,7 +908,8 @@ def __init__(

        self.gae_lambda = gae_lambda
        self.gamma = gamma

        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()

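For orientation, a minimal sketch of the per-step storage that add() performs. The tensors below are made-up toy values; only the attribute names (.x, .edge_index, .w) and the three storage keys mirror the code above:

import torch as th
from torch_geometric.data import Data

# A toy 3-node, 2-edge observation: node features in .x,
# COO connectivity in .edge_index, per-edge weights in .w.
obs = Data(
    x=th.randn(3, 4),
    edge_index=th.tensor([[0, 1], [1, 2]], dtype=th.long),
    w=th.rand(2),
)

# add() splits each graph across three dicts keyed by buffer position,
# since variable-size graphs cannot live in one fixed-shape ndarray.
store = {"node": {}, "edge_weight": {}, "edge_index": {}}
store["node"][0] = obs.x
store["edge_index"][0] = obs.edge_index
store["edge_weight"][0] = obs.w

get() later hands these dicts to dict_to_obs (patched in stable_baselines3/common/vec_env/util.py below), which rebuilds the Data objects and collates them into a single PyG Batch for minibatching.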
49 changes: 44 additions & 5 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -7,7 +7,7 @@
from gymnasium import spaces

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer, GraphRolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -100,15 +100,20 @@ def __init__(
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.episodes = 0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
        if isinstance(self.observation_space, spaces.Dict):
            buffer_cls = DictRolloutBuffer
        elif isinstance(self.observation_space, spaces.Graph):
            buffer_cls = GraphRolloutBuffer
        else:
            buffer_cls = RolloutBuffer

        self.rollout_buffer = buffer_cls(
            self.n_steps,
@@ -158,6 +163,8 @@ def collect_rollouts(

        callback.on_rollout_start()

        self.episodes = 0

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
@@ -167,13 +174,32 @@ def collect_rollouts(
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()
            actions, log_probs = actions, log_probs

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, spaces.Box):
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
                if isinstance(actions, List):
                    clipped_actions = [
                        th.clamp(
                            a,
                            min=th.Tensor(self.action_space.low).to(a.device),
                            max=th.Tensor(self.action_space.high).to(a.device),
                        )
                        for a in actions
                    ]
                    actions = [a.cpu().numpy() for a in actions]
                else:
                    clipped_actions = th.clamp(
                        actions,
                        min=th.Tensor(self.action_space.low).to(actions.device),
                        max=th.Tensor(self.action_space.high).to(actions.device),
                    )
                    actions = actions.cpu().numpy()
            else:
                clipped_actions = actions.cpu().numpy()
                actions = actions.cpu().numpy()

            new_obs, rewards, dones, infos = env.step(clipped_actions)

@@ -203,6 +229,8 @@ def collect_rollouts(
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                    rewards[idx] += self.gamma * terminal_value
                if done:
                    self.episodes += 1

            rollout_buffer.add(
                self._last_obs,  # type: ignore[arg-type]
@@ -273,6 +301,17 @@ def learn(
            if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))

                diff = self.episodes
                self.logger.record("result/success", sum([ep_info["success"] for ep_info in self.ep_info_buffer][-diff:]))
                self.logger.record("result/failed", sum([ep_info["failed"] for ep_info in self.ep_info_buffer][-diff:]))
                self.logger.record(
                    "result/truncated", sum([ep_info["truncated"] for ep_info in self.ep_info_buffer][-diff:])
                )
                self.logger.record(
                    "result/terminated", sum([ep_info["terminated"] for ep_info in self.ep_info_buffer][-diff:])
                )

            self.logger.record("time/fps", fps)
            self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
            self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
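
Since clipping now happens on tensors before the .cpu() transfer, the torch path should reproduce the old np.clip behaviour exactly. A small self-contained check of that equivalence (bounds and values are arbitrary illustrations, not taken from the diff):

import numpy as np
import torch as th
from gymnasium import spaces

action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))  # hypothetical bounds
actions = th.tensor([[1.5, -2.0], [0.3, 0.9]])

# Tensor-side clipping, mirroring the th.clamp call in the new branch.
clipped = th.clamp(
    actions,
    min=th.Tensor(action_space.low),
    max=th.Tensor(action_space.high),
)
# Matches the np.clip path that this PR replaces.
assert np.allclose(clipped.numpy(), np.clip(actions.numpy(), action_space.low, action_space.high))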
67 changes: 46 additions & 21 deletions stable_baselines3/common/policies.py
@@ -624,9 +624,13 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        actions = actions.reshape((-1, *self.action_space.shape))
        if isinstance(distribution, List):
            actions = [d.get_actions(deterministic=deterministic) for d in distribution]
            log_prob = [d.log_prob(a) for d, a in zip(distribution, actions)]
        else:
            actions = distribution.get_actions(deterministic=deterministic)
            log_prob = distribution.log_prob(actions)
            actions = actions.reshape((-1, *self.action_space.shape))
        return actions, values, log_prob

    def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
@@ -650,23 +654,36 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
        :param latent_pi: Latent code for the actor
        :return: Action distribution
        """
        mean_actions = self.action_net(latent_pi)

        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
        if isinstance(latent_pi, List):  # indicates we're doing one policy to rule them all
            mean_actions_graphs = [self.action_net(a_latent_pi) for a_latent_pi in latent_pi]
            if isinstance(self.action_dist, DiagGaussianDistribution):
                dist = [
                    self.action_dist.proba_distribution(mean_actions, self.log_std)
                    for mean_actions in mean_actions_graphs
                ]
            else:
                raise NotImplementedError
            return dist
        else:
            raise ValueError("Invalid action distribution")

        mean_actions = self.action_net(latent_pi)

        dist = None
        if isinstance(self.action_dist, DiagGaussianDistribution):
            dist = self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            dist = self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            dist = self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            dist = self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            dist = self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
        else:
            raise ValueError("Invalid action distribution")

        return dist

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
@@ -697,9 +714,17 @@ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
        latent_pi = self.mlp_extractor.forward_actor(pi_features)
        latent_vf = self.mlp_extractor.forward_critic(vf_features)
        distribution = self._get_action_dist_from_latent(latent_pi)
        log_prob = distribution.log_prob(actions)
        if isinstance(distribution, List):
            log_prob = [d.log_prob(a) for d, a in zip(distribution, actions)]
            log_prob = th.stack(log_prob)
            entropy = [d.entropy() for d in distribution]
            entropy = th.stack(entropy)
        else:
            log_prob = distribution.log_prob(actions)

        entropy = distribution.entropy()
        values = self.value_net(latent_vf)
        entropy = distribution.entropy()

        return values, log_prob, entropy

    def get_distribution(self, obs: th.Tensor) -> Distribution:
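
For the list branch, a sketch of how per-graph DiagGaussian distributions stack the way evaluate_actions does. Dimensions are invented, and each graph gets its own distribution instance here to avoid shared state (note the diff's list comprehension reuses self.action_dist, so its entries share one underlying distribution object):

import torch as th
from stable_baselines3.common.distributions import DiagGaussianDistribution

action_dim = 2
log_std = th.zeros(action_dim)

# One mean-action tensor per graph (toy values).
mean_actions_graphs = [th.randn(1, action_dim) for _ in range(3)]
distribution = [
    DiagGaussianDistribution(action_dim).proba_distribution(m, log_std)
    for m in mean_actions_graphs
]

actions = [d.get_actions() for d in distribution]
log_prob = th.stack([d.log_prob(a) for d, a in zip(distribution, actions)])
entropy = th.stack([d.entropy() for d in distribution])
assert log_prob.shape == (3, 1) and entropy.shape == (3, 1)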
6 changes: 4 additions & 2 deletions stable_baselines3/common/preprocessing.py
@@ -134,7 +134,8 @@ def preprocess_obs(
        for key, _obs in obs.items():
            preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
        return preprocessed_obs
    elif isinstance(observation_space, spaces.Graph):
        return obs
    else:
        raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")

@@ -161,7 +162,8 @@ def get_obs_shape(
        return observation_space.shape
    elif isinstance(observation_space, spaces.Dict):
        return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}  # type: ignore[misc]
    elif isinstance(observation_space, spaces.Graph):
        return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}
    else:
        raise NotImplementedError(f"{observation_space} observation space is not supported")

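A caveat worth noting: both new branches treat a Graph space like a Dict by reading observation_space.spaces, but gymnasium's stock spaces.Graph exposes node_space and edge_space rather than a .spaces mapping. Presumably the PR targets a Graph subclass along these lines (purely an illustration; DictLikeGraph is hypothetical and not part of the diff):

from gymnasium import spaces

class DictLikeGraph(spaces.Graph):
    """Hypothetical Graph subclass exposing .spaces the way the patched helpers expect."""

    @property
    def spaces(self):
        # Keys mirror the storage layout used by GraphRolloutBuffer and
        # DummyGraphVecEnv; the edge_index entry only matters for key
        # enumeration, so its subspace is a placeholder.
        return {
            "node": self.node_space,
            "edge_weight": self.edge_space,
            "edge_index": self.edge_space,
        }

graph_space = DictLikeGraph(
    node_space=spaces.Box(low=-1.0, high=1.0, shape=(4,)),
    edge_space=spaces.Box(low=0.0, high=1.0, shape=(1,)),
)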
10 changes: 10 additions & 0 deletions stable_baselines3/common/type_aliases.py
@@ -7,6 +7,7 @@
import gymnasium as gym
import numpy as np
import torch as th
from torch_geometric.data import Batch

if sys.version_info >= (3, 8):
from typing import Protocol
@@ -39,6 +40,15 @@ class RolloutBufferSamples(NamedTuple):
    returns: th.Tensor


class GraphRolloutBufferSamples(NamedTuple):
    observations: Batch
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor


class DictRolloutBufferSamples(NamedTuple):
    observations: TensorDict
    actions: th.Tensor
3 changes: 3 additions & 0 deletions stable_baselines3/common/utils.py
@@ -13,6 +13,7 @@
import numpy as np
import torch as th
from gymnasium import spaces
from torch_geometric.data import Data

import stable_baselines3 as sb3

@@ -482,6 +483,8 @@ def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
    """
    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, device=device)
    elif isinstance(obs, Data):
        return obs.to(device)
    elif isinstance(obs, dict):
        return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
    else:
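
With this branch applied, a PyG Data observation passes through obs_as_tensor via Data.to(device) instead of th.as_tensor. A quick sketch with a toy graph on CPU:

import torch as th
from torch_geometric.data import Data
from stable_baselines3.common.utils import obs_as_tensor

obs = Data(x=th.randn(2, 3), edge_index=th.tensor([[0], [1]], dtype=th.long))
moved = obs_as_tensor(obs, th.device("cpu"))  # Data.to() moves every tensor attribute
assert moved.x.device.type == "cpu"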
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/__init__.py
@@ -3,7 +3,7 @@
from typing import Optional, Type, Union

from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv, DummyGraphVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
26 changes: 21 additions & 5 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -8,7 +8,7 @@

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.patch_gym import _patch_env
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
from stable_baselines3.common.vec_env.util import graph_copy_obs_dict, copy_obs_dict, dict_to_obs, obs_space_info


class DummyVecEnv(VecEnv):
@@ -24,8 +24,6 @@ class DummyVecEnv(VecEnv):
    :raises ValueError: If the same environment instance is passed as the output of two or more different env_fn.
    """

    actions: np.ndarray

    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        self.envs = [_patch_env(fn()) for fn in env_fns]
        if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs):
@@ -40,13 +38,14 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]):
            )
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode)
        obs_space = env.observation_space
        self.keys, shapes, dtypes = obs_space_info(obs_space)
        self.obs_space = env.observation_space
        self.keys, shapes, dtypes = obs_space_info(self.obs_space)

        self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
        self.actions = None
        self.metadata = env.metadata

    def step_async(self, actions: np.ndarray) -> None:
@@ -146,3 +145,20 @@ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices) -> List[bool]:
    def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
        indices = self._get_indices(indices)
        return [self.envs[i] for i in indices]


class DummyGraphVecEnv(DummyVecEnv):
    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        super().__init__(env_fns)
        assert isinstance(self.obs_space, gym.spaces.Graph)
        self.buf_obs = OrderedDict([(k, {}) for k in self.keys])

    def _obs_from_buf(self) -> VecEnvObs:
        assert isinstance(self.observation_space, gym.spaces.Graph)
        return dict_to_obs(self.observation_space, graph_copy_obs_dict(self.buf_obs))

    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        assert isinstance(self.envs[env_idx].observation_space, gym.spaces.Graph)
        self.buf_obs["node"][env_idx] = obs.x
        self.buf_obs["edge_weight"][env_idx] = obs.w
        self.buf_obs["edge_index"][env_idx] = obs.edge_index
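
Usage mirrors DummyVecEnv. A hedged sketch, assuming a hypothetical make_graph_env() whose observation space is the Graph space discussed above and whose observations arrive as PyG Data objects carrying .x, .edge_index, and .w:

from stable_baselines3.common.vec_env import DummyGraphVecEnv

# make_graph_env is hypothetical and not part of this PR.
env_fns = [lambda: make_graph_env() for _ in range(2)]
vec_env = DummyGraphVecEnv(env_fns)
obs = vec_env.reset()  # a single PyG Batch holding both envs' graphs via dict_to_obs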
20 changes: 20 additions & 0 deletions stable_baselines3/common/vec_env/util.py
@@ -3,6 +3,7 @@
"""
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
from torch_geometric.data import Data, Batch

import numpy as np
from gymnasium import spaces
@@ -22,6 +23,17 @@ def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])


def graph_copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """
    Deep-copy a dict of per-env graph tensors by cloning each entry.
    :param obs: a dict mapping observation keys to per-env dicts of tensors.
    :return: a dict of cloned tensors.
    """
    assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
    return OrderedDict([(k, {i: env.clone() for i, env in v.items()}) for k, v in obs.items()])


def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
    """
    Convert an internal representation raw_obs into the appropriate type
@@ -38,6 +50,12 @@ def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
    elif isinstance(obs_space, spaces.Tuple):
        assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
        return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
    elif isinstance(obs_space, spaces.Graph):
        indexes = obs_dict["node"].keys()
        list_of_graphs = [
            Data(x=obs_dict["node"][i], edge_index=obs_dict["edge_index"][i], w=obs_dict["edge_weight"][i])
            for i in indexes
        ]
        return Batch.from_data_list(list_of_graphs)
    else:
        assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
        return obs_dict[None]
@@ -63,6 +81,8 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
        subspaces = obs_space.spaces
    elif isinstance(obs_space, spaces.Tuple):
        subspaces = {i: space for i, space in enumerate(obs_space.spaces)}  # type: ignore[assignment]
    elif isinstance(obs_space, spaces.Graph):
        subspaces = obs_space.spaces
    else:
        assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
        subspaces = {None: obs_space}  # type: ignore[assignment]
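
The Graph branch of dict_to_obs turns the per-env storage back into one PyG Batch. A toy walk-through of that conversion (tensor values are arbitrary):

import torch as th
from torch_geometric.data import Batch, Data

# Internal storage as DummyGraphVecEnv fills it: one entry per env index.
obs_dict = {
    "node": {0: th.randn(3, 4), 1: th.randn(2, 4)},
    "edge_index": {0: th.tensor([[0, 1], [1, 2]]), 1: th.tensor([[0], [1]])},
    "edge_weight": {0: th.rand(2), 1: th.rand(1)},
}
graphs = [
    Data(x=obs_dict["node"][i], edge_index=obs_dict["edge_index"][i], w=obs_dict["edge_weight"][i])
    for i in obs_dict["node"]
]
batch = Batch.from_data_list(graphs)
assert batch.num_graphs == 2  # variable-size graphs collated into one Batch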