From 661cafaf68d5e32e97c2a32eb09caa7b660e9bd2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 9 Oct 2024 15:29:47 +0530 Subject: [PATCH] fix: handle empty `idxs` in interpolation --- src/solutions/ode_solutions.jl | 12 ++++++++++++ test/solution_interface.jl | 14 ++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 9454902fe..2d49b6a1e 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -245,6 +245,9 @@ end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector{<:Integer}, continuity) where {deriv} + if isempty(idxs) + return eltype(eltype(sol.u))[] + end if eltype(sol.u) <: Number idxs = only(idxs) end @@ -259,6 +262,9 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::AbstractVector{<:Integer}, continuity) where {deriv} + if isempty(idxs) + return map(_ -> eltype(eltype(sol.u))[], t) + end if eltype(sol.u) <: Number idxs = only(idxs) end @@ -295,6 +301,9 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect any(isequal(NotSymbolic()), symbolic_type.(idxs)) error("Incorrect specification of `idxs`") end + if isempty(idxs) + return eltype(eltype(sol.u))[] + end error_if_observed_derivative(sol, idxs, deriv) ps = parameter_values(sol) if is_parameter_timeseries(sol) == Timeseries() && is_discrete_expression(sol, idxs) @@ -335,6 +344,9 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::AbstractVector, continuity) where {deriv} + if isempty(idxs) + return map(_ -> eltype(eltype(sol.u))[], t) + end error_if_observed_derivative(sol, idxs, deriv) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing getter = getu(sol, idxs) diff --git a/test/solution_interface.jl b/test/solution_interface.jl index 1f7773084..c38d02fd4 100644 --- a/test/solution_interface.jl +++ b/test/solution_interface.jl @@ -37,3 +37,17 @@ end @test_throws ErrorException sol(-0.5) @test_throws ErrorException sol([0, -0.5, 0]) end + +@testset "interpolate with empty idxs" begin + f = (u, p, t) -> u + sol1 = SciMLBase.build_solution( + ODEProblem(f, 1.0, (0.0, 1.0)), :NoAlgorithm, 0.0:0.1:1.0, exp.(0.0:0.1:1.0)) + sol2 = SciMLBase.build_solution(ODEProblem(f, [1.0, 2.0], (0.0, 1.0)), :NoAlgorithm, + 0.0:0.1:1.0, vcat.(exp.(0.0:0.1:1.0), 2exp.(0.0:0.1:1.0))) + for sol in [sol1, sol2] + @test sol(0.15; idxs = []) == Float64[] + @test sol(0.15; idxs = Int[]) == Float64[] + @test sol([0.15, 0.25]; idxs = []) == [Float64[], Float64[]] + @test sol([0.15, 0.25]; idxs = Int[]) == [Float64[], Float64[]] + end +end