View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  import org.apache.commons.rng.sampling.RandomAssert;
21  import org.apache.commons.rng.sampling.SharedStateSampler;
22  import org.junit.jupiter.api.Assertions;
23  import org.junit.jupiter.api.Test;
24  
25  /**
26   * Test for the {@link GaussianSampler}. The tests hit edge cases for the sampler.
27   */
28  class GaussianSamplerTest {
29      /**
30       * Test the constructor with a zero standard deviation.
31       */
32      @Test
33      void testConstructorThrowsWithZeroStandardDeviation() {
34          final UniformRandomProvider rng = RandomAssert.seededRNG();
35          final NormalizedGaussianSampler gauss = ZigguratSampler.NormalizedGaussian.of(rng);
36          final double mean = 1;
37          final double standardDeviation = 0;
38          Assertions.assertThrows(IllegalArgumentException.class,
39              () -> GaussianSampler.of(gauss, mean, standardDeviation));
40      }
41  
42      /**
43       * Test the constructor with an infinite standard deviation.
44       */
45      @Test
46      void testConstructorThrowsWithInfiniteStandardDeviation() {
47          final UniformRandomProvider rng = RandomAssert.seededRNG();
48          final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng);
49          final double mean = 1;
50          final double standardDeviation = Double.POSITIVE_INFINITY;
51          Assertions.assertThrows(IllegalArgumentException.class,
52              () -> GaussianSampler.of(gauss, mean, standardDeviation));
53      }
54  
55      /**
56       * Test the constructor with a NaN standard deviation.
57       */
58      @Test
59      void testConstructorThrowsWithNaNStandardDeviation() {
60          final UniformRandomProvider rng = RandomAssert.seededRNG();
61          final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng);
62          final double mean = 1;
63          final double standardDeviation = Double.NaN;
64          Assertions.assertThrows(IllegalArgumentException.class,
65              () -> GaussianSampler.of(gauss, mean, standardDeviation));
66      }
67  
68      /**
69       * Test the constructor with an infinite mean.
70       */
71      @Test
72      void testConstructorThrowsWithInfiniteMean() {
73          final UniformRandomProvider rng = RandomAssert.seededRNG();
74          final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng);
75          final double mean = Double.POSITIVE_INFINITY;
76          final double standardDeviation = 1;
77          Assertions.assertThrows(IllegalArgumentException.class,
78              () -> GaussianSampler.of(gauss, mean, standardDeviation));
79      }
80  
81      /**
82       * Test the constructor with a NaN mean.
83       */
84      @Test
85      void testConstructorThrowsWithNaNMean() {
86          final UniformRandomProvider rng = RandomAssert.seededRNG();
87          final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng);
88          final double mean = Double.NaN;
89          final double standardDeviation = 1;
90          Assertions.assertThrows(IllegalArgumentException.class,
91              () -> GaussianSampler.of(gauss, mean, standardDeviation));
92      }
93  
94      /**
95       * Test the SharedStateSampler implementation.
96       */
97      @Test
98      void testSharedStateSampler() {
99          final UniformRandomProvider rng1 = RandomAssert.seededRNG();
100         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
101         final NormalizedGaussianSampler gauss = ZigguratSampler.NormalizedGaussian.of(rng1);
102         final double mean = 1.23;
103         final double standardDeviation = 4.56;
104         final SharedStateContinuousSampler sampler1 =
105             GaussianSampler.of(gauss, mean, standardDeviation);
106         final SharedStateContinuousSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
107         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
108     }
109 
110     /**
111      * Test the SharedStateSampler implementation throws if the underlying sampler is
112      * not a SharedStateSampler.
113      */
114     @Test
115     void testSharedStateSamplerThrowsIfUnderlyingSamplerDoesNotShareState() {
116         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
117         final NormalizedGaussianSampler gauss = new NormalizedGaussianSampler() {
118             @Override
119             public double sample() {
120                 return 0;
121             }
122         };
123         final double mean = 1.23;
124         final double standardDeviation = 4.56;
125         final SharedStateContinuousSampler sampler1 =
126             GaussianSampler.of(gauss, mean, standardDeviation);
127         Assertions.assertThrows(UnsupportedOperationException.class,
128             () -> sampler1.withUniformRandomProvider(rng2));
129     }
130 
131     /**
132      * Test the SharedStateSampler implementation throws if the underlying sampler is
133      * a SharedStateSampler that returns an incorrect type.
134      */
135     @Test
136     void testSharedStateSamplerThrowsIfUnderlyingSamplerReturnsWrongSharedState() {
137         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
138         final NormalizedGaussianSampler gauss = new BadSharedStateNormalizedGaussianSampler();
139         final double mean = 1.23;
140         final double standardDeviation = 4.56;
141         final SharedStateContinuousSampler sampler1 =
142             GaussianSampler.of(gauss, mean, standardDeviation);
143         Assertions.assertThrows(UnsupportedOperationException.class,
144             () -> sampler1.withUniformRandomProvider(rng2));
145     }
146 
147     /**
148      * Test class to return an incorrect sampler from the SharedStateSampler method.
149      *
150      * <p>Note that due to type erasure the type returned by the SharedStateSampler is not
151      * available at run-time and the GaussianSampler has to assume it is the correct type.</p>
152      */
153     private static class BadSharedStateNormalizedGaussianSampler
154             implements NormalizedGaussianSampler, SharedStateSampler<Integer> {
155         @Override
156         public double sample() {
157             return 0;
158         }
159 
160         @Override
161         public Integer withUniformRandomProvider(UniformRandomProvider rng) {
162             // Something that is not a NormalizedGaussianSampler
163             return Integer.valueOf(44);
164         }
165     }
166 }