View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.analysis.function;
19  
20  import java.util.Arrays;
21  
22  import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
23  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
24  import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
25  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
26  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
27  import org.apache.commons.math4.legacy.exception.NullArgumentException;
28  import org.apache.commons.math4.core.jdkmath.JdkMath;
29  import org.apache.commons.numbers.core.Precision;
30  
31  /**
32   * <a href="http://en.wikipedia.org/wiki/Gaussian_function">
33   *  Gaussian</a> function.
34   *
35   * @since 3.0
36   */
37  public class Gaussian implements UnivariateDifferentiableFunction {
38      /** Mean. */
39      private final double mean;
40      /** Inverse of the standard deviation. */
41      private final double is;
42      /** Inverse of twice the square of the standard deviation. */
43      private final double i2s2;
44      /** Normalization factor. */
45      private final double norm;
46  
47      /**
48       * Gaussian with given normalization factor, mean and standard deviation.
49       *
50       * @param norm Normalization factor.
51       * @param mean Mean.
52       * @param sigma Standard deviation.
53       * @throws NotStrictlyPositiveException if {@code sigma <= 0}.
54       */
55      public Gaussian(double norm,
56                      double mean,
57                      double sigma)
58          throws NotStrictlyPositiveException {
59          if (sigma <= 0) {
60              throw new NotStrictlyPositiveException(sigma);
61          }
62  
63          this.norm = norm;
64          this.mean = mean;
65          this.is   = 1 / sigma;
66          this.i2s2 = 0.5 * is * is;
67      }
68  
69      /**
70       * Normalized gaussian with given mean and standard deviation.
71       *
72       * @param mean Mean.
73       * @param sigma Standard deviation.
74       * @throws NotStrictlyPositiveException if {@code sigma <= 0}.
75       */
76      public Gaussian(double mean,
77                      double sigma)
78          throws NotStrictlyPositiveException {
79          this(1 / (sigma * JdkMath.sqrt(2 * Math.PI)), mean, sigma);
80      }
81  
82      /**
83       * Normalized gaussian with zero mean and unit standard deviation.
84       */
85      public Gaussian() {
86          this(0, 1);
87      }
88  
89      /** {@inheritDoc} */
90      @Override
91      public double value(double x) {
92          return value(x - mean, norm, i2s2);
93      }
94  
95      /**
96       * Parametric function where the input array contains the parameters of
97       * the Gaussian. Ordered as follows:
98       * <ul>
99       *  <li>Norm</li>
100      *  <li>Mean</li>
101      *  <li>Standard deviation</li>
102      * </ul>
103      */
104     public static class Parametric implements ParametricUnivariateFunction {
105         /**
106          * Computes the value of the Gaussian at {@code x}.
107          *
108          * @param x Value for which the function must be computed.
109          * @param param Values of norm, mean and standard deviation.
110          * @return the value of the function.
111          * @throws NullArgumentException if {@code param} is {@code null}.
112          * @throws DimensionMismatchException if the size of {@code param} is
113          * not 3.
114          * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
115          */
116         @Override
117         public double value(double x, double ... param)
118             throws NullArgumentException,
119                    DimensionMismatchException,
120                    NotStrictlyPositiveException {
121             validateParameters(param);
122 
123             final double diff = x - param[1];
124             final double i2s2 = 1 / (2 * param[2] * param[2]);
125             return Gaussian.value(diff, param[0], i2s2);
126         }
127 
128         /**
129          * Computes the value of the gradient at {@code x}.
130          * The components of the gradient vector are the partial
131          * derivatives of the function with respect to each of the
132          * <em>parameters</em> (norm, mean and standard deviation).
133          *
134          * @param x Value at which the gradient must be computed.
135          * @param param Values of norm, mean and standard deviation.
136          * @return the gradient vector at {@code x}.
137          * @throws NullArgumentException if {@code param} is {@code null}.
138          * @throws DimensionMismatchException if the size of {@code param} is
139          * not 3.
140          * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
141          */
142         @Override
143         public double[] gradient(double x, double ... param)
144             throws NullArgumentException,
145                    DimensionMismatchException,
146                    NotStrictlyPositiveException {
147             validateParameters(param);
148 
149             final double norm = param[0];
150             final double diff = x - param[1];
151             final double sigma = param[2];
152             final double i2s2 = 1 / (2 * sigma * sigma);
153 
154             final double n = Gaussian.value(diff, 1, i2s2);
155             final double m = norm * n * 2 * i2s2 * diff;
156             final double s = m * diff / sigma;
157 
158             return new double[] { n, m, s };
159         }
160 
161         /**
162          * Validates parameters to ensure they are appropriate for the evaluation of
163          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
164          * methods.
165          *
166          * @param param Values of norm, mean and standard deviation.
167          * @throws NullArgumentException if {@code param} is {@code null}.
168          * @throws DimensionMismatchException if the size of {@code param} is
169          * not 3.
170          * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
171          */
172         private void validateParameters(double[] param)
173             throws NullArgumentException,
174                    DimensionMismatchException,
175                    NotStrictlyPositiveException {
176             if (param == null) {
177                 throw new NullArgumentException();
178             }
179             if (param.length != 3) {
180                 throw new DimensionMismatchException(param.length, 3);
181             }
182             if (param[2] <= 0) {
183                 throw new NotStrictlyPositiveException(param[2]);
184             }
185         }
186     }
187 
188     /**
189      * @param xMinusMean {@code x - mean}.
190      * @param norm Normalization factor.
191      * @param i2s2 Inverse of twice the square of the standard deviation.
192      * @return the value of the Gaussian at {@code x}.
193      */
194     private static double value(double xMinusMean,
195                                 double norm,
196                                 double i2s2) {
197         return norm * JdkMath.exp(-xMinusMean * xMinusMean * i2s2);
198     }
199 
200     /** {@inheritDoc}
201      * @since 3.1
202      */
203     @Override
204     public DerivativeStructure value(final DerivativeStructure t)
205         throws DimensionMismatchException {
206 
207         final double u = is * (t.getValue() - mean);
208         double[] f = new double[t.getOrder() + 1];
209 
210         // the nth order derivative of the Gaussian has the form:
211         // dn(g(x)/dxn = (norm / s^n) P_n(u) exp(-u^2/2) with u=(x-m)/s
212         // where P_n(u) is a degree n polynomial with same parity as n
213         // P_0(u) = 1, P_1(u) = -u, P_2(u) = u^2 - 1, P_3(u) = -u^3 + 3 u...
214         // the general recurrence relation for P_n is:
215         // P_n(u) = P_(n-1)'(u) - u P_(n-1)(u)
216         // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array
217         final double[] p = new double[f.length];
218         p[0] = 1;
219         final double u2 = u * u;
220         double coeff = norm * JdkMath.exp(-0.5 * u2);
221         if (coeff <= Precision.SAFE_MIN) {
222             Arrays.fill(f, 0.0);
223         } else {
224             f[0] = coeff;
225             for (int n = 1; n < f.length; ++n) {
226 
227                 // update and evaluate polynomial P_n(x)
228                 double v = 0;
229                 p[n] = -p[n - 1];
230                 for (int k = n; k >= 0; k -= 2) {
231                     v = v * u2 + p[k];
232                     if (k > 2) {
233                         p[k - 2] = (k - 1) * p[k - 1] - p[k - 3];
234                     } else if (k == 2) {
235                         p[0] = p[1];
236                     }
237                 }
238                 if ((n & 0x1) == 1) {
239                     v *= u;
240                 }
241 
242                 coeff *= is;
243                 f[n] = coeff * v;
244             }
245         }
246 
247         return t.compose(f);
248     }
249 }