NelderMeadTransform.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;

import java.util.Comparator;
import java.util.function.UnaryOperator;
import java.util.function.DoublePredicate;

import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
import org.apache.commons.math4.legacy.optim.PointValuePair;

/**
 * <a href="https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method">Nelder-Mead method</a>.
 *
 * <p>The operator returned by {@link #create} performs one iteration on a simplex
 * of {@code n + 1} vertices. It reflects the worst vertex through the centroid of
 * the others. Depending on how the reflected point compares with the current
 * vertices, it may then try an expansion or a contraction. If no acceptable trial
 * point is found, the whole simplex is shrunk.</p>
 */
public class NelderMeadTransform
    implements Simplex.TransformFactory {
    /** Default value for {@link #alpha}: {@value}. */
    private static final double DEFAULT_ALPHA = 1;
    /** Default value for {@link #gamma}: {@value}. */
    private static final double DEFAULT_GAMMA = 2;
    /** Default value for {@link #rho}: {@value}. */
    private static final double DEFAULT_RHO = 0.5;
    /** Default value for {@link #sigma}: {@value}. */
    private static final double DEFAULT_SIGMA = 0.5;
    /** Reflection coefficient (used with a negative sign when building the reflected point). */
    private final double alpha;
    /** Expansion coefficient (used with a negative sign when building the expanded point). */
    private final double gamma;
    /** Contraction coefficient (used for both outside and inside contraction). */
    private final double rho;
    /** Shrinkage coefficient (passed to {@code Simplex.shrink}). */
    private final double sigma;

    /**
     * Creates a transform with the given coefficients.
     * No validation is performed on the values.
     *
     * @param alpha Reflection coefficient.
     * @param gamma Expansion coefficient.
     * @param rho Contraction coefficient.
     * @param sigma Shrinkage coefficient.
     */
    public NelderMeadTransform(double alpha,
                               double gamma,
                               double rho,
                               double sigma) {
        this.alpha = alpha;
        this.gamma = gamma;
        this.rho = rho;
        this.sigma = sigma;
    }

    /**
     * Creates a transform with the default coefficients
     * (reflection 1, expansion 2, contraction 0.5, shrinkage 0.5).
     */
    public NelderMeadTransform() {
        this(DEFAULT_ALPHA, DEFAULT_GAMMA, DEFAULT_RHO, DEFAULT_SIGMA);
    }

    /** {@inheritDoc} */
    @Override
    public UnaryOperator<Simplex> create(final MultivariateFunction evaluationFunction,
                                         final Comparator<PointValuePair> comparator,
                                         final DoublePredicate sa) {
        return simplex -> iterate(simplex, evaluationFunction, comparator, sa);
    }

    /**
     * Performs a single Nelder-Mead iteration.
     *
     * @param simplex Current simplex. Vertex {@code 0} is treated as the best and
     * vertex {@code n} (where {@code n} is the dimension) as the worst; presumably
     * the caller keeps the vertices ordered according to {@code comparator}.
     * @param evaluationFunction Objective function, evaluated at every trial point.
     * @param comparator Ordering of points; a negative result means the first
     * argument is the better one.
     * @param sa Optional extra acceptance criterion for the expanded point. It is
     * tested with {@code expanded value - reflected value}. May be {@code null}.
     * @return the simplex produced by this iteration.
     */
    private Simplex iterate(final Simplex simplex,
                            final MultivariateFunction evaluationFunction,
                            final Comparator<PointValuePair> comparator,
                            final DoublePredicate sa) {
        // A simplex of dimension "dim" holds "dim + 1" vertices.
        final int dim = simplex.getDimension();

        final PointValuePair best = simplex.get(0);
        final PointValuePair secondWorst = simplex.get(dim - 1);
        final PointValuePair worst = simplex.get(dim);
        final double[] worstPoint = worst.getPoint();

        // Centroid of every vertex except the worst one (index "dim").
        final double[] centroid = Simplex.centroid(simplex.asList().subList(0, dim));

        // Reflect the worst vertex through the centroid.
        final PointValuePair reflected =
            Simplex.newPoint(centroid, -alpha, worstPoint, evaluationFunction);

        // Reflected point sits between the best (inclusive) and the second worst (exclusive).
        if (comparator.compare(reflected, secondWorst) < 0 &&
            comparator.compare(best, reflected) <= 0) {
            return simplex.replaceLast(reflected);
        }

        // Reflected point beats the best vertex: try to move further in that direction.
        if (comparator.compare(reflected, best) < 0) {
            final PointValuePair expanded =
                Simplex.newPoint(centroid, -gamma, worstPoint, evaluationFunction);
            final boolean keepExpanded = comparator.compare(expanded, reflected) < 0 ||
                (sa != null && sa.test(expanded.getValue() - reflected.getValue()));
            return simplex.replaceLast(keepExpanded ? expanded : reflected);
        }

        // Contraction. It is "outside" (toward the reflected point) when the reflected
        // point beats the worst vertex, and "inside" (toward the worst vertex) otherwise.
        // The contracted point must beat the point it was contracted toward.
        final boolean outside = comparator.compare(reflected, worst) < 0;
        final PointValuePair reference = outside ? reflected : worst;
        final double[] target = outside ? reflected.getPoint() : worstPoint;
        final PointValuePair contracted =
            Simplex.newPoint(centroid, rho, target, evaluationFunction);
        if (comparator.compare(contracted, reference) < 0) {
            return simplex.replaceLast(contracted);
        }

        // No trial point was accepted: shrink the whole simplex.
        return simplex.shrink(sigma, evaluationFunction);
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return new StringBuilder("Nelder-Mead [a=").append(alpha)
            .append(" g=").append(gamma)
            .append(" r=").append(rho)
            .append(" s=").append(sigma)
            .append("]")
            .toString();
    }
}