View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.distribution;
18  
19  import java.util.stream.Stream;
20  import org.junit.jupiter.params.ParameterizedTest;
21  import org.junit.jupiter.params.provider.Arguments;
22  import org.junit.jupiter.params.provider.CsvSource;
23  import org.junit.jupiter.params.provider.MethodSource;
24  
25  /**
26   * Test cases for {@link NakagamiDistribution}.
27   * Extends {@link BaseContinuousDistributionTest}. See javadoc of that class for details.
28   */
29  class NakagamiDistributionTest extends BaseContinuousDistributionTest {
30      @Override
31      ContinuousDistribution makeDistribution(Object... parameters) {
32          final double mu = (Double) parameters[0];
33          final double omega = (Double) parameters[1];
34          return NakagamiDistribution.of(mu, omega);
35      }
36  
37      @Override
38      Object[][] makeInvalidParameters() {
39          return new Object[][] {
40              {0.0, 1.0},
41              {-0.1, 1.0},
42              {0.5, 0.0},
43              {0.5, -0.1}
44          };
45      }
46  
47      @Override
48      String[] getParameterNames() {
49          return new String[] {"Shape", "Scale"};
50      }
51  
52      @Override
53      protected double getRelativeTolerance() {
54          return 5e-15;
55      }
56  
57      //-------------------- Additional test cases -------------------------------
58  
59      /**
60       * Test additional moments.
61       * Includes cases where {@code gamma(mu + 0.5) / gamma(mu)} is not computable
62       * directly due to overflow of the gamma function.
63       */
64      @ParameterizedTest
65      @CsvSource({
66          // Generated using matlab
67          "175, 0.75, 0.86540703592357171, 0.0010706621739778321",
68          "175, 1, 0.99928597029814059, 0.0014275495653037762",
69          "175, 1.25, 1.1172356792742391, 0.0017844369566297202",
70          "175, 3.75, 1.9351089605317091, 0.0053533108698891607",
71          "205.25, 0.75, 0.86549814380218737, 0.00091296307496802065",
72          "205.25, 1, 0.99939117261462862, 0.0012172840999573609",
73          "205.25, 1.25, 1.1173532990397681, 0.0015216051249467011",
74          "205.25, 3.75, 1.9353126839415795, 0.0045648153748401032",
75          "305.25, 0.75, 0.865670838787722, 0.00061399887256183283",
76          "305.25, 1.75, 1.32233404855355, 0.0014326640359776099",
77          "305.25, 3.75, 1.9356988416686078, 0.0030699943628091642",
78          "305.25, 12.75, 3.5692523053388152, 0.010437980833551158",
79          "305.25, 25.25, 5.0228805186490098, 0.020671295376248372",
80      })
81      void testAdditionalMoments(double mu, double omega, double mean, double variance) {
82          // Note:
83          // The relative error of the variance is much greater than the mean.
84          //   variance = omega - mean^2; omega > 0; x > 0; mean > 0
85          // This computation is subject to cancellation due to subtraction of two large
86          // values to approach a result of zero.
87          // Use a moderate threshold.
88          final DoubleTolerance tolerance = createRelTolerance(2e-10);
89          final NakagamiDistribution dist = NakagamiDistribution.of(mu, omega);
90          testMoments(dist, mean, variance, tolerance);
91      }
92  
93      /**
94       * Repeat test of additional moments with alternative source for the expected result.
95       */
96      @ParameterizedTest
97      @CsvSource({
98          // Generated using 128-bit quad precision implementation using Boost C++:
99          // #include <boost/multiprecision/float128.hpp>
100         // #include <boost/math/special_functions/gamma.hpp>
101         // #define quad boost::multiprecision::float128
102         // T v = boost::math::tgamma_delta_ratio(mu, T(0.5));
103         // T mean = sqrt(omega / mu) / v;
104         // T var = omega - (omega / mu) / v / v;
105         "175, 0.75, 0.865407035923572335404337637742305354, 0.00107066217397678136642741884083229635",
106         "175, 1, 0.999285970298141244170512691211913862, 0.0014275495653023751552365584544430618",
107         "175, 1.25, 1.11723567927423980521693795242933784, 0.00178443695662796894404569806805382725",
108         "175, 3.75, 1.93510896053171023839534780723184735, 0.00535331086988390683213709420416109656",
109         "205.25, 0.75, 0.865498143802251959479795150977083271, 0.000912963074856388060643895128688537674",
110         "205.25, 1, 0.999391172614703197622376095323984551, 0.0012172840998085174141918601715848132",
111         "205.25, 1.25, 1.11735329903985129515900415713529348, 0.00152160512476064676773982521448079983",
112         "205.25, 3.75, 1.93531268394172368161190235734322469, 0.00456481537428194030321947564344316985",
113         "305.25, 0.75, 0.865670838787713729127832304174216151, 0.000613998872576147383115881187594898943",
114         "305.25, 1.75, 1.32233404855353739372707758901129787, 0.00143266403601101056060372277105460371",
115         "305.25, 3.75, 1.93569884166858953645398412102636382, 0.00306999436288073691557940593797382064",
116         "305.25, 12.75, 3.56925230533878138370667203279492999, 0.010437980833794505512969980189112608",
117         "305.25, 25.25, 5.02288051864896241877391197174369638, 0.0206712953767302952315679999823609879",
118     })
119     void testAdditionalMoments2(double mu, double omega, double mean, double variance) {
120         // The mean is within 2 ULP.
121         // The variance is closer than the matlab result but the effect of cancellation
122         // prevents high accuracy.
123         final DoubleTolerance tolerance = createRelTolerance(1e-12);
124         final NakagamiDistribution dist = NakagamiDistribution.of(mu, omega);
125         testMoments(dist, mean, variance, tolerance);
126     }
127 
128     /**
129      * Test log density where the density is zero.
130      */
131     @ParameterizedTest
132     @MethodSource
133     void testAdditionalLogDensity(double mu, double omega, double[] x, double[] expected) {
134         final NakagamiDistribution dist = NakagamiDistribution.of(mu, omega);
135         testLogDensity(dist, x, expected, DoubleTolerances.relative(1e-15));
136     }
137 
138     static Stream<Arguments> testAdditionalLogDensity() {
139         final double[] x = {50, 55, 60, 80, 120};
140         return Stream.of(
141             // scipy.stats 1.9.3  (no support for omega):
142             // nakagami.logpdf(x, 0.5)
143             Arguments.of(0.5, 1, x,
144                 new double[]{-1250.2257913526448, -1512.7257913526448, -1800.2257913526448,
145                     -3200.2257913526446, -7200.225791352645}),
146             // nakagami.logpdf(x, 1.5)   (no support for omega)
147             Arguments.of(1.5, 1, x,
148                 new double[]{-3740.7538269087863, -4528.063206549177, -5390.389183795199,
149                     -9589.813819650295, -21589.00288943408}),
150             // R nakagami 1.1.0 package:
151             // print(dnaka(x, 0.5, 2, log=TRUE), digits=17)
152             Arguments.of(0.5, 2, x,
153                 new double[]{-625.57236494292476, -756.82236494292465, -900.57236494292465,
154                     -1600.57236494292442, -3600.57236494292420}),
155             // print(dnaka(x, 0.5, 0.75, log=TRUE), digits=17)
156             Arguments.of(0.5, 0.75, x,
157                 new double[]{-1666.7486169830854, -2016.7486169830854, -2400.0819503164184,
158                     -4266.7486169830854, -9600.0819503164203}),
159             // print(dnaka(x, 1.5, 0.75, log=TRUE), digits=17)
160             Arguments.of(1.5, 0.75, x,
161                 new double[]{-4990.3223038001088, -6040.1316834404988, -7189.9576606865212,
162                     -12789.3822965416184, -28788.5713663254028}),
163             // print(dnaka(x, 1.5, 1.75, log=TRUE), digits=17)
164             Arguments.of(1.5, 1.75, x,
165                 new double[]{-2134.4503934478316, -2584.2597730882230, -3076.9428931913867,
166                     -5476.3675290464835, -12332.6994559731247}),
167             // print(dnaka(x, 1.5, 7.75, log=TRUE), digits=17)
168             Arguments.of(1.5, 7.75, x,
169                 new double[]{-477.69633391576963, -579.11861678196749, -690.23491660863317,
170                     -1231.59503633469740, -2779.17120289267450})
171         );
172     }
173 }