/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  
18  package org.apache.commons.statistics.distribution;
19  
20  import java.util.function.DoubleSupplier;
21  import org.apache.commons.numbers.gamma.Erf;
22  import org.apache.commons.numbers.gamma.ErfDifference;
23  import org.apache.commons.numbers.gamma.Erfcx;
24  import org.apache.commons.rng.UniformRandomProvider;
25  import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
26  
27  /**
28   * Implementation of the truncated normal distribution.
29   *
30   * <p>The probability density function of \( X \) is:
31   *
32   * <p>\[ f(x;\mu,\sigma,a,b) = \frac{1}{\sigma}\,\frac{\phi(\frac{x - \mu}{\sigma})}{\Phi(\frac{b - \mu}{\sigma}) - \Phi(\frac{a - \mu}{\sigma}) } \]
33   *
34   * <p>for \( \mu \) mean of the parent normal distribution,
35   * \( \sigma \) standard deviation of the parent normal distribution,
36   * \( -\infty \le a \lt b \le \infty \) the truncation interval, and
37   * \( x \in [a, b] \), where \( \phi \) is the probability
38   * density function of the standard normal distribution and \( \Phi \)
39   * is its cumulative distribution function.
40   *
41   * @see <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution">
42   * Truncated normal distribution (Wikipedia)</a>
43   */
44  public final class TruncatedNormalDistribution extends AbstractContinuousDistribution {
45  
46      /** The max allowed value for x where (x*x) will not overflow.
47       * This is a limit on computation of the moments of the truncated normal
48       * as some calculations assume x*x is finite. Value is sqrt(MAX_VALUE). */
49      private static final double MAX_X = 0x1.fffffffffffffp511;
50  
51      /** The min allowed probability range of the parent normal distribution.
52       * Set to 0.0. This may be too low for accurate usage. It is a signal that
53       * the truncation is invalid. */
54      private static final double MIN_P = 0.0;
55  
56      /** sqrt(2). */
57      private static final double ROOT2 = 1.414213562373095048801688724209698078;
58      /** Normalisation constant 2 / sqrt(2 pi) = sqrt(2 / pi). */
59      private static final double ROOT_2_PI = 0.797884560802865405726436165423365309;
60      /** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). */
61      private static final double ROOT_PI_2 = 1.253314137315500251207882642405522626;
62  
63      /**
64       * The threshold to switch to a rejection sampler. When the truncated
65       * distribution covers more than this fraction of the CDF then rejection
66       * sampling will be more efficient than inverse CDF sampling. Performance
67       * benchmarks indicate that a normalized Gaussian sampler is up to 10 times
68       * faster than inverse transform sampling using a fast random generator. See
69       * STATISTICS-55.
70       */
71      private static final double REJECTION_THRESHOLD = 0.2;
72  
73      /** Parent normal distribution. */
74      private final NormalDistribution parentNormal;
75      /** Lower bound of this distribution. */
76      private final double lower;
77      /** Upper bound of this distribution. */
78      private final double upper;
79  
80      /** Stored value of {@code parentNormal.probability(lower, upper)}. This is used to
81       * normalise the probability computations. */
82      private final double cdfDelta;
83      /** log(cdfDelta). */
84      private final double logCdfDelta;
85      /** Stored value of {@code parentNormal.cumulativeProbability(lower)}. Used to map
86       * a probability into the range of the parent normal distribution. */
87      private final double cdfAlpha;
88      /** Stored value of {@code parentNormal.survivalProbability(upper)}. Used to map
89       * a probability into the range of the parent normal distribution. */
90      private final double sfBeta;
91  
92      /**
93       * @param parent Parent distribution.
94       * @param z Probability of the parent distribution for {@code [lower, upper]}.
95       * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
96       * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
97       */
98      private TruncatedNormalDistribution(NormalDistribution parent, double z, double lower, double upper) {
99          this.parentNormal = parent;
100         this.lower = lower;
101         this.upper = upper;
102 
103         cdfDelta = z;
104         logCdfDelta = Math.log(cdfDelta);
105         // Used to map the inverse probability.
106         cdfAlpha = parentNormal.cumulativeProbability(lower);
107         sfBeta = parentNormal.survivalProbability(upper);
108     }
109 
110     /**
111      * Creates a truncated normal distribution.
112      *
113      * <p>Note that the {@code mean} and {@code sd} is of the parent normal distribution,
114      * and not the true mean and standard deviation of the truncated normal distribution.
115      * The {@code lower} and {@code upper} bounds define the truncation of the parent
116      * normal distribution.
117      *
118      * @param mean Mean for the parent distribution.
119      * @param sd Standard deviation for the parent distribution.
120      * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
121      * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
122      * @return the distribution
123      * @throws IllegalArgumentException if {@code sd <= 0}; if {@code lower >= upper}; or if
124      * the truncation covers no probability range in the parent distribution.
125      */
126     public static TruncatedNormalDistribution of(double mean, double sd, double lower, double upper) {
127         if (sd <= 0) {
128             throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
129         }
130         if (lower >= upper) {
131             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GTE_HIGH, lower, upper);
132         }
133 
134         // Use an instance for the parent normal distribution to maximise accuracy
135         // in range computations using the error function
136         final NormalDistribution parent = NormalDistribution.of(mean, sd);
137 
138         // If there is no computable range then raise an exception.
139         final double z = parent.probability(lower, upper);
140         if (z <= MIN_P) {
141             // Map the bounds to a standard normal distribution for the message
142             final double a = (lower - mean) / sd;
143             final double b = (upper - mean) / sd;
144             throw new DistributionException(
145                "Excess truncation of standard normal : CDF(%s, %s) = %s", a, b, z);
146         }
147 
148         // Here we have a meaningful truncation. Note that excess truncation may not be optimal.
149         // For example truncation close to zero where the PDF is constant can be approximated
150         // using a uniform distribution.
151 
152         return new TruncatedNormalDistribution(parent, z, lower, upper);
153     }
154 
155     /** {@inheritDoc} */
156     @Override
157     public double density(double x) {
158         if (x < lower || x > upper) {
159             return 0;
160         }
161         return parentNormal.density(x) / cdfDelta;
162     }
163 
164     /** {@inheritDoc} */
165     @Override
166     public double probability(double x0, double x1) {
167         if (x0 > x1) {
168             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
169                                             x0, x1);
170         }
171         return parentNormal.probability(clipToRange(x0), clipToRange(x1)) / cdfDelta;
172     }
173 
174     /** {@inheritDoc} */
175     @Override
176     public double logDensity(double x) {
177         if (x < lower || x > upper) {
178             return Double.NEGATIVE_INFINITY;
179         }
180         return parentNormal.logDensity(x) - logCdfDelta;
181     }
182 
183     /** {@inheritDoc} */
184     @Override
185     public double cumulativeProbability(double x) {
186         if (x <= lower) {
187             return 0;
188         } else if (x >= upper) {
189             return 1;
190         }
191         return parentNormal.probability(lower, x) / cdfDelta;
192     }
193 
194     /** {@inheritDoc} */
195     @Override
196     public double survivalProbability(double x) {
197         if (x <= lower) {
198             return 1;
199         } else if (x >= upper) {
200             return 0;
201         }
202         return parentNormal.probability(x, upper) / cdfDelta;
203     }
204 
205     /** {@inheritDoc} */
206     @Override
207     public double inverseCumulativeProbability(double p) {
208         ArgumentUtils.checkProbability(p);
209         // Exact bound
210         if (p == 0) {
211             return lower;
212         } else if (p == 1) {
213             return upper;
214         }
215         // Linearly map p to the range [lower, upper]
216         final double x = parentNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta);
217         return clipToRange(x);
218     }
219 
220     /** {@inheritDoc} */
221     @Override
222     public double inverseSurvivalProbability(double p) {
223         ArgumentUtils.checkProbability(p);
224         // Exact bound
225         if (p == 1) {
226             return lower;
227         } else if (p == 0) {
228             return upper;
229         }
230         // Linearly map p to the range [lower, upper]
231         final double x = parentNormal.inverseSurvivalProbability(sfBeta + p * cdfDelta);
232         return clipToRange(x);
233     }
234 
235     /** {@inheritDoc} */
236     @Override
237     public Sampler createSampler(UniformRandomProvider rng) {
238         // If the truncation covers a reasonable amount of the normal distribution
239         // then a rejection sampler can be used.
240         double threshold = REJECTION_THRESHOLD;
241         // If the truncation is entirely in the upper or lower half then adjust the
242         // threshold as twice the samples can be used
243         if (lower >= 0 || upper <= 0) {
244             threshold *= 0.5;
245         }
246 
247         if (cdfDelta > threshold) {
248             // Create the rejection sampler
249             final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
250             DoubleSupplier gen;
251             // Use mirroring if possible
252             if (lower >= 0) {
253                 // Return the upper-half of the Gaussian
254                 gen = () -> Math.abs(sampler.sample());
255             } else if (upper <= 0) {
256                 // Return the lower-half of the Gaussian
257                 gen = () -> -Math.abs(sampler.sample());
258             } else {
259                 // Return the full range of the Gaussian
260                 gen = sampler::sample;
261             }
262             // Map the bounds to a standard normal distribution
263             final double u = parentNormal.getMean();
264             final double s = parentNormal.getStandardDeviation();
265             final double a = (lower - u) / s;
266             final double b = (upper - u) / s;
267             // Sample in [a, b] using rejection
268             return () -> {
269                 double x = gen.getAsDouble();
270                 while (x < a || x > b) {
271                     x = gen.getAsDouble();
272                 }
273                 // Avoid floating-point error when mapping back
274                 return clipToRange(u + x * s);
275             };
276         }
277 
278         // Default to an inverse CDF sampler
279         return super.createSampler(rng);
280     }
281 
282     /**
283      * {@inheritDoc}
284      *
285      * <p>Represents the true mean of the truncated normal distribution rather
286      * than the parent normal distribution mean.
287      *
288      * <p>For \( \mu \) mean of the parent normal distribution,
289      * \( \sigma \) standard deviation of the parent normal distribution, and
290      * \( a \lt b \) the truncation interval of the parent normal distribution, the mean is:
291      *
292      * <p>\[ \mu + \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)}\sigma \]
293      *
294      * <p>where \( \phi \) is the probability density function of the standard normal distribution
295      * and \( \Phi \) is its cumulative distribution function.
296      */
297     @Override
298     public double getMean() {
299         final double u = parentNormal.getMean();
300         final double s = parentNormal.getStandardDeviation();
301         final double a = (lower - u) / s;
302         final double b = (upper - u) / s;
303         return u + moment1(a, b) * s;
304     }
305 
306     /**
307      * {@inheritDoc}
308      *
309      * <p>Represents the true variance of the truncated normal distribution rather
310      * than the parent normal distribution variance.
311      *
312      * <p>For \( \mu \) mean of the parent normal distribution,
313      * \( \sigma \) standard deviation of the parent normal distribution, and
314      * \( a \lt b \) the truncation interval of the parent normal distribution, the variance is:
315      *
316      * <p>\[ \sigma^2 \left[1 + \frac{a\phi(a)-b\phi(b)}{\Phi(b) - \Phi(a)} -
317      *       \left( \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)} \right)^2 \right] \]
318      *
319      * <p>where \( \phi \) is the probability density function of the standard normal distribution
320      * and \( \Phi \) is its cumulative distribution function.
321      */
322     @Override
323     public double getVariance() {
324         final double u = parentNormal.getMean();
325         final double s = parentNormal.getStandardDeviation();
326         final double a = (lower - u) / s;
327         final double b = (upper - u) / s;
328         return variance(a, b) * s * s;
329     }
330 
331     /**
332      * {@inheritDoc}
333      *
334      * <p>The lower bound of the support is equal to the lower bound parameter
335      * of the distribution.
336      */
337     @Override
338     public double getSupportLowerBound() {
339         return lower;
340     }
341 
342     /**
343      * {@inheritDoc}
344      *
345      * <p>The upper bound of the support is equal to the upper bound parameter
346      * of the distribution.
347      */
348     @Override
349     public double getSupportUpperBound() {
350         return upper;
351     }
352 
353     /**
354      * Clip the value to the range [lower, upper].
355      * This is used to handle floating-point error at the support bound.
356      *
357      * @param x Value x
358      * @return x clipped to the range
359      */
360     private double clipToRange(double x) {
361         return clip(x, lower, upper);
362     }
363 
364     /**
365      * Clip the value to the range [lower, upper].
366      *
367      * @param x Value x
368      * @param lower Lower bound (inclusive)
369      * @param upper Upper bound (inclusive)
370      * @return x clipped to the range
371      */
372     private static double clip(double x, double lower, double upper) {
373         if (x <= lower) {
374             return lower;
375         }
376         return x < upper ? x : upper;
377     }
378 
379     // Calculation of variance and mean can suffer from cancellation.
380     //
381     // Use formulas from Jorge Fernandez-de-Cossio-Diaz adapted under the
382     // terms of the MIT "Expat" License (see NOTICE and LICENSE).
383     //
384     // These formulas use the complementary error function
385     //   erfcx(z) = erfc(z) * exp(z^2)
386     // This avoids computation of exp terms for the Gaussian PDF and then
387     // dividing by the error functions erf or erfc:
388     //   exp(-0.5*x*x) / erfc(x / sqrt(2)) == 1 / erfcx(x / sqrt(2))
389     // At large z the erfcx function is computable but exp(-0.5*z*z) and
390     // erfc(z) are zero. Use of these formulas allows computation of the
391     // mean and variance for the usable range of the truncated distribution
392     // (cdf(a, b) != 0). The variance is not accurate when it approaches
393     // machine epsilon (2^-52) at extremely narrow truncations and the
394     // computation -> 0.
395     //
396     // See: https://github.com/cossio/TruncatedNormal.jl
397 
398     /**
399      * Compute the first moment (mean) of the truncated standard normal distribution.
400      *
401      * <p>Assumes {@code a <= b}.
402      *
403      * @param a Lower bound
404      * @param b Upper bound
405      * @return the first moment
406      */
407     static double moment1(double a, double b) {
408         // Assume a <= b
409         if (a == b) {
410             return a;
411         }
412         if (Math.abs(a) > Math.abs(b)) {
413             // Subtract from zero to avoid generating -0.0
414             return 0 - moment1(-b, -a);
415         }
416 
417         // Here:
418         // |a| <= |b|
419         // a < b
420         // 0 < b
421 
422         if (a <= -MAX_X) {
423             // No truncation
424             return 0;
425         }
426         if (b >= MAX_X) {
427             // One-sided truncation
428             return ROOT_2_PI / Erfcx.value(a / ROOT2);
429         }
430 
431         // pdf = exp(-0.5*x*x) / sqrt(2*pi)
432         // cdf = erfc(-x/sqrt(2)) / 2
433         // Compute:
434         // -(pdf(b) - pdf(a)) / cdf(b, a)
435         // Note:
436         // exp(-0.5*b*b) - exp(-0.5*a*a)
437         // Use cancellation of powers:
438         // exp(-0.5*(b*b-a*a)) * exp(-0.5*a*a) - exp(-0.5*a*a)
439         // expm1(-0.5*(b*b-a*a)) * exp(-0.5*a*a)
440 
441         // dx = -0.5*(b*b-a*a)
442         final double dx = 0.5 * (b + a) * (b - a);
443         double m;
444         if (a <= 0) {
445             // Opposite signs
446             m = ROOT_2_PI * -Math.expm1(-dx) * Math.exp(-0.5 * a * a) / ErfDifference.value(a / ROOT2, b / ROOT2);
447         } else {
448             final double z = Math.exp(-dx) * Erfcx.value(b / ROOT2) - Erfcx.value(a / ROOT2);
449             if (z == 0) {
450                 // Occurs when a and b have large magnitudes and are very close
451                 return (a + b) * 0.5;
452             }
453             m = ROOT_2_PI * Math.expm1(-dx) / z;
454         }
455 
456         // Clip to the range
457         return clip(m, a, b);
458     }
459 
460     /**
461      * Compute the second moment of the truncated standard normal distribution.
462      *
463      * <p>Assumes {@code a <= b}.
464      *
465      * @param a Lower bound
466      * @param b Upper bound
467      * @return the first moment
468      */
469     private static double moment2(double a, double b) {
470         // Assume a < b.
471         // a == b is handled in the variance method
472         if (Math.abs(a) > Math.abs(b)) {
473             return moment2(-b, -a);
474         }
475 
476         // Here:
477         // |a| <= |b|
478         // a < b
479         // 0 < b
480 
481         if (a <= -MAX_X) {
482             // No truncation
483             return 1;
484         }
485         if (b >= MAX_X) {
486             // One-sided truncation.
487             // For a -> inf : moment2 -> a*a
488             // This occurs when erfcx(z) is approximated by (1/sqrt(pi)) / z and terms
489             // cancel. z > 6.71e7, a > 9.49e7
490             return 1 + ROOT_2_PI * a / Erfcx.value(a / ROOT2);
491         }
492 
493         // pdf = exp(-0.5*x*x) / sqrt(2*pi)
494         // cdf = erfc(-x/sqrt(2)) / 2
495         // Compute:
496         // 1 - (b*pdf(b) - a*pdf(a)) / cdf(b, a)
497         // = (cdf(b, a) - b*pdf(b) -a*pdf(a)) / cdf(b, a)
498 
499         // Note:
500         // For z -> 0:
501         //   sqrt(pi / 2) * erf(z / sqrt(2)) -> z
502         //   z * Math.exp(-0.5 * z * z) -> z
503         // Both computations below have cancellation as b -> 0 and the
504         // second moment is not computable as the fraction P/Q
505         // since P < ulp(Q). This always occurs when b < MIN_X
506         // if MIN_X is set at the point where
507         //   exp(-0.5 * z * z) / sqrt(2 pi) == 1 / sqrt(2 pi).
508         // This is JDK dependent due to variations in Math.exp.
509         // For b < MIN_X the second moment can be approximated using
510         // a uniform distribution: (b^3 - a^3) / (3b - 3a).
511         // In practice it also occurs when b > MIN_X since any a < MIN_X
512         // is effectively zero for part of the computation. A
513         // threshold to transition to a uniform distribution
514         // approximation is a compromise. Also note it will not
515         // correct computation when (b-a) is small and is far from 0.
516         // Thus the second moment is left to be inaccurate for
517         // small ranges (b-a) and the variance -> 0 when the true
518         // variance is close to or below machine epsilon.
519 
520         double m;
521 
522         if (a <= 0) {
523             // Opposite signs
524             final double ea = ROOT_PI_2 * Erf.value(a / ROOT2);
525             final double eb = ROOT_PI_2 * Erf.value(b / ROOT2);
526             final double fa = ea - a * Math.exp(-0.5 * a * a);
527             final double fb = eb - b * Math.exp(-0.5 * b * b);
528             // Assume fb >= fa && eb >= ea
529             // If fb <= fa this is a tiny range around 0
530             m = (fb - fa) / (eb - ea);
531             // Clip to the range
532             m = clip(m, 0, 1);
533         } else {
534             final double dx = 0.5 * (b + a) * (b - a);
535             final double ex = Math.exp(-dx);
536             final double ea = ROOT_PI_2 * Erfcx.value(a / ROOT2);
537             final double eb = ROOT_PI_2 * Erfcx.value(b / ROOT2);
538             final double fa = ea + a;
539             final double fb = eb + b;
540             m = (fa - fb * ex) / (ea - eb * ex);
541             // Clip to the range
542             m = clip(m, a * a, b * b);
543         }
544         return m;
545     }
546 
547     /**
548      * Compute the variance of the truncated standard normal distribution.
549      *
550      * <p>Assumes {@code a <= b}.
551      *
552      * @param a Lower bound
553      * @param b Upper bound
554      * @return the first moment
555      */
556     static double variance(double a, double b) {
557         if (a == b) {
558             return 0;
559         }
560 
561         final double m1 = moment1(a, b);
562         double m2 = moment2(a, b);
563         // variance = m2 - m1*m1
564         // rearrange x^2 - y^2 as (x-y)(x+y)
565         m2 = Math.sqrt(m2);
566         final double variance = (m2 - m1) * (m2 + m1);
567 
568         // Detect floating-point error.
569         if (variance >= 1) {
570             // Note:
571             // Extreme truncations in the tails can compute a variance above 1,
572             // for example if m2 is infinite: m2 - m1*m1 > 1
573             // Detect no truncation as the terms a and b lie far either side of zero;
574             // otherwise return 0 to indicate very small unknown variance.
575             return a < -1 && b > 1 ? 1 : 0;
576         } else if (variance <= 0) {
577             // Floating-point error can create negative variance so return 0.
578             return 0;
579         }
580 
581         return variance;
582     }
583 }