This repository was archived by the owner on Jun 4, 2024. It is now read-only.

Commit 90c6b10

[RLlib] Decentralized multi-agent learning; PR #1 (#21421)
1 parent d392f97 commit 90c6b10

File tree

7 files changed: +472 -70 lines

rllib/agents/ddpg/apex.py (+10)

@@ -23,6 +23,15 @@
         "buffer_size": 2000000,
         # TODO(jungong) : update once Apex supports replay_buffer_config.
         "replay_buffer_config": None,
+        # Whether all shards of the replay buffer must be co-located
+        # with the learner process (running the execution plan).
+        # This is preferred b/c the learner process should have quick
+        # access to the data from the buffer shards, avoiding network
+        # traffic each time samples from the buffer(s) are drawn.
+        # Set this to False for relaxing this constraint and allowing
+        # replay shards to be created on node(s) other than the one
+        # on which the learner is located.
+        "replay_buffer_shards_colocated_with_driver": True,
         "learning_starts": 50000,
         "train_batch_size": 512,
         "rollout_fragment_length": 50,
@@ -31,6 +40,7 @@
         "worker_side_prioritization": True,
         "min_iter_time_s": 30,
     },
+    _allow_unknown_configs=True,
 )

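The new `replay_buffer_shards_colocated_with_driver` key is an ordinary config entry and can be overridden per trainer. A minimal usage sketch, assuming the `ApexDDPGTrainer` class and the `Pendulum-v1` environment available with the RLlib/Gym versions of this era (both are illustrative and not part of this diff):

import ray
from ray.rllib.agents.ddpg import ApexDDPGTrainer  # assumed import path for this RLlib version

ray.init()

# Relax the co-location constraint: replay shards may be scheduled on
# any node instead of being forced onto the learner (driver) node.
trainer = ApexDDPGTrainer(
    env="Pendulum-v1",
    config={
        "num_workers": 4,
        "num_gpus": 0,
        "replay_buffer_shards_colocated_with_driver": False,
    },
)
trainer.train()
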
rllib/agents/dqn/apex.py (+33 -3)

@@ -14,6 +14,7 @@
 
 import collections
 import copy
+import platform
 from typing import Tuple
 
 import ray
@@ -32,7 +33,7 @@
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import UpdateTargetNetwork
 from ray.rllib.utils import merge_dicts
-from ray.rllib.utils.actors import create_colocated
+from ray.rllib.utils.actors import create_colocated_actors
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
@@ -55,10 +56,21 @@
         "n_step": 3,
         "num_gpus": 1,
         "num_workers": 32,
+
         "buffer_size": 2000000,
         # TODO(jungong) : add proper replay_buffer_config after
         # DistributedReplayBuffer type is supported.
         "replay_buffer_config": None,
+        # Whether all shards of the replay buffer must be co-located
+        # with the learner process (running the execution plan).
+        # This is preferred b/c the learner process should have quick
+        # access to the data from the buffer shards, avoiding network
+        # traffic each time samples from the buffer(s) are drawn.
+        # Set this to False for relaxing this constraint and allowing
+        # replay shards to be created on node(s) other than the one
+        # on which the learner is located.
+        "replay_buffer_shards_colocated_with_driver": True,
+
         "learning_starts": 50000,
         "train_batch_size": 512,
         "rollout_fragment_length": 50,
@@ -129,7 +141,8 @@ def execution_plan(workers: WorkerSet, config: dict,
     # Create a number of replay buffer actors.
     num_replay_buffer_shards = config["optimizer"][
         "num_replay_buffer_shards"]
-    replay_actors = create_colocated(ReplayActor, [
+
+    replay_actor_args = [
         num_replay_buffer_shards,
         config["learning_starts"],
         config["buffer_size"],
@@ -139,7 +152,24 @@ def execution_plan(workers: WorkerSet, config: dict,
         config["prioritized_replay_eps"],
         config["multiagent"]["replay_mode"],
         config.get("replay_sequence_length", 1),
-    ], num_replay_buffer_shards)
+    ]
+    # Place all replay buffer shards on the same node as the learner
+    # (driver process that runs this execution plan).
+    if config["replay_buffer_shards_colocated_with_driver"]:
+        replay_actors = create_colocated_actors(
+            actor_specs=[
+                # (class, args, kwargs={}, count)
+                (ReplayActor, replay_actor_args, {},
+                 num_replay_buffer_shards)
+            ],
+            node=platform.node(),  # localhost
+        )[0]  # [0]=only one item in `actor_specs`.
+    # Place replay buffer shards on any node(s).
+    else:
+        replay_actors = [
+            ReplayActor(*replay_actor_args)
+            for _ in range(num_replay_buffer_shards)
+        ]
 
     # Start the learner thread.
     learner_thread = LearnerThread(workers.local_worker())

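For context, the new `create_colocated_actors` utility takes a list of `(class, args, kwargs, count)` specs plus a target node name and returns one list of actors per spec, which is why the execution plan above indexes the result with `[0]`. A standalone sketch of that call shape, assuming the utility accepts any Ray actor class the same way it accepts `ReplayActor` here; the `_DummyShard` actor below is hypothetical and used only for illustration:

import platform

import ray
from ray.rllib.utils.actors import create_colocated_actors


@ray.remote(num_cpus=0)
class _DummyShard:  # hypothetical stand-in for ReplayActor
    def __init__(self, capacity):
        self.capacity = capacity

    def node(self):
        return platform.node()


ray.init()

# Ask for two shard actors, all placed on this (driver) node, mirroring
# the `replay_buffer_shards_colocated_with_driver=True` branch above.
shards = create_colocated_actors(
    actor_specs=[
        # (class, args, kwargs, count)
        (_DummyShard, [1000], {}, 2),
    ],
    node=platform.node(),
)[0]  # one result list per entry in `actor_specs`

# Both shards should report the driver's hostname.
print(ray.get([s.node.remote() for s in shards]))
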
rllib/agents/trainer.py (+19 -12)

@@ -11,9 +11,11 @@
 import pickle
 import tempfile
 import time
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Callable, DefaultDict, Dict, List, Optional, Set, Tuple, \
+    Type, Union
 
 import ray
+from ray.actor import ActorHandle
 from ray.exceptions import RayError
 from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.env.env_context import EnvContext
@@ -722,8 +724,9 @@ def default_logger_creator(config):
         self._episode_history = []
         self._episodes_to_be_collected = []
 
-        # Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
-        self.evaluation_workers = None
+        # Evaluation WorkerSet.
+        self.evaluation_workers: Optional[WorkerSet] = None
+        # Metrics most recently returned by `self.evaluate()`.
         self.evaluation_metrics = {}
 
         super().__init__(config, logger_creator, remote_checkpoint_dir,
@@ -798,12 +801,19 @@ def env_creator_from_classpath(env_context):
         self.local_replay_buffer = (
             self._create_local_replay_buffer_if_necessary(self.config))
 
+        # Create a dict, mapping ActorHandles to sets of open remote
+        # requests (object refs). This way, we keep track, of which actors
+        # inside this Trainer (e.g. a remote RolloutWorker) have
+        # already been sent how many (e.g. `sample()`) requests.
+        self.remote_requests_in_flight: \
+            DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
+
         # Deprecated way of implementing Trainer sub-classes (or "templates"
         # via the soon-to-be deprecated `build_trainer` utility function).
         # Instead, sub-classes should override the Trainable's `setup()`
         # method and call super().setup() from within that override at some
         # point.
-        self.workers = None
+        self.workers: Optional[WorkerSet] = None
         self.train_exec_impl = None
 
         # Old design: Override `Trainer._init` (or use `build_trainer()`, which
@@ -845,13 +855,10 @@ def env_creator_from_classpath(env_context):
                 self.workers, self.config,
                 **self._kwargs_for_execution_plan())
 
-            # TODO: Now that workers have been created, update our policy
-            # specs in the config[multiagent] dict with the correct spaces.
-            # However, this leads to a problem with the evaluation
-            # workers' observation one-hot preprocessor in
-            # `examples/documentation/rllib_in_6sec.py` script.
-            # self.config["multiagent"]["policies"] = \
-            #     self.workers.local_worker().policy_map.policy_specs
+            # Now that workers have been created, update our policy
+            # specs in the config[multiagent] dict with the correct spaces.
+            self.config["multiagent"]["policies"] = \
+                self.workers.local_worker().policy_dict
 
             # Evaluation WorkerSet setup.
             # User would like to setup a separate evaluation worker set.
@@ -912,7 +919,7 @@ def env_creator_from_classpath(env_context):
                 # If evaluation_num_workers=0, use the evaluation set's local
                 # worker for evaluation, otherwise, use its remote workers
                 # (parallelized evaluation).
-                self.evaluation_workers = self._make_workers(
+                self.evaluation_workers: WorkerSet = self._make_workers(
                    env_creator=self.env_creator,
                    validate_env=None,
                    policy_class=self.get_default_policy_class(self.config),

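The `remote_requests_in_flight` attribute added above is plain bookkeeping: for each actor handle, the set of object refs for requests that have been sent but not yet collected, which lets the trainer cap how many concurrent calls (e.g. `sample()`) a given actor has outstanding. A minimal sketch of that pattern in plain Ray; the `_Worker` actor and the cap of 2 requests are illustrative only:

from collections import defaultdict
from typing import DefaultDict, Set

import ray
from ray.actor import ActorHandle


@ray.remote
class _Worker:  # hypothetical stand-in for a remote RolloutWorker
    def sample(self):
        return [0, 1, 2]


ray.init()
workers = [_Worker.remote() for _ in range(2)]

# Same shape as Trainer.remote_requests_in_flight in the diff above.
in_flight: DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

# Send requests only to workers with fewer than 2 open requests.
for w in workers:
    while len(in_flight[w]) < 2:
        in_flight[w].add(w.sample.remote())

# Collect one ready result and drop its ref from the bookkeeping.
ready, _ = ray.wait([ref for refs in in_flight.values() for ref in refs])
for ref in ready:
    for refs in in_flight.values():
        refs.discard(ref)
    print(ray.get(ref))
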
rllib/evaluation/rollout_worker.py (+16 -6)

@@ -10,6 +10,7 @@
     TYPE_CHECKING, Union
 
 import ray
+from ray import ObjectRef
 from ray import cloudpickle as pickle
 from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
 from ray.rllib.env.env_context import EnvContext
@@ -537,16 +538,16 @@ def make_sub_env(vector_index):
         self.make_sub_env_fn = make_sub_env
         self.spaces = spaces
 
-        policy_dict = _determine_spaces_for_multi_agent_dict(
+        self.policy_dict = _determine_spaces_for_multi_agent_dict(
             policy_spec,
             self.env,
             spaces=self.spaces,
             policy_config=policy_config)
 
         # List of IDs of those policies, which should be trained.
-        # By default, these are all policies found in the policy_dict.
+        # By default, these are all policies found in `self.policy_dict`.
         self.policies_to_train: List[PolicyID] = policies_to_train or list(
-            policy_dict.keys())
+            self.policy_dict.keys())
         self.set_policies_to_train(self.policies_to_train)
 
         self.policy_map: PolicyMap = None
@@ -583,7 +584,7 @@ def make_sub_env(vector_index):
                 f"is ignored.")
 
         self._build_policy_map(
-            policy_dict,
+            self.policy_dict,
             policy_config,
             session_creator=tf_session_creator,
             seed=seed)
@@ -1111,7 +1112,7 @@ def add_policy(
         """
         if policy_id in self.policy_map:
             raise ValueError(f"Policy ID '{policy_id}' already in policy map!")
-        policy_dict = _determine_spaces_for_multi_agent_dict(
+        policy_dict_to_add = _determine_spaces_for_multi_agent_dict(
             {
                 policy_id: PolicySpec(policy_cls, observation_space,
                                       action_space, config or {})
@@ -1120,8 +1121,9 @@ def add_policy(
             spaces=self.spaces,
             policy_config=self.policy_config,
         )
+        self.policy_dict.update(policy_dict_to_add)
         self._build_policy_map(
-            policy_dict,
+            policy_dict_to_add,
             self.policy_config,
             seed=self.policy_config.get("seed"))
         new_policy = self.policy_map[policy_id]
@@ -1386,6 +1388,14 @@ def set_weights(self,
         >>> # Set `global_vars` (timestep) as well.
         >>> worker.set_weights(weights, {"timestep": 42})
         """
+        # If per-policy weights are object refs, `ray.get()` them first.
+        if weights and isinstance(next(iter(weights.values())), ObjectRef):
+            actual_weights = ray.get(list(weights.values()))
+            weights = {
+                pid: actual_weights[i]
+                for i, pid in enumerate(weights.keys())
+            }
+
         for pid, w in weights.items():
             self.policy_map[pid].set_weights(w)
         if global_vars:

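The `set_weights()` change above means a driver can `ray.put()` each policy's weights once and hand the same small dict of object refs to every remote worker; the worker then resolves the refs with a `ray.get()` before applying them. A minimal sketch of that broadcast pattern; the weight contents and the commented-out worker loop are illustrative only:

import ray

ray.init()

# Driver side: put each policy's (possibly large) weights into the
# object store once, instead of copying them into every RPC.
local_weights = {"default_policy": {"fc_1/kernel": [0.1, 0.2, 0.3]}}
weight_refs = {pid: ray.put(w) for pid, w in local_weights.items()}

# Every remote RolloutWorker can now receive the same small ref dict;
# the new code path in set_weights() detects the ObjectRef values and
# resolves them with a single ray.get() per worker.
# for w in remote_workers:
#     w.set_weights.remote(weight_refs, {"timestep": 42})
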