Skip to content

Commit 6d92f53

Browse files
Add PRIMA wrapper
1 parent 6d625af commit 6d92f53

File tree

4 files changed

+209
-0
lines changed

4 files changed

+209
-0
lines changed

lib/OptimizationPRIMA/LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2023 Vaibhav Dixit <[email protected]> and contributors
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

lib/OptimizationPRIMA/Project.toml

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
name = "OptimizationPRIMA"
2+
uuid = "72f8369c-a2ea-4298-9126-56167ce9cbc2"
3+
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4+
version = "1.0.0-DEV"
5+
6+
[deps]
7+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
8+
PRIMA = "0a7d04aa-8ac2-47b3-b7a7-9dbd6ad661ed"
9+
10+
[compat]
11+
julia = "1"
12+
13+
[extras]
14+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15+
16+
[targets]
17+
test = ["Test"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
module OptimizationPRIMA
2+
3+
using PRIMA, Optimization, Optimization.SciMLBase
4+
5+
abstract type PRIMASolvers end
6+
7+
struct UOBYQA <: PRIMASolvers end
8+
struct NEWUOA <: PRIMASolvers end
9+
struct BOBYQA <: PRIMASolvers end
10+
struct LINCOA <: PRIMASolvers end
11+
struct COBYLA <: PRIMASolvers end
12+
13+
SciMLBase.supports_opt_cache_interface(::PRIMASolvers) = true
14+
SciMLBase.allowsconstraints(::Union{LINCOA, COBYLA}) = true
15+
SciMLBase.allowsbounds(opt::Union{BOBYQA, LINCOA, COBYLA}) = true
16+
SciMLBase.requiresconstraints(opt::COBYLA) = true
17+
18+
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::PRIMASolvers,
19+
data = Optimization.DEFAULT_DATA;
20+
callback = (args...) -> (false),
21+
progress = false, kwargs...)
22+
return OptimizationCache(prob, opt, data; callback, progress,
23+
kwargs...)
24+
end
25+
26+
function get_solve_func(opt::PRIMASolvers)
27+
if opt isa UOBYQA
28+
return PRIMA.uobyqa
29+
elseif opt isa NEWUOA
30+
return PRIMA.newuoa
31+
elseif opt isa BOBYQA
32+
return PRIMA.bobyqa
33+
elseif opt isa LINCOA
34+
return PRIMA.lincoa
35+
elseif opt isa COBYLA
36+
return PRIMA.cobyla
37+
end
38+
end
39+
40+
function __map_optimizer_args!(cache::OptimizationCache, opt::PRIMASolvers;
41+
callback = nothing,
42+
maxiters::Union{Number, Nothing} = nothing,
43+
maxtime::Union{Number, Nothing} = nothing,
44+
abstol::Union{Number, Nothing} = nothing,
45+
reltol::Union{Number, Nothing} = nothing,
46+
kwargs...)
47+
kws = (; kwargs...)
48+
49+
if !isnothing(maxiters)
50+
kws = (; kws..., maxfun = maxiters)
51+
end
52+
53+
if cache.ub !== nothing
54+
kws = (; kws..., xu = cache.ub, xl = cache.lb)
55+
end
56+
57+
if !isnothing(maxtime) || !isnothing(abstol) || !isnothing(reltol)
58+
error("maxtime, abstol and reltol kwargs not supported in $opt")
59+
end
60+
61+
return kws
62+
end
63+
64+
function sciml_prima_retcode(rc::AbstractString)
65+
if rc in ["SMALL_TR_RADIUS", "TRSUBP_FAILED","NAN_INF_X"
66+
,"NAN_INF_F"
67+
,"NAN_INF_MODEL"
68+
,"DAMAGING_ROUNDING"
69+
,"ZERO_LINEAR_CONSTRAINT"
70+
,"INVALID_INPUT"
71+
,"ASSERTION_FAILS"
72+
,"VALIDATION_FAILS"
73+
,"MEMORY_ALLOCATION_FAILS"]
74+
return ReturnCode.Failure
75+
else rc in [
76+
"FTARGET_ACHIEVED"
77+
"MAXFUN_REACHED"
78+
"MAXTR_REACHED"
79+
"NO_SPACE_BETWEEN_BOUNDS"
80+
]
81+
return ReturnCode.Success
82+
end
83+
end
84+
85+
function SciMLBase.__solve(cache::OptimizationCache{
86+
F,
87+
RC,
88+
LB,
89+
UB,
90+
LC,
91+
UC,
92+
S,
93+
O,
94+
D,
95+
P,
96+
C,
97+
}) where {
98+
F,
99+
RC,
100+
LB,
101+
UB,
102+
LC,
103+
UC,
104+
S,
105+
O <: PRIMASolvers,
106+
D,
107+
P,
108+
C,
109+
}
110+
_loss = function (θ)
111+
x = cache.f(θ, cache.p)
112+
if cache.callback(θ, x...)
113+
error("Optimization halted by callback.")
114+
end
115+
return x[1]
116+
end
117+
118+
optfunc = get_solve_func(cache.opt)
119+
120+
121+
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
122+
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)
123+
124+
kws = __map_optimizer_args!(cache, cache.opt; callback = cache.callback, maxiters = maxiters,
125+
maxtime = maxtime,
126+
cache.solver_args...)
127+
128+
t0 = time()
129+
if cache.opt isa COBYLA
130+
function fwcons(θ, res)
131+
cache.f.cons(res, θ, cache.p)
132+
return _loss(θ)
133+
end
134+
(minx, minf, nf, rc, cstrv) = optfunc(fwcons, cache.u0; kws...)
135+
elseif cache.opt isa LINCOA
136+
(minx, minf, nf, rc, cstrv) = optfunc(_loss, cache.u0; kws...)
137+
else
138+
(minx, minf, nf, rc) = optfunc(_loss, cache.u0; kws...)
139+
end
140+
t1 = time()
141+
142+
retcode = sciml_prima_retcode(PRIMA.reason(rc))
143+
144+
SciMLBase.build_solution(cache, cache.opt, minx,
145+
minf; retcode = retcode,
146+
solve_time = t1 - t0)
147+
end
148+
149+
export UOBYQA, NEWUOA, BOBYQA, LINCOA, COBYLA
150+
end
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using OptimizationPRIMA, Optimization
2+
using Test
3+
4+
@testset "OptimizationPRIMA.jl" begin
5+
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
6+
x0 = zeros(2)
7+
_p = [1.0, 100.0]
8+
l1 = rosenbrock(x0, _p)
9+
10+
prob = OptimizationProblem(rosenbrock, x0, _p)
11+
sol = Optimization.solve(prob, UOBYQA(), maxiters = 1000)
12+
@test 10 * sol.objective < l1
13+
sol = Optimization.solve(prob, NEWUOA(), maxiters = 1000)
14+
@test 10 * sol.objective < l1
15+
sol = Optimization.solve(prob, BOBYQA(), maxiters = 1000)
16+
@test 10 * sol.objective < l1
17+
sol = Optimization.solve(prob, LINCOA(), maxiters = 1000)
18+
@test 10 * sol.objective < l1
19+
@test_throws SciMLBase.IncompatibleOptimizerError Optimization.solve(prob, COBYLA(), maxiters = 1000)
20+
21+
end

0 commit comments

Comments
 (0)