From 5c619cec318c937c42e7b30ddc4bc9229e8622ab Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 10 Aug 2023 16:47:00 -0400 Subject: [PATCH] fix vector interpolation and remove unnecessary inbounds --- src/interpolation.jl | 44 +++++++++++++++++++------------------- test/solution_interface.jl | 4 +++- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/interpolation.jl b/src/interpolation.jl index dab2f0249..e64d91c35 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -83,10 +83,10 @@ end u = id.u typeof(id) <: HermiteInterpolation && (du = id.du) tdir = sign(t[end] - t[1]) - t[end] == t[1] && tval != t[end] && - error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") idx = sortperm(tvals, rev = tdir < 0) i = 2 # Start the search thinking it's between t[1] and t[2] + t[end] == t[1] && (tvals[idx[1]] != t[1] || tvals[idx[end]] != t[1]) && + error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") tdir * tvals[idx[end]] > tdir * t[end] && error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") tdir * tvals[idx[1]] < tdir * t[1] && @@ -98,12 +98,18 @@ end else vals = Vector{eltype(u)}(undef, length(tvals)) end - @inbounds for j in idx + for j in idx tval = tvals[j] i = searchsortedfirst(@view(t[i:end]), tval, rev = tdir < 0) + i - 1 # It's in the interval t[i-1] to t[i] avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual avoid_constant_ends && i == 1 && (i += 1) - if !avoid_constant_ends && t[i] == tval + if !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value! + if idxs === nothing + vals[j] = u[i - 1] + else + vals[j] = u[i - 1][idxs] + end + elseif !avoid_constant_ends && t[i] == tval lasti = lastindex(t) k = continuity == :right && i + 1 <= lasti && t[i + 1] == tval ? i + 1 : i if idxs === nothing @@ -111,12 +117,6 @@ end else vals[j] = u[k][idxs] end - elseif !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value! - if idxs === nothing - vals[j] = u[i - 1] - else - vals[j] = u[i - 1][idxs] - end else typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] @@ -145,20 +145,26 @@ times t (sorted), with values u and derivatives ks u = id.u typeof(id) <: HermiteInterpolation && (du = id.du) tdir = sign(t[end] - t[1]) - t[end] == t[1] && tval != t[end] && - error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") idx = sortperm(tvals, rev = tdir < 0) i = 2 # Start the search thinking it's between t[1] and t[2] + t[end] == t[1] && (tvals[idx[1]] != t[1] || tvals[idx[end]] != t[1]) && + error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") tdir * tvals[idx[end]] > tdir * t[end] && error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") tdir * tvals[idx[1]] < tdir * t[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.") - @inbounds for j in idx + for j in idx tval = tvals[j] i = searchsortedfirst(@view(t[i:end]), tval, rev = tdir < 0) + i - 1 # It's in the interval t[i-1] to t[i] avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual avoid_constant_ends && i == 1 && (i += 1) - if !avoid_constant_ends && t[i] == tval + if !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value! + if idxs === nothing + vals[j] = u[i - 1] + else + vals[j] = u[i - 1][idxs] + end + elseif !avoid_constant_ends && t[i] == tval lasti = lastindex(t) k = continuity == :right && i + 1 <= lasti && t[i + 1] == tval ? i + 1 : i if idxs === nothing @@ -166,12 +172,6 @@ times t (sorted), with values u and derivatives ks else vals[j] = u[k][idxs] end - elseif !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value! - if idxs === nothing - vals[j] = u[i - 1] - else - vals[j] = u[i - 1][idxs] - end else typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] @@ -217,7 +217,7 @@ times t (sorted), with values u and derivatives ks @inbounds i = searchsortedfirst(t, tval, rev = tdir < 0) # It's in the interval t[i-1] to t[i] avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual avoid_constant_ends && i == 1 && (i += 1) - @inbounds if !avoid_constant_ends && t[i] == tval + if !avoid_constant_ends && t[i] == tval lasti = lastindex(t) k = continuity == :right && i + 1 <= lasti && t[i + 1] == tval ? i + 1 : i if idxs === nothing @@ -267,7 +267,7 @@ times t (sorted), with values u and derivatives ks @inbounds i = searchsortedfirst(t, tval, rev = tdir < 0) # It's in the interval t[i-1] to t[i] avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual avoid_constant_ends && i == 1 && (i += 1) - @inbounds if !avoid_constant_ends && t[i] == tval + if !avoid_constant_ends && t[i] == tval lasti = lastindex(t) k = continuity == :right && i + 1 <= lasti && t[i + 1] == tval ? i + 1 : i if idxs === nothing diff --git a/test/solution_interface.jl b/test/solution_interface.jl index d2fa128dd..6d1a1ec5c 100644 --- a/test/solution_interface.jl +++ b/test/solution_interface.jl @@ -34,7 +34,9 @@ end ode = ODEProblem(f, 1.0, (0.0, 1.0)) sol = SciMLBase.build_solution(ode, :NoAlgorithm, [ode.tspan[begin]], [ode.u0]) @test sol(0.0) == 1.0 + @test sol([0.0,0.0]) == [1.0, 1.0] # test that indexing out of bounds doesn't segfault - @test_throws ErrorException sol(1.0) + @test_throws ErrorException sol(1) @test_throws ErrorException sol(-0.5) + @test_throws ErrorException sol([0, -0.5, 0]) end