1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.numbers.examples.jmh.gamma;
19
20 import java.util.SplittableRandom;
21 import java.util.concurrent.ThreadLocalRandom;
22 import java.util.concurrent.TimeUnit;
23 import java.util.function.DoubleSupplier;
24 import java.util.function.DoubleUnaryOperator;
25 import java.util.function.Supplier;
26 import java.util.stream.DoubleStream;
27 import org.apache.commons.numbers.core.Precision;
28 import org.apache.commons.numbers.fraction.ContinuedFraction;
29 import org.apache.commons.numbers.gamma.Erf;
30 import org.apache.commons.numbers.gamma.Erfc;
31 import org.apache.commons.numbers.gamma.InverseErf;
32 import org.apache.commons.numbers.gamma.InverseErfc;
33 import org.apache.commons.numbers.gamma.LogGamma;
34 import org.openjdk.jmh.annotations.Benchmark;
35 import org.openjdk.jmh.annotations.BenchmarkMode;
36 import org.openjdk.jmh.annotations.Fork;
37 import org.openjdk.jmh.annotations.Measurement;
38 import org.openjdk.jmh.annotations.Mode;
39 import org.openjdk.jmh.annotations.OutputTimeUnit;
40 import org.openjdk.jmh.annotations.Param;
41 import org.openjdk.jmh.annotations.Scope;
42 import org.openjdk.jmh.annotations.Setup;
43 import org.openjdk.jmh.annotations.State;
44 import org.openjdk.jmh.annotations.Warmup;
45 import org.openjdk.jmh.infra.Blackhole;
46
47
48
49
50 @BenchmarkMode(Mode.AverageTime)
51 @OutputTimeUnit(TimeUnit.NANOSECONDS)
52 @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
53 @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
54 @State(Scope.Benchmark)
55 @Fork(value = 1, jvmArgs = {"-server", "-Xms512M", "-Xmx512M"})
56 public class ErfPerformance {
57
58 private static final double EXTREME_VALUE_BOUND = 40;
59
60 private static final String IMP_NUMBERS_1_0 = "Numbers 1.0";
61
62 private static final String IMP_NUMBERS_1_1 = "Boost";
63
64 private static final String NUM_UNIFORM = "uniform";
65
66 private static final String NUM_INVERSE_UNIFORM = "inverse uniform";
67
68 private static final String UNKNOWN = "unknown parameter: ";
69
70 private static final String ERF_DOMAIN_ERROR = "erf domain error: ";
71
72 private static final String ERFC_DOMAIN_ERROR = "erfc domain error: ";
73
74
75
76
77
78
79
80
81
82
83 private static final long SEED = ThreadLocalRandom.current().nextLong();
84
85
86
87
88 public abstract static class NumberData {
89
90 @Param({"1000"})
91 private int size;
92
93
94 private double[] numbers;
95
96
97
98
99
100
101 public int getSize() {
102 return size;
103 }
104
105
106
107
108
109
110 public double[] getNumbers() {
111 return numbers;
112 }
113
114
115
116
117 @Setup
118 public void setup() {
119 numbers = createNumbers(new SplittableRandom(SEED));
120 }
121
122
123
124
125
126
127
128
129 protected abstract double[] createNumbers(SplittableRandom rng);
130 }
131
132
133
134
135
136 @State(Scope.Benchmark)
137 public static class BaseData extends NumberData {
138
139 @Override
140 protected double[] createNumbers(SplittableRandom rng) {
141 return rng.doubles().limit(getSize()).toArray();
142 }
143 }
144
145
146
147
148 public abstract static class FunctionData extends NumberData {
149
150
151 private DoubleUnaryOperator function;
152
153
154
155
156 @Param({IMP_NUMBERS_1_0, IMP_NUMBERS_1_1})
157 private String implementation;
158
159
160
161
162
163
164 public String getImplementation() {
165 return implementation;
166 }
167
168
169
170
171
172
173 public DoubleUnaryOperator getFunction() {
174 return function;
175 }
176
177
178
179
180 @Override
181 @Setup
182 public void setup() {
183 super.setup();
184 function = createFunction();
185 verify();
186 }
187
188
189
190
191
192
193
194 protected abstract DoubleUnaryOperator createFunction();
195
196
197
198
199
200
201
202
203 protected abstract void verify();
204 }
205
206
207
208
209 @State(Scope.Benchmark)
210 public static class ErfData extends FunctionData {
211
212 @Param({NUM_UNIFORM, NUM_INVERSE_UNIFORM})
213 private String type;
214
215
216 @Override
217 protected double[] createNumbers(SplittableRandom rng) {
218 DoubleSupplier generator;
219 if (NUM_INVERSE_UNIFORM.equals(type)) {
220
221
222 generator = () -> InverseErf.value(makeSignedDouble(rng));
223 } else if (NUM_UNIFORM.equals(type)) {
224
225
226 generator = () -> makeSignedDouble(rng) * 6;
227 } else {
228 throw new IllegalStateException(UNKNOWN + type);
229 }
230 return DoubleStream.generate(generator).limit(getSize()).toArray();
231 }
232
233
234 @Override
235 protected DoubleUnaryOperator createFunction() {
236 final String impl = getImplementation();
237 if (IMP_NUMBERS_1_0.equals(impl)) {
238 return ErfPerformance::erf;
239 } else if (IMP_NUMBERS_1_1.equals(impl)) {
240 return Erf::value;
241 } else {
242 throw new IllegalStateException(UNKNOWN + impl);
243 }
244 }
245
246
247 @Override
248 protected void verify() {
249 final DoubleUnaryOperator function = getFunction();
250 final double relativeEps = 1e-6;
251 for (final double x : getNumbers()) {
252 final double p = function.applyAsDouble(x);
253 assert -1 <= p & p <= 1 : ERF_DOMAIN_ERROR + p;
254
255
256
257
258
259 if (p < 1e-10 || Math.abs(p - 1) < 1e-10) {
260 continue;
261 }
262 assertEquals(x, InverseErf.value(p), Math.abs(x) * relativeEps,
263 () -> getImplementation() + " inverse erf " + p);
264 }
265 }
266 }
267
268
269
270
271 @State(Scope.Benchmark)
272 public static class ErfcData extends FunctionData {
273
274 @Param({NUM_UNIFORM, NUM_INVERSE_UNIFORM})
275 private String type;
276
277
278 @Override
279 protected double[] createNumbers(SplittableRandom rng) {
280 DoubleSupplier generator;
281 if (NUM_INVERSE_UNIFORM.equals(type)) {
282
283
284 generator = () -> InverseErfc.value(rng.nextDouble() * 2);
285 } else if (NUM_UNIFORM.equals(type)) {
286
287
288
289 generator = () -> makeSignedDouble(rng) * 17 + 11;
290 } else {
291 throw new IllegalStateException(UNKNOWN + type);
292 }
293 return DoubleStream.generate(generator).limit(getSize()).toArray();
294 }
295
296
297 @Override
298 protected DoubleUnaryOperator createFunction() {
299 final String impl = getImplementation();
300 if (IMP_NUMBERS_1_0.equals(impl)) {
301 return ErfPerformance::erfc;
302 } else if (IMP_NUMBERS_1_1.equals(impl)) {
303 return Erfc::value;
304 } else {
305 throw new IllegalStateException(UNKNOWN + impl);
306 }
307 }
308
309
310 @Override
311 protected void verify() {
312 final DoubleUnaryOperator function = getFunction();
313 final double relativeEps = 1e-6;
314 for (final double x : getNumbers()) {
315 final double q = function.applyAsDouble(x);
316 assert 0 <= q && q <= 2 : ERFC_DOMAIN_ERROR + q;
317
318
319
320
321
322
323 if (q < 1e-10 || Math.abs(q - 1) < 1e-10 || q > 2 - 1e-10) {
324 continue;
325 }
326 assertEquals(x, InverseErfc.value(q), Math.abs(x) * relativeEps,
327 () -> getImplementation() + " inverse erfc " + q);
328 }
329 }
330 }
331
332
333
334
335 @State(Scope.Benchmark)
336 public static class InverseErfData extends FunctionData {
337
338
339
340 @Param({NUM_UNIFORM})
341 private String type;
342
343
344 @Override
345 protected double[] createNumbers(SplittableRandom rng) {
346 DoubleSupplier generator;
347 if (NUM_UNIFORM.equals(type)) {
348
349 generator = () -> makeSignedDouble(rng);
350 } else {
351 throw new IllegalStateException(UNKNOWN + type);
352 }
353 return DoubleStream.generate(generator).limit(getSize()).toArray();
354 }
355
356
357 @Override
358 protected DoubleUnaryOperator createFunction() {
359 final String impl = getImplementation();
360 if (IMP_NUMBERS_1_0.equals(impl)) {
361 return ErfPerformance::inverseErf;
362 } else if (IMP_NUMBERS_1_1.equals(impl)) {
363 return InverseErf::value;
364 } else {
365 throw new IllegalStateException(UNKNOWN + impl);
366 }
367 }
368
369
370 @Override
371 protected void verify() {
372 final DoubleUnaryOperator function = getFunction();
373 final double relativeEps = 1e-12;
374 for (final double x : getNumbers()) {
375 assert -1 <= x && x <= 1 : ERF_DOMAIN_ERROR + x;
376
377
378
379
380
381 if (x < 1e-10 || Math.abs(x - 1) < 1e-10) {
382 continue;
383 }
384 final double t = function.applyAsDouble(x);
385 assertEquals(x, Erf.value(t), Math.abs(x) * relativeEps,
386 () -> getImplementation() + " erf " + t);
387 }
388 }
389 }
390
391
392
393
394 @State(Scope.Benchmark)
395 public static class InverseErfcData extends FunctionData {
396
397
398
399 @Param({NUM_UNIFORM})
400 private String type;
401
402
403 @Override
404 protected double[] createNumbers(SplittableRandom rng) {
405 DoubleSupplier generator;
406 if (NUM_UNIFORM.equals(type)) {
407
408 generator = () -> rng.nextDouble() * 2;
409 } else {
410 throw new IllegalStateException(UNKNOWN + type);
411 }
412 return DoubleStream.generate(generator).limit(getSize()).toArray();
413 }
414
415
416 @Override
417 protected DoubleUnaryOperator createFunction() {
418 final String impl = getImplementation();
419 if (IMP_NUMBERS_1_0.equals(impl)) {
420 return ErfPerformance::inverseErfc;
421 } else if (IMP_NUMBERS_1_1.equals(impl)) {
422 return InverseErfc::value;
423 } else {
424 throw new IllegalStateException(UNKNOWN + impl);
425 }
426 }
427
428
429 @Override
430 protected void verify() {
431 final DoubleUnaryOperator function = getFunction();
432 final double relativeEps = 1e-12;
433 for (final double x : getNumbers()) {
434 assert 0 <= x && x <= 2 : ERFC_DOMAIN_ERROR + x;
435
436
437
438
439
440
441 if (x < 1e-10 || Math.abs(x - 1) < 1e-10 || x > 2 - 1e-10) {
442 continue;
443 }
444 final double t = function.applyAsDouble(x);
445 assertEquals(x, Erfc.value(t), Math.abs(x) * relativeEps,
446 () -> getImplementation() + " erfc " + t);
447 }
448 }
449 }
450
451
452
453
454
455
456
457 private static double makeSignedDouble(SplittableRandom rng) {
458
459
460
461
462
463 return (rng.nextLong() >> 10) * 0x1.0p-53;
464 }
465
466
467
468
469
470
471
472
473
474 private static double inverseErfc(final double x) {
475 return inverseErf(1 - x);
476 }
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493 private static double inverseErf(final double x) {
494
495
496
497
498 double w = -Math.log((1 - x) * (1 + x));
499 double p;
500
501 if (w < 6.25) {
502 w -= 3.125;
503 p = -3.6444120640178196996e-21;
504 p = -1.685059138182016589e-19 + p * w;
505 p = 1.2858480715256400167e-18 + p * w;
506 p = 1.115787767802518096e-17 + p * w;
507 p = -1.333171662854620906e-16 + p * w;
508 p = 2.0972767875968561637e-17 + p * w;
509 p = 6.6376381343583238325e-15 + p * w;
510 p = -4.0545662729752068639e-14 + p * w;
511 p = -8.1519341976054721522e-14 + p * w;
512 p = 2.6335093153082322977e-12 + p * w;
513 p = -1.2975133253453532498e-11 + p * w;
514 p = -5.4154120542946279317e-11 + p * w;
515 p = 1.051212273321532285e-09 + p * w;
516 p = -4.1126339803469836976e-09 + p * w;
517 p = -2.9070369957882005086e-08 + p * w;
518 p = 4.2347877827932403518e-07 + p * w;
519 p = -1.3654692000834678645e-06 + p * w;
520 p = -1.3882523362786468719e-05 + p * w;
521 p = 0.0001867342080340571352 + p * w;
522 p = -0.00074070253416626697512 + p * w;
523 p = -0.0060336708714301490533 + p * w;
524 p = 0.24015818242558961693 + p * w;
525 p = 1.6536545626831027356 + p * w;
526 } else if (w < 16.0) {
527 w = Math.sqrt(w) - 3.25;
528 p = 2.2137376921775787049e-09;
529 p = 9.0756561938885390979e-08 + p * w;
530 p = -2.7517406297064545428e-07 + p * w;
531 p = 1.8239629214389227755e-08 + p * w;
532 p = 1.5027403968909827627e-06 + p * w;
533 p = -4.013867526981545969e-06 + p * w;
534 p = 2.9234449089955446044e-06 + p * w;
535 p = 1.2475304481671778723e-05 + p * w;
536 p = -4.7318229009055733981e-05 + p * w;
537 p = 6.8284851459573175448e-05 + p * w;
538 p = 2.4031110387097893999e-05 + p * w;
539 p = -0.0003550375203628474796 + p * w;
540 p = 0.00095328937973738049703 + p * w;
541 p = -0.0016882755560235047313 + p * w;
542 p = 0.0024914420961078508066 + p * w;
543 p = -0.0037512085075692412107 + p * w;
544 p = 0.005370914553590063617 + p * w;
545 p = 1.0052589676941592334 + p * w;
546 p = 3.0838856104922207635 + p * w;
547 } else if (w < Double.POSITIVE_INFINITY) {
548 w = Math.sqrt(w) - 5;
549 p = -2.7109920616438573243e-11;
550 p = -2.5556418169965252055e-10 + p * w;
551 p = 1.5076572693500548083e-09 + p * w;
552 p = -3.7894654401267369937e-09 + p * w;
553 p = 7.6157012080783393804e-09 + p * w;
554 p = -1.4960026627149240478e-08 + p * w;
555 p = 2.9147953450901080826e-08 + p * w;
556 p = -6.7711997758452339498e-08 + p * w;
557 p = 2.2900482228026654717e-07 + p * w;
558 p = -9.9298272942317002539e-07 + p * w;
559 p = 4.5260625972231537039e-06 + p * w;
560 p = -1.9681778105531670567e-05 + p * w;
561 p = 7.5995277030017761139e-05 + p * w;
562 p = -0.00021503011930044477347 + p * w;
563 p = -0.00013871931833623122026 + p * w;
564 p = 1.0103004648645343977 + p * w;
565 p = 4.8499064014085844221 + p * w;
566 } else if (w == Double.POSITIVE_INFINITY) {
567
568
569
570
571
572
573
574
575 p = Double.POSITIVE_INFINITY;
576 } else {
577
578
579 return Double.NaN;
580 }
581
582 return p * x;
583 }
584
585
586
587
588
589
590
591
592
593
594
595
596
597 private static double erfc(final double x) {
598 if (Math.abs(x) > EXTREME_VALUE_BOUND) {
599 return x > 0 ? 0 : 2;
600 }
601 final double ret = RegularizedGamma.Q.value(0.5, x * x, 1e-15, 10000);
602 return x < 0 ? 2 - ret : ret;
603 }
604
605
606
607
608
609
610
611
612
613
614
615
616
617 private static double erf(final double x) {
618 if (Math.abs(x) > EXTREME_VALUE_BOUND) {
619 return x > 0 ? 1 : -1;
620 }
621 final double ret = RegularizedGamma.P.value(0.5, x * x, 1e-15, 10000);
622 return x < 0 ? -ret : ret;
623 }
624
625
626
627
628
629
630
631
632
633
634
635
636 private static final class RegularizedGamma {
637
638 private RegularizedGamma() {
639
640 }
641
642
643
644
645
646
647
648 static final class P {
649
650 private P() {}
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678 static double value(double a,
679 double x,
680 double epsilon,
681 int maxIterations) {
682 if (Double.isNaN(a) ||
683 Double.isNaN(x) ||
684 a <= 0 ||
685 x < 0) {
686 return Double.NaN;
687 } else if (x == 0) {
688 return 0;
689 } else if (x >= a + 1) {
690
691 return 1 - RegularizedGamma.Q.value(a, x, epsilon, maxIterations);
692 } else {
693
694 double n = 0;
695 double an = 1 / a;
696 double sum = an;
697 while (Math.abs(an / sum) > epsilon &&
698 n < maxIterations &&
699 sum < Double.POSITIVE_INFINITY) {
700
701 n += 1;
702 an *= x / (a + n);
703
704
705 sum += an;
706 }
707 if (n >= maxIterations) {
708 throw new ArithmeticException("Max iterations exceeded: " + maxIterations);
709 } else if (Double.isInfinite(sum)) {
710 return 1;
711 } else {
712
713 final double result = Math.exp(-x + (a * Math.log(x)) - LogGamma.value(a)) * sum;
714 return result > 1.0 ? 1.0 : result;
715 }
716 }
717 }
718 }
719
720
721
722
723
724
725
726 static final class Q {
727
728 private Q() {}
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753 static double value(final double a,
754 double x,
755 double epsilon,
756 int maxIterations) {
757 if (Double.isNaN(a) ||
758 Double.isNaN(x) ||
759 a <= 0 ||
760 x < 0) {
761 return Double.NaN;
762 } else if (x == 0) {
763 return 1;
764 } else if (x < a + 1) {
765
766 return 1 - RegularizedGamma.P.value(a, x, epsilon, maxIterations);
767 } else {
768 final ContinuedFraction cf = new ContinuedFraction() {
769
770 @Override
771 protected double getA(int n, double x) {
772 return n * (a - n);
773 }
774
775
776 @Override
777 protected double getB(int n, double x) {
778 return ((2 * n) + 1) - a + x;
779 }
780 };
781
782 return Math.exp(-x + (a * Math.log(x)) - LogGamma.value(a)) /
783 cf.evaluate(x, epsilon, maxIterations);
784 }
785 }
786 }
787 }
788
789
790
791
792
793
794
795
796
797 static void assertEquals(double x, double y, double eps, Supplier<String> msg) {
798 if (!Precision.equalsIncludingNaN(x, y, eps)) {
799 throw new AssertionError(msg.get() + ": " + x + " != " + y);
800 }
801 }
802
803
804
805
806
807
808
809
810 private static void apply(double[] numbers, DoubleUnaryOperator fun, Blackhole bh) {
811 for (int i = 0; i < numbers.length; i++) {
812 bh.consume(fun.applyAsDouble(numbers[i]));
813 }
814 }
815
816
817
818
819
820
821
822
823 private static double identity(double z) {
824 return z;
825 }
826
827
828
829
830
831
832
833
834
835
836
837 @Benchmark
838 public void baseline(BaseData numbers, Blackhole bh) {
839 apply(numbers.getNumbers(), ErfPerformance::identity, bh);
840 }
841
842
843
844
845
846
847
848 @Benchmark
849 public void erf(ErfData data, Blackhole bh) {
850 apply(data.getNumbers(), data.getFunction(), bh);
851 }
852
853
854
855
856
857
858
859 @Benchmark
860 public void erfc(ErfcData data, Blackhole bh) {
861 apply(data.getNumbers(), data.getFunction(), bh);
862 }
863
864
865
866
867
868
869
870 @Benchmark
871 public void inverseErf(InverseErfData data, Blackhole bh) {
872 apply(data.getNumbers(), data.getFunction(), bh);
873 }
874
875
876
877
878
879
880
881 @Benchmark
882 public void inverseErfc(InverseErfcData data, Blackhole bh) {
883 apply(data.getNumbers(), data.getFunction(), bh);
884 }
885 }