View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling;
19  
20  import java.util.List;
21  import java.util.Map;
22  import java.util.ArrayList;
23  import org.apache.commons.rng.UniformRandomProvider;
24  import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
25  import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
26  
27  /**
28   * Sampling from a collection of items with user-defined
29   * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
30   * probabilities</a>.
31   * Note that if all unique items are assigned the same probability,
32   * it is much more efficient to use {@link CollectionSampler}.
33   *
34   * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p>
35   *
36   * @param <T> Type of items in the collection.
37   *
38   * @since 1.1
39   */
40  public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjectSampler<T> {
41      /** The error message for an empty collection. */
42      private static final String EMPTY_COLLECTION = "Empty collection";
43      /** Collection to be sampled from. */
44      private final List<T> items;
45      /** Sampler for the probabilities. */
46      private final SharedStateDiscreteSampler sampler;
47  
48      /**
49       * Creates a sampler.
50       *
51       * @param rng Generator of uniformly distributed random numbers.
52       * @param collection Collection to be sampled, with the probabilities
53       * associated to each of its items.
54       * A (shallow) copy of the items will be stored in the created instance.
55       * The probabilities must be non-negative, but zero values are allowed
56       * and their sum does not have to equal one (input will be normalized
57       * to make the probabilities sum to one).
58       * @throws IllegalArgumentException if {@code collection} is empty, a
59       * probability is negative, infinite or {@code NaN}, or the sum of all
60       * probabilities is not strictly positive.
61       */
62      public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
63                                                  Map<T, Double> collection) {
64          this(toList(collection),
65               createSampler(rng, toProbabilities(collection)));
66      }
67  
68      /**
69       * Creates a sampler.
70       *
71       * @param rng Generator of uniformly distributed random numbers.
72       * @param collection Collection to be sampled.
73       * A (shallow) copy of the items will be stored in the created instance.
74       * @param probabilities Probability associated to each item of the
75       * {@code collection}.
76       * The probabilities must be non-negative, but zero values are allowed
77       * and their sum does not have to equal one (input will be normalized
78       * to make the probabilities sum to one).
79       * @throws IllegalArgumentException if {@code collection} is empty or
80       * a probability is negative, infinite or {@code NaN}, or if the number
81       * of items in the {@code collection} is not equal to the number of
82       * provided {@code probabilities}.
83       */
84      public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
85                                                  List<T> collection,
86                                                  double[] probabilities) {
87          this(copyList(collection),
88               createSampler(rng, collection, probabilities));
89      }
90  
91      /**
92       * @param items Collection to be sampled.
93       * @param sampler Sampler for the probabilities.
94       */
95      private DiscreteProbabilityCollectionSampler(List<T> items,
96                                                   SharedStateDiscreteSampler sampler) {
97          this.items = items;
98          this.sampler = sampler;
99      }
100 
101     /**
102      * Picks one of the items from the collection passed to the constructor.
103      *
104      * @return a random sample.
105      */
106     @Override
107     public T sample() {
108         return items.get(sampler.sample());
109     }
110 
111     /**
112      * {@inheritDoc}
113      *
114      * @since 1.3
115      */
116     @Override
117     public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) {
118         return new DiscreteProbabilityCollectionSampler<>(items, sampler.withUniformRandomProvider(rng));
119     }
120 
121     /**
122      * Creates the sampler of the enumerated probability distribution.
123      *
124      * @param rng Generator of uniformly distributed random numbers.
125      * @param probabilities Probability associated to each item.
126      * @return the sampler
127      */
128     private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
129                                                             double[] probabilities) {
130         return GuideTableDiscreteSampler.of(rng, probabilities);
131     }
132 
133     /**
134      * Creates the sampler of the enumerated probability distribution.
135      *
136      * @param <T> Type of items in the collection.
137      * @param rng Generator of uniformly distributed random numbers.
138      * @param collection Collection to be sampled.
139      * @param probabilities Probability associated to each item.
140      * @return the sampler
141      * @throws IllegalArgumentException if the number
142      * of items in the {@code collection} is not equal to the number of
143      * provided {@code probabilities}.
144      */
145     private static <T> SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
146                                                                 List<T> collection,
147                                                                 double[] probabilities) {
148         if (probabilities.length != collection.size()) {
149             throw new IllegalArgumentException("Size mismatch: " +
150                                                probabilities.length + " != " +
151                                                collection.size());
152         }
153         return GuideTableDiscreteSampler.of(rng, probabilities);
154     }
155 
156     // Validation methods exist to raise an exception before invocation of the
157     // private constructor; this mitigates Finalizer attacks
158     // (see SpotBugs CT_CONSTRUCTOR_THROW).
159 
160     /**
161      * Extract the items.
162      *
163      * @param <T> Type of items in the collection.
164      * @param collection Collection.
165      * @return the items
166      * @throws IllegalArgumentException if {@code collection} is empty.
167      */
168     private static <T> List<T> toList(Map<T, Double> collection) {
169         if (collection.isEmpty()) {
170             throw new IllegalArgumentException(EMPTY_COLLECTION);
171         }
172         return new ArrayList<>(collection.keySet());
173     }
174 
175     /**
176      * Extract the probabilities.
177      *
178      * @param <T> Type of items in the collection.
179      * @param collection Collection.
180      * @return the probabilities
181      */
182     private static <T> double[] toProbabilities(Map<T, Double> collection) {
183         final int size = collection.size();
184         final double[] probabilities = new double[size];
185         int count = 0;
186         for (final Double e : collection.values()) {
187             final double probability = e;
188             if (probability < 0 ||
189                 Double.isInfinite(probability) ||
190                 Double.isNaN(probability)) {
191                 throw new IllegalArgumentException("Invalid probability: " +
192                                                    probability);
193             }
194             probabilities[count++] = probability;
195         }
196         return probabilities;
197     }
198 
199     /**
200      * Create a (shallow) copy of the collection.
201      *
202      * @param <T> Type of items in the collection.
203      * @param collection Collection.
204      * @return the copy
205      * @throws IllegalArgumentException if {@code collection} is empty.
206      */
207     private static <T> List<T> copyList(List<T> collection) {
208         if (collection.isEmpty()) {
209             throw new IllegalArgumentException(EMPTY_COLLECTION);
210         }
211         return new ArrayList<>(collection);
212     }
213 }