1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.statistics.distribution;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler;
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 public final class UniformDiscreteDistribution extends AbstractDiscreteDistribution {
39
40 private final int lower;
41
42 private final int upper;
43
44 private final double upperMinusLowerPlus1;
45
46 private final double pmf;
47
48 private final double logPmf;
49
50 private final double sf0;
51
52
53
54
55
56 private UniformDiscreteDistribution(int lower,
57 int upper) {
58 this.lower = lower;
59 this.upper = upper;
60 upperMinusLowerPlus1 = (double) upper - (double) lower + 1;
61 pmf = 1.0 / upperMinusLowerPlus1;
62 logPmf = -Math.log(upperMinusLowerPlus1);
63 sf0 = (upperMinusLowerPlus1 - 1) / upperMinusLowerPlus1;
64 }
65
66
67
68
69
70
71
72
73
74 public static UniformDiscreteDistribution of(int lower,
75 int upper) {
76 if (lower > upper) {
77 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
78 lower, upper);
79 }
80 return new UniformDiscreteDistribution(lower, upper);
81 }
82
83
84 @Override
85 public double probability(int x) {
86 if (x < lower || x > upper) {
87 return 0;
88 }
89 return pmf;
90 }
91
92
93 @Override
94 public double probability(int x0,
95 int x1) {
96 if (x0 > x1) {
97 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
98 }
99 if (x0 >= upper || x1 < lower) {
100
101 return 0;
102 }
103
104
105
106
107
108
109
110 final long l = Math.max(lower - 1L, x0);
111 final long u = Math.min(upper, x1);
112
113 return (u - l) / upperMinusLowerPlus1;
114 }
115
116
117 @Override
118 public double logProbability(int x) {
119 if (x < lower || x > upper) {
120 return Double.NEGATIVE_INFINITY;
121 }
122 return logPmf;
123 }
124
125
126 @Override
127 public double cumulativeProbability(int x) {
128 if (x <= lower) {
129
130 return x == lower ? pmf : 0;
131 }
132 if (x >= upper) {
133 return 1;
134 }
135 return ((double) x - lower + 1) / upperMinusLowerPlus1;
136 }
137
138
139 @Override
140 public double survivalProbability(int x) {
141 if (x <= lower) {
142
143
144 return x == lower ? sf0 : 1;
145 }
146 if (x >= upper) {
147 return 0;
148 }
149 return ((double) upper - x) / upperMinusLowerPlus1;
150 }
151
152
153 @Override
154 public int inverseCumulativeProbability(double p) {
155 ArgumentUtils.checkProbability(p);
156 if (p > sf0) {
157 return upper;
158 }
159 if (p <= pmf) {
160 return lower;
161 }
162
163
164
165 int x = (int) (lower + Math.ceil(p * upperMinusLowerPlus1) - 1);
166
167
168
169
170
171 if (((double) x - lower) / upperMinusLowerPlus1 >= p) {
172
173
174 x--;
175 } else if (((double) x - lower + 1) / upperMinusLowerPlus1 < p) {
176
177
178 x++;
179 }
180
181 return x;
182 }
183
184
185 @Override
186 public int inverseSurvivalProbability(final double p) {
187 ArgumentUtils.checkProbability(p);
188 if (p < pmf) {
189 return upper;
190 }
191 if (p >= sf0) {
192 return lower;
193 }
194
195
196
197 int x = (int) (upper - Math.floor(p * upperMinusLowerPlus1));
198
199
200
201
202
203 if (((double) upper - x + 1) / upperMinusLowerPlus1 <= p) {
204
205
206 x--;
207 } else if (((double) upper - x) / upperMinusLowerPlus1 > p) {
208
209
210 x++;
211 }
212
213 return x;
214 }
215
216
217
218
219
220
221 @Override
222 public double getMean() {
223
224 return 0.5 * ((double) upper + (double) lower);
225 }
226
227
228
229
230
231
232
233
234
235
236 @Override
237 public double getVariance() {
238 return (upperMinusLowerPlus1 * upperMinusLowerPlus1 - 1) / 12;
239 }
240
241
242
243
244
245
246
247 @Override
248 public int getSupportLowerBound() {
249 return lower;
250 }
251
252
253
254
255
256
257
258 @Override
259 public int getSupportUpperBound() {
260 return upper;
261 }
262
263
264 @Override
265 public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
266
267 return DiscreteUniformSampler.of(rng, lower, upper)::sample;
268 }
269 }