Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BVPFunction #370

Merged
merged 10 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,9 @@ function specialization(::Union{ODEFunction{iip, specialize},
ImplicitDiscreteFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize}}) where {iip,
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize}}) where {iip,
specialize}
specialize
end

specialization(f::AbstractSciMLFunction) = FullSpecialize
Expand Down Expand Up @@ -784,7 +784,7 @@ export remake

export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction,
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction

export OptimizationFunction

Expand Down
2 changes: 1 addition & 1 deletion src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}
ensemblealg::BasicEnsembleAlgorithm; kwargs...)
# TODO: @invoke
invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)},
prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
prob, alg, ensemblealg; trajectories = length(prob.prob), kwargs...)
end

function __solve(prob::AbstractEnsembleProblem,
Expand Down
33 changes: 20 additions & 13 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
# TODO: @invoke
invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...)
invoke(EnsembleProblem,
Tuple{Any},
prob;
prob_func = DEFAULT_VECTOR_PROB_FUNC,
kwargs...)
end
function EnsembleProblem(prob;
output_func = DEFAULT_OUTPUT_FUNC,
Expand All @@ -36,20 +40,23 @@
EnsembleProblem(prob, prob_func, output_func, reduction, u_init, safetycopy)
end

struct WeightedEnsembleProblem{T1<:AbstractEnsembleProblem, T2<:AbstractVector} <: AbstractEnsembleProblem
ensembleprob::T1
weights::T2
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
weights::T2
end
function Base.propertynames(e::WeightedEnsembleProblem)
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)

Check warning on line 49 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L48-L49

Added lines #L48 - L49 were not covered by tests
end
Base.propertynames(e::WeightedEnsembleProblem) = (Base.propertynames(getfield(e, :ensembleprob))..., :weights)
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)

Check warning on line 54 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L52-L54

Added lines #L52 - L54 were not covered by tests
end
function WeightedEnsembleProblem(args...; weights, kwargs...)
# TODO: allow skipping checks?
@assert sum(weights) ≈ 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)
# TODO: allow skipping checks?
@assert sum(weights) ≈ 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)

Check warning on line 61 in src/ensemble/ensemble_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_problems.jl#L58-L61

Added lines #L58 - L61 were not covered by tests
end
13 changes: 9 additions & 4 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
converged)
end

struct WeightedEnsembleSolution{T1<:AbstractEnsembleSolution, T2<:Number}
struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
ensol::T1
weights::Vector{T2}
function WeightedEnsembleSolution(ensol, weights)
Expand Down Expand Up @@ -207,13 +207,18 @@
end
end


Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,

Check warning on line 214 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L214

Added line #L214 was not covered by tests
::Colon,
args::Colon...)
return invoke(getindex,

Check warning on line 217 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L217

Added line #L217 was not covered by tests
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
32 changes: 10 additions & 22 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,53 +78,41 @@ every solve call.
* `p`: The parameters for the problem. Defaults to `NullParameters`
* `kwargs`: The keyword arguments passed onto the solves.
"""
struct BVProblem{uType, tType, isinplace, P, F, bF, PT, K} <:
struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
AbstractBVProblem{uType, tType, isinplace}
f::F
bc::bF
bc::BF
u0::uType
tspan::tType
p::P
problem_type::PT
kwargs::K
@add_kwonly function BVProblem{iip}(f::AbstractODEFunction, bc, u0, tspan,

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip}, bc, u0, tspan,
p = NullParameters(),
problem_type = StandardBVProblem();
kwargs...) where {iip}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), iip, typeof(p),
typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p,
typeof(f.f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f.f, bc, u0, _tspan, p,
problem_type, kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
BVProblem(ODEFunction{iip}(f), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
end
end

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f::AbstractODEFunction, bc, u0, tspan, args...; kwargs...)
BVProblem{isinplace(f, 4)}(f, bc, u0, tspan, args...; kwargs...)
end

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
BVProblem(ODEFunction(f), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction(f, bc), u0, tspan, p; kwargs...)
end

# convenience interfaces:
# Allow any previous timeseries solution
function BVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple, p = NullParameters();
kwargs...) where {T <: AbstractTimeseriesSolution}
BVProblem(f, bc, sol.u, tspan, p)
end
# Allow a function of time for the initial guess
function BVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector,
p = NullParameters(); kwargs...)
u0 = [initialGuess(i) for i in tspan]
BVProblem(f, bc, u0, (tspan[1], tspan[end]), p)
function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
BVProblem{isinplace(f)}(f.f, f.bc, u0, tspan, p; kwargs...)
end

"""
Expand Down
Loading
Loading