Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add integrand interface #497

Merged
merged 28 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
57ba954
add integrand interface
lxvm Sep 16, 2023
21b7895
add InplaceBatchIntegrand
lxvm Sep 16, 2023
6a2038a
format and include
lxvm Sep 17, 2023
3f77759
make the IntegralFunctions
lxvm Sep 19, 2023
f56d654
canonicalize
ChrisRackauckas Sep 19, 2023
e3ee453
Remove error checking on function definition of batch integral
ChrisRackauckas Sep 19, 2023
318e79f
add error test on incorrect integral function dispatches
ChrisRackauckas Sep 19, 2023
783b88e
argument amounts testing
ChrisRackauckas Sep 19, 2023
fcc7edb
some better utils checks
ChrisRackauckas Sep 19, 2023
5a37040
apply format
lxvm Sep 19, 2023
b02e470
fix integralfunction iip
lxvm Sep 19, 2023
95bdb1d
rename integrand_prototype to integral_prototype
lxvm Sep 19, 2023
5675e6f
Update test/function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
774e4be
fix typos
ChrisRackauckas Sep 21, 2023
bbe691b
revert naming to integrand_prototype
lxvm Sep 21, 2023
c0f4062
wrap integrand with IntegralFunction in IntegralProblem
lxvm Sep 21, 2023
740576a
make integral functions callable
lxvm Sep 21, 2023
3f7d1fb
simplify IntegralProblem definition
lxvm Sep 21, 2023
0deeefb
update docstrings
lxvm Sep 21, 2023
e27965d
apply format
lxvm Sep 21, 2023
5be7d7a
remove output_prototype
lxvm Sep 21, 2023
8ebfe42
add deprecation method
lxvm Sep 21, 2023
e6a0547
Update src/problems/basic_problems.jl
ChrisRackauckas Sep 21, 2023
a3a09d4
Merge branch 'master' into integrands
ChrisRackauckas Sep 21, 2023
619ac07
Update test/function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
83f933d
Update function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
7740dd4
fix default batch and dispatch
lxvm Sep 21, 2023
a6fd63a
Change version just to run tests
ChrisRackauckas Sep 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,14 @@ abstract type AbstractDiffEqFunction{iip} <:
"""
$(TYPEDEF)

Base for types defining integrand functions.
"""
abstract type AbstractIntegralFunction{iip} <:
AbstractSciMLFunction{iip} end

"""
$(TYPEDEF)

Base for types defining optimization functions.
"""
abstract type AbstractOptimizationFunction{iip} <: AbstractSciMLFunction{iip} end
Expand Down Expand Up @@ -659,7 +667,9 @@ function specialization(::Union{ODEFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize}}) where {iip,
BVPFunction{iip, specialize},
IntegralFunction{iip, specialize},
BatchIntegralFunction{iip, specialize}}) where {iip,
specialize}
specialize
end
Expand Down Expand Up @@ -787,7 +797,8 @@ export remake

export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction,
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction,
IntegralFunction, BatchIntegralFunction

export OptimizationFunction

Expand Down
89 changes: 54 additions & 35 deletions src/problems/basic_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,26 +335,16 @@
### Constructors

```
IntegralProblem{iip}(f,lb,ub,p=NullParameters();
nout=1, batch = 0, kwargs...)
IntegralProblem(f,domain,p=NullParameters(); kwargs...)
IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...)
```

- f: the integrand, callable function `y = f(u,p)` for out-of-place or `f(y,u,p)` for in-place.
- f: the integrand, callable function `y = f(u,p)` for out-of-place (default) or an
`IntegralFunction` or `BatchIntegralFunction` for inplace and batching optimizations.
- domain: an object representing an integration domain, i.e. the tuple `(lb, ub)`.
- lb: Either a number or vector of lower bounds.
- ub: Either a number or vector of upper bounds.
- p: The parameters associated with the problem.
- nout: The output size of the function f. Defaults to 1, i.e., a scalar valued function.
If `nout > 1` f is a vector valued function .
- batch: The preferred number of points to batch. This allows user-side parallelization
of the integrand. If `batch == 0` no batching is performed.
If `batch > 0` both `u` and `y` get an additional dimension added to it.
This means that:
if `f` is a multi variable function each `u[:,i]` is a different point to evaluate `f` at,
if `f` is a single variable function each `u[i]` is a different point to evaluate `f` at,
if `f` is a vector valued function each `y[:,i]` is the evaluation of `f` at a different point,
if `f` is a scalar valued function `y[i]` is the evaluation of `f` at a different point.
Note that batch is a suggestion for the number of points,
and it is not necessarily true that batch is the same as batchsize in all algorithms.
- kwargs: Keyword arguments copied to the solvers.

Additionally, we can supply iip like IntegralProblem{iip}(...) as true or false to declare at
Expand All @@ -364,30 +354,59 @@

The fields match the names of the constructor arguments.
"""
struct IntegralProblem{isinplace, P, F, B, K} <: AbstractIntegralProblem{isinplace}
struct IntegralProblem{isinplace, P, F, T, K} <: AbstractIntegralProblem{isinplace}
f::F
lb::B
ub::B
nout::Int
domain::T
p::P
batch::Int
kwargs::K
@add_kwonly function IntegralProblem{iip}(f, lb, ub, p = NullParameters();
nout = 1,
batch = 0, kwargs...) where {iip}
@assert typeof(lb)==typeof(ub) "Type of lower and upper bound must match"
@add_kwonly function IntegralProblem{iip}(f::AbstractIntegralFunction{iip}, domain,

Check warning on line 362 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L362

Added line #L362 was not covered by tests
p = NullParameters();
kwargs...) where {iip}
warn_paramtype(p)
new{iip, typeof(p), typeof(f), typeof(lb), typeof(kwargs)}(f,
lb, ub, nout, p,
batch, kwargs)
Comment on lines -381 to -382
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have a deprecation path where if nout and batch are supplied we throw a warning and just define the prototypes as Arrays appropriately sized

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks for bringing this up. While implementing it I realized that the BatchIntegralFunction can let the user pass a two-argument out-of-place form and a 3-argument in-place form, which was allowed before and would match IntegralFunction. Then I'll remove the output_prototype field, since we can still query the output type of an out-of-place BatchIntegralFunction by calling the function on an empty vector of input points. The details of allocating an output_prototype may differ across libraries, but we have a mechanism to get the output type for both iip and oop forms, so this buffer can be correctly allocated by the solver, and solves our previous issue.

new{iip, typeof(p), typeof(f), typeof(domain), typeof(kwargs)}(f,

Check warning on line 366 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L366

Added line #L366 was not covered by tests
domain, p, kwargs)
end
end

TruncatedStacktraces.@truncate_stacktrace IntegralProblem 1 4

function IntegralProblem(f, lb, ub, args...;
function IntegralProblem(f::AbstractIntegralFunction,

Check warning on line 373 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L373

Added line #L373 was not covered by tests
domain,
p = NullParameters();
kwargs...)
IntegralProblem{isinplace(f, 3)}(f, lb, ub, args...; kwargs...)
IntegralProblem{isinplace(f)}(f, domain, p; kwargs...)

Check warning on line 377 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L377

Added line #L377 was not covered by tests
end

function IntegralProblem(f::AbstractIntegralFunction,

Check warning on line 380 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L380

Added line #L380 was not covered by tests
lb::B,
ub::B,
p = NullParameters();
kwargs...) where {B}
IntegralProblem(f, (lb, ub), p; kwargs...)

Check warning on line 385 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L385

Added line #L385 was not covered by tests
end

function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...)
if nout !== nothing || batch !== nothing
@warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details."

Check warning on line 390 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L388-L390

Added lines #L388 - L390 were not covered by tests
end

iip = isinplace(f, 3)
g = if iip
nout = nout === nothing ? 1 : nout
output_prototype = Vector{Float64}(undef, nout)
if batch == 0
IntegralFunction(f, output_prototype)

Check warning on line 398 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L393-L398

Added lines #L393 - L398 were not covered by tests
else
BatchIntegralFunction(f, output_prototype, max_batch=batch)

Check warning on line 400 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L400

Added line #L400 was not covered by tests
end
else
if batch == 0
IntegralFunction(f)

Check warning on line 404 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L403-L404

Added lines #L403 - L404 were not covered by tests
else
BatchIntegralFunction(f, max_batch=batch)

Check warning on line 406 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L406

Added line #L406 was not covered by tests
end
end
IntegralProblem{iip}(g, args...; kwargs...)

Check warning on line 409 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L409

Added line #L409 was not covered by tests
end

struct QuadratureProblem end
Expand All @@ -405,8 +424,8 @@
```math
\sum_i w_i y_i
```
where `y_i` are sampled values of the integrand, and `w_i` are weights
assigned by a quadrature rule, which depend on sampling points `x`.
where `y_i` are sampled values of the integrand, and `w_i` are weights
assigned by a quadrature rule, which depend on sampling points `x`.

## Problem Type

Expand All @@ -415,10 +434,10 @@
```
SampledIntegralProblem(y::AbstractArray, x::AbstractVector; dim=ndims(y), kwargs...)
```
- y: The sampled integrand, must be a subtype of `AbstractArray`.
It is assumed that the values of `y` along dimension `dim`
- y: The sampled integrand, must be a subtype of `AbstractArray`.
It is assumed that the values of `y` along dimension `dim`
correspond to the integrand evaluated at sampling points `x`
- x: Sampling points, must be a subtype of `AbstractVector`.
- x: Sampling points, must be a subtype of `AbstractVector`.
- dim: Dimension along which to integrate. Defaults to the last dimension of `y`.
- kwargs: Keyword arguments copied to the solvers.

Expand All @@ -434,7 +453,7 @@
@add_kwonly function SampledIntegralProblem(y::AbstractArray, x::AbstractVector;
dim = ndims(y),
kwargs...)
@assert dim <= ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"
@assert dim<=ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"

Check warning on line 456 in src/problems/basic_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/basic_problems.jl#L456

Added line #L456 was not covered by tests
@assert length(x)==size(y, dim) "The integrand `y` must have the same length as the sampling points `x` along the integrated dimension."
@assert axes(x, 1)==axes(y, dim) "The integrand `y` must obey the same indexing as the sampling points `x` along the integrated dimension."
new{typeof(y), typeof(x), typeof(kwargs)}(y, x, dim, kwargs)
Expand Down
Loading
Loading