Skip to content

Commit cae1c71

Browse files
committed
Fix indexation in D4PG implementations
1 parent 2ea002f commit cae1c71

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

Diff for: Chapter07/lib/common.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -175,11 +175,11 @@ def distr_projection(next_distr, rewards, dones, Vmin, Vmax, n_atoms, gamma):
175 175   eq_dones = dones.copy()
176 176   eq_dones[dones] = eq_mask
177 177   if eq_dones.any():
178     -     proj_distr[eq_dones, l] = 1.0
    178 +     proj_distr[eq_dones, l[eq_mask]] = 1.0
179 179   ne_mask = u != l
180 180   ne_dones = dones.copy()
181 181   ne_dones[dones] = ne_mask
182 182   if ne_dones.any():
183     -     proj_distr[ne_dones, l] = (u - b_j)[ne_mask]
184     -     proj_distr[ne_dones, u] = (b_j - l)[ne_mask]
    183 +     proj_distr[ne_dones, l[ne_mask]] = (u - b_j)[ne_mask]
    184 +     proj_distr[ne_dones, u[ne_mask]] = (b_j - l)[ne_mask]
185 185   return proj_distr

Diff for: Chapter14/06_train_d4pg.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -77,13 +77,13 @@ def distr_projection(next_distr_v, rewards_v, dones_mask_t, gamma, device="cpu")
77 77   eq_dones = dones_mask.copy()
78 78   eq_dones[dones_mask] = eq_mask
79 79   if eq_dones.any():
80    -     proj_distr[eq_dones, l] = 1.0
   80 +     proj_distr[eq_dones, l[eq_mask]] = 1.0
81 81   ne_mask = u != l
82 82   ne_dones = dones_mask.copy()
83 83   ne_dones[dones_mask] = ne_mask
84 84   if ne_dones.any():
85    -     proj_distr[ne_dones, l] = (u - b_j)[ne_mask]
86    -     proj_distr[ne_dones, u] = (b_j - l)[ne_mask]
   85 +     proj_distr[ne_dones, l[ne_mask]] = (u - b_j)[ne_mask]
   86 +     proj_distr[ne_dones, u[ne_mask]] = (b_j - l)[ne_mask]
87 87   return torch.FloatTensor(proj_distr).to(device)
88 88
89 89
0 commit comments

Comments
 (0)