Skip to content

Commit 02e21ba

Browse files
authored
Merge pull request #814 from pxl-th/pxl-th/eachslice
Unthunk each element in `∇eachslice`
2 parents e055009 + f9754c1 commit 02e21ba

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.72.1"
3+
version = "1.72.2"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/indexing.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -262,20 +262,22 @@ end
262262
# Using Val(dim) here is worth a factor of 2 in this, on Julia 1.8-
263263
# @btime rrule(eachcol, $([1 2; 3 4]))[2]($([[10, 20], [30, 40]]))
264264
function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
265-
dys = unthunk(dys_raw)
265+
dys = unthunk.(unthunk(dys_raw))
266266
i1 = findfirst(dy -> dy isa AbstractArray, dys)
267267
if i1 === nothing # all slices are Zero!
268268
return _zero_fill!(similar(x, float(eltype(x)), axes(x)))
269269
end
270+
270271
T = Base.promote_eltype(dys...)
271272
# The whole point of this gradient is that we can allocate one `dx` array:
272273
dx = similar(x, T, axes(x))
273274
for i in axes(x, dim)
274275
slice = selectdim(dx, dim, i)
275-
if dys[i] isa AbstractZero
276+
dy = dys[i]
277+
if dy isa AbstractZero
276278
_zero_fill!(slice) # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3]
277279
else
278-
copyto!(slice, dys[i])
280+
copyto!(slice, dy)
279281
end
280282
end
281283
return ProjectTo(x)(dx)

test/rulesets/Base/indexing.jl

+13
Original file line numberDiff line numberDiff line change
@@ -261,4 +261,17 @@ end
261261
Val(3);
262262
check_inferred=(VERSION >= v"1.7"),
263263
)
264+
265+
# eachslice: Make sure pulling back an array of thunks unthunks them and does not return all zeros.
266+
x = ones(Float32, 3)
267+
Δ = ones(Float32, 1)
268+
_, norm_back = ChainRules.rrule(norm, x)
269+
dx = norm_back(Δ)[2]
270+
@test dx isa AbstractThunk
271+
272+
x = ones(Float32, 3, 1)
273+
_, eachcol_back = ChainRules.rrule(eachcol, x)
274+
Δ2 = [dx]
275+
dx2 = eachcol_back(Δ2)[2]
276+
@test all(dx2 .≉ 0f0)
264277
end

0 commit comments

Comments
 (0)