View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.numbers.examples.jmh.core;
18  
19  import java.math.BigDecimal;
20  import java.math.MathContext;
21  import java.util.HashMap;
22  import java.util.LinkedHashMap;
23  import java.util.Map;
24  import java.util.function.ToDoubleFunction;
25  
/**
 * Class used to evaluate the accuracy of different norm computation
 * methods.
 *
 * <p>Each configured method is applied to every input vector twice: once with the
 * elements in their given order and once with the elements reversed, so that
 * summation-order artifacts are included in the results. Each result is compared
 * against a reference value computed with {@link BigDecimal} arithmetic, and the
 * error is measured as the difference between the IEEE 754 bit patterns of the
 * reference and computed values (i.e. in ulps, for values of the same sign).
 */
public class EuclideanNormEvaluator {

    /** Map of names to norm computation methods, kept in insertion order. */
    private final Map<String, ToDoubleFunction<double[]>> methods = new LinkedHashMap<>();

    /** Add a computation method to be evaluated. A method added with a name that is
     * already present replaces the previous method of that name.
     * @param name method name
     * @param method computation method
     * @return this instance
     */
    public EuclideanNormEvaluator addMethod(final String name, final ToDoubleFunction<double[]> method) {
        methods.put(name, method);
        return this;
    }

    /** Evaluate the configured computation methods against the given array of input vectors.
     * Each input vector contributes two samples per method (forward and reverse order).
     * @param inputs array of input vectors
     * @return map of evaluation results keyed by method name, in the order the
     *      methods were added
     */
    public Map<String, Stats> evaluate(final double[][] inputs) {

        final Map<String, StatsAccumulator> accumulators = new HashMap<>();
        for (final String name : methods.keySet()) {
            // two samples (forward and reverse) are reported per input vector
            accumulators.put(name, new StatsAccumulator(inputs.length * 2));
        }

        for (final double[] vec : inputs) {
            // compute the norm in both forward and reverse directions to include
            // summation artifacts
            final double[] reverseVec = reverse(vec);

            // the reference value does not depend on element order
            final double exact = computeExact(vec);

            for (final Map.Entry<String, ToDoubleFunction<double[]>> entry : methods.entrySet()) {
                final ToDoubleFunction<double[]> fn = entry.getValue();
                final StatsAccumulator acc = accumulators.get(entry.getKey());

                acc.report(exact, fn.applyAsDouble(vec));
                acc.report(exact, fn.applyAsDouble(reverseVec));
            }
        }

        final Map<String, Stats> stats = new LinkedHashMap<>();
        for (final String name : methods.keySet()) {
            stats.put(name, accumulators.get(name).computeStats());
        }

        return stats;
    }

    /** Create a new array containing the elements of the argument in reverse order.
     * @param vec input vector
     * @return new reversed copy of the input vector
     */
    private static double[] reverse(final double[] vec) {
        final int last = vec.length - 1;
        final double[] reversed = new double[vec.length];
        for (int j = 0; j <= last; ++j) {
            reversed[last - j] = vec[j];
        }
        return reversed;
    }

    /** Compute the exact double value of the vector norm using BigDecimals
     * with a math context of {@link MathContext#DECIMAL128}.
     * @param vec input vector
     * @return euclidean norm
     * @throws NumberFormatException if any vector element is infinite or NaN
     *      (thrown by {@link BigDecimal#BigDecimal(double)})
     */
    private static double computeExact(final double[] vec) {
        final MathContext ctx = MathContext.DECIMAL128;

        BigDecimal sum = BigDecimal.ZERO;
        for (final double v : vec) {
            // the double constructor is used deliberately: it captures the exact binary
            // value of the element, and pow(2) without a context squares it exactly
            sum = sum.add(new BigDecimal(v).pow(2), ctx);
        }

        return sum.sqrt(ctx).doubleValue();
    }

    /** Compute the ulp difference between two values of the same sign.
     * The difference is computed on the full 64-bit representations so that large
     * errors are not truncated. The result is not meaningful for values of opposite sign.
     * @param a first input
     * @param b second input
     * @return ulp difference between the arguments ({@code a - b})
     */
    private static long computeUlpDifference(final double a, final double b) {
        // for same-sign values both bit patterns lie in the same half of the long
        // range, so this subtraction cannot overflow
        return Double.doubleToLongBits(a) - Double.doubleToLongBits(b);
    }

    /** Class containing evaluation statistics for a single computation method.
     */
    public static final class Stats {

        /** Mean ulp error. */
        private final double ulpErrorMean;

        /** Ulp error standard deviation. */
        private final double ulpErrorStdDev;

        /** Ulp error minimum value. */
        private final double ulpErrorMin;

        /** Ulp error maximum value. */
        private final double ulpErrorMax;

        /** Number of failed computations. */
        private final int failCount;

        /** Construct a new instance.
         * @param ulpErrorMean ulp error mean
         * @param ulpErrorStdDev ulp error standard deviation
         * @param ulpErrorMin ulp error minimum value
         * @param ulpErrorMax ulp error maximum value
         * @param failCount number of failed computations
         */
        Stats(final double ulpErrorMean, final double ulpErrorStdDev, final double ulpErrorMin,
                final double ulpErrorMax, final int failCount) {
            this.ulpErrorMean = ulpErrorMean;
            this.ulpErrorStdDev = ulpErrorStdDev;
            this.ulpErrorMin = ulpErrorMin;
            this.ulpErrorMax = ulpErrorMax;
            this.failCount = failCount;
        }

        /** Get the ulp error mean over the successful computations.
         * @return ulp error mean, or NaN if no computation succeeded
         */
        public double getUlpErrorMean() {
            return ulpErrorMean;
        }

        /** Get the ulp error (sample) standard deviation over the successful computations.
         * @return ulp error standard deviation; zero if exactly one computation succeeded
         *      and NaN if none succeeded
         */
        public double getUlpErrorStdDev() {
            return ulpErrorStdDev;
        }

        /** Get the ulp error minimum value over the successful computations.
         * @return ulp error minimum value, or NaN if no computation succeeded
         */
        public double getUlpErrorMin() {
            return ulpErrorMin;
        }

        /** Get the ulp error maximum value over the successful computations.
         * @return ulp error maximum value, or NaN if no computation succeeded
         */
        public double getUlpErrorMax() {
            return ulpErrorMax;
        }

        /** Get the number of failed computations, meaning the number of
         * computations that overflowed (returned a non-finite value) or
         * underflowed (returned zero for a non-zero norm).
         * @return number of failed computations
         */
        public int getFailCount() {
            return failCount;
        }
    }

    /** Class used to accumulate statistics during a norm evaluation run.
     */
    private static final class StatsAccumulator {

        /** Number of samples reported so far; also the index of the next sample. */
        private int sampleIdx;

        /** Array of ulp errors for each sample; NaN marks a failed computation. */
        private final double[] ulpErrors;

        /** Construct a new instance.
         * @param count number of samples to be accumulated
         */
        StatsAccumulator(final int count) {
            ulpErrors = new double[count];
        }

        /** Report a computation result.
         * @param expected expected result
         * @param actual actual result
         */
        public void report(final double expected, final double actual) {
            final double error;
            if (!Double.isFinite(actual)) {
                // overflow (or an otherwise invalid result)
                error = Double.NaN;
            } else if (actual == expected) {
                // exact result; checked before the zero test so that a zero norm
                // computed as zero (including -0.0) is not counted as underflow
                error = 0d;
            } else if (actual == 0.0) {
                // a non-zero norm computed as zero: underflow
                error = Double.NaN;
            } else {
                error = computeUlpDifference(expected, actual);
            }
            ulpErrors[sampleIdx++] = error;
        }

        /** Compute the final statistics for the run, using only the samples reported so far.
         * @return statistics object
         */
        public Stats computeStats() {
            int successCount = 0;
            double sum = 0d;
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;

            for (int i = 0; i < sampleIdx; ++i) {
                final double ulpError = ulpErrors[i];
                if (Double.isFinite(ulpError)) {
                    ++successCount;
                    min = Math.min(ulpError, min);
                    max = Math.max(ulpError, max);
                    sum += ulpError;
                }
            }

            final int failCount = sampleIdx - successCount;
            if (successCount == 0) {
                // no data: report undefined statistics rather than the infinite sentinels
                return new Stats(Double.NaN, Double.NaN, Double.NaN, Double.NaN, failCount);
            }

            final double mean = sum / successCount;

            double diffSumSq = 0d;
            for (int i = 0; i < sampleIdx; ++i) {
                final double ulpError = ulpErrors[i];
                if (Double.isFinite(ulpError)) {
                    final double diff = ulpError - mean;
                    diffSumSq += diff * diff;
                }
            }

            // sample (n - 1) standard deviation; defined as zero for a single sample
            final double stdDev = successCount > 1 ?
                    Math.sqrt(diffSumSq / (successCount - 1)) :
                    0d;

            return new Stats(mean, stdDev, min, max, failCount);
        }
    }
}