1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
18
19 import java.util.Arrays;
20 import java.util.List;
21 import java.util.ArrayList;
22 import java.io.PrintWriter;
23 import java.io.IOException;
24 import java.nio.file.Files;
25 import java.nio.file.Paths;
26 import java.nio.file.StandardOpenOption;
27 import org.junit.jupiter.api.Assertions;
28 import org.junit.jupiter.api.Test;
29 import org.junit.jupiter.api.extension.ParameterContext;
30 import org.junit.jupiter.params.ParameterizedTest;
31 import org.junit.jupiter.params.aggregator.ArgumentsAggregator;
32 import org.junit.jupiter.params.aggregator.ArgumentsAccessor;
33 import org.junit.jupiter.params.aggregator.ArgumentsAggregationException;
34 import org.junit.jupiter.params.aggregator.AggregateWith;
35 import org.junit.jupiter.params.provider.CsvFileSource;
36 import org.apache.commons.rng.UniformRandomProvider;
37 import org.apache.commons.rng.simple.RandomSource;
38 import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
39 import org.apache.commons.rng.sampling.UnitSphereSampler;
40 import org.apache.commons.math4.legacy.core.MathArrays;
41 import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
42 import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
43 import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
44 import org.apache.commons.math4.legacy.optim.InitialGuess;
45 import org.apache.commons.math4.legacy.optim.MaxEval;
46 import org.apache.commons.math4.legacy.optim.PointValuePair;
47 import org.apache.commons.math4.legacy.optim.SimpleBounds;
48 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
49 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.ObjectiveFunction;
50 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.SimulatedAnnealing;
51 import org.apache.commons.math4.legacy.optim.nonlinear.scalar.TestFunction;
52
53
54
55
56 public class SimplexOptimizerTest {
57 private static final String NELDER_MEAD_INPUT_FILE = "std_test_func.simplex.nelder_mead.csv";
58 private static final String MULTIDIRECTIONAL_INPUT_FILE = "std_test_func.simplex.multidirectional.csv";
59 private static final String HEDAR_FUKUSHIMA_INPUT_FILE = "std_test_func.simplex.hedar_fukushima.csv";
60
61 @Test
62 public void testMaxEvaluations() {
63 Assertions.assertThrows(TooManyEvaluationsException.class, () -> {
64 final int dim = 4;
65 final SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-3);
66 optimizer.optimize(new MaxEval(20),
67 new ObjectiveFunction(TestFunction.PARABOLA.withDimension(dim)),
68 GoalType.MINIMIZE,
69 new InitialGuess(new double[] { 3, -1, -3, 1 }),
70 Simplex.equalSidesAlongAxes(dim, 1d),
71 new NelderMeadTransform());
72 });
73 }
74
75 @Test
76 public void testBoundsUnsupported() {
77 Assertions.assertThrows(MathUnsupportedOperationException.class, () -> {
78 final int dim = 2;
79 final SimplexOptimizer optimizer = new SimplexOptimizer(1e-10, 1e-30);
80 optimizer.optimize(new MaxEval(100),
81 new ObjectiveFunction(TestFunction.PARABOLA.withDimension(dim)),
82 GoalType.MINIMIZE,
83 new InitialGuess(new double[] { -3, 0 }),
84 Simplex.alongAxes(new double[] { 0.2, 0.2 }),
85 new NelderMeadTransform(),
86 new SimpleBounds(new double[] { -5, -1 },
87 new double[] { 5, 1 }));
88 });
89 }
90
91 @ParameterizedTest
92 @CsvFileSource(resources = NELDER_MEAD_INPUT_FILE)
93 void testFunctionWithNelderMead(@AggregateWith(TaskAggregator.class) Task task) {
94
95 task.run(new NelderMeadTransform());
96 }
97
98 @ParameterizedTest
99 @CsvFileSource(resources = MULTIDIRECTIONAL_INPUT_FILE)
100 void testFunctionWithMultiDirectional(@AggregateWith(TaskAggregator.class) Task task) {
101 task.run(new MultiDirectionalTransform());
102 }
103
104 @ParameterizedTest
105 @CsvFileSource(resources = HEDAR_FUKUSHIMA_INPUT_FILE)
106 void testFunctionWithHedarFukushima(@AggregateWith(TaskAggregator.class) Task task) {
107 task.run(new HedarFukushimaTransform());
108 }
109
110
111
112
113 public static class Task {
114
115 private static final int FUNC_EVAL_DEBUG = 500000;
116
117 private static final double CONVERGENCE_CHECK = 1e-9;
118
119 private static final double SA_COOL_FACTOR = 0.7;
120
121 private static final double SA_START_PROB = 0.9;
122
123 private static final double SA_END_PROB = 1e-20;
124
125 private final MultivariateFunction function;
126
127 private final double[] start;
128
129 private final double[] optimum;
130
131 private final double pointTolerance;
132
133 private final int functionEvaluations;
134
135 private final double simplexSideLength;
136
137 private final boolean withSA;
138
139 private final String tracePrefix;
140
141 private final int[] traceIndices;
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158 Task(MultivariateFunction function,
159 double[] start,
160 double[] optimum,
161 double pointTolerance,
162 int functionEvaluations,
163 double simplexSideLength,
164 boolean withSA,
165 String tracePrefix,
166 int[] traceIndices) {
167 this.function = function;
168 this.start = start;
169 this.optimum = optimum;
170 this.pointTolerance = pointTolerance;
171 this.functionEvaluations = functionEvaluations;
172 this.simplexSideLength = simplexSideLength;
173 this.withSA = withSA;
174 this.tracePrefix = tracePrefix;
175 this.traceIndices = traceIndices;
176 }
177
178 @Override
179 public String toString() {
180 return function.toString();
181 }
182
183
184
185
186 void run(Simplex.TransformFactory factory) {
187
188
189
190
191 final int maxEval = Math.max(functionEvaluations, FUNC_EVAL_DEBUG);
192
193 final String name = function.toString();
194 final int dim = start.length;
195
196 final SimulatedAnnealing sa;
197 if (withSA) {
198 final SimulatedAnnealing.CoolingSchedule coolSched =
199 SimulatedAnnealing.CoolingSchedule.decreasingExponential(SA_COOL_FACTOR);
200
201 sa = new SimulatedAnnealing(dim,
202 SA_START_PROB,
203 SA_END_PROB,
204 coolSched,
205 RandomSource.KISS.create());
206 } else {
207 sa = null;
208 }
209
210 final SimplexOptimizer optim = new SimplexOptimizer(-1, CONVERGENCE_CHECK);
211 if (tracePrefix != null) {
212 optim.addObserver(createCallback(factory));
213 }
214
215 final Simplex initialSimplex = Simplex.equalSidesAlongAxes(dim, simplexSideLength);
216 final PointValuePair result =
217 optim.optimize(new MaxEval(maxEval),
218 new ObjectiveFunction(function),
219 GoalType.MINIMIZE,
220 new InitialGuess(start),
221 initialSimplex,
222 factory,
223 sa);
224
225 final double[] endPoint = result.getPoint();
226 final double funcValue = result.getValue();
227 final double dist = MathArrays.distance(optimum, endPoint);
228 Assertions.assertEquals(0d, dist, pointTolerance,
229 () -> name + ": distance to optimum" +
230 " f(" + Arrays.toString(endPoint) + ")=" +
231 funcValue);
232
233 final int nEval = optim.getEvaluations();
234 Assertions.assertTrue(nEval < functionEvaluations,
235 () -> name + ": nEval=" + nEval + " < " + functionEvaluations);
236 }
237
238
239
240
241
242 private SimplexOptimizer.Observer createCallback(Simplex.TransformFactory factory) {
243 if (tracePrefix == null) {
244 throw new IllegalArgumentException("Missing file prefix");
245 }
246
247 final String sep = "__";
248 final String name = tracePrefix + sanitizeBasename(function + sep +
249 Arrays.toString(start) + sep +
250 factory + sep);
251
252
253 try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name)))) {
254 out.println("# Function: " + function);
255 out.println("# Transform: " + factory);
256 out.println("#");
257
258 out.println("# Optimum");
259 for (double c : optimum) {
260 out.print(c + " ");
261 }
262 out.println();
263 out.println();
264
265 out.println("#");
266 out.print("# <1: evaluations> <2: f(x)> <3: |f(x) - f(optimum)|>");
267 for (int i = 0; i < start.length; i++) {
268 out.print(" <" + (i + 4) + ": x[" + i + "]>");
269 }
270 out.println();
271 } catch (IOException e) {
272 Assertions.fail(e.getMessage());
273 }
274
275 final double fAtOptimum = function.value(optimum);
276
277
278 return (simplex, isInit, numEval) -> {
279 try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name),
280 StandardOpenOption.APPEND))) {
281 if (isInit) {
282
283
284 out.println();
285 out.println("# [init]");
286 }
287
288 final String fieldSep = " ";
289
290 final List<PointValuePair> points = simplex.asList();
291 for (int index : traceIndices) {
292 final PointValuePair p = points.get(index);
293 out.print(numEval + fieldSep +
294 p.getValue() + fieldSep +
295 Math.abs(p.getValue() - fAtOptimum) + fieldSep);
296
297 final double[] coord = p.getPoint();
298 for (int i = 0; i < coord.length; i++) {
299 out.print(coord[i] + fieldSep);
300 }
301 out.println();
302 }
303
304 out.println();
305 } catch (IOException e) {
306 Assertions.fail(e.getMessage());
307 }
308 };
309 }
310
311
312
313
314
315
316
317 public void checkAlongLine(int numPoints) {
318 if (tracePrefix != null) {
319 final String name = tracePrefix + createPlotBasename(function, start, optimum);
320 try (PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(name)))) {
321 checkAlongLine(numPoints, out);
322 } catch (IOException e) {
323 Assertions.fail(e.getMessage());
324 }
325 } else {
326 checkAlongLine(numPoints, null);
327 }
328 }
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347 private void checkAlongLine(int numPoints,
348 PrintWriter output) {
349 final double delta = 1d / numPoints;
350
351 final int dim = start.length;
352 final double[] dir = new double[dim];
353 for (int i = 0; i < dim; i++) {
354 dir[i] = optimum[i] - start[i];
355 }
356
357 double[] minPoint = null;
358 double minValue = Double.POSITIVE_INFINITY;
359 int count = 0;
360 while (count <= numPoints) {
361 final double[] p = new double[dim];
362 final double t = count * delta;
363 for (int i = 0; i < dim; i++) {
364 p[i] = start[i] + t * dir[i];
365 }
366
367 final double value = function.value(p);
368 if (value <= minValue) {
369 minValue = value;
370 minPoint = p;
371 }
372
373 if (output != null) {
374 output.println(t + " " + value);
375 }
376
377 ++count;
378 }
379
380 final double tol = 1e-15;
381 final double[] point = minPoint;
382 final double value = minValue;
383 Assertions.assertArrayEquals(optimum, minPoint, tol,
384 () -> "Minimum: f(" + Arrays.toString(point) + ")=" + value);
385 }
386
387
388
389
390
391
392
393
394
395 private static String createPlotBasename(MultivariateFunction f,
396 double[] start,
397 double[] end) {
398 final String s = f.toString() + "__" +
399 Arrays.toString(start) + "__" +
400 Arrays.toString(end);
401
402 return sanitizeBasename(s) + ".dat";
403 }
404
405
406
407
408
409
410
411
412
413 private static String sanitizeBasename(String str) {
414 final String repl = "_";
415 return str
416 .replaceAll("\\(", "")
417 .replaceAll("\\)", "")
418 .replaceAll("\\[", "")
419 .replaceAll("\\]", "")
420 .replaceAll("=", repl)
421 .replaceAll(",\\s+", repl)
422 .replaceAll(",", repl)
423 .replaceAll("\\s", repl)
424 .replaceAll("/", repl)
425 .replaceAll("^_+", "")
426 .replaceAll("_+$", "");
427 }
428 }
429
430
431
432
433 public static class TaskAggregator implements ArgumentsAggregator {
434 @Override
435 public Object aggregateArguments(ArgumentsAccessor a,
436 ParameterContext context)
437 throws ArgumentsAggregationException {
438
439 int index = 0;
440
441 final TestFunction funcGen = a.get(index++, TestFunction.class);
442 final int dim = a.getInteger(index++);
443 final double[] optimum = toArrayOfDoubles(a.getString(index++), dim);
444 final double minRadius = a.getDouble(index++);
445 final double maxRadius = a.getDouble(index++);
446 if (minRadius < 0 ||
447 maxRadius < 0 ||
448 minRadius >= maxRadius) {
449 throw new ArgumentsAggregationException("radii");
450 }
451 final double pointTol = a.getDouble(index++);
452 final int funcEval = a.getInteger(index++);
453 final boolean withSA = a.getBoolean(index++);
454
455
456 final UniformRandomProvider rng = OptimTestUtils.rng();
457 final double radius = ContinuousUniformSampler.of(rng, minRadius, maxRadius).sample();
458 final double[] start = UnitSphereSampler.of(rng, dim).sample();
459 for (int i = 0; i < dim; i++) {
460 start[i] *= radius;
461 start[i] += optimum[i];
462 }
463
464 final double sideLength = 0.5 * (maxRadius - minRadius);
465
466 if (index == a.size()) {
467
468 return new Task(funcGen.withDimension(dim),
469 start,
470 optimum,
471 pointTol,
472 funcEval,
473 sideLength,
474 withSA,
475 null,
476 null);
477 } else {
478
479 final String tracePrefix = a.getString(index++);
480 final int[] spxIndices = tracePrefix == null ?
481 null :
482 toSimplexIndices(a.getString(index++), dim);
483
484 return new Task(funcGen.withDimension(dim),
485 start,
486 optimum,
487 pointTol,
488 funcEval,
489 sideLength,
490 withSA,
491 tracePrefix,
492 spxIndices);
493 }
494 }
495
496
497
498
499
500
501
502
503
504
505
506
507 private static int[] toSimplexIndices(String str,
508 int dim) {
509 final List<Integer> list = new ArrayList<>();
510
511 if (str == null ||
512 str.isEmpty()) {
513 for (int i = 0; i <= dim; i++) {
514 list.add(i);
515 }
516 } else {
517 for (String s : str.split("\\s+")) {
518 if (s.equals("LAST")) {
519 list.add(dim);
520 } else if (str.equals("ALL")) {
521 for (int i = 0; i <= dim; i++) {
522 list.add(i);
523 }
524 } else {
525 final int index = Integer.valueOf(s);
526 if (index < 0 ||
527 index > dim) {
528 throw new IllegalArgumentException("index: " + index +
529 " (dim=" + dim + ")");
530 }
531 list.add(index);
532 }
533 }
534 }
535
536 final int len = list.size();
537 final int[] indices = new int[len];
538 for (int i = 0; i < len; i++) {
539 indices[i] = list.get(i);
540 }
541
542 return indices;
543 }
544
545
546
547
548
549
550
551
552 private static double[] toArrayOfDoubles(String params,
553 int dim) {
554 final String[] s = params.trim().split("\\s+");
555
556 if (s.length != dim) {
557 final String msg = "Expected " + dim + " values: " + Arrays.toString(s);
558 throw new ArgumentsAggregationException(msg);
559 }
560
561 final double[] p = new double[dim];
562 for (int i = 0; i < dim; i++) {
563 p[i] = Double.valueOf(s[i]);
564 }
565
566 return p;
567 }
568 }
569 }