1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.fitting.leastsquares;
18
19 import org.apache.commons.geometry.euclidean.twod.Vector2D;
20 import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
21 import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
22 import org.apache.commons.math4.legacy.exception.ConvergenceException;
23 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
25 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem.Evaluation;
26 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
27 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
28 import org.apache.commons.math4.legacy.linear.BlockRealMatrix;
29 import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
30 import org.apache.commons.math4.legacy.linear.RealMatrix;
31 import org.apache.commons.math4.legacy.linear.RealVector;
32 import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
33 import org.apache.commons.math4.legacy.optim.SimpleVectorValueChecker;
34 import org.apache.commons.math4.core.jdkmath.JdkMath;
35 import org.apache.commons.math4.legacy.core.Pair;
36 import org.junit.Assert;
37 import org.junit.Test;
38
39 import java.io.IOException;
40 import java.util.Arrays;
41
42
43
44
45
46
47
48
49
50
51 public abstract class AbstractLeastSquaresOptimizerAbstractTest {
52
53
54 public static final double TOL = 1e-10;
55
56 public LeastSquaresBuilder base() {
57 return new LeastSquaresBuilder()
58 .checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6))
59 .maxEvaluations(100)
60 .maxIterations(getMaxIterations());
61 }
62
63 public LeastSquaresBuilder builder(CircleVectorial c) {
64 final double[] weights = new double[c.getN()];
65 Arrays.fill(weights, 1.0);
66 return base()
67 .model(c.getModelFunction(), c.getModelFunctionJacobian())
68 .target(new double[c.getN()])
69 .weight(new DiagonalMatrix(weights));
70 }
71
72 public LeastSquaresBuilder builder(StatisticalReferenceDataset dataset) {
73 StatisticalReferenceDataset.LeastSquaresProblem problem
74 = dataset.getLeastSquaresProblem();
75 final double[] weights = new double[dataset.getNumObservations()];
76 Arrays.fill(weights, 1.0);
77 return base()
78 .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
79 .target(dataset.getData()[1])
80 .weight(new DiagonalMatrix(weights))
81 .start(dataset.getStartingPoint(0));
82 }
83
84 public void fail(LeastSquaresOptimizer optimizer) {
85 Assert.fail("Expected Exception from: " + optimizer.toString());
86 }
87
88
89
90
91
92
93
94 public void assertEquals(double tolerance, RealVector actual, double... expected){
95 for (int i = 0; i < expected.length; i++) {
96 Assert.assertEquals(expected[i], actual.getEntry(i), tolerance);
97 }
98 Assert.assertEquals(expected.length, actual.getDimension());
99 }
100
101
102
103
104
105 public abstract int getMaxIterations();
106
107
108
109
110
111
112 public abstract LeastSquaresOptimizer getOptimizer();
113
114
115
116
117 protected final LeastSquaresOptimizer optimizer = this.getOptimizer();
118
119 @Test
120 public void testGetIterations() {
121 LeastSquaresProblem lsp = base()
122 .target(new double[]{1})
123 .weight(new DiagonalMatrix(new double[]{1}))
124 .start(new double[]{3})
125 .model(new MultivariateJacobianFunction() {
126 @Override
127 public Pair<RealVector, RealMatrix> value(final RealVector point) {
128 return new Pair<>(
129 new ArrayRealVector(
130 new double[]{
131 JdkMath.pow(point.getEntry(0), 4)
132 },
133 false),
134 new Array2DRowRealMatrix(
135 new double[][]{
136 {0.25 * JdkMath.pow(point.getEntry(0), 3)}
137 },
138 false)
139 );
140 }
141 })
142 .build();
143
144 Optimum optimum = optimizer.optimize(lsp);
145
146
147 Assert.assertTrue(optimum.getIterations() > 0);
148 }
149
150 @Test
151 public void testTrivial() {
152 LinearProblem problem
153 = new LinearProblem(new double[][]{{2}},
154 new double[]{3});
155 LeastSquaresProblem ls = problem.getBuilder().build();
156
157 Optimum optimum = optimizer.optimize(ls);
158
159 Assert.assertEquals(0, optimum.getRMS(), TOL);
160 assertEquals(TOL, optimum.getPoint(), 1.5);
161 Assert.assertEquals(0.0, optimum.getResiduals().getEntry(0), TOL);
162 }
163
164 @Test
165 public void testQRColumnsPermutation() {
166 LinearProblem problem
167 = new LinearProblem(new double[][]{{1, -1}, {0, 2}, {1, -2}},
168 new double[]{4, 6, 1});
169
170 Optimum optimum = optimizer.optimize(problem.getBuilder().build());
171
172 Assert.assertEquals(0, optimum.getRMS(), TOL);
173 assertEquals(TOL, optimum.getPoint(), 7, 3);
174 assertEquals(TOL, optimum.getResiduals(), 0, 0, 0);
175 }
176
177 @Test
178 public void testNoDependency() {
179 LinearProblem problem = new LinearProblem(new double[][]{
180 {2, 0, 0, 0, 0, 0},
181 {0, 2, 0, 0, 0, 0},
182 {0, 0, 2, 0, 0, 0},
183 {0, 0, 0, 2, 0, 0},
184 {0, 0, 0, 0, 2, 0},
185 {0, 0, 0, 0, 0, 2}
186 }, new double[]{0, 1.1, 2.2, 3.3, 4.4, 5.5});
187
188 Optimum optimum = optimizer.optimize(problem.getBuilder().build());
189
190 Assert.assertEquals(0, optimum.getRMS(), TOL);
191 for (int i = 0; i < problem.target.length; ++i) {
192 Assert.assertEquals(0.55 * i, optimum.getPoint().getEntry(i), TOL);
193 }
194 }
195
196 @Test
197 public void testOneSet() {
198 LinearProblem problem = new LinearProblem(new double[][]{
199 {1, 0, 0},
200 {-1, 1, 0},
201 {0, -1, 1}
202 }, new double[]{1, 1, 1});
203
204 Optimum optimum = optimizer.optimize(problem.getBuilder().build());
205
206 Assert.assertEquals(0, optimum.getRMS(), TOL);
207 assertEquals(TOL, optimum.getPoint(), 1, 2, 3);
208 }
209
210 @Test
211 public void testTwoSets() {
212 double epsilon = 1e-7;
213 LinearProblem problem = new LinearProblem(new double[][]{
214 {2, 1, 0, 4, 0, 0},
215 {-4, -2, 3, -7, 0, 0},
216 {4, 1, -2, 8, 0, 0},
217 {0, -3, -12, -1, 0, 0},
218 {0, 0, 0, 0, epsilon, 1},
219 {0, 0, 0, 0, 1, 1}
220 }, new double[]{2, -9, 2, 2, 1 + epsilon * epsilon, 2});
221
222 Optimum optimum = optimizer.optimize(problem.getBuilder().build());
223
224 Assert.assertEquals(0, optimum.getRMS(), TOL);
225 assertEquals(TOL, optimum.getPoint(), 3, 4, -1, -2, 1 + epsilon, 1 - epsilon);
226 }
227
228 @Test
229 public void testNonInvertible() throws Exception {
230 try {
231 LinearProblem problem = new LinearProblem(new double[][]{
232 {1, 2, -3},
233 {2, 1, 3},
234 {-3, 0, -9}
235 }, new double[]{1, 1, 1});
236
237 optimizer.optimize(problem.getBuilder().build());
238
239 fail(optimizer);
240 } catch (ConvergenceException e) {
241
242 }
243 }
244
245 @Test
246 public void testIllConditioned() {
247 LinearProblem problem1 = new LinearProblem(new double[][]{
248 {10, 7, 8, 7},
249 {7, 5, 6, 5},
250 {8, 6, 10, 9},
251 {7, 5, 9, 10}
252 }, new double[]{32, 23, 33, 31});
253 final double[] start = {0, 1, 2, 3};
254
255 Optimum optimum = optimizer
256 .optimize(problem1.getBuilder().start(start).build());
257
258 Assert.assertEquals(0, optimum.getRMS(), TOL);
259 assertEquals(TOL, optimum.getPoint(), 1, 1, 1, 1);
260
261 LinearProblem problem2 = new LinearProblem(new double[][]{
262 {10.00, 7.00, 8.10, 7.20},
263 {7.08, 5.04, 6.00, 5.00},
264 {8.00, 5.98, 9.89, 9.00},
265 {6.99, 4.99, 9.00, 9.98}
266 }, new double[]{32, 23, 33, 31});
267
268 optimum = optimizer.optimize(problem2.getBuilder().start(start).build());
269
270 Assert.assertEquals(0, optimum.getRMS(), TOL);
271 assertEquals(1e-8, optimum.getPoint(), -81, 137, -34, 22);
272 }
273
274 @Test
275 public void testMoreEstimatedParametersSimple() {
276 LinearProblem problem = new LinearProblem(new double[][]{
277 {3, 2, 0, 0},
278 {0, 1, -1, 1},
279 {2, 0, 1, 0}
280 }, new double[]{7, 3, 5});
281
282 Optimum optimum = optimizer
283 .optimize(problem.getBuilder().start(new double[]{7, 6, 5, 4}).build());
284
285 Assert.assertEquals(0, optimum.getRMS(), TOL);
286 }
287
288 @Test
289 public void testMoreEstimatedParametersUnsorted() {
290 LinearProblem problem = new LinearProblem(new double[][]{
291 {1, 1, 0, 0, 0, 0},
292 {0, 0, 1, 1, 1, 0},
293 {0, 0, 0, 0, 1, -1},
294 {0, 0, -1, 1, 0, 1},
295 {0, 0, 0, -1, 1, 0}
296 }, new double[]{3, 12, -1, 7, 1});
297
298 Optimum optimum = optimizer.optimize(
299 problem.getBuilder().start(new double[]{2, 2, 2, 2, 2, 2}).build());
300
301 Assert.assertEquals(0, optimum.getRMS(), TOL);
302 RealVector point = optimum.getPoint();
303
304
305 Assert.assertEquals(3, point.getEntry(0) + point.getEntry(1), TOL);
306
307 assertEquals(TOL, point.getSubVector(2, 4), 3, 4, 5, 6);
308 }
309
310 @Test
311 public void testRedundantEquations() {
312 LinearProblem problem = new LinearProblem(new double[][]{
313 {1, 1},
314 {1, -1},
315 {1, 3}
316 }, new double[]{3, 1, 5});
317
318 Optimum optimum = optimizer
319 .optimize(problem.getBuilder().start(new double[]{1, 1}).build());
320
321 Assert.assertEquals(0, optimum.getRMS(), TOL);
322 assertEquals(TOL, optimum.getPoint(), 2, 1);
323 }
324
325 @Test
326 public void testInconsistentEquations() {
327 LinearProblem problem = new LinearProblem(new double[][]{
328 {1, 1},
329 {1, -1},
330 {1, 3}
331 }, new double[]{3, 1, 4});
332
333 Optimum optimum = optimizer
334 .optimize(problem.getBuilder().start(new double[]{1, 1}).build());
335
336
337 Assert.assertTrue(optimum.getRMS() > 0.1);
338 }
339
340 @Test
341 public void testInconsistentSizes1() {
342 try {
343 LinearProblem problem
344 = new LinearProblem(new double[][]{{1, 0},
345 {0, 1}},
346 new double[]{-1, 1});
347
348
349 Optimum optimum = optimizer.optimize(problem.getBuilder().build());
350
351 Assert.assertEquals(0, optimum.getRMS(), TOL);
352 assertEquals(TOL, optimum.getPoint(), -1, 1);
353
354
355 optimizer.optimize(
356 problem.getBuilder().weight(new DiagonalMatrix(new double[]{1})).build());
357
358 fail(optimizer);
359 } catch (DimensionMismatchException e) {
360
361 }
362 }
363
364 @Test
365 public void testInconsistentSizes2() {
366 try {
367 LinearProblem problem
368 = new LinearProblem(new double[][]{{1, 0}, {0, 1}},
369 new double[]{-1, 1});
370
371 Optimum optimum = optimizer.optimize(problem.getBuilder().build());
372
373 Assert.assertEquals(0, optimum.getRMS(), TOL);
374 assertEquals(TOL, optimum.getPoint(), -1, 1);
375
376
377 optimizer.optimize(
378 problem.getBuilder()
379 .target(new double[]{1})
380 .weight(new DiagonalMatrix(new double[]{1}))
381 .build()
382 );
383
384 fail(optimizer);
385 } catch (DimensionMismatchException e) {
386
387 }
388 }
389
390 @Test
391 public void testCircleFitting() {
392 CircleVectorial circle = new CircleVectorial();
393 circle.addPoint(30, 68);
394 circle.addPoint(50, -6);
395 circle.addPoint(110, -20);
396 circle.addPoint(35, 15);
397 circle.addPoint(45, 97);
398 final double[] start = {98.680, 47.345};
399
400 Optimum optimum = optimizer.optimize(builder(circle).start(start).build());
401
402 Assert.assertTrue(optimum.getEvaluations() < 10);
403
404 double rms = optimum.getRMS();
405 Assert.assertEquals(1.768262623567235, JdkMath.sqrt(circle.getN()) * rms, TOL);
406
407 Vector2D center = Vector2D.of(optimum.getPoint().getEntry(0), optimum.getPoint().getEntry(1));
408 Assert.assertEquals(69.96016176931406, circle.getRadius(center), 1e-6);
409 Assert.assertEquals(96.07590211815305, center.getX(), 1e-6);
410 Assert.assertEquals(48.13516790438953, center.getY(), 1e-6);
411
412 double[][] cov = optimum.getCovariances(1e-14).getData();
413 Assert.assertEquals(1.839, cov[0][0], 0.001);
414 Assert.assertEquals(0.731, cov[0][1], 0.001);
415 Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
416 Assert.assertEquals(0.786, cov[1][1], 0.001);
417
418
419 double r = circle.getRadius(center);
420 for (double d = 0; d < 2 * JdkMath.PI; d += 0.01) {
421 circle.addPoint(center.getX() + r * JdkMath.cos(d), center.getY() + r * JdkMath.sin(d));
422 }
423
424 double[] weights = new double[circle.getN()];
425 Arrays.fill(weights, 2);
426
427 optimum = optimizer.optimize(
428 builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
429
430 cov = optimum.getCovariances(1e-14).getData();
431 Assert.assertEquals(0.0016, cov[0][0], 0.001);
432 Assert.assertEquals(3.2e-7, cov[0][1], 1e-9);
433 Assert.assertEquals(cov[0][1], cov[1][0], 1e-14);
434 Assert.assertEquals(0.0016, cov[1][1], 0.001);
435 }
436
437 @Test
438 public void testCircleFittingBadInit() {
439 CircleVectorial circle = new CircleVectorial();
440 double[][] points = circlePoints;
441 double[] weights = new double[points.length];
442 final double[] start = {-12, -12};
443 Arrays.fill(weights, 2);
444 for (int i = 0; i < points.length; ++i) {
445 circle.addPoint(points[i][0], points[i][1]);
446 }
447
448 Optimum optimum = optimizer.optimize(builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
449
450 Vector2D center = Vector2D.of(optimum.getPoint().getEntry(0), optimum.getPoint().getEntry(1));
451 Assert.assertTrue(optimum.getEvaluations() < 25);
452 Assert.assertEquals(0.043, optimum.getRMS(), 1e-3);
453 Assert.assertEquals(0.292235, circle.getRadius(center), 1e-6);
454 Assert.assertEquals(-0.151738, center.getX(), 1e-6);
455 Assert.assertEquals(0.2075001, center.getY(), 1e-6);
456 }
457
458 @Test
459 public void testCircleFittingGoodInit() {
460 CircleVectorial circle = new CircleVectorial();
461 double[][] points = circlePoints;
462 double[] weights = new double[points.length];
463 Arrays.fill(weights, 2);
464 for (int i = 0; i < points.length; ++i) {
465 circle.addPoint(points[i][0], points[i][1]);
466 }
467 final double[] start = {0, 0};
468
469 Optimum optimum = optimizer.optimize(
470 builder(circle).weight(new DiagonalMatrix(weights)).start(start).build());
471
472 assertEquals(1e-6, optimum.getPoint(), -0.1517383071957963, 0.2074999736353867);
473 Assert.assertEquals(0.04268731682389561, optimum.getRMS(), 1e-8);
474 }
475
476 private final double[][] circlePoints = new double[][]{
477 {-0.312967, 0.072366}, {-0.339248, 0.132965}, {-0.379780, 0.202724},
478 {-0.390426, 0.260487}, {-0.361212, 0.328325}, {-0.346039, 0.392619},
479 {-0.280579, 0.444306}, {-0.216035, 0.470009}, {-0.149127, 0.493832},
480 {-0.075133, 0.483271}, {-0.007759, 0.452680}, {0.060071, 0.410235},
481 {0.103037, 0.341076}, {0.118438, 0.273884}, {0.131293, 0.192201},
482 {0.115869, 0.129797}, {0.072223, 0.058396}, {0.022884, 0.000718},
483 {-0.053355, -0.020405}, {-0.123584, -0.032451}, {-0.216248, -0.032862},
484 {-0.278592, -0.005008}, {-0.337655, 0.056658}, {-0.385899, 0.112526},
485 {-0.405517, 0.186957}, {-0.415374, 0.262071}, {-0.387482, 0.343398},
486 {-0.347322, 0.397943}, {-0.287623, 0.458425}, {-0.223502, 0.475513},
487 {-0.135352, 0.478186}, {-0.061221, 0.483371}, {0.003711, 0.422737},
488 {0.065054, 0.375830}, {0.108108, 0.297099}, {0.123882, 0.222850},
489 {0.117729, 0.134382}, {0.085195, 0.056820}, {0.029800, -0.019138},
490 {-0.027520, -0.072374}, {-0.102268, -0.091555}, {-0.200299, -0.106578},
491 {-0.292731, -0.091473}, {-0.356288, -0.051108}, {-0.420561, 0.014926},
492 {-0.471036, 0.074716}, {-0.488638, 0.182508}, {-0.485990, 0.254068},
493 {-0.463943, 0.338438}, {-0.406453, 0.404704}, {-0.334287, 0.466119},
494 {-0.254244, 0.503188}, {-0.161548, 0.495769}, {-0.075733, 0.495560},
495 {0.001375, 0.434937}, {0.082787, 0.385806}, {0.115490, 0.323807},
496 {0.141089, 0.223450}, {0.138693, 0.131703}, {0.126415, 0.049174},
497 {0.066518, -0.010217}, {-0.005184, -0.070647}, {-0.080985, -0.103635},
498 {-0.177377, -0.116887}, {-0.260628, -0.100258}, {-0.335756, -0.056251},
499 {-0.405195, -0.000895}, {-0.444937, 0.085456}, {-0.484357, 0.175597},
500 {-0.472453, 0.248681}, {-0.438580, 0.347463}, {-0.402304, 0.422428},
501 {-0.326777, 0.479438}, {-0.247797, 0.505581}, {-0.152676, 0.519380},
502 {-0.071754, 0.516264}, {0.015942, 0.472802}, {0.076608, 0.419077},
503 {0.127673, 0.330264}, {0.159951, 0.262150}, {0.153530, 0.172681},
504 {0.140653, 0.089229}, {0.078666, 0.024981}, {0.023807, -0.037022},
505 {-0.048837, -0.077056}, {-0.127729, -0.075338}, {-0.221271, -0.067526}
506 };
507
508 public void doTestStRD(final StatisticalReferenceDataset dataset,
509 final LeastSquaresOptimizer optimizer,
510 final double errParams,
511 final double errParamsSd) {
512
513 final Optimum optimum = optimizer.optimize(builder(dataset).build());
514
515 final RealVector actual = optimum.getPoint();
516 for (int i = 0; i < actual.getDimension(); i++) {
517 double expected = dataset.getParameter(i);
518 double delta = JdkMath.abs(errParams * expected);
519 Assert.assertEquals(dataset.getName() + ", param #" + i,
520 expected, actual.getEntry(i), delta);
521 }
522 }
523
524 @Test
525 public void testKirby2() throws IOException {
526 doTestStRD(StatisticalReferenceDatasetFactory.createKirby2(), optimizer, 1E-7, 1E-7);
527 }
528
529 @Test
530 public void testHahn1() throws IOException {
531 doTestStRD(StatisticalReferenceDatasetFactory.createHahn1(), optimizer, 1E-7, 1E-4);
532 }
533
534 @Test
535 public void testPointCopy() {
536 LinearProblem problem = new LinearProblem(new double[][]{
537 {1, 0, 0},
538 {-1, 1, 0},
539 {0, -1, 1}
540 }, new double[]{1, 1, 1});
541
542 final boolean[] checked = {false};
543
544 final LeastSquaresBuilder builder = problem.getBuilder()
545 .checker(new ConvergenceChecker<Evaluation>() {
546 @Override
547 public boolean converged(int iteration, Evaluation previous, Evaluation current) {
548 Assert.assertFalse(previous.getPoint().equals(current.getPoint()));
549 Assert.assertArrayEquals(new double[3], previous.getPoint().toArray(), 0);
550 Assert.assertArrayEquals(new double[] {1, 2, 3}, current.getPoint().toArray(), TOL);
551 checked[0] = true;
552 return true;
553 }
554 });
555 optimizer.optimize(builder.build());
556
557 Assert.assertTrue(checked[0]);
558 }
559
560 class LinearProblem {
561 private final RealMatrix factors;
562 private final double[] target;
563
564 LinearProblem(double[][] factors, double[] target) {
565 this.factors = new BlockRealMatrix(factors);
566 this.target = target;
567 }
568
569 public double[] getTarget() {
570 return target;
571 }
572
573 public MultivariateVectorFunction getModelFunction() {
574 return new MultivariateVectorFunction() {
575 @Override
576 public double[] value(double[] params) {
577 return factors.operate(params);
578 }
579 };
580 }
581
582 public MultivariateMatrixFunction getModelFunctionJacobian() {
583 return new MultivariateMatrixFunction() {
584 @Override
585 public double[][] value(double[] params) {
586 return factors.getData();
587 }
588 };
589 }
590
591 public LeastSquaresBuilder getBuilder() {
592 final double[] weights = new double[target.length];
593 Arrays.fill(weights, 1.0);
594 return base()
595 .model(getModelFunction(), getModelFunctionJacobian())
596 .target(target)
597 .weight(new DiagonalMatrix(weights))
598 .start(new double[factors.getColumnDimension()]);
599 }
600 }
601 }