View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling;
19  
20  import java.util.Arrays;
21  import java.util.Collections;
22  import java.util.List;
23  import java.util.Map;
24  import java.util.TreeMap;
25  import org.apache.commons.rng.UniformRandomProvider;
26  import org.junit.jupiter.api.Assertions;
27  import org.junit.jupiter.api.Test;
28  
29  /**
30   * Test class for {@link DiscreteProbabilityCollectionSampler}.
31   */
32  class DiscreteProbabilityCollectionSamplerTest {
33      /** RNG. */
34      private final UniformRandomProvider rng = RandomAssert.createRNG();
35  
36      @Test
37      void testPrecondition1() {
38          // Size mismatch
39          final List<Double> collection = Arrays.asList(1d, 2d);
40          final double[] probabilities = {0};
41          Assertions.assertThrows(IllegalArgumentException.class,
42              () -> new DiscreteProbabilityCollectionSampler<>(rng,
43                   collection,
44                   probabilities));
45      }
46  
47      @Test
48      void testPrecondition2() {
49          // Negative probability
50          final List<Double> collection = Arrays.asList(1d, 2d);
51          final double[] probabilities = {0, -1};
52          Assertions.assertThrows(IllegalArgumentException.class,
53              () -> new DiscreteProbabilityCollectionSampler<>(rng,
54                  collection,
55                  probabilities));
56      }
57  
58      @Test
59      void testPrecondition3() {
60          // Probabilities do not sum above 0
61          final List<Double> collection = Arrays.asList(1d, 2d);
62          final double[] probabilities = {0, 0};
63          Assertions.assertThrows(IllegalArgumentException.class,
64              () -> new DiscreteProbabilityCollectionSampler<>(rng,
65                  collection,
66                  probabilities));
67      }
68  
69      @Test
70      void testPrecondition4() {
71          // NaN probability
72          final List<Double> collection = Arrays.asList(1d, 2d);
73          final double[] probabilities = {0, Double.NaN};
74          Assertions.assertThrows(IllegalArgumentException.class,
75              () -> new DiscreteProbabilityCollectionSampler<>(rng,
76                  collection,
77                  probabilities));
78      }
79  
80      @Test
81      void testPrecondition5() {
82          // Infinite probability
83          final List<Double> collection = Arrays.asList(1d, 2d);
84          final double[] probabilities = {0, Double.POSITIVE_INFINITY};
85          Assertions.assertThrows(IllegalArgumentException.class,
86              () -> new DiscreteProbabilityCollectionSampler<>(rng,
87                  collection,
88                  probabilities));
89      }
90  
91      @Test
92      void testPrecondition6() {
93          // Empty Map<T, Double> not allowed
94          final Map<String, Double> collection = Collections.emptyMap();
95          Assertions.assertThrows(IllegalArgumentException.class,
96              () -> new DiscreteProbabilityCollectionSampler<>(rng,
97                   collection));
98      }
99  
100     @Test
101     void testPrecondition7() {
102         // Empty List<T> not allowed
103         final List<Double> collection = Collections.emptyList();
104         final double[] probabilities = {};
105         Assertions.assertThrows(IllegalArgumentException.class,
106             () -> new DiscreteProbabilityCollectionSampler<>(rng,
107                 collection,
108                 probabilities));
109     }
110 
111     @Test
112     void testSample() {
113         final DiscreteProbabilityCollectionSampler<Double> sampler =
114             new DiscreteProbabilityCollectionSampler<>(rng,
115                                                        Arrays.asList(3d, -1d, 3d, 7d, -2d, 8d),
116                                                        new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
117         final double expectedMean = 3.4;
118         final double expectedVariance = 7.84;
119 
120         final int n = 100000000;
121         double sum = 0;
122         double sumOfSquares = 0;
123         for (int i = 0; i < n; i++) {
124             final double rand = sampler.sample();
125             sum += rand;
126             sumOfSquares += rand * rand;
127         }
128 
129         final double mean = sum / n;
130         Assertions.assertEquals(expectedMean, mean, 1e-3);
131         final double variance = sumOfSquares / n - mean * mean;
132         Assertions.assertEquals(expectedVariance, variance, 2e-3);
133     }
134 
135 
136     @Test
137     void testSampleUsingMap() {
138         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
139         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
140         final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
141         final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
142         final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
143             new DiscreteProbabilityCollectionSampler<>(rng1, items, probabilities);
144 
145         // Create a map version. The map iterator must be ordered so use a TreeMap.
146         final Map<Integer, Double> map = new TreeMap<>();
147         for (int i = 0; i < probabilities.length; i++) {
148             map.put(items.get(i), probabilities[i]);
149         }
150         final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
151             new DiscreteProbabilityCollectionSampler<>(rng2, map);
152 
153         for (int i = 0; i < 50; i++) {
154             Assertions.assertEquals(sampler1.sample(), sampler2.sample());
155         }
156     }
157 
158     /**
159      * Edge-case test:
160      * Create a sampler that will return 1 for nextDouble() forcing the search to
161      * identify the end item of the cumulative probability array.
162      */
163     @Test
164     void testSampleWithProbabilityAtLastItem() {
165         // Ensure the samples pick probability 0 (the first item) and then
166         // a probability (for the second item) that hits an edge case.
167         final UniformRandomProvider dummyRng = new UniformRandomProvider() {
168             private int count;
169 
170             @Override
171             public long nextLong() {
172                 return 0;
173             }
174 
175             @Override
176             public double nextDouble() {
177                 // Return 0 then the 1.0 for the probability
178                 return (count++ == 0) ? 0 : 1.0;
179             }
180         };
181 
182         final List<Double> items = Arrays.asList(1d, 2d);
183         final DiscreteProbabilityCollectionSampler<Double> sampler =
184             new DiscreteProbabilityCollectionSampler<>(dummyRng,
185                                                        items,
186                                                        new double[] {0.5, 0.5});
187         final Double item1 = sampler.sample();
188         final Double item2 = sampler.sample();
189         // Check they are in the list
190         Assertions.assertTrue(items.contains(item1), "Sample item1 is not from the list");
191         Assertions.assertTrue(items.contains(item2), "Sample item2 is not from the list");
192         // Test the two samples are different items
193         Assertions.assertNotSame(item1, item2, "Item1 and 2 should be different");
194     }
195 
196     /**
197      * Test the SharedStateSampler implementation.
198      */
199     @Test
200     void testSharedStateSampler() {
201         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
202         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
203         final List<Double> items = Arrays.asList(1d, 2d, 3d, 4d);
204         final DiscreteProbabilityCollectionSampler<Double> sampler1 =
205             new DiscreteProbabilityCollectionSampler<>(rng1,
206                                                        items,
207                                                        new double[] {0.1, 0.2, 0.3, 0.4});
208         final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
209         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
210     }
211 }