1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
18
19 import java.util.Arrays;
20 import java.util.List;
21 import java.util.ArrayList;
22 import java.util.Comparator;
23 import java.util.Collections;
24 import java.util.Objects;
25 import java.util.function.UnaryOperator;
26 import java.util.function.IntSupplier;
27 import java.util.concurrent.CopyOnWriteArrayList;
28
29 import org.apache.commons.math4.legacy.core.MathArrays;
30 import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
31 import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
32 import org.apache.commons.math4.legacy.exception.MathInternalError;
33 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
34 import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
35 import org.apache.commons.math4.legacy.optim.OptimizationData;
36 import org.apache.commons.math4.legacy.optim.PointValuePair;
37 import org.apache.commons.math4.legacy.optim.SimpleValueChecker;
38 import org.apache.commons.math4.legacy.optim.InitialGuess;
39 import org.apache.commons.math4.legacy.optim.MaxEval;
40 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
41 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.MultivariateOptimizer;
42 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.SimulatedAnnealing;
43 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.PopulationSize;
44 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.ObjectiveFunction;
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110 public class SimplexOptimizer extends MultivariateOptimizer {
111
112 private static final double SIMPLEX_SIDE_RATIO = 1e-1;
113
114 private Simplex.TransformFactory updateRule;
115
116 private Simplex initialSimplex;
117
118 private SimulatedAnnealing simulatedAnnealing = null;
119
120 private int populationSize = 0;
121
122 private int additionalSearch = 0;
123
124 private final List<Observer> callbacks = new CopyOnWriteArrayList<>();
125
126
127
128
129 public SimplexOptimizer(ConvergenceChecker<PointValuePair> checker) {
130 super(checker);
131 }
132
133
134
135
136
137 public SimplexOptimizer(double rel,
138 double abs) {
139 this(new SimpleValueChecker(rel, abs));
140 }
141
142
143
144
145
146 @FunctionalInterface
147 public interface Observer {
148
149
150
151
152
153
154
155
156
157 void update(Simplex simplex,
158 boolean isInit,
159 int numEval);
160 }
161
162
163
164
165
166
167 public void addObserver(Observer cb) {
168 Objects.requireNonNull(cb, "Callback");
169 callbacks.add(cb);
170 }
171
172
173 @Override
174 protected PointValuePair doOptimize() {
175 checkParameters();
176
177 final MultivariateFunction evalFunc = getObjectiveFunction();
178
179 final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
180 final Comparator<PointValuePair> comparator = (o1, o2) -> {
181 final double v1 = o1.getValue();
182 final double v2 = o2.getValue();
183 return isMinim ? Double.compare(v1, v2) : Double.compare(v2, v1);
184 };
185
186
187 final List<PointValuePair> bestList = new ArrayList<>();
188
189 Simplex currentSimplex = initialSimplex.translate(getStartPoint()).evaluate(evalFunc, comparator);
190 notifyObservers(currentSimplex, true);
191 double temperature = Double.NaN;
192 Simplex previousSimplex = null;
193
194 if (simulatedAnnealing != null) {
195 temperature =
196 temperature(currentSimplex.get(0),
197 currentSimplex.get(currentSimplex.getDimension()),
198 simulatedAnnealing.getStartProbability());
199 }
200
201 while (true) {
202 if (previousSimplex != null) {
203 if (hasConverged(previousSimplex, currentSimplex)) {
204 return currentSimplex.get(0);
205 }
206 }
207
208
209 previousSimplex = currentSimplex;
210
211 if (simulatedAnnealing != null) {
212
213 temperature =
214 simulatedAnnealing.getCoolingSchedule().apply(temperature,
215 currentSimplex);
216
217 final double endTemperature =
218 temperature(currentSimplex.get(0),
219 currentSimplex.get(currentSimplex.getDimension()),
220 simulatedAnnealing.getEndProbability());
221
222 if (temperature < endTemperature) {
223 break;
224 }
225
226 final UnaryOperator<Simplex> update =
227 updateRule.create(evalFunc,
228 comparator,
229 simulatedAnnealing.metropolis(temperature));
230
231 for (int i = 0; i < simulatedAnnealing.getEpochDuration(); i++) {
232
233 currentSimplex = applyUpdate(update,
234 currentSimplex,
235 evalFunc,
236 comparator);
237 }
238 } else {
239
240 final UnaryOperator<Simplex> update =
241 updateRule.create(evalFunc, comparator, null);
242
243
244 currentSimplex = applyUpdate(update,
245 currentSimplex,
246 evalFunc,
247 comparator);
248 }
249
250 if (additionalSearch != 0) {
251
252
253
254 final int max = Math.max(additionalSearch, 2);
255
256
257 for (int i = 0; i < currentSimplex.getSize(); i++) {
258 keepIfBetter(currentSimplex.get(i),
259 comparator,
260 bestList,
261 max);
262 }
263 }
264
265 incrementIterationCount();
266 }
267
268
269
270 if (additionalSearch > 0) {
271
272
273
274 final IntSupplier evalCount = () -> getEvaluations();
275
276 return bestListSearch(evalFunc,
277 comparator,
278 bestList,
279 evalCount);
280 }
281
282 throw new MathInternalError();
283 }
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298 @Override
299 protected void parseOptimizationData(OptimizationData... optData) {
300
301 super.parseOptimizationData(optData);
302
303
304
305 for (OptimizationData data : optData) {
306 if (data instanceof Simplex) {
307 initialSimplex = (Simplex) data;
308 } else if (data instanceof Simplex.TransformFactory) {
309 updateRule = (Simplex.TransformFactory) data;
310 } else if (data instanceof SimulatedAnnealing) {
311 simulatedAnnealing = (SimulatedAnnealing) data;
312 } else if (data instanceof PopulationSize) {
313 populationSize = ((PopulationSize) data).getPopulationSize();
314 }
315 }
316 }
317
318
319
320
321
322
323
324
325
326 private boolean hasConverged(Simplex previous,
327 Simplex current) {
328 final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
329
330 for (int i = 0; i < current.getSize(); i++) {
331 final PointValuePair prev = previous.get(i);
332 final PointValuePair curr = current.get(i);
333
334 if (!checker.converged(getIterations(), prev, curr)) {
335 return false;
336 }
337 }
338
339 return true;
340 }
341
342
343
344
345
346
347
348
349 private void checkParameters() {
350 Objects.requireNonNull(updateRule, "Update rule");
351 Objects.requireNonNull(initialSimplex, "Initial simplex");
352
353 if (getLowerBound() != null ||
354 getUpperBound() != null) {
355 throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
356 }
357
358 if (populationSize < 0) {
359 throw new IllegalArgumentException("Population size");
360 }
361
362 additionalSearch = simulatedAnnealing == null ?
363 Math.max(0, populationSize) :
364 Math.max(1, populationSize);
365 }
366
367
368
369
370
371
372
373
374
375
376
377 private double temperature(PointValuePair p1,
378 PointValuePair p2,
379 double prob) {
380 return -Math.abs(p1.getValue() - p2.getValue()) / Math.log(prob);
381 }
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396 private static void keepIfBetter(PointValuePair candidate,
397 Comparator<PointValuePair> comp,
398 List<PointValuePair> list,
399 int max) {
400 final int listSize = list.size();
401 final double[] candidatePoint = candidate.getPoint();
402 if (listSize == 0) {
403 list.add(candidate);
404 } else if (listSize < max) {
405
406 for (PointValuePair p : list) {
407 final double[] pPoint = p.getPoint();
408 if (Arrays.equals(pPoint, candidatePoint)) {
409
410 return;
411 }
412 }
413
414 list.add(candidate);
415
416 if (list.size() == max) {
417 Collections.sort(list, comp);
418 }
419 } else {
420 final int last = max - 1;
421 if (comp.compare(candidate, list.get(last)) < 0) {
422 for (PointValuePair p : list) {
423 final double[] pPoint = p.getPoint();
424 if (Arrays.equals(pPoint, candidatePoint)) {
425
426 return;
427 }
428 }
429
430
431 list.set(last, candidate);
432 Collections.sort(list, comp);
433 }
434 }
435 }
436
437
438
439
440
441
442
443
444
445 private static double shortestDistance(PointValuePair point,
446 List<PointValuePair> list) {
447 double minDist = Double.POSITIVE_INFINITY;
448
449 final double[] p = point.getPoint();
450 for (PointValuePair other : list) {
451 final double[] pOther = other.getPoint();
452 if (!Arrays.equals(p, pOther)) {
453 final double dist = MathArrays.distance(p, pOther);
454 if (dist < minDist) {
455 minDist = dist;
456 }
457 }
458 }
459
460 return minDist;
461 }
462
463
464
465
466
467
468
469
470
471
472
473 private PointValuePair bestListSearch(MultivariateFunction evalFunc,
474 Comparator<PointValuePair> comp,
475 List<PointValuePair> bestList,
476 IntSupplier evalCount) {
477 PointValuePair best = bestList.get(0);
478
479
480
481 for (int i = 0; i < additionalSearch; i++) {
482 final PointValuePair start = bestList.get(i);
483
484 final double dist = shortestDistance(start, bestList);
485 final double[] init = start.getPoint();
486
487 final Simplex simplex = Simplex.equalSidesAlongAxes(init.length,
488 SIMPLEX_SIDE_RATIO * dist);
489
490 final PointValuePair r = directSearch(init,
491 simplex,
492 evalFunc,
493 getConvergenceChecker(),
494 getGoalType(),
495 callbacks,
496 evalCount);
497 if (comp.compare(r, best) < 0) {
498 best = r;
499 }
500 }
501
502 return best;
503 }
504
505
506
507
508
509
510
511
512
513
514
515
516
517 private static PointValuePair directSearch(double[] init,
518 Simplex simplex,
519 MultivariateFunction eval,
520 ConvergenceChecker<PointValuePair> checker,
521 GoalType goalType,
522 List<Observer> cbList,
523 final IntSupplier evalCount) {
524 final SimplexOptimizer optim = new SimplexOptimizer(checker);
525
526 for (Observer cOrig : cbList) {
527 final SimplexOptimizer.Observer cNew = (spx, isInit, numEval) ->
528 cOrig.update(spx, isInit, evalCount.getAsInt());
529
530 optim.addObserver(cNew);
531 }
532
533 return optim.optimize(MaxEval.unlimited(),
534 new ObjectiveFunction(eval),
535 goalType,
536 new InitialGuess(init),
537 simplex,
538 new MultiDirectionalTransform());
539 }
540
541
542
543
544
545
546
547 private void notifyObservers(Simplex simplex,
548 boolean isInit) {
549 for (Observer cb : callbacks) {
550 cb.update(simplex,
551 isInit,
552 getEvaluations());
553 }
554 }
555
556
557
558
559
560
561
562
563
564
565
566 private Simplex applyUpdate(UnaryOperator<Simplex> update,
567 Simplex simplex,
568 MultivariateFunction eval,
569 Comparator<PointValuePair> comp) {
570 final Simplex transformed = update.apply(simplex).evaluate(eval, comp);
571
572 notifyObservers(transformed, false);
573
574 return transformed;
575 }
576 }