View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.analysis.function;
19  
20  import java.util.Arrays;
21  
22  import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
23  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
24  import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
25  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
26  import org.apache.commons.math4.legacy.exception.NullArgumentException;
27  import org.apache.commons.math4.core.jdkmath.JdkMath;
28  
29  /**
30   * <a href="http://en.wikipedia.org/wiki/Sigmoid_function">
31   *  Sigmoid</a> function.
32   * It is the inverse of the {@link Logit logit} function.
33   * A more flexible version, the generalised logistic, is implemented
34   * by the {@link Logistic} class.
35   *
36   * @since 3.0
37   */
38  public class Sigmoid implements UnivariateDifferentiableFunction {
39      /** Lower asymptote. */
40      private final double lo;
41      /** Higher asymptote. */
42      private final double hi;
43  
44      /**
45       * Usual sigmoid function, where the lower asymptote is 0 and the higher
46       * asymptote is 1.
47       */
48      public Sigmoid() {
49          this(0, 1);
50      }
51  
52      /**
53       * Sigmoid function.
54       *
55       * @param lo Lower asymptote.
56       * @param hi Higher asymptote.
57       */
58      public Sigmoid(double lo,
59                     double hi) {
60          this.lo = lo;
61          this.hi = hi;
62      }
63  
64      /** {@inheritDoc} */
65      @Override
66      public double value(double x) {
67          return value(x, lo, hi);
68      }
69  
70      /**
71       * Parametric function where the input array contains the parameters of
72       * the {@link Sigmoid#Sigmoid(double,double) sigmoid function}. Ordered
73       * as follows:
74       * <ul>
75       *  <li>Lower asymptote</li>
76       *  <li>Higher asymptote</li>
77       * </ul>
78       */
79      public static class Parametric implements ParametricUnivariateFunction {
80          /**
81           * Computes the value of the sigmoid at {@code x}.
82           *
83           * @param x Value for which the function must be computed.
84           * @param param Values of lower asymptote and higher asymptote.
85           * @return the value of the function.
86           * @throws NullArgumentException if {@code param} is {@code null}.
87           * @throws DimensionMismatchException if the size of {@code param} is
88           * not 2.
89           */
90          @Override
91          public double value(double x, double ... param)
92              throws NullArgumentException,
93                     DimensionMismatchException {
94              validateParameters(param);
95              return Sigmoid.value(x, param[0], param[1]);
96          }
97  
98          /**
99           * Computes the value of the gradient at {@code x}.
100          * The components of the gradient vector are the partial
101          * derivatives of the function with respect to each of the
102          * <em>parameters</em> (lower asymptote and higher asymptote).
103          *
104          * @param x Value at which the gradient must be computed.
105          * @param param Values for lower asymptote and higher asymptote.
106          * @return the gradient vector at {@code x}.
107          * @throws NullArgumentException if {@code param} is {@code null}.
108          * @throws DimensionMismatchException if the size of {@code param} is
109          * not 2.
110          */
111         @Override
112         public double[] gradient(double x, double ... param)
113             throws NullArgumentException,
114                    DimensionMismatchException {
115             validateParameters(param);
116 
117             final double invExp1 = 1 / (1 + JdkMath.exp(-x));
118 
119             return new double[] { 1 - invExp1, invExp1 };
120         }
121 
122         /**
123          * Validates parameters to ensure they are appropriate for the evaluation of
124          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
125          * methods.
126          *
127          * @param param Values for lower and higher asymptotes.
128          * @throws NullArgumentException if {@code param} is {@code null}.
129          * @throws DimensionMismatchException if the size of {@code param} is
130          * not 2.
131          */
132         private void validateParameters(double[] param)
133             throws NullArgumentException,
134                    DimensionMismatchException {
135             if (param == null) {
136                 throw new NullArgumentException();
137             }
138             if (param.length != 2) {
139                 throw new DimensionMismatchException(param.length, 2);
140             }
141         }
142     }
143 
144     /**
145      * @param x Value at which to compute the sigmoid.
146      * @param lo Lower asymptote.
147      * @param hi Higher asymptote.
148      * @return the value of the sigmoid function at {@code x}.
149      */
150     private static double value(double x,
151                                 double lo,
152                                 double hi) {
153         return lo + (hi - lo) / (1 + JdkMath.exp(-x));
154     }
155 
156     /** {@inheritDoc}
157      * @since 3.1
158      */
159     @Override
160     public DerivativeStructure value(final DerivativeStructure t)
161         throws DimensionMismatchException {
162 
163         double[] f = new double[t.getOrder() + 1];
164         final double exp = JdkMath.exp(-t.getValue());
165         if (Double.isInfinite(exp)) {
166 
167             // special handling near lower boundary, to avoid NaN
168             f[0] = lo;
169             Arrays.fill(f, 1, f.length, 0.0);
170         } else {
171 
172             // the nth order derivative of sigmoid has the form:
173             // dn(sigmoid(x)/dxn = P_n(exp(-x)) / (1+exp(-x))^(n+1)
174             // where P_n(t) is a degree n polynomial with normalized higher term
175             // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t...
176             // the general recurrence relation for P_n is:
177             // P_n(x) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t)
178             final double[] p = new double[f.length];
179 
180             final double inv   = 1 / (1 + exp);
181             double coeff = hi - lo;
182             for (int n = 0; n < f.length; ++n) {
183 
184                 // update and evaluate polynomial P_n(t)
185                 double v = 0;
186                 p[n] = 1;
187                 for (int k = n; k >= 0; --k) {
188                     v = v * exp + p[k];
189                     if (k > 1) {
190                         p[k - 1] = (n - k + 2) * p[k - 2] - (k - 1) * p[k - 1];
191                     } else {
192                         p[0] = 0;
193                     }
194                 }
195 
196                 coeff *= inv;
197                 f[n]   = coeff * v;
198             }
199 
200             // fix function value
201             f[0] += lo;
202         }
203 
204         return t.compose(f);
205     }
206 }