Skip to content

Commit f6ed7de

Browse files
authored
Merge pull request #417 from JuliaDiff/ox/lkf1
Don't try and convert to FloatX except if Integer or AbstractFloat (option 1)
2 parents 9c0f08e + 127bae5 commit f6ed7de

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.0.1"
3+
version = "1.0.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

Diff for: src/projection.jl

+23-5
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,32 @@ ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::Abst
134134
ProjectTo(::Real) = ProjectTo{Real}()
135135
ProjectTo(::Complex) = ProjectTo{Complex}()
136136
ProjectTo(::Number) = ProjectTo{Number}()
137+
138+
ProjectTo(x::Integer) = ProjectTo(float(x))
139+
ProjectTo(x::Complex{<:Integer}) = ProjectTo(float(x))
140+
141+
# Preserve low-precision floats as accidental promotion is a common performance bug
137142
for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
138-
# Preserve low-precision floats as accidental promotion is a common perforance bug
139143
@eval ProjectTo(::$T) = ProjectTo{$T}()
140144
end
141-
ProjectTo(x::Integer) = ProjectTo(float(x))
142-
ProjectTo(x::Complex{<:Integer}) = ProjectTo(float(x))
143-
(::ProjectTo{T})(dx::Number) where {T<:Number} = convert(T, dx)
144-
(::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx))
145+
146+
# In these cases we can just `convert` as we know we are dealing with plain and simple types
147+
(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx)
148+
(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity
149+
# simple Complex{<:AbstractFloat}} cases
150+
(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
151+
(::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
152+
(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
153+
(::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)
154+
155+
# Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through.
156+
# We assume (lacking evidence to the contrary) that it is the right subspace of numebers
157+
# The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different
158+
# Number type that might not be a subtype of the `project_type`.
159+
(::ProjectTo{<:Number})(dx::Number) = dx
160+
161+
(project::ProjectTo{<:Real})(dx::Complex) = project(real(dx))
162+
(project::ProjectTo{<:Complex})(dx::Real) = project(complex(dx))
145163

146164
# Arrays
147165
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is

Diff for: test/projection.jl

+24-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@ using ChainRulesCore, Test
22
using LinearAlgebra, SparseArrays
33
using OffsetArrays, BenchmarkTools
44

5+
# Like ForwardDiff.jl's Dual
6+
struct Dual{T<:Real} <: Real
7+
value::T
8+
partial::T
9+
end
10+
Base.real(x::Dual) = x
11+
Base.float(x::Dual) = Dual(float(x.value), float(x.partial))
12+
Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
13+
514
@testset "projection" begin
615

716
#####
@@ -12,14 +21,28 @@ using OffsetArrays, BenchmarkTools
1221
# real / complex
1322
@test ProjectTo(1.0)(2.0 + 3im) === 2.0
1423
@test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im
24+
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
25+
@test ProjectTo(2.0)(1+1im) === 1.0
26+
1527

1628
# storage
17-
@test ProjectTo(1)(pi) === Float64(pi)
29+
@test ProjectTo(1)(pi) === pi
1830
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
1931
@test ProjectTo(1//2)(3//4) === 3//4
2032
@test ProjectTo(1.0f0)(1 / 2) === 0.5f0
2133
@test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im
2234
@test ProjectTo(big(1.0))(2) === 2
35+
@test ProjectTo(1.0)(2) === 2.0
36+
end
37+
38+
@testset "Dual" begin # some weird Real subtype that we should basically leave alone
39+
@test ProjectTo(1.0)(Dual(1.0, 2.0)) isa Dual
40+
@test ProjectTo(1.0)(Dual(1, 2)) isa Dual
41+
@test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual}
42+
@test ProjectTo(1.0 + 1im)(
43+
Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))
44+
) isa Complex{<:Dual}
45+
@test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual
2346
end
2447

2548
@testset "Base: arrays of numbers" begin

0 commit comments

Comments
 (0)