
Commit 99ae7ba

[RLlib] JAXPolicy prep. PR #1. (#13077)
1 parent 25f9f0d · commit 99ae7ba

28 files changed: +501 -359 lines
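
The bulk of the change across these files is the same: the Torch-only `build_torch_policy` helper from `ray.rllib.policy.torch_policy_template` is replaced by the framework-agnostic `build_policy_class` from `ray.rllib.policy.policy_template`, which takes an explicit `framework` argument ("torch" everywhere here), apparently in preparation for a JAXPolicy, per the commit title. Below is a minimal, hedged sketch of the migration pattern; the policy name and the toy loss function are made up for illustration and are not part of this commit:

import ray
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch


def toy_pg_loss(policy, model, dist_class, train_batch):
    """Toy policy-gradient loss, only to make the template call concrete."""
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -(log_probs * train_batch[SampleBatch.REWARDS]).mean()


# Before this commit (torch-specific template):
#   MyTorchPolicy = build_torch_policy(
#       name="MyTorchPolicy",
#       loss_fn=toy_pg_loss,
#       get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG)
# After this commit (shared template, framework passed explicitly):
MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy",
    framework="torch",
    loss_fn=toy_pg_loss,
    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG)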

rllib/agents/a3c/a3c_torch_policy.py

+3 -2
@@ -1,8 +1,8 @@
 import ray
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
     Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 
 torch, nn = try_import_torch()
@@ -84,8 +84,9 @@ def _value(self, obs):
         return self.model.value_function()[0]
 
 
-A3CTorchPolicy = build_torch_policy(
+A3CTorchPolicy = build_policy_class(
     name="A3CTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
     loss_fn=actor_critic_loss,
     stats_fn=loss_and_entropy_stats,

rllib/agents/ars/ars_torch_policy.py

+3 -2
@@ -4,10 +4,11 @@
 import ray
 from ray.rllib.agents.es.es_torch_policy import after_init, before_init, \
     make_model_and_action_dist
-from ray.rllib.policy.torch_policy_template import build_torch_policy
+from ray.rllib.policy.policy_template import build_policy_class
 
-ARSTorchPolicy = build_torch_policy(
+ARSTorchPolicy = build_policy_class(
     name="ARSTorchPolicy",
+    framework="torch",
     loss_fn=None,
     get_default_config=lambda: ray.rllib.agents.ars.ars.DEFAULT_CONFIG,
     before_init=before_init,

rllib/agents/ddpg/ddpg_torch_model.py

+1 -1
@@ -2,7 +2,7 @@
 
 from ray.rllib.models.torch.misc import SlimFC
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
-from ray.rllib.utils.framework import try_import_torch, get_activation_fn
+from ray.rllib.utils.framework import get_activation_fn, try_import_torch
 
 torch, nn = try_import_torch()
 
rllib/agents/ddpg/ddpg_torch_policy.py

+3 -2
@@ -7,8 +7,8 @@
 from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
     PRIO_WEIGHTS
 from ray.rllib.models.torch.torch_action_dist import TorchDeterministic
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import huber_loss, l2_loss
 
@@ -264,8 +264,9 @@ def setup_late_mixins(policy, obs_space, action_space, config):
     TargetNetworkMixin.__init__(policy)
 
 
-DDPGTorchPolicy = build_torch_policy(
+DDPGTorchPolicy = build_policy_class(
     name="DDPGTorchPolicy",
+    framework="torch",
     loss_fn=ddpg_actor_critic_loss,
     get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
     stats_fn=build_ddpg_stats,

rllib/agents/dqn/dqn_torch_policy.py

+3 -2
@@ -14,9 +14,9 @@
 from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
                                                        TorchDistributionWrapper)
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
 from ray.rllib.utils.framework import try_import_torch
@@ -384,8 +384,9 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
     return {"q_values": policy.q_values}
 
 
-DQNTorchPolicy = build_torch_policy(
+DQNTorchPolicy = build_policy_class(
     name="DQNTorchPolicy",
+    framework="torch",
     loss_fn=build_q_losses,
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
     make_model_and_action_dist=build_q_model_and_distribution,

rllib/agents/dqn/simple_q_torch_policy.py

+3 -2
@@ -11,8 +11,8 @@
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
     TorchDistributionWrapper
 from ray.rllib.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import huber_loss
 from ray.rllib.utils.typing import TensorType, TrainerConfigDict
@@ -127,8 +127,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
     TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
 
 
-SimpleQTorchPolicy = build_torch_policy(
+SimpleQTorchPolicy = build_policy_class(
     name="SimpleQPolicy",
+    framework="torch",
     loss_fn=build_q_losses,
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
     extra_action_out_fn=extra_action_out_fn,

rllib/agents/dreamer/dreamer_torch_policy.py

+5 -4
@@ -1,11 +1,11 @@
 import logging
 
 import ray
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
-from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.agents.dreamer.utils import FreezeParameters
+from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.policy.policy_template import build_policy_class
+from ray.rllib.utils.framework import try_import_torch
 
 torch, nn = try_import_torch()
 if torch:
@@ -236,8 +236,9 @@ def dreamer_optimizer_fn(policy, config):
     return (model_opt, actor_opt, critic_opt)
 
 
-DreamerTorchPolicy = build_torch_policy(
+DreamerTorchPolicy = build_policy_class(
     name="DreamerTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
     action_sampler_fn=action_sampler_fn,
     loss_fn=dreamer_loss,

rllib/agents/es/es_torch_policy.py

+3 -2
@@ -7,8 +7,8 @@
 
 import ray
 from ray.rllib.models import ModelCatalog
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
@@ -126,8 +126,9 @@ def make_model_and_action_dist(policy, observation_space, action_space,
     return model, dist_class
 
 
-ESTorchPolicy = build_torch_policy(
+ESTorchPolicy = build_policy_class(
     name="ESTorchPolicy",
+    framework="torch",
     loss_fn=None,
     get_default_config=lambda: ray.rllib.agents.es.es.DEFAULT_CONFIG,
     before_init=before_init,

rllib/agents/impala/vtrace_torch_policy.py

+3 -2
@@ -6,10 +6,10 @@
 from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 import ray.rllib.agents.impala.vtrace_torch as vtrace
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule, \
     EntropyCoeffSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
     sequence_mask
@@ -260,8 +260,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
 
 
-VTraceTorchPolicy = build_torch_policy(
+VTraceTorchPolicy = build_policy_class(
     name="VTraceTorchPolicy",
+    framework="torch",
     loss_fn=build_vtrace_loss,
     get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
     stats_fn=stats,

rllib/agents/maml/maml_tf_policy.py

+3 -3
@@ -1,13 +1,13 @@
 import logging
 
 import ray
+from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
+    vf_preds_fetches, compute_and_clip_gradients, setup_config, \
+    ValueNetworkMixin
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils import try_import_tf
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
-    vf_preds_fetches, compute_and_clip_gradients, setup_config, \
-    ValueNetworkMixin
 from ray.rllib.utils.framework import get_activation_fn
 
 tf1, tf, tfv = try_import_tf()

rllib/agents/maml/maml_torch_policy.py

+3 -2
@@ -2,8 +2,8 @@
 
 import ray
 from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
     setup_config
 from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
@@ -347,8 +347,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     KLCoeffMixin.__init__(policy, config)
 
 
-MAMLTorchPolicy = build_torch_policy(
+MAMLTorchPolicy = build_policy_class(
     name="MAMLTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
     loss_fn=maml_loss,
     stats_fn=maml_stats,

rllib/agents/marwil/marwil_torch_policy.py

+3 -2
@@ -1,8 +1,8 @@
 import ray
 from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
 from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance
 
@@ -75,8 +75,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     ValueNetworkMixin.__init__(policy)
 
 
-MARWILTorchPolicy = build_torch_policy(
+MARWILTorchPolicy = build_policy_class(
     name="MARWILTorchPolicy",
+    framework="torch",
     loss_fn=marwil_loss,
     get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
     stats_fn=stats,

rllib/agents/mbmpo/mbmpo_torch_policy.py

+3 -2
@@ -13,7 +13,7 @@
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy
-from ray.rllib.policy.torch_policy_template import build_torch_policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TrainerConfigDict
 
@@ -76,8 +76,9 @@ def make_model_and_action_dist(
 
 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-MBMPOTorchPolicy = build_torch_policy(
+MBMPOTorchPolicy = build_policy_class(
     name="MBMPOTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
     make_model_and_action_dist=make_model_and_action_dist,
     loss_fn=maml_loss,

rllib/agents/pg/pg_torch_policy.py

+3 -2
@@ -10,8 +10,8 @@
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TensorType
 
@@ -72,8 +72,9 @@ def pg_loss_stats(policy: Policy,
 # Build a child class of `TFPolicy`, given the extra options:
 # - trajectory post-processing function (to calculate advantages)
 # - PG loss function
-PGTorchPolicy = build_torch_policy(
+PGTorchPolicy = build_policy_class(
     name="PGTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
     loss_fn=pg_torch_loss,
     stats_fn=pg_loss_stats,

rllib/agents/ppo/appo_torch_policy.py

+3 -2
@@ -23,9 +23,9 @@
 from ray.rllib.models.torch.torch_action_dist import \
     TorchDistributionWrapper, TorchCategorical
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
     sequence_mask
@@ -322,8 +322,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
 
 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-AsyncPPOTorchPolicy = build_torch_policy(
+AsyncPPOTorchPolicy = build_policy_class(
     name="AsyncPPOTorchPolicy",
+    framework="torch",
     loss_fn=appo_surrogate_loss,
     stats_fn=stats,
     postprocess_fn=postprocess_trajectory,

rllib/agents/ppo/ppo_torch_policy.py

+7 -5
@@ -14,10 +14,10 @@
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
     LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
     explained_variance, sequence_mask
@@ -111,6 +111,9 @@ def reduce_mean_valid(t):
     policy._total_loss = total_loss
     policy._mean_policy_loss = mean_policy_loss
     policy._mean_vf_loss = mean_vf_loss
+    policy._vf_explained_var = explained_variance(
+        train_batch[Postprocessing.VALUE_TARGETS],
+        policy.model.value_function())
     policy._mean_entropy = mean_entropy
     policy._mean_kl = mean_kl
 
@@ -134,9 +137,7 @@ def kl_and_loss_stats(policy: Policy,
         "total_loss": policy._total_loss,
         "policy_loss": policy._mean_policy_loss,
         "vf_loss": policy._mean_vf_loss,
-        "vf_explained_var": explained_variance(
-            train_batch[Postprocessing.VALUE_TARGETS],
-            policy.model.value_function()),
+        "vf_explained_var": policy._vf_explained_var,
         "kl": policy._mean_kl,
         "entropy": policy._mean_entropy,
         "entropy_coeff": policy.entropy_coeff,
@@ -271,8 +272,9 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
 
 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-PPOTorchPolicy = build_torch_policy(
+PPOTorchPolicy = build_policy_class(
     name="PPOTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
     loss_fn=ppo_surrogate_loss,
     stats_fn=kl_and_loss_stats,
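
Aside from the mechanical rename to `build_policy_class`, the PPO diff above contains the one behavioral tweak in this commit: `vf_explained_var` is now computed inside the loss function and cached on the policy, and `kl_and_loss_stats` merely reports the cached value instead of calling `explained_variance` itself. A plausible reading is that the loss runs immediately after the model's forward pass on the train batch, so `value_function()` is guaranteed to correspond to that batch at that point. A minimal sketch of the pattern follows, with a toy loss standing in for RLlib's real PPO loss:

from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.torch_ops import explained_variance


def example_loss(policy, model, dist_class, train_batch):
    # Forward pass; value_function() below refers to this same batch.
    logits, _ = model.from_batch(train_batch)
    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])
    total_loss = -(log_probs * train_batch[Postprocessing.ADVANTAGES]).mean()

    # Compute the metric once, here, and cache it on the policy ...
    policy._vf_explained_var = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
    return total_loss


def example_stats(policy, train_batch):
    # ... so the stats function only reads the cached tensor instead of
    # re-running explained_variance on a possibly stale value_function().
    return {"vf_explained_var": policy._vf_explained_var}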

rllib/agents/qmix/qmix_policy.py

+1 -1
@@ -143,7 +143,7 @@ def forward(self,
         return loss, mask, masked_td_error, chosen_action_qvals, targets
 
 
-# TODO(sven): Make this a TorchPolicy child via `build_torch_policy`.
+# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
 class QMixTorchPolicy(Policy):
     """QMix impl. Assumes homogeneous agents for now.
 
rllib/agents/sac/sac_torch_policy.py

+3 -2
@@ -17,8 +17,8 @@
 from ray.rllib.models.torch.torch_action_dist import \
     TorchDistributionWrapper, TorchDirichlet
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.models.torch.torch_action_dist import (
     TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
 from ray.rllib.utils.framework import try_import_torch
@@ -480,8 +480,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
 
 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-SACTorchPolicy = build_torch_policy(
+SACTorchPolicy = build_policy_class(
     name="SACTorchPolicy",
+    framework="torch",
     loss_fn=actor_critic_loss,
     get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
     stats_fn=stats,

rllib/agents/slateq/slateq_torch_policy.py

+3 -2
@@ -11,8 +11,8 @@
                                                        TorchDistributionWrapper)
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import (ModelConfigDict, TensorType,
                                     TrainerConfigDict)
@@ -403,8 +403,9 @@ def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
     return batch
 
 
-SlateQTorchPolicy = build_torch_policy(
+SlateQTorchPolicy = build_policy_class(
     name="SlateQTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,
 
     # build model, loss functions, and optimizers

0 commit comments
