Skip to content

Commit fe63195

Browse files
committed
Fix unsafe @pures via separate methods for BitInts
The current usages of `Base.@pure` are _unsafe if called with user-defined types_, since their definitions may change arbitrarily. We fix this by _only providing `@pure` methods for the built in BitInteger types,_ and non-`@pure` methods for any types. Consider the following demonstration of the broken behavior: ```julia (@v1.4) pkg> add FixedPointDecimals#master Updating git-repo `https://github.com/JuliaMath/FixedPointDecimals.jl.git` julia> struct MyInt <: Integer x::Int end julia> Base.typemax(::Type{MyInt}) = 10000000 julia> Base.widen(::Type{MyInt}) = Int128 julia> using FixedPointDecimals julia> a = reinterpret(FixedPointDecimals.FD{MyInt,2}, MyInt(2)); julia> Base.typemax(::Type{MyInt}) = 10 julia> a = reinterpret(FixedPointDecimals.FD{MyInt,2}, MyInt(2)); # SHOULD BE AN ERROR! julia> ``` This is fixed, after this commit: ```julia julia> struct MyInt <: Integer x::Int end julia> Base.typemax(::Type{MyInt}) = 10000000 julia> Base.widen(::Type{MyInt}) = Int128 julia> using FixedPointDecimals [ Info: Precompiling FixedPointDecimals [fb4d412d-6eee-574d-9565-ede6634db7b0] julia> a = reinterpret(FixedPointDecimals.FD{MyInt,2}, MyInt(2)); julia> Base.typemax(::Type{MyInt}) = 10 julia> a = reinterpret(FixedPointDecimals.FD{MyInt,2}, MyInt(2)); # SHOULD BE AN ERROR! ERROR: ArgumentError: Requested number of decimal places 2 exceeds the max allowed for the storage type MyInt: [0, 0] ``` --------------------------------------------------------------- In order to reduce code-duplication, we add a small macro that duplicates the function definition to provide two identical macros, where one is restricted to BitIntegers.
1 parent 317c82b commit fe63195

File tree

1 file changed

+36
-11
lines changed

1 file changed

+36
-11
lines changed

Diff for: src/FixedPointDecimals.jl

+36-11
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,35 @@ for fn in [:trunc, :floor, :ceil]
6969
end
7070
end
7171

72+
"""
73+
@_pure_for_BitInts T function foo(::FD{T,f}) where {T,f} ... end
74+
@_pure_for_BitInts {T,U} bar(::FD{T,f}, ::FD{U,g}) where {T,f,U,g} = ...
75+
76+
Defines the provided function _twice_, once as written, and once as a _@pure method_,
77+
restricted to only Base.BitInteger types. This is because `@pure` is not safe for calling
78+
generic user-defined methods on generic user-defined types.
79+
"""
80+
macro _pure_for_BitInts(int_type_params, f)
81+
if int_type_params isa Symbol
82+
int_type_params = [int_type_params]
83+
else
84+
int_type_params = int_type_params.args
85+
end
86+
@assert length(int_type_params) >= 1 && f.args[1].head == :where "Usage: @_pure_for_BitInts T function foo(::FD{T,f}) where {T,f} ... end"
87+
bitint_f = deepcopy(f)
88+
for param_name in int_type_params
89+
type_param_idx = findfirst(isequal(param_name), bitint_f.args[1].args)
90+
@assert type_param_idx !== nothing "Unmatched int_type param name `$param_name` in where clause params $(f.args[1].args[2:end])"
91+
new_type_restriction = :($param_name <: $Base.BitInteger)
92+
bitint_f.args[1].args[type_param_idx] = new_type_restriction
93+
end
94+
95+
esc(quote
96+
Base.@__doc__ $f
97+
Base.@pure $bitint_f
98+
end)
99+
end
100+
72101
"""
73102
FixedDecimal{T <: Integer, f::Int}
74103
@@ -80,7 +109,7 @@ struct FixedDecimal{T <: Integer, f} <: Real
80109

81110
# inner constructor
82111
# This function is marked as `Base.@pure`. It does not have or depend on any side-effects.
83-
Base.@pure function Base.reinterpret(::Type{FixedDecimal{T, f}}, i::Integer) where {T, f}
112+
@_pure_for_BitInts T function Base.reinterpret(::Type{FixedDecimal{T, f}}, i::Integer) where {T, f}
84113
n = max_exp10(T)
85114
if f >= 0 && (n < 0 || f <= n)
86115
new{T, f}(i % T)
@@ -115,15 +144,15 @@ Base.:+(x::FD{T, f}, y::FD{T, f}) where {T, f} = reinterpret(FD{T, f}, x.i+y.i)
115144
Base.:-(x::FD{T, f}, y::FD{T, f}) where {T, f} = reinterpret(FD{T, f}, x.i-y.i)
116145

117146
# wide multiplication
118-
Base.@pure function Base.widemul(x::FD{<:Any, f}, y::FD{<:Any, g}) where {f, g}
147+
@_pure_for_BitInts {T,U} function Base.widemul(x::FD{<:T, f}, y::FD{<:U, g}) where {T, f, U, g}
119148
i = widemul(x.i, y.i)
120149
reinterpret(FD{typeof(i), f + g}, i)
121150
end
122-
Base.@pure function Base.widemul(x::FD{T, f}, y::Integer) where {T, f}
151+
@_pure_for_BitInts T function Base.widemul(x::FD{T, f}, y::Integer) where {T, f}
123152
i = widemul(x.i, y)
124153
reinterpret(FD{typeof(i), f}, i)
125154
end
126-
Base.@pure Base.widemul(x::Integer, y::FD) = widemul(y, x)
155+
@_pure_for_BitInts T Base.widemul(x::Integer, y::FD{T}) where T = widemul(y, x)
127156

128157
"""
129158
_round_to_even(quotient, remainder, divisor)
@@ -333,7 +362,7 @@ Base.promote_rule(::Type{<:FD}, ::Type{Rational{TR}}) where {TR} = Rational{TR}
333362

334363
# TODO: decide if these are the right semantics;
335364
# right now we pick the bigger int type and the bigger decimal point
336-
Base.@pure function Base.promote_rule(::Type{FD{T, f}}, ::Type{FD{U, g}}) where {T, f, U, g}
365+
@_pure_for_BitInts T,U function Base.promote_rule(::Type{FD{T, f}}, ::Type{FD{U, g}}) where {T, f, U, g}
337366
FD{promote_type(T, U), max(f, g)}
338367
end
339368

@@ -480,10 +509,6 @@ NOTE: This function is expensive, since it contains a while-loop, but it is actu
480509
This function does not have or depend on any side-effects.
481510
"""
482511
function max_exp10(::Type{T}) where {T <: Integer}
483-
# This function is marked as `Base.@pure`. Even though it does call some generic
484-
# functions, they are all simple methods that should be able to be evaluated as
485-
# constants. This function does not have or depend on any side-effects.
486-
487512
W = widen(T)
488513
type_max = W(typemax(T))
489514

@@ -515,8 +540,8 @@ end
515540
Compute `10^f` as an Integer without overflow. Note that overflow will not occur for any
516541
constructable `FD{T, f}`.
517542
"""
518-
Base.@pure coefficient(::Type{FD{T, f}}) where {T, f} = T(10)^f
519-
Base.@pure coefficient(fd::FD{T, f}) where {T, f} = coefficient(FD{T, f})
543+
@_pure_for_BitInts T coefficient(::Type{FD{T, f}}) where {T, f} = T(10)^f
544+
@_pure_for_BitInts T coefficient(fd::FD{T, f}) where {T, f} = coefficient(FD{T, f})
520545
value(fd::FD) = fd.i
521546

522547
# for generic hashing

0 commit comments

Comments
 (0)