View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.statistics.distribution;
19  
20  import java.lang.reflect.Array;
21  import java.text.DecimalFormat;
22  import java.util.function.Supplier;
23  import org.apache.commons.math3.stat.inference.ChiSquareTest;
24  import org.junit.jupiter.api.Assertions;
25  
26  /**
27   * Test utilities.
28   */
29  final class TestUtils {
    /**
     * The relative error threshold below which the error is reported in ULP
     * rather than as an absolute error.
     *
     * <p>Set to 100 times {@code Math.ulp(1.0)}.
     */
    private static final double ULP_THRESHOLD = 100 * Math.ulp(1.0);
    /**
     * The prefix for the formatted expected value.
     *
     * <p>This should be followed by the expected value then '>'.
     * The closing '>' is supplied by the first character of the next format prefix.
     */
    private static final String EXPECTED_FORMAT = "expected: <";
    /**
     * The prefix for the formatted actual value.
     *
     * <p>It is assumed this will be following the expected value.
     *
     * <p>This should be followed by the actual value then '>'.
     */
    private static final String ACTUAL_FORMAT = ">, actual: <";
    /**
     * The prefix for the formatted relative error value.
     *
     * <p>It is assumed this will be following the actual value.
     *
     * <p>This should be followed by the relative error value then '>'.
     */
    private static final String RELATIVE_ERROR_FORMAT = ">, rel.error: <";
    /**
     * The prefix for the formatted absolute error value.
     *
     * <p>It is assumed this will be following the relative error value.
     *
     * <p>This should be followed by the absolute error value then '>'.
     */
    private static final String ABSOLUTE_ERROR_FORMAT = ">, abs.error: <";
    /**
     * The prefix for the formatted ULP error value.
     *
     * <p>It is assumed this will be following the relative error value.
     *
     * <p>This should be followed by the ULP error value then '>'.
     */
    private static final String ULP_ERROR_FORMAT = ">, ulp error: <";
72  
    /**
     * No instances. This class only provides static utility methods.
     */
    private TestUtils() {}
77  
78      ////////////////////////////////////////////////////////////////////////////////////////////
79      // Custom assertions using a DoubleTolerance
80      ////////////////////////////////////////////////////////////////////////////////////////////
81  
82      /**
83       * <em>Asserts</em> {@code expected} and {@code actual} are considered equal with the
84       * provided tolerance.
85       *
86       * @param expected The expected value.
87       * @param actual The value to tolerance against {@code expected}.
88       * @param tolerance The tolerance.
89       * @throws AssertionError If the values are not considered equal
90       */
91      static void assertEquals(double expected, double actual, DoubleTolerance tolerance) {
92          assertEquals(expected, actual, tolerance, (String) null);
93      }
94  
95      /**
96       * <em>Asserts</em> {@code expected} and {@code actual} are considered equal with the
97       * provided tolerance.
98       *
99       * <p>Fails with the supplied failure {@code message}.
100      *
101      * @param expected The expected value.
102      * @param actual The value to tolerance against {@code expected}.
103      * @param tolerance The tolerance.
104      * @param message The message.
105      * @throws AssertionError If the values are not considered equal
106      */
107     static void assertEquals(double expected, double actual, DoubleTolerance tolerance, String message) {
108         if (!tolerance.test(expected, actual)) {
109             throw new AssertionError(format(expected, actual, tolerance, message));
110         }
111     }
112 
113     /**
114      * <em>Asserts</em> {@code expected} and {@code actual} are considered equal with the
115      * provided tolerance.
116      *
117      * <p>If necessary, the failure message will be retrieved lazily from the supplied
118      * {@code messageSupplier}.
119      *
120      * @param expected The expected value.
121      * @param actual The value to tolerance against {@code expected}.
122      * @param tolerance The tolerance.
123      * @param messageSupplier The message supplier.
124      * @throws AssertionError If the values are not considered equal
125      */
126     static void assertEquals(double expected, double actual, DoubleTolerance tolerance,
127         Supplier<String> messageSupplier) {
128         if (!tolerance.test(expected, actual)) {
129             throw new AssertionError(
130                 format(expected, actual, tolerance, messageSupplier == null ? null : messageSupplier.get()));
131         }
132     }
133 
134     /**
135      * Format the message.
136      *
137      * @param expected The expected value.
138      * @param actual The value to check against <code>expected</code>.
139      * @param tolerance The tolerance.
140      * @param message The message.
141      * @return the formatted message
142      */
143     private static String format(double expected, double actual, DoubleTolerance tolerance, String message) {
144         return buildPrefix(message) + formatValues(expected, actual, tolerance);
145     }
146 
147     /**
148      * Builds the fail message prefix.
149      *
150      * @param message the message
151      * @return the prefix
152      */
153     private static String buildPrefix(String message) {
154         return StringUtils.isNotEmpty(message) ? message + " ==> " : "";
155     }
156 
157     /**
158      * Format the values.
159      *
160      * @param expected The expected value.
161      * @param actual The value to check against <code>expected</code>.
162      * @param tolerance The tolerance.
163      * @return the formatted values
164      */
165     private static String formatValues(double expected, double actual, DoubleTolerance tolerance) {
166         // Add error
167         final double diff = Math.abs(expected - actual);
168         final double rel = diff / Math.max(Math.abs(expected), Math.abs(actual));
169         final StringBuilder msg = new StringBuilder(EXPECTED_FORMAT).append(expected).append(ACTUAL_FORMAT)
170             .append(actual).append(RELATIVE_ERROR_FORMAT).append(rel);
171         if (rel < ULP_THRESHOLD) {
172             final long ulp = Math.abs(Double.doubleToRawLongBits(expected) - Double.doubleToRawLongBits(actual));
173             msg.append(ULP_ERROR_FORMAT).append(ulp);
174         } else {
175             msg.append(ABSOLUTE_ERROR_FORMAT).append(diff);
176         }
177         msg.append('>');
178         appendTolerance(msg, tolerance);
179         return msg.toString();
180     }
181 
182     /**
183      * Append the tolerance to the message.
184      *
185      * @param msg The message
186      * @param tolerance the tolerance
187      */
188     private static void appendTolerance(final StringBuilder msg, final Object tolerance) {
189         final String description = StringUtils.toString(tolerance);
190         if (StringUtils.isNotEmpty(description)) {
191             msg.append(", tolerance: ").append(description);
192         }
193     }
194 
195     ////////////////////////////////////////////////////////////////////////////////////////////
196 
197     /**
198      * Verifies that the relative error in actual vs. expected is less than or
199      * equal to relativeError.  If expected is infinite or NaN, actual must be
200      * the same (NaN or infinity of the same sign).
201      *
202      * @param msg  message to return with failure
203      * @param expected expected value
204      * @param actual  observed value
205      * @param relativeError  maximum allowable relative error
206      */
207     static void assertRelativelyEquals(Supplier<String> msg,
208                                        double expected,
209                                        double actual,
210                                        double relativeError) {
211         if (Double.isNaN(expected)) {
212             Assertions.assertTrue(Double.isNaN(actual), msg);
213         } else if (Double.isNaN(actual)) {
214             Assertions.assertTrue(Double.isNaN(expected), msg);
215         } else if (Double.isInfinite(actual) || Double.isInfinite(expected)) {
216             Assertions.assertEquals(expected, actual, relativeError);
217         } else if (expected == 0.0) {
218             Assertions.assertEquals(actual, expected, relativeError, msg);
219         } else {
220             final double absError = Math.abs(expected) * relativeError;
221             Assertions.assertEquals(expected, actual, absError, msg);
222         }
223     }
224 
225     /**
226      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
227      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
228      *
229      * @param valueLabels labels for the values of the discrete distribution under test
230      * @param expected expected counts
231      * @param observed observed counts
232      * @param alpha significance level of the test
233      */
234     private static void assertChiSquare(int[] valueLabels,
235                                         double[] expected,
236                                         long[] observed,
237                                         double alpha) {
238         final ChiSquareTest chiSquareTest = new ChiSquareTest();
239 
240         // Fail if we can reject null hypothesis that distributions are the same
241         if (chiSquareTest.chiSquareTest(expected, observed, alpha)) {
242             final StringBuilder msgBuffer = new StringBuilder();
243             final DecimalFormat df = new DecimalFormat("#.##");
244             msgBuffer.append("Chisquare test failed");
245             msgBuffer.append(" p-value = ");
246             msgBuffer.append(chiSquareTest.chiSquareTest(expected, observed));
247             msgBuffer.append(" chisquare statistic = ");
248             msgBuffer.append(chiSquareTest.chiSquare(expected, observed));
249             msgBuffer.append(". \n");
250             msgBuffer.append("value\texpected\tobserved\n");
251             for (int i = 0; i < expected.length; i++) {
252                 msgBuffer.append(valueLabels[i]);
253                 msgBuffer.append('\t');
254                 msgBuffer.append(df.format(expected[i]));
255                 msgBuffer.append("\t\t");
256                 msgBuffer.append(observed[i]);
257                 msgBuffer.append('\n');
258             }
259             msgBuffer.append("This test can fail randomly due to sampling error with probability ");
260             msgBuffer.append(alpha);
261             msgBuffer.append('.');
262             Assertions.fail(msgBuffer.toString());
263         }
264     }
265 
266     /**
267      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
268      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
269      *
270      * @param values integer values whose observed and expected counts are being compared
271      * @param expected expected counts
272      * @param observed observed counts
273      * @param alpha significance level of the test
274      */
275     static void assertChiSquareAccept(int[] values,
276                                       double[] expected,
277                                       long[] observed,
278                                       double alpha) {
279         assertChiSquare(values, expected, observed, alpha);
280     }
281 
282     /**
283      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
284      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
285      *
286      * @param expected expected counts
287      * @param observed observed counts
288      * @param alpha significance level of the test
289      */
290     static void assertChiSquareAccept(double[] expected,
291                                       long[] observed,
292                                       double alpha) {
293         final int[] values = new int[expected.length];
294         for (int i = 0; i < values.length; i++) {
295             values[i] = i + 1;
296         }
297         assertChiSquare(values, expected, observed, alpha);
298     }
299 
300     /**
301      * Computes the 25th, 50th and 75th percentiles of the given distribution and returns
302      * these values in an array.
303      *
304      * @param distribution Distribution.
305      * @return the quartiles
306      */
307     static double[] getDistributionQuartiles(ContinuousDistribution distribution) {
308         final double[] quantiles = new double[3];
309         quantiles[0] = distribution.inverseCumulativeProbability(0.25d);
310         quantiles[1] = distribution.inverseCumulativeProbability(0.5d);
311         quantiles[2] = distribution.inverseCumulativeProbability(0.75d);
312         return quantiles;
313     }
314 
315     /**
316      * Computes the 25th, 50th and 75th percentiles of the given distribution and returns
317      * these values in an array.
318      *
319      * @param distribution Distribution.
320      * @return the quartiles
321      */
322     static int[] getDistributionQuartiles(DiscreteDistribution distribution) {
323         final int[] quantiles = new int[3];
324         quantiles[0] = distribution.inverseCumulativeProbability(0.25d);
325         quantiles[1] = distribution.inverseCumulativeProbability(0.5d);
326         quantiles[2] = distribution.inverseCumulativeProbability(0.75d);
327         return quantiles;
328     }
329 
330     /**
331      * Updates observed counts of values in quartiles.
332      * counts[0] <-> 1st quartile ... counts[3] <-> top quartile
333      *
334      * @param value Observed value.
335      * @param counts Counts for each quartile.
336      * @param quartiles Quartiles.
337      */
338     static void updateCounts(double value, long[] counts, double[] quartiles) {
339         if (value > quartiles[1]) {
340             counts[value <= quartiles[2] ? 2 : 3]++;
341         } else {
342             counts[value <= quartiles[0] ? 0 : 1]++;
343         }
344     }
345 
346     /**
347      * Updates observed counts of values in quartiles.
348      * counts[0] <-> 1st quartile ... counts[3] <-> top quartile
349      *
350      * @param value Observed value.
351      * @param counts Counts for each quartile.
352      * @param quartiles Quartiles.
353      */
354     static void updateCounts(double value, long[] counts, int[] quartiles) {
355         if (value > quartiles[1]) {
356             counts[value <= quartiles[2] ? 2 : 3]++;
357         } else {
358             counts[value <= quartiles[0] ? 0 : 1]++;
359         }
360     }
361 
362     /**
363      * Eliminates points with zero mass from densityPoints and densityValues parallel
364      * arrays. Returns the number of positive mass points and collapses the arrays so that
365      * the first <returned value> elements of the input arrays represent the positive mass
366      * points.
367      *
368      * @param densityPoints Density points.
369      * @param densityValues Density values.
370      * @return number of positive mass points
371      */
372     static int eliminateZeroMassPoints(int[] densityPoints, double[] densityValues) {
373         int positiveMassCount = 0;
374         for (int i = 0; i < densityValues.length; i++) {
375             if (densityValues[i] > 0) {
376                 positiveMassCount++;
377             }
378         }
379         if (positiveMassCount < densityValues.length) {
380             final int[] newPoints = new int[positiveMassCount];
381             final double[] newValues = new double[positiveMassCount];
382             int j = 0;
383             for (int i = 0; i < densityValues.length; i++) {
384                 if (densityValues[i] > 0) {
385                     newPoints[j] = densityPoints[i];
386                     newValues[j] = densityValues[i];
387                     j++;
388                 }
389             }
390             System.arraycopy(newPoints, 0, densityPoints, 0, positiveMassCount);
391             System.arraycopy(newValues, 0, densityValues, 0, positiveMassCount);
392         }
393         return positiveMassCount;
394     }
395 
396     /**
397      * Utility function for allocating an array and filling it with {@code n}
398      * samples generated by the given {@code sampler}.
399      *
400      * @param n Number of samples.b
401      * @param sampler Sampler.
402      * @return an array of size {@code n}.
403      */
404     static double[] sample(int n,
405                            ContinuousDistribution.Sampler sampler) {
406         final double[] samples = new double[n];
407         for (int i = 0; i < n; i++) {
408             samples[i] = sampler.sample();
409         }
410         return samples;
411     }
412 
413     /**
414      * Utility function for allocating an array and filling it with {@code n}
415      * samples generated by the given {@code sampler}.
416      *
417      * @param n Number of samples.
418      * @param sampler Sampler.
419      * @return an array of size {@code n}.
420      */
421     static int[] sample(int n,
422                         DiscreteDistribution.Sampler sampler) {
423         final int[] samples = new int[n];
424         for (int i = 0; i < n; i++) {
425             samples[i] = sampler.sample();
426         }
427         return samples;
428     }
429 
430     /**
431      * Gets the length of the array.
432      *
433      * @param array Array
434      * @return the length (or 0 for null array)
435      */
436     static int getLength(double[] array) {
437         return array == null ? 0 : array.length;
438     }
439 
440     /**
441      * Gets the length of the array.
442      *
443      * @param array Array
444      * @return the length (or 0 for null array)
445      */
446     static int getLength(int[] array) {
447         return array == null ? 0 : array.length;
448     }
449 
450     /**
451      * Gets the length of the array.
452      *
453      * @param array Array
454      * @return the length (or 0 for null array)
455      * @throws IllegalArgumentException if the object is not an array
456      */
457     static int getLength(Object array) {
458         return array == null ? 0 : Array.getLength(array);
459     }
460 }