14
14
from ray .rllib .models .modelv2 import ModelV2
15
15
from ray .rllib .models .torch .torch_action_dist import TorchDistributionWrapper
16
16
from ray .rllib .policy .policy import Policy
17
+ from ray .rllib .policy .policy_template import build_policy_class
17
18
from ray .rllib .policy .sample_batch import SampleBatch
18
19
from ray .rllib .policy .torch_policy import EntropyCoeffSchedule , \
19
20
LearningRateSchedule
20
- from ray .rllib .policy .torch_policy_template import build_torch_policy
21
21
from ray .rllib .utils .framework import try_import_torch
22
22
from ray .rllib .utils .torch_ops import convert_to_torch_tensor , \
23
23
explained_variance , sequence_mask
@@ -111,6 +111,9 @@ def reduce_mean_valid(t):
111
111
policy ._total_loss = total_loss
112
112
policy ._mean_policy_loss = mean_policy_loss
113
113
policy ._mean_vf_loss = mean_vf_loss
114
+ policy ._vf_explained_var = explained_variance (
115
+ train_batch [Postprocessing .VALUE_TARGETS ],
116
+ policy .model .value_function ())
114
117
policy ._mean_entropy = mean_entropy
115
118
policy ._mean_kl = mean_kl
116
119
@@ -134,9 +137,7 @@ def kl_and_loss_stats(policy: Policy,
134
137
"total_loss" : policy ._total_loss ,
135
138
"policy_loss" : policy ._mean_policy_loss ,
136
139
"vf_loss" : policy ._mean_vf_loss ,
137
- "vf_explained_var" : explained_variance (
138
- train_batch [Postprocessing .VALUE_TARGETS ],
139
- policy .model .value_function ()),
140
+ "vf_explained_var" : policy ._vf_explained_var ,
140
141
"kl" : policy ._mean_kl ,
141
142
"entropy" : policy ._mean_entropy ,
142
143
"entropy_coeff" : policy .entropy_coeff ,
@@ -271,8 +272,9 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
271
272
272
273
# Build a child class of `TorchPolicy`, given the custom functions defined
273
274
# above.
274
- PPOTorchPolicy = build_torch_policy (
275
+ PPOTorchPolicy = build_policy_class (
275
276
name = "PPOTorchPolicy" ,
277
+ framework = "torch" ,
276
278
get_default_config = lambda : ray .rllib .agents .ppo .ppo .DEFAULT_CONFIG ,
277
279
loss_fn = ppo_surrogate_loss ,
278
280
stats_fn = kl_and_loss_stats ,
0 commit comments