View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.distribution;
18  
19  import java.util.concurrent.ThreadLocalRandom;
20  import org.apache.commons.rng.UniformRandomProvider;
21  import org.junit.jupiter.api.Assertions;
22  import org.junit.jupiter.api.Test;
23  import org.junit.jupiter.params.ParameterizedTest;
24  import org.junit.jupiter.params.provider.ValueSource;
25  
26  /**
27   * Test default implementations in the {@link ContinuousDistribution} interface.
28   */
29  class ContinuousDistributionTest {
30      /**
31       * Test the default interface methods.
32       */
33      @Test
34      void testDefaultMethods() {
35          final double high = 0.54;
36          final double low = 0.313;
37  
38          final ContinuousDistribution dist = new ContinuousDistribution() {
39              @Override
40              public double inverseCumulativeProbability(double p) {
41                  // For the default inverseSurvivalProbability(double) method
42                  return 10 * p;
43              }
44              @Override
45              public double getVariance() {
46                  return 0;
47              }
48              @Override
49              public double getSupportUpperBound() {
50                  return 0;
51              }
52              @Override
53              public double getSupportLowerBound() {
54                  return 0;
55              }
56              @Override
57              public double getMean() {
58                  return 0;
59              }
60              @Override
61              public double density(double x) {
62                  // Return input value for testing
63                  return x;
64              }
65              @Override
66              public double cumulativeProbability(double x) {
67                  // For the default probability(double, double) method
68                  return x > 1 ? high : low;
69              }
70              @Override
71              public Sampler createSampler(UniformRandomProvider rng) {
72                  return null;
73              }
74          };
75  
76          for (final double x : new double[] {Double.NaN, Double.POSITIVE_INFINITY,
77              Double.NEGATIVE_INFINITY, 0, 1, 0.123}) {
78              // Return the log of the density
79              Assertions.assertEquals(Math.log(x), dist.logDensity(x));
80          }
81  
82          // Should throw for bad range
83          Assertions.assertThrows(DistributionException.class, () -> dist.probability(0.5, 0.4));
84          Assertions.assertEquals(high - low, dist.probability(0.5, 1.5));
85          Assertions.assertEquals(high - low, dist.probability(0.5, 1.5));
86          for (final double p : new double[] {0.2, 0.5, 0.7}) {
87              Assertions.assertEquals(dist.inverseCumulativeProbability(1 - p),
88                                      dist.inverseSurvivalProbability(p));
89          }
90      }
91  
92      /**
93       * Test the {@link ContinuousDistribution.Sampler} default stream methods.
94       *
95       * @param streamSize Number of values to generate.
96       */
97      @ParameterizedTest
98      @ValueSource(longs = {0, 1, 13})
99      void testSamplerStreamMethods(long streamSize) {
100         final double seed = ThreadLocalRandom.current().nextDouble();
101         final ContinuousDistribution.Sampler s1 = createIncrementSampler(seed);
102         final ContinuousDistribution.Sampler s2 = createIncrementSampler(seed);
103         final ContinuousDistribution.Sampler s3 = createIncrementSampler(seed);
104         // Get the reference output from the sample() method
105         final double[] x = new double[(int) streamSize];
106         for (int i = 0; i < x.length; i++) {
107             x[i] = s1.sample();
108         }
109         // Test default stream methods
110         Assertions.assertArrayEquals(x, s2.samples().limit(streamSize).toArray(), "samples()");
111         Assertions.assertArrayEquals(x, s3.samples(streamSize).toArray(), "samples(long)");
112     }
113 
114     /**
115      * Test the {@link ContinuousDistribution.Sampler} default stream method with a bad stream size.
116      *
117      * @param streamSize Number of values to generate.
118      */
119     @ParameterizedTest
120     @ValueSource(longs = {-1, -6576237846822L})
121     void testSamplerStreamMethodsThrow(long streamSize) {
122         final ContinuousDistribution.Sampler s = createIncrementSampler(42);
123         Assertions.assertThrows(IllegalArgumentException.class, () -> s.samples(streamSize));
124     }
125 
126     /**
127      * Test the {@link ContinuousDistribution.Sampler} default stream methods are not parallel.
128      */
129     @Test
130     void testSamplerStreamMethodsNotParallel() {
131         final ContinuousDistribution.Sampler s = createIncrementSampler(42);
132         Assertions.assertFalse(s.samples().isParallel(), "samples() should not be parallel");
133         Assertions.assertFalse(s.samples(11).isParallel(), "samples(long) should not be parallel");
134     }
135 
136     /**
137      * Creates the sampler with a given seed value.
138      * Each successive output sample will increment this value by 1.
139      *
140      * @param seed Seed value.
141      * @return the sampler
142      */
143     private static ContinuousDistribution.Sampler createIncrementSampler(double seed) {
144         return new ContinuousDistribution.Sampler() {
145             private double x = seed;
146 
147             @Override
148             public double sample() {
149                 return x += 1;
150             }
151         };
152     }
153 }