Skip to content

Commit

Permalink
Merge pull request #478 from oscardssmith/fix-interpolation-bounds-fo…
Browse files Browse the repository at this point in the history
…llowup

fix vector interpolation and remove unnecessary `@inbounds`
  • Loading branch information
ChrisRackauckas authored Aug 10, 2023
2 parents 0bffe3e + 5c619ce commit cce1ce0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
44 changes: 22 additions & 22 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ end
u = id.u
typeof(id) <: HermiteInterpolation && (du = id.du)
tdir = sign(t[end] - t[1])
t[end] == t[1] && tval != t[end] &&
error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
idx = sortperm(tvals, rev = tdir < 0)
i = 2 # Start the search thinking it's between t[1] and t[2]
t[end] == t[1] && (tvals[idx[1]] != t[1] || tvals[idx[end]] != t[1]) &&
error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
tdir * tvals[idx[end]] > tdir * t[end] &&
error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
tdir * tvals[idx[1]] < tdir * t[1] &&
Expand All @@ -98,25 +98,25 @@ end
else
vals = Vector{eltype(u)}(undef, length(tvals))
end
@inbounds for j in idx
for j in idx
tval = tvals[j]
i = searchsortedfirst(@view(t[i:end]), tval, rev = tdir < 0) + i - 1 # It's in the interval t[i-1] to t[i]
avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual
avoid_constant_ends && i == 1 && (i += 1)
if !avoid_constant_ends && t[i] == tval
if !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value!
if idxs === nothing
vals[j] = u[i - 1]
else
vals[j] = u[i - 1][idxs]
end
elseif !avoid_constant_ends && t[i] == tval
lasti = lastindex(t)
k = continuity == :right && i + 1 <= lasti && t[i + 1] == tval ? i + 1 : i
if idxs === nothing
vals[j] = u[k]
else
vals[j] = u[k][idxs]
end
elseif !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value!
if idxs === nothing
vals[j] = u[i - 1]
else
vals[j] = u[i - 1][idxs]
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Expand Down Expand Up @@ -145,33 +145,33 @@ times t (sorted), with values u and derivatives ks
u = id.u
typeof(id) <: HermiteInterpolation && (du = id.du)
tdir = sign(t[end] - t[1])
t[end] == t[1] && tval != t[end] &&
error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
idx = sortperm(tvals, rev = tdir < 0)
i = 2 # Start the search thinking it's between t[1] and t[2]
t[end] == t[1] && (tvals[idx[1]] != t[1] || tvals[idx[end]] != t[1]) &&
error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
tdir * tvals[idx[end]] > tdir * t[end] &&
error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
tdir * tvals[idx[1]] < tdir * t[1] &&
error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
@inbounds for j in idx
for j in idx
tval = tvals[j]
i = searchsortedfirst(@view(t[i:end]), tval, rev = tdir < 0) + i - 1 # It's in the interval t[i-1] to t[i]
avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual
avoid_constant_ends && i == 1 && (i += 1)
if !avoid_constant_ends && t[i] == tval
if !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value!
if idxs === nothing
vals[j] = u[i - 1]
else
vals[j] = u[i - 1][idxs]
end
elseif !avoid_constant_ends && t[i] == tval
lasti = lastindex(t)
k = continuity == :right && i + 1 <= lasti && t[i + 1] == tval ? i + 1 : i
if idxs === nothing
vals[j] = u[k]
else
vals[j] = u[k][idxs]
end
elseif !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value!
if idxs === nothing
vals[j] = u[i - 1]
else
vals[j] = u[i - 1][idxs]
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Expand Down Expand Up @@ -217,7 +217,7 @@ times t (sorted), with values u and derivatives ks
@inbounds i = searchsortedfirst(t, tval, rev = tdir < 0) # It's in the interval t[i-1] to t[i]
avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual
avoid_constant_ends && i == 1 && (i += 1)
@inbounds if !avoid_constant_ends && t[i] == tval
if !avoid_constant_ends && t[i] == tval
lasti = lastindex(t)
k = continuity == :right && i + 1 <= lasti && t[i + 1] == tval ? i + 1 : i
if idxs === nothing
Expand Down Expand Up @@ -267,7 +267,7 @@ times t (sorted), with values u and derivatives ks
@inbounds i = searchsortedfirst(t, tval, rev = tdir < 0) # It's in the interval t[i-1] to t[i]
avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual
avoid_constant_ends && i == 1 && (i += 1)
@inbounds if !avoid_constant_ends && t[i] == tval
if !avoid_constant_ends && t[i] == tval
lasti = lastindex(t)
k = continuity == :right && i + 1 <= lasti && t[i + 1] == tval ? i + 1 : i
if idxs === nothing
Expand Down
4 changes: 3 additions & 1 deletion test/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ end
ode = ODEProblem(f, 1.0, (0.0, 1.0))
sol = SciMLBase.build_solution(ode, :NoAlgorithm, [ode.tspan[begin]], [ode.u0])
@test sol(0.0) == 1.0
@test sol([0.0,0.0]) == [1.0, 1.0]
# test that indexing out of bounds doesn't segfault
@test_throws ErrorException sol(1.0)
@test_throws ErrorException sol(1)
@test_throws ErrorException sol(-0.5)
@test_throws ErrorException sol([0, -0.5, 0])
end

0 comments on commit cce1ce0

Please sign in to comment.