1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.statistics.distribution;
19
20 import java.lang.reflect.Array;
21 import java.text.DecimalFormat;
22 import java.util.function.Supplier;
23 import org.apache.commons.math3.stat.inference.ChiSquareTest;
24 import org.junit.jupiter.api.Assertions;
25
26
27
28
29 final class TestUtils {
30
31
32
33 private static final double ULP_THRESHOLD = 100 * Math.ulp(1.0);
34
35
36
37
38
39 private static final String EXPECTED_FORMAT = "expected: <";
40
41
42
43
44
45
46
47 private static final String ACTUAL_FORMAT = ">, actual: <";
48
49
50
51
52
53
54
55 private static final String RELATIVE_ERROR_FORMAT = ">, rel.error: <";
56
57
58
59
60
61
62
63 private static final String ABSOLUTE_ERROR_FORMAT = ">, abs.error: <";
64
65
66
67
68
69
70
71 private static final String ULP_ERROR_FORMAT = ">, ulp error: <";
72
73
74
75
76 private TestUtils() {}
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91 static void assertEquals(double expected, double actual, DoubleTolerance tolerance) {
92 assertEquals(expected, actual, tolerance, (String) null);
93 }
94
95
96
97
98
99
100
101
102
103
104
105
106
107 static void assertEquals(double expected, double actual, DoubleTolerance tolerance, String message) {
108 if (!tolerance.test(expected, actual)) {
109 throw new AssertionError(format(expected, actual, tolerance, message));
110 }
111 }
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126 static void assertEquals(double expected, double actual, DoubleTolerance tolerance,
127 Supplier<String> messageSupplier) {
128 if (!tolerance.test(expected, actual)) {
129 throw new AssertionError(
130 format(expected, actual, tolerance, messageSupplier == null ? null : messageSupplier.get()));
131 }
132 }
133
134
135
136
137
138
139
140
141
142
143 private static String format(double expected, double actual, DoubleTolerance tolerance, String message) {
144 return buildPrefix(message) + formatValues(expected, actual, tolerance);
145 }
146
147
148
149
150
151
152
153 private static String buildPrefix(String message) {
154 return StringUtils.isNotEmpty(message) ? message + " ==> " : "";
155 }
156
157
158
159
160
161
162
163
164
165 private static String formatValues(double expected, double actual, DoubleTolerance tolerance) {
166
167 final double diff = Math.abs(expected - actual);
168 final double rel = diff / Math.max(Math.abs(expected), Math.abs(actual));
169 final StringBuilder msg = new StringBuilder(EXPECTED_FORMAT).append(expected).append(ACTUAL_FORMAT)
170 .append(actual).append(RELATIVE_ERROR_FORMAT).append(rel);
171 if (rel < ULP_THRESHOLD) {
172 final long ulp = Math.abs(Double.doubleToRawLongBits(expected) - Double.doubleToRawLongBits(actual));
173 msg.append(ULP_ERROR_FORMAT).append(ulp);
174 } else {
175 msg.append(ABSOLUTE_ERROR_FORMAT).append(diff);
176 }
177 msg.append('>');
178 appendTolerance(msg, tolerance);
179 return msg.toString();
180 }
181
182
183
184
185
186
187
188 private static void appendTolerance(final StringBuilder msg, final Object tolerance) {
189 final String description = StringUtils.toString(tolerance);
190 if (StringUtils.isNotEmpty(description)) {
191 msg.append(", tolerance: ").append(description);
192 }
193 }
194
195
196
197
198
199
200
201
202
203
204
205
206
207 static void assertRelativelyEquals(Supplier<String> msg,
208 double expected,
209 double actual,
210 double relativeError) {
211 if (Double.isNaN(expected)) {
212 Assertions.assertTrue(Double.isNaN(actual), msg);
213 } else if (Double.isNaN(actual)) {
214 Assertions.assertTrue(Double.isNaN(expected), msg);
215 } else if (Double.isInfinite(actual) || Double.isInfinite(expected)) {
216 Assertions.assertEquals(expected, actual, relativeError);
217 } else if (expected == 0.0) {
218 Assertions.assertEquals(actual, expected, relativeError, msg);
219 } else {
220 final double absError = Math.abs(expected) * relativeError;
221 Assertions.assertEquals(expected, actual, absError, msg);
222 }
223 }
224
225
226
227
228
229
230
231
232
233
234 private static void assertChiSquare(int[] valueLabels,
235 double[] expected,
236 long[] observed,
237 double alpha) {
238 final ChiSquareTest chiSquareTest = new ChiSquareTest();
239
240
241 if (chiSquareTest.chiSquareTest(expected, observed, alpha)) {
242 final StringBuilder msgBuffer = new StringBuilder();
243 final DecimalFormat df = new DecimalFormat("#.##");
244 msgBuffer.append("Chisquare test failed");
245 msgBuffer.append(" p-value = ");
246 msgBuffer.append(chiSquareTest.chiSquareTest(expected, observed));
247 msgBuffer.append(" chisquare statistic = ");
248 msgBuffer.append(chiSquareTest.chiSquare(expected, observed));
249 msgBuffer.append(". \n");
250 msgBuffer.append("value\texpected\tobserved\n");
251 for (int i = 0; i < expected.length; i++) {
252 msgBuffer.append(valueLabels[i]);
253 msgBuffer.append('\t');
254 msgBuffer.append(df.format(expected[i]));
255 msgBuffer.append("\t\t");
256 msgBuffer.append(observed[i]);
257 msgBuffer.append('\n');
258 }
259 msgBuffer.append("This test can fail randomly due to sampling error with probability ");
260 msgBuffer.append(alpha);
261 msgBuffer.append('.');
262 Assertions.fail(msgBuffer.toString());
263 }
264 }
265
266
267
268
269
270
271
272
273
274
275 static void assertChiSquareAccept(int[] values,
276 double[] expected,
277 long[] observed,
278 double alpha) {
279 assertChiSquare(values, expected, observed, alpha);
280 }
281
282
283
284
285
286
287
288
289
290 static void assertChiSquareAccept(double[] expected,
291 long[] observed,
292 double alpha) {
293 final int[] values = new int[expected.length];
294 for (int i = 0; i < values.length; i++) {
295 values[i] = i + 1;
296 }
297 assertChiSquare(values, expected, observed, alpha);
298 }
299
300
301
302
303
304
305
306
307 static double[] getDistributionQuartiles(ContinuousDistribution distribution) {
308 final double[] quantiles = new double[3];
309 quantiles[0] = distribution.inverseCumulativeProbability(0.25d);
310 quantiles[1] = distribution.inverseCumulativeProbability(0.5d);
311 quantiles[2] = distribution.inverseCumulativeProbability(0.75d);
312 return quantiles;
313 }
314
315
316
317
318
319
320
321
322 static int[] getDistributionQuartiles(DiscreteDistribution distribution) {
323 final int[] quantiles = new int[3];
324 quantiles[0] = distribution.inverseCumulativeProbability(0.25d);
325 quantiles[1] = distribution.inverseCumulativeProbability(0.5d);
326 quantiles[2] = distribution.inverseCumulativeProbability(0.75d);
327 return quantiles;
328 }
329
330
331
332
333
334
335
336
337
338 static void updateCounts(double value, long[] counts, double[] quartiles) {
339 if (value > quartiles[1]) {
340 counts[value <= quartiles[2] ? 2 : 3]++;
341 } else {
342 counts[value <= quartiles[0] ? 0 : 1]++;
343 }
344 }
345
346
347
348
349
350
351
352
353
354 static void updateCounts(double value, long[] counts, int[] quartiles) {
355 if (value > quartiles[1]) {
356 counts[value <= quartiles[2] ? 2 : 3]++;
357 } else {
358 counts[value <= quartiles[0] ? 0 : 1]++;
359 }
360 }
361
362
363
364
365
366
367
368
369
370
371
372 static int eliminateZeroMassPoints(int[] densityPoints, double[] densityValues) {
373 int positiveMassCount = 0;
374 for (int i = 0; i < densityValues.length; i++) {
375 if (densityValues[i] > 0) {
376 positiveMassCount++;
377 }
378 }
379 if (positiveMassCount < densityValues.length) {
380 final int[] newPoints = new int[positiveMassCount];
381 final double[] newValues = new double[positiveMassCount];
382 int j = 0;
383 for (int i = 0; i < densityValues.length; i++) {
384 if (densityValues[i] > 0) {
385 newPoints[j] = densityPoints[i];
386 newValues[j] = densityValues[i];
387 j++;
388 }
389 }
390 System.arraycopy(newPoints, 0, densityPoints, 0, positiveMassCount);
391 System.arraycopy(newValues, 0, densityValues, 0, positiveMassCount);
392 }
393 return positiveMassCount;
394 }
395
396
397
398
399
400
401
402
403
404 static double[] sample(int n,
405 ContinuousDistribution.Sampler sampler) {
406 final double[] samples = new double[n];
407 for (int i = 0; i < n; i++) {
408 samples[i] = sampler.sample();
409 }
410 return samples;
411 }
412
413
414
415
416
417
418
419
420
421 static int[] sample(int n,
422 DiscreteDistribution.Sampler sampler) {
423 final int[] samples = new int[n];
424 for (int i = 0; i < n; i++) {
425 samples[i] = sampler.sample();
426 }
427 return samples;
428 }
429
430
431
432
433
434
435
436 static int getLength(double[] array) {
437 return array == null ? 0 : array.length;
438 }
439
440
441
442
443
444
445
446 static int getLength(int[] array) {
447 return array == null ? 0 : array.length;
448 }
449
450
451
452
453
454
455
456
457 static int getLength(Object array) {
458 return array == null ? 0 : Array.getLength(array);
459 }
460 }