View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.analysis.function;
19  
20  import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
21  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22  import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
23  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
25  import org.apache.commons.math4.legacy.exception.NullArgumentException;
26  import org.apache.commons.math4.core.jdkmath.JdkMath;
27  
28  /**
29   * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
30   *  Generalised logistic</a> function.
31   *
32   * @since 3.0
33   */
34  public class Logistic implements UnivariateDifferentiableFunction {
35      /** Lower asymptote. */
36      private final double a;
37      /** Upper asymptote. */
38      private final double k;
39      /** Growth rate. */
40      private final double b;
41      /** Parameter that affects near which asymptote maximum growth occurs. */
42      private final double oneOverN;
43      /** Parameter that affects the position of the curve along the ordinate axis. */
44      private final double q;
45      /** Abscissa of maximum growth. */
46      private final double m;
47  
48      /**
49       * @param k If {@code b > 0}, value of the function for x going towards +&infin;.
50       * If {@code b < 0}, value of the function for x going towards -&infin;.
51       * @param m Abscissa of maximum growth.
52       * @param b Growth rate.
53       * @param q Parameter that affects the position of the curve along the
54       * ordinate axis.
55       * @param a If {@code b > 0}, value of the function for x going towards -&infin;.
56       * If {@code b < 0}, value of the function for x going towards +&infin;.
57       * @param n Parameter that affects near which asymptote the maximum
58       * growth occurs.
59       * @throws NotStrictlyPositiveException if {@code n <= 0}.
60       */
61      public Logistic(double k,
62                      double m,
63                      double b,
64                      double q,
65                      double a,
66                      double n)
67          throws NotStrictlyPositiveException {
68          if (n <= 0) {
69              throw new NotStrictlyPositiveException(n);
70          }
71  
72          this.k = k;
73          this.m = m;
74          this.b = b;
75          this.q = q;
76          this.a = a;
77          oneOverN = 1 / n;
78      }
79  
80      /** {@inheritDoc} */
81      @Override
82      public double value(double x) {
83          return value(m - x, k, b, q, a, oneOverN);
84      }
85  
86      /**
87       * Parametric function where the input array contains the parameters of
88       * the {@link Logistic#Logistic(double,double,double,double,double,double)
89       * logistic function}. Ordered as follows:
90       * <ul>
91       *  <li>k</li>
92       *  <li>m</li>
93       *  <li>b</li>
94       *  <li>q</li>
95       *  <li>a</li>
96       *  <li>n</li>
97       * </ul>
98       */
99      public static class Parametric implements ParametricUnivariateFunction {
100         /**
101          * Computes the value of the sigmoid at {@code x}.
102          *
103          * @param x Value for which the function must be computed.
104          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
105          * {@code a} and  {@code n}.
106          * @return the value of the function.
107          * @throws NullArgumentException if {@code param} is {@code null}.
108          * @throws DimensionMismatchException if the size of {@code param} is
109          * not 6.
110          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
111          */
112         @Override
113         public double value(double x, double ... param)
114             throws NullArgumentException,
115                    DimensionMismatchException,
116                    NotStrictlyPositiveException {
117             validateParameters(param);
118             return Logistic.value(param[1] - x, param[0],
119                                   param[2], param[3],
120                                   param[4], 1 / param[5]);
121         }
122 
123         /**
124          * Computes the value of the gradient at {@code x}.
125          * The components of the gradient vector are the partial
126          * derivatives of the function with respect to each of the
127          * <em>parameters</em>.
128          *
129          * @param x Value at which the gradient must be computed.
130          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
131          * {@code a} and  {@code n}.
132          * @return the gradient vector at {@code x}.
133          * @throws NullArgumentException if {@code param} is {@code null}.
134          * @throws DimensionMismatchException if the size of {@code param} is
135          * not 6.
136          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
137          */
138         @Override
139         public double[] gradient(double x, double ... param)
140             throws NullArgumentException,
141                    DimensionMismatchException,
142                    NotStrictlyPositiveException {
143             validateParameters(param);
144 
145             final double b = param[2];
146             final double q = param[3];
147 
148             final double mMinusX = param[1] - x;
149             final double oneOverN = 1 / param[5];
150             final double exp = JdkMath.exp(b * mMinusX);
151             final double qExp = q * exp;
152             final double qExp1 = qExp + 1;
153             final double factor1 = (param[0] - param[4]) * oneOverN / JdkMath.pow(qExp1, oneOverN);
154             final double factor2 = -factor1 / qExp1;
155 
156             // Components of the gradient.
157             final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
158             final double gm = factor2 * b * qExp;
159             final double gb = factor2 * mMinusX * qExp;
160             final double gq = factor2 * exp;
161             final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
162             final double gn = factor1 * JdkMath.log(qExp1) * oneOverN;
163 
164             return new double[] { gk, gm, gb, gq, ga, gn };
165         }
166 
167         /**
168          * Validates parameters to ensure they are appropriate for the evaluation of
169          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
170          * methods.
171          *
172          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
173          * {@code a} and {@code n}.
174          * @throws NullArgumentException if {@code param} is {@code null}.
175          * @throws DimensionMismatchException if the size of {@code param} is
176          * not 6.
177          * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
178          */
179         private void validateParameters(double[] param)
180             throws NullArgumentException,
181                    DimensionMismatchException,
182                    NotStrictlyPositiveException {
183             if (param == null) {
184                 throw new NullArgumentException();
185             }
186             if (param.length != 6) {
187                 throw new DimensionMismatchException(param.length, 6);
188             }
189             if (param[5] <= 0) {
190                 throw new NotStrictlyPositiveException(param[5]);
191             }
192         }
193     }
194 
195     /**
196      * @param mMinusX {@code m - x}.
197      * @param k {@code k}.
198      * @param b {@code b}.
199      * @param q {@code q}.
200      * @param a {@code a}.
201      * @param oneOverN {@code 1 / n}.
202      * @return the value of the function.
203      */
204     private static double value(double mMinusX,
205                                 double k,
206                                 double b,
207                                 double q,
208                                 double a,
209                                 double oneOverN) {
210         return a + (k - a) / JdkMath.pow(1 + q * JdkMath.exp(b * mMinusX), oneOverN);
211     }
212 
213     /** {@inheritDoc}
214      * @since 3.1
215      */
216     @Override
217     public DerivativeStructure value(final DerivativeStructure t) {
218         return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
219     }
220 }