/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;

import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Collections;
import java.util.Objects;
import java.util.function.UnaryOperator;
import java.util.function.IntSupplier;
import java.util.concurrent.CopyOnWriteArrayList;

import org.apache.commons.math4.legacy.core.MathArrays;
import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
import org.apache.commons.math4.legacy.exception.MathInternalError;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
import org.apache.commons.math4.legacy.optim.OptimizationData;
import org.apache.commons.math4.legacy.optim.PointValuePair;
import org.apache.commons.math4.legacy.optim.SimpleValueChecker;
import org.apache.commons.math4.legacy.optim.InitialGuess;
import org.apache.commons.math4.legacy.optim.MaxEval;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.SimulatedAnnealing;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.PopulationSize;
import org.apache.commons.math4.legacy.optim.nonlinear.scalar.ObjectiveFunction;

/**
 * This class implements simplex-based direct search optimization.
 *
 * <p>
 * Direct search methods only use objective function values: they do
 * not need derivatives, nor do they try to compute approximations
 * of the derivatives.  According to a 1996 paper by Margaret H. Wright
 * (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct
 * Search Methods: Once Scorned, Now Respectable</a>), they are used
 * when either the computation of the derivative is impossible (noisy
 * functions, unpredictable discontinuities) or difficult (complexity,
 * computation cost).  In the first cases, rather than an optimum, a
 * <em>not too bad</em> point is desired.  In the latter cases, an
 * optimum is desired but cannot be reasonably found.  In all cases
 * direct search methods can be useful.
 *
 * <p>
 * Simplex-based direct search methods are based on comparison of
 * the objective function values at the vertices of a simplex (which is a
 * set of n+1 points in dimension n) that is updated by the algorithm's
 * steps.
 *
 * <p>
 * In addition to those documented in
 * {@link MultivariateOptimizer#optimize(OptimizationData[]) MultivariateOptimizer},
 * an instance of this class will register the following data:
 * <ul>
 *  <li>{@link Simplex}</li>
 *  <li>{@link Simplex.TransformFactory}</li>
 *  <li>{@link SimulatedAnnealing}</li>
 *  <li>{@link PopulationSize}</li>
 * </ul>
 *
 * <p>
 * Each call to {@code optimize} will re-use the start configuration of
 * the current simplex and move it such that its first vertex is at the
 * provided start point of the optimization.
 * If the {@code optimize} method is called to solve a different problem
 * and the number of parameters changes, the simplex must be re-initialized
 * to one with the appropriate dimensions.
 *
 * <p>
 * Convergence is considered achieved when <em>all</em> the simplex points
 * have converged.
 * <p>
 * Whenever {@link SimulatedAnnealing simulated annealing (SA)} is activated,
 * and the SA phase has completed, convergence has probably not been reached
 * yet; whenever it's the case, an additional (non-SA) search will be performed
 * (using the current best simplex point as a start point).
 * <p>
 * Additional "best list" searches can be requested through setting the
 * {@link PopulationSize} argument of the {@link #optimize(OptimizationData[])
 * optimize} method.
 *
 * <p>
 * This implementation does not directly support constrained optimization
 * with simple bounds.
 * The call to {@link #optimize(OptimizationData[]) optimize} will throw
 * {@link MathUnsupportedOperationException} if bounds are passed to it.
 *
 * @see NelderMeadTransform
 * @see MultiDirectionalTransform
 * @see HedarFukushimaTransform
 */
public class SimplexOptimizer extends MultivariateOptimizer {
    /** Default simplex side length ratio. */
    private static final double SIMPLEX_SIDE_RATIO = 1e-1;
    /** Simplex update function factory. */
    private Simplex.TransformFactory updateRule;
    /** Initial simplex. */
    private Simplex initialSimplex;
    /** Simulated annealing setup (optional). */
    private SimulatedAnnealing simulatedAnnealing = null;
    /** User-defined number of additional optimizations (optional). */
    private int populationSize = 0;
    /** Actual number of additional optimizations. */
    private int additionalSearch = 0;
    /** Callbacks. */
    private final List<Observer> callbacks = new CopyOnWriteArrayList<>();

    /**
     * Creates an optimizer that uses the given convergence checker.
     *
     * @param checker Convergence checker.
     */
    public SimplexOptimizer(ConvergenceChecker<PointValuePair> checker) {
        super(checker);
    }

    /**
     * Creates an optimizer whose convergence checker is a
     * {@link SimpleValueChecker} built from the given thresholds.
     *
     * @param rel Relative threshold.
     * @param abs Absolute threshold.
     */
    public SimplexOptimizer(double rel,
                            double abs) {
        this(new SimpleValueChecker(rel, abs));
    }

    /**
     * Callback interface for updating caller's code with the current
     * state of the optimization.
     */
    @FunctionalInterface
    public interface Observer {
        /**
         * Method called after each modification of the {@code simplex}.
         *
         * @param simplex Current simplex.
         * @param isInit {@code true} at the start of a new search (either
         * "main" or "best list"), after the initial simplex's vertices
         * have been evaluated.
         * @param numEval Number of evaluations of the objective function.
         */
        void update(Simplex simplex,
                    boolean isInit,
                    int numEval);
    }

    /**
     * Register a callback.
     *
     * @param cb Callback.
     * @throws NullPointerException if {@code cb} is {@code null}.
     */
    public void addObserver(Observer cb) {
        Objects.requireNonNull(cb, "Callback");
        callbacks.add(cb);
    }

    /** {@inheritDoc} */
    @Override
    protected PointValuePair doOptimize() {
        checkParameters();

        final MultivariateFunction evalFunc = getObjectiveFunction();

        // Orders points from best to worst according to the goal type.
        final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
        final Comparator<PointValuePair> comparator = (o1, o2) -> {
            final double v1 = o1.getValue();
            final double v2 = o2.getValue();
            return isMinim ? Double.compare(v1, v2) : Double.compare(v2, v1);
        };

        // Start points for additional search.
        final List<PointValuePair> bestList = new ArrayList<>();

        Simplex currentSimplex = initialSimplex.translate(getStartPoint()).evaluate(evalFunc, comparator);
        notifyObservers(currentSimplex, true);
        double temperature = Double.NaN; // Only used with simulated annealing.
        Simplex previousSimplex = null;

        if (simulatedAnnealing != null) {
            temperature =
                temperature(currentSimplex.get(0),
                            currentSimplex.get(currentSimplex.getDimension()),
                            simulatedAnnealing.getStartProbability());
        }

        while (true) {
            // The check is skipped at the first iteration (no previous simplex).
            if (previousSimplex != null &&
                hasConverged(previousSimplex, currentSimplex)) {
                return currentSimplex.get(0);
            }

            // We still need to search.
            previousSimplex = currentSimplex;

            if (simulatedAnnealing != null) {
                // Update current temperature.
                temperature =
                    simulatedAnnealing.getCoolingSchedule().apply(temperature,
                                                                  currentSimplex);

                final double endTemperature =
                    temperature(currentSimplex.get(0),
                                currentSimplex.get(currentSimplex.getDimension()),
                                simulatedAnnealing.getEndProbability());

                if (temperature < endTemperature) {
                    // End of the simulated annealing phase.
                    break;
                }

                final UnaryOperator<Simplex> update =
                    updateRule.create(evalFunc,
                                      comparator,
                                      simulatedAnnealing.metropolis(temperature));

                for (int i = 0; i < simulatedAnnealing.getEpochDuration(); i++) {
                    // Simplex is transformed (and observers are notified).
                    currentSimplex = applyUpdate(update,
                                                 currentSimplex,
                                                 evalFunc,
                                                 comparator);
                }
            } else {
                // No simulated annealing.
                final UnaryOperator<Simplex> update =
                    updateRule.create(evalFunc, comparator, null);

                // Simplex is transformed (and observers are notified).
                currentSimplex = applyUpdate(update,
                                             currentSimplex,
                                             evalFunc,
                                             comparator);
            }

            if (additionalSearch != 0) {
                storeBestPoints(currentSimplex, comparator, bestList);
            }

            incrementIterationCount();
        }

        // No convergence.

        if (additionalSearch > 0) {
            if (bestList.isEmpty()) {
                // The loop was exited before any point could be recorded
                // (the "break" above precedes the bookkeeping): fall back
                // on the vertices of the current simplex so that the
                // additional search has at least one start point.
                storeBestPoints(currentSimplex, comparator, bestList);
            }

            // Additional optimizations.
            // Reference to counter in the "main" search in order to retrieve
            // the total number of evaluations in the "best list" search.
            final IntSupplier evalCount = this::getEvaluations;

            return bestListSearch(evalFunc,
                                  comparator,
                                  bestList,
                                  evalCount);
        }

        throw new MathInternalError(); // Should never happen.
    }

    /**
     * Scans the list of (required and optional) optimization data that
     * characterize the problem.
     *
     * @param optData Optimization data.
     * The following data will be looked for:
     * <ul>
     *  <li>{@link Simplex}</li>
     *  <li>{@link Simplex.TransformFactory}</li>
     *  <li>{@link SimulatedAnnealing}</li>
     *  <li>{@link PopulationSize}</li>
     * </ul>
     */
    @Override
    protected void parseOptimizationData(OptimizationData... optData) {
        // Allow base class to register its own data.
        super.parseOptimizationData(optData);

        // The existing values (as set by the previous call) are reused
        // if not provided in the argument list.
        for (OptimizationData data : optData) {
            if (data instanceof Simplex) {
                initialSimplex = (Simplex) data;
            } else if (data instanceof Simplex.TransformFactory) {
                updateRule = (Simplex.TransformFactory) data;
            } else if (data instanceof SimulatedAnnealing) {
                simulatedAnnealing = (SimulatedAnnealing) data;
            } else if (data instanceof PopulationSize) {
                populationSize = ((PopulationSize) data).getPopulationSize();
            }
        }
    }

    /**
     * Detects whether the simplex has shrunk below the user-defined
     * tolerance.
     * Vertices of both simplexes are compared pairwise (same index) with
     * the convergence checker; all pairs must have converged.
     *
     * @param previous Simplex at previous iteration.
     * @param current Simplex at current iteration.
     * @return {@code true} if convergence is considered achieved.
     */
    private boolean hasConverged(Simplex previous,
                                 Simplex current) {
        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();

        for (int i = 0; i < current.getSize(); i++) {
            final PointValuePair prev = previous.get(i);
            final PointValuePair curr = current.get(i);

            if (!checker.converged(getIterations(), prev, curr)) {
                return false;
            }
        }

        return true;
    }

    /**
     * Validates the registered data and computes the actual number of
     * additional searches: at least one is performed whenever simulated
     * annealing is activated.
     *
     * @throws MathUnsupportedOperationException if bounds were passed to the
     * {@link #optimize(OptimizationData[]) optimize} method.
     * @throws NullPointerException if no initial simplex or no transform rule
     * was passed to the {@link #optimize(OptimizationData[]) optimize} method.
     * @throws IllegalArgumentException if {@link #populationSize} is negative.
     */
    private void checkParameters() {
        Objects.requireNonNull(updateRule, "Update rule");
        Objects.requireNonNull(initialSimplex, "Initial simplex");

        if (getLowerBound() != null ||
            getUpperBound() != null) {
            throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
        }

        if (populationSize < 0) {
            throw new IllegalArgumentException("Population size");
        }

        additionalSearch = simulatedAnnealing == null ?
            Math.max(0, populationSize) :
            Math.max(1, populationSize);
    }

    /**
     * Computes the temperature as a function of the acceptance probability
     * and the fitness difference between two of the simplex vertices (usually
     * the best and worst points).
     *
     * @param p1 Simplex point.
     * @param p2 Simplex point.
     * @param prob Acceptance probability.
     * @return the temperature.
     */
    private double temperature(PointValuePair p1,
                               PointValuePair p2,
                               double prob) {
        return -Math.abs(p1.getValue() - p2.getValue()) / Math.log(prob);
    }

    /**
     * Records the vertices of the given {@code simplex} as candidate start
     * points for the additional search.
     *
     * @param simplex Simplex whose vertices are candidates.
     * @param comp Fitness comparator.
     * @param bestList Start points (modified in-place).
     */
    private void storeBestPoints(Simplex simplex,
                                 Comparator<PointValuePair> comp,
                                 List<PointValuePair> bestList) {
        // In "bestList", we must keep track of at least two points
        // in order to be able to compute the new initial simplex for
        // the additional search.
        final int max = Math.max(additionalSearch, 2);

        for (int i = 0; i < simplex.getSize(); i++) {
            keepIfBetter(simplex.get(i),
                         comp,
                         bestList,
                         max);
        }
    }

    /**
     * Stores the given {@code candidate} if its fitness is better than
     * that of the last (assumed to be the worst) point in {@code list}.
     *
     * <p>If the list is below the maximum size then the {@code candidate}
     * is added if it is not already in the list. The list is sorted
     * when it reaches the maximum size.
     *
     * @param candidate Point to be stored.
     * @param comp Fitness comparator.
     * @param list Starting points (modified in-place).
     * @param max Maximum size of the {@code list}.
     */
    private static void keepIfBetter(PointValuePair candidate,
                                     Comparator<PointValuePair> comp,
                                     List<PointValuePair> list,
                                     int max) {
        final int listSize = list.size();
        final double[] candidatePoint = candidate.getPoint();
        if (listSize == 0) {
            list.add(candidate);
        } else if (listSize < max) {
            // List is not fully populated yet.
            if (containsPoint(list, candidatePoint)) {
                // Point was already stored.
                return;
            }
            // Store candidate.
            list.add(candidate);
            // Sort the list when required.
            if (list.size() == max) {
                Collections.sort(list, comp);
            }
        } else {
            final int last = max - 1;
            if (comp.compare(candidate, list.get(last)) < 0) {
                if (containsPoint(list, candidatePoint)) {
                    // Point was already stored.
                    return;
                }

                // Store better candidate and reorder the list.
                list.set(last, candidate);
                Collections.sort(list, comp);
            }
        }
    }

    /**
     * Checks whether a point with the given coordinates is present
     * in the {@code list}.
     *
     * @param list List.
     * @param point Coordinates to look for.
     * @return {@code true} if an element of the {@code list} has exactly
     * the same coordinates as {@code point}.
     */
    private static boolean containsPoint(List<PointValuePair> list,
                                         double[] point) {
        for (PointValuePair p : list) {
            if (Arrays.equals(p.getPoint(), point)) {
                return true;
            }
        }

        return false;
    }

    /**
     * Computes the smallest distance between the given {@code point}
     * and any of the other points in the {@code list}.
     *
     * @param point Point.
     * @param list List.
     * @return the smallest distance, or {@link Double#POSITIVE_INFINITY}
     * if the {@code list} contains no point whose coordinates differ from
     * those of {@code point}.
     */
    private static double shortestDistance(PointValuePair point,
                                           List<PointValuePair> list) {
        double minDist = Double.POSITIVE_INFINITY;

        final double[] p = point.getPoint();
        for (PointValuePair other : list) {
            final double[] pOther = other.getPoint();
            if (!Arrays.equals(p, pOther)) {
                final double dist = MathArrays.distance(p, pOther);
                if (dist < minDist) {
                    minDist = dist;
                }
            }
        }

        return minDist;
    }

    /**
     * Perform additional optimizations.
     *
     * <p>At most {@link #additionalSearch} searches are performed; fewer
     * if {@code bestList} does not contain enough points.
     *
     * @param evalFunc Objective function.
     * @param comp Fitness comparator.
     * @param bestList Best points encountered during the "main" search.
     * Must not be empty. The list is sorted in-place from best to worst.
     * @param evalCount Evaluation counter.
     * @return the optimum.
     */
    private PointValuePair bestListSearch(MultivariateFunction evalFunc,
                                          Comparator<PointValuePair> comp,
                                          List<PointValuePair> bestList,
                                          IntSupplier evalCount) {
        // A partially filled list has not been sorted yet (sorting only
        // occurs once the maximum size has been reached).
        // The sort is stable: an already sorted list is left untouched.
        Collections.sort(bestList, comp);

        PointValuePair best = bestList.get(0); // Overall best result.

        // Duplicate points are never stored, so the list may hold fewer
        // points than the requested number of searches.
        final int numSearches = Math.min(additionalSearch, bestList.size());

        // Additional local optimizations using each of the best
        // points visited during the main search.
        for (int i = 0; i < numSearches; i++) {
            final PointValuePair start = bestList.get(i);
            // Find shortest distance to the other points.
            final double dist = shortestDistance(start, bestList);
            final double[] init = start.getPoint();
            // Create smaller initial simplex.
            // When no usable distance is available (no other distinct
            // point, or zero distance), a simplex cannot be derived from
            // it: reuse the registered initial simplex instead.
            final Simplex simplex = Double.isFinite(dist) && dist > 0 ?
                Simplex.equalSidesAlongAxes(init.length,
                                            SIMPLEX_SIDE_RATIO * dist) :
                initialSimplex;

            final PointValuePair r = directSearch(init,
                                                  simplex,
                                                  evalFunc,
                                                  getConvergenceChecker(),
                                                  getGoalType(),
                                                  callbacks,
                                                  evalCount);
            if (comp.compare(r, best) < 0) {
                best = r; // New overall best.
            }
        }

        return best;
    }

    /**
     * Runs a search with a new optimizer instance that uses the
     * {@link MultiDirectionalTransform} update rule and no limit on the
     * number of evaluations.
     *
     * @param init Start point.
     * @param simplex Initial simplex.
     * @param eval Objective function.
     * Note: It is assumed that evaluations of this function are
     * incrementing the main counter.
     * @param checker Convergence checker.
     * @param goalType Whether to minimize or maximize the objective function.
     * @param cbList Callbacks.
     * @param evalCount Evaluation counter.
     * @return the optimum.
     */
    private static PointValuePair directSearch(double[] init,
                                               Simplex simplex,
                                               MultivariateFunction eval,
                                               ConvergenceChecker<PointValuePair> checker,
                                               GoalType goalType,
                                               List<Observer> cbList,
                                               final IntSupplier evalCount) {
        final SimplexOptimizer optim = new SimplexOptimizer(checker);

        for (Observer cOrig : cbList) {
            // Forward notifications to the original callback, replacing the
            // evaluation count with the one provided by "evalCount".
            final SimplexOptimizer.Observer cNew = (spx, isInit, numEval) ->
                cOrig.update(spx, isInit, evalCount.getAsInt());

            optim.addObserver(cNew);
        }

        return optim.optimize(MaxEval.unlimited(),
                              new ObjectiveFunction(eval),
                              goalType,
                              new InitialGuess(init),
                              simplex,
                              new MultiDirectionalTransform());
    }

    /**
     * Calls every registered callback with the given simplex and the
     * current number of evaluations.
     *
     * @param simplex Current simplex.
     * @param isInit Set to {@code true} at the start of a new search
     * (either "main" or "best list"), after the evaluation of the initial
     * simplex's vertices.
     */
    private void notifyObservers(Simplex simplex,
                                 boolean isInit) {
        for (Observer cb : callbacks) {
            cb.update(simplex,
                      isInit,
                      getEvaluations());
        }
    }

    /**
     * Applies the {@code update} to the given {@code simplex} (and notifies
     * observers).
     *
     * @param update Simplex transformation.
     * @param simplex Current simplex.
     * @param eval Objective function.
     * @param comp Fitness comparator.
     * @return the transformed simplex.
     */
    private Simplex applyUpdate(UnaryOperator<Simplex> update,
                                Simplex simplex,
                                MultivariateFunction eval,
                                Comparator<PointValuePair> comp) {
        final Simplex transformed = update.apply(simplex).evaluate(eval, comp);

        notifyObservers(transformed, false);

        return transformed;
    }
}