Skip to content
This repository was archived by the owner on Nov 23, 2018. It is now read-only.

Commit df1c7d6

Browse files
committed
Added NelderMead method and settings for derivative-free optimization
1 parent d702f88 commit df1c7d6

6 files changed

+217
-57
lines changed

functionconvergence.go

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright ©2015 The gonum Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package optimize
6+
7+
import "math"
8+
9+
// FunctionConverge tests for the convergence of function values. See comment
10+
// in Settings.
11+
type FunctionConverge struct {
12+
Absolute float64
13+
Relative float64
14+
Iterations int
15+
16+
best float64
17+
iter int
18+
}
19+
20+
func (fc *FunctionConverge) Init(f float64) {
21+
fc.best = f
22+
fc.iter = 0
23+
}
24+
25+
func (fc *FunctionConverge) FunctionConverged(f float64) Status {
26+
if fc.Iterations == 0 {
27+
return NotTerminated
28+
}
29+
maxAbs := math.Max(math.Abs(f), math.Abs(fc.best))
30+
if f < fc.best && fc.best-f > fc.Relative*maxAbs+fc.Absolute {
31+
fc.best = f
32+
fc.iter = 0
33+
return NotTerminated
34+
}
35+
fc.iter++
36+
if fc.iter < fc.Iterations {
37+
return NotTerminated
38+
}
39+
return FunctionConvergence
40+
}

gradfree_test.go

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Copyright ©2015 The gonum Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package optimize
6+
7+
import (
8+
"testing"
9+
10+
"github.com/gonum/floats"
11+
"github.com/gonum/optimize/functions"
12+
)
13+
14+
// gradFreeTest describes a single minimization test case for a
// gradient-free optimization method.
type gradFreeTest struct {
	// f is the function that is being minimized.
	f Function
	// x is the initial guess.
	x []float64
	// absTol is the absolute function convergence for the test. If absTol == 0,
	// the default value of 1e-6 will be used.
	absTol float64
	// absIter is the number of iterations for function convergence. If absIter == 0,
	// the default value of 5 will be used.
	absIter int
	// long indicates that the test takes a long time to finish and will be
	// excluded if testing.Short() is true.
	long bool
}
29+
30+
// gradFree ensures that the function is gradient free
31+
type gradFree struct {
32+
f Function
33+
}
34+
35+
func (g gradFree) Func(x []float64) float64 {
36+
return g.f.Func(x)
37+
}
38+
39+
// makeGradFree ensures that a function contains no gradient method
40+
func makeGradFree(f Function) gradFree {
41+
return gradFree{f}
42+
}
43+
44+
// TODO(btracey): The gradient is still evaluated and tested if available
// even if a gradient-free method is being used. This should be fixed. When that
// is fixed, this should include Functions that also have gradients.

// gradFreeTests are the cases run against each gradient-free method.
var gradFreeTests = []gradFreeTest{
	{
		f: makeGradFree(functions.ExtendedRosenbrock{}),
		x: []float64{-10, 10},
	},
	{
		f: makeGradFree(functions.ExtendedRosenbrock{}),
		x: []float64{-5, 4, 16, 3},
	},
}
57+
58+
// TestLocalGradFree runs the gradient-free cases with a nil method,
// exercising the default method selection in Local.
func TestLocalGradFree(t *testing.T) {
	testLocalGradFree(t, gradFreeTests, nil)
}
61+
62+
// TestNelderMead runs the gradient-free cases with an explicit
// NelderMead method.
func TestNelderMead(t *testing.T) {
	testLocalGradFree(t, gradFreeTests, &NelderMead{})
}
65+
66+
func testLocalGradFree(t *testing.T, tests []gradFreeTest, method Method) {
67+
for _, test := range tests {
68+
if test.long && testing.Short() {
69+
continue
70+
}
71+
settings := DefaultSettings()
72+
settings.Recorder = nil
73+
if test.absIter == 0 {
74+
test.absIter = 5
75+
}
76+
if test.absTol == 0 {
77+
test.absTol = 1e-6
78+
}
79+
result, err := Local(test.f, test.x, settings, method)
80+
if err != nil {
81+
t.Errorf("error finding minimum (%v) for \n%v", err, test)
82+
}
83+
if result == nil {
84+
t.Errorf("nil result without error for:\n%v", test)
85+
continue
86+
}
87+
if result.Status != FunctionConvergence {
88+
t.Errorf("Status not %v, %v instead", FunctionConvergence, result.Status)
89+
}
90+
91+
result2, err := Local(test.f, test.x, settings, method)
92+
if err != nil {
93+
t.Errorf("error finding minimum (%v) when reusing Method for \n%v", err, test)
94+
}
95+
if result.FuncEvaluations != result2.FuncEvaluations ||
96+
result.F != result2.F || !floats.Equal(result.X, result2.X) {
97+
t.Errorf("Different result when reuse method")
98+
}
99+
}
100+
}

local.go

+16-6
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ func Local(f Function, initX []float64, settings *Settings, method Method) (*Res
8989
return nil, err
9090
}
9191

92+
if settings.FunctionConverge != nil {
93+
settings.FunctionConverge.Init(optLoc.F)
94+
}
95+
9296
// Runtime is the only Stats field that needs to be updated here.
9397
stats.Runtime = time.Since(startTime)
9498
// Send optLoc to Recorder before checking it for convergence.
@@ -198,8 +202,7 @@ func getDefaultMethod(funcInfo *functionInfo) Method {
198202
if funcInfo.IsGradient || funcInfo.IsFunctionGradient {
199203
return &BFGS{}
200204
}
201-
// TODO: Implement a gradient-free method
202-
panic("optimize: gradient-free methods not yet coded")
205+
return &NelderMead{}
203206
}
204207

205208
// getStartingLocation allocates and initializes the starting location for the minimization.
@@ -256,12 +259,19 @@ func checkConvergence(loc *Location, iterType IterationType, stats *Stats, setti
256259
if iterType == MajorIteration || iterType == InitIteration {
257260
if loc.Gradient != nil {
258261
norm := floats.Norm(loc.Gradient, math.Inf(1))
259-
if norm < settings.GradientAbsTol {
260-
return GradientThreshhold
262+
if norm < settings.GradientThreshold {
263+
return GradientThreshold
261264
}
262265
}
263-
if loc.F < settings.FunctionAbsTol {
264-
return FunctionThreshhold
266+
if loc.F < settings.FunctionThreshold {
267+
return FunctionThreshold
268+
}
269+
}
270+
271+
if iterType == MajorIteration && settings.FunctionConverge != nil {
272+
status := settings.FunctionConverge.FunctionConverged(loc.F)
273+
if status != NotTerminated {
274+
return status
265275
}
266276
}
267277

neldermead.go

+46-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright ©2014 The gonum Authors. All rights reserved.
1+
// Copyright ©2015 The gonum Authors. All rights reserved.
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44

@@ -49,7 +49,10 @@ func (n nmVertexSorter) Swap(i, j int) {
4949
//
5050
// If an initial simplex is provided, it is used and initLoc is ignored. If
5151
// InitialVertices and InitialValues are both nil, an initial simplex will be
52-
// generated automatically. If the simplex update parameters (Reflection, etc.)
52+
// generated automatically using the initial location as one vertex, and each
53+
// additional vertex as SimplexSize away in one dimension.
54+
//
55+
// If the simplex update parameters (Reflection, etc.)
5356
// are zero, they will be set automatically based on the dimension according to
5457
// the recommendations in
5558
//
@@ -61,6 +64,12 @@ type NelderMead struct {
6164
Expansion float64 // Expansion parameter (>1)
6265
Contraction float64 // Contraction parameter (>0, <1)
6366
Shrink float64 // Shrink parameter (>0, <1)
67+
SimplexSize float64 // size of auto-constructed initial simplex
68+
69+
reflection float64
70+
expansion float64
71+
contraction float64
72+
shrink float64
6473

6574
vertices [][]float64 // location of the vertices sorted in ascending f
6675
values []float64 // function values at the vertices sorted in ascending f
@@ -85,19 +94,27 @@ func (n *NelderMead) Init(loc *Location, f *FunctionInfo, xNext []float64) (Eval
8594
n.centroid = resize(n.centroid, dim)
8695
n.reflectedPoint = resize(n.reflectedPoint, dim)
8796

97+
if n.SimplexSize == 0 {
98+
n.SimplexSize = 0.05
99+
}
100+
88101
// Default parameter choices are chosen in a dimension-dependent way
89102
// from http://www.webpages.uidaho.edu/~fuchang/res/ANMS.pdf
90-
if n.Reflection == 0 {
91-
n.Reflection = 1
103+
n.reflection = n.Reflection
104+
if n.reflection == 0 {
105+
n.reflection = 1
92106
}
93-
if n.Expansion == 0 {
94-
n.Expansion = 1 + 2.0/float64(dim)
107+
n.expansion = n.Expansion
108+
if n.expansion == 0 {
109+
n.expansion = 1 + 2/float64(dim)
95110
}
96-
if n.Contraction == 0 {
97-
n.Contraction = 0.75 - 1.0/(2.0*float64(dim))
111+
n.contraction = n.Contraction
112+
if n.contraction == 0 {
113+
n.contraction = 0.75 - 1/(2*float64(dim))
98114
}
99-
if n.Shrink == 0 {
100-
n.Shrink = 1 - 1.0/float64(dim)
115+
n.shrink = n.Shrink
116+
if n.shrink == 0 {
117+
n.shrink = 1 - 1/float64(dim)
101118
}
102119

103120
if n.InitialVertices != nil {
@@ -126,8 +143,7 @@ func (n *NelderMead) Init(loc *Location, f *FunctionInfo, xNext []float64) (Eval
126143
n.values[dim] = loc.F
127144
n.fillIdx = 0
128145
copy(xNext, loc.X)
129-
xNext[0] += 1
130-
copy(n.vertices[0], xNext)
146+
xNext[0] += n.SimplexSize
131147
n.lastIter = nmInitialize
132148
return FuncEvaluation, InitIteration, nil
133149
}
@@ -155,17 +171,16 @@ func (n *NelderMead) Iterate(loc *Location, xNext []float64) (EvaluationType, It
155171
switch n.lastIter {
156172
case nmInitialize:
157173
n.values[n.fillIdx] = loc.F
174+
copy(n.vertices[n.fillIdx], loc.X)
158175
n.fillIdx++
159176
if n.fillIdx == dim {
160177
// Successfully finished building initial simplex.
161178
sort.Sort(nmVertexSorter{n.vertices, n.values})
162179
computeCentroid(n.vertices, n.centroid)
163180
return n.returnNext(nmReflected, xNext)
164181
}
165-
copy(xNext, loc.X)
166-
xNext[n.fillIdx-1] -= 1
167-
xNext[n.fillIdx] += 1
168-
copy(n.vertices[n.fillIdx], xNext)
182+
copy(xNext, n.vertices[dim])
183+
xNext[n.fillIdx] += n.SimplexSize
169184
return FuncEvaluation, InitIteration, nil
170185
case nmReflected:
171186
n.reflectedValue = loc.F
@@ -228,13 +243,13 @@ func (n *NelderMead) returnNext(iter nmIterType, xNext []float64) (EvaluationTyp
228243
var scale float64
229244
switch iter {
230245
case nmReflected:
231-
scale = n.Reflection
246+
scale = n.reflection
232247
case nmExpanded:
233-
scale = n.Reflection * n.Expansion
248+
scale = n.reflection * n.expansion
234249
case nmContractedOutside:
235-
scale = n.Reflection * n.Contraction
250+
scale = n.reflection * n.contraction
236251
case nmContractedInside:
237-
scale = -n.Contraction
252+
scale = -n.contraction
238253
}
239254
floats.SubTo(xNext, n.centroid, n.vertices[dim])
240255
floats.Scale(scale, xNext)
@@ -248,7 +263,7 @@ func (n *NelderMead) returnNext(iter nmIterType, xNext []float64) (EvaluationTyp
248263
case nmShrink:
249264
// x_shrink = x_best + delta * (x_i + x_best)
250265
floats.SubTo(xNext, n.vertices[n.fillIdx], n.vertices[0])
251-
floats.Scale(n.Shrink, xNext)
266+
floats.Scale(n.shrink, xNext)
252267
floats.Add(xNext, n.vertices[0])
253268
return FuncEvaluation, SubIteration, nil
254269
default:
@@ -280,3 +295,13 @@ func (n *NelderMead) replaceWorst(x []float64, f float64) {
280295
floats.AddScaled(n.centroid, -1/float64(dim), n.vertices[dim])
281296
floats.AddScaled(n.centroid, 1/float64(dim), x)
282297
}
298+
299+
// Needs reports the derivative information required by NelderMead:
// neither the gradient nor the Hessian is needed.
func (*NelderMead) Needs() struct {
	Gradient bool
	Hessian bool
} {
	return struct {
		Gradient bool
		Hessian bool
	}{false, false}
}

termination.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@ type Status int
1313
const (
1414
NotTerminated Status = iota
1515
Success
16-
FunctionAbsoluteConvergence
17-
FunctionRelativeConvergence
18-
FunctionThreshhold
19-
GradientThreshhold
16+
FunctionThreshold
17+
FunctionConvergence
18+
GradientThreshold
2019
StepConvergence
2120
FunctionNegativeInfinity
2221
Failure
@@ -57,10 +56,13 @@ var statuses = []struct {
5756
name: "Success",
5857
},
5958
{
60-
name: "FunctionAbsoluteConvergence",
59+
name: "FunctionThreshold",
6160
},
6261
{
63-
name: "GradientAbsoluteConvergence",
62+
name: "FunctionConvergence",
63+
},
64+
{
65+
name: "GradientThreshold",
6466
},
6567
{
6668
name: "StepConvergence",

0 commit comments

Comments
 (0)