Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BVPFunction #370

Merged
merged 10 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,9 @@ function specialization(::Union{ODEFunction{iip, specialize},
ImplicitDiscreteFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize}}) where {iip,
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize}}) where {iip,
specialize}
specialize
end

specialization(f::AbstractSciMLFunction) = FullSpecialize
Expand Down Expand Up @@ -784,7 +784,7 @@ export remake

export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction,
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction

export OptimizationFunction

Expand Down
2 changes: 1 addition & 1 deletion src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}
ensemblealg::BasicEnsembleAlgorithm; kwargs...)
# TODO: @invoke
invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)},
prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
prob, alg, ensemblealg; trajectories = length(prob.prob), kwargs...)
end

function __solve(prob::AbstractEnsembleProblem,
Expand Down
33 changes: 20 additions & 13 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
# TODO: @invoke
invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...)
invoke(EnsembleProblem,
Tuple{Any},
prob;
prob_func = DEFAULT_VECTOR_PROB_FUNC,
kwargs...)
end
function EnsembleProblem(prob;
output_func = DEFAULT_OUTPUT_FUNC,
Expand All @@ -36,20 +40,23 @@ function EnsembleProblem(; prob,
EnsembleProblem(prob, prob_func, output_func, reduction, u_init, safetycopy)
end

struct WeightedEnsembleProblem{T1<:AbstractEnsembleProblem, T2<:AbstractVector} <: AbstractEnsembleProblem
ensembleprob::T1
weights::T2
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
weights::T2
end
function Base.propertynames(e::WeightedEnsembleProblem)
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)
end
Base.propertynames(e::WeightedEnsembleProblem) = (Base.propertynames(getfield(e, :ensembleprob))..., :weights)
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)
end
function WeightedEnsembleProblem(args...; weights, kwargs...)
# TODO: allow skipping checks?
@assert sum(weights) ≈ 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)
# TODO: allow skipping checks?
@assert sum(weights) ≈ 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)
end
13 changes: 9 additions & 4 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function EnsembleSolution(sim::T, elapsedTime,
converged)
end

struct WeightedEnsembleSolution{T1<:AbstractEnsembleSolution, T2<:Number}
struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
ensol::T1
weights::Vector{T2}
function WeightedEnsembleSolution(ensol, weights)
Expand Down Expand Up @@ -207,13 +207,18 @@ end
end
end


Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,
::Colon,
args::Colon...)
return invoke(getindex,
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
32 changes: 10 additions & 22 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,53 +78,41 @@ every solve call.
* `p`: The parameters for the problem. Defaults to `NullParameters`
* `kwargs`: The keyword arguments passed onto the solves.
"""
struct BVProblem{uType, tType, isinplace, P, F, bF, PT, K} <:
struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
AbstractBVProblem{uType, tType, isinplace}
f::F
bc::bF
bc::BF
u0::uType
tspan::tType
p::P
problem_type::PT
kwargs::K
@add_kwonly function BVProblem{iip}(f::AbstractODEFunction, bc, u0, tspan,

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip}, bc, u0, tspan,
p = NullParameters(),
problem_type = StandardBVProblem();
kwargs...) where {iip}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), iip, typeof(p),
typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p,
typeof(f.f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f.f, bc, u0, _tspan, p,
problem_type, kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
BVProblem(ODEFunction{iip}(f), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
end
end

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f::AbstractODEFunction, bc, u0, tspan, args...; kwargs...)
BVProblem{isinplace(f, 4)}(f, bc, u0, tspan, args...; kwargs...)
end

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
BVProblem(ODEFunction(f), bc, u0, tspan, p; kwargs...)
BVProblem(BVPFunction(f, bc), u0, tspan, p; kwargs...)
end

# convenience interfaces:
# Allow any previous timeseries solution
function BVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple, p = NullParameters();
kwargs...) where {T <: AbstractTimeseriesSolution}
BVProblem(f, bc, sol.u, tspan, p)
end
# Allow a function of time for the initial guess
function BVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector,
p = NullParameters(); kwargs...)
u0 = [initialGuess(i) for i in tspan]
BVProblem(f, bc, u0, (tspan[1], tspan[end]), p)
function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
BVProblem{isinplace(f)}(f.f, f.bc, u0, tspan, p; kwargs...)
end

"""
Expand Down
Loading