1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.math4.neuralnet.sofm; 19 20 import java.util.Collection; 21 import java.util.HashSet; 22 import java.util.concurrent.atomic.AtomicLong; 23 import java.util.function.DoubleUnaryOperator; 24 25 import org.apache.commons.math4.neuralnet.DistanceMeasure; 26 import org.apache.commons.math4.neuralnet.MapRanking; 27 import org.apache.commons.math4.neuralnet.Network; 28 import org.apache.commons.math4.neuralnet.Neuron; 29 import org.apache.commons.math4.neuralnet.UpdateAction; 30 31 /** 32 * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen"> 33 * Kohonen's Self-Organizing Map</a>. 34 * <br> 35 * The {@link #update(Network,double[]) update} method modifies the 36 * features {@code w} of the "winning" neuron and its neighbours 37 * according to the following rule: 38 * <code> 39 * w<sub>new</sub> = w<sub>old</sub> + α e<sup>(-d / σ)</sup> * (sample - w<sub>old</sub>) 40 * </code> 41 * where 42 * <ul> 43 * <li>α is the current <em>learning rate</em>, </li> 44 * <li>σ is the current <em>neighbourhood size</em>, and</li> 45 * <li>{@code d} is the number of links to traverse in order to reach 46 * the neuron from the winning neuron.</li> 47 * </ul> 48 * <br> 49 * This class is thread-safe as long as the arguments passed to the 50 * {@link #KohonenUpdateAction(DistanceMeasure,LearningFactorFunction, 51 * NeighbourhoodSizeFunction) constructor} are instances of thread-safe 52 * classes. 53 * <br> 54 * Each call to the {@link #update(Network,double[]) update} method 55 * will increment the internal counter used to compute the current 56 * values for 57 * <ul> 58 * <li>the <em>learning rate</em>, and</li> 59 * <li>the <em>neighbourhood size</em>.</li> 60 * </ul> 61 * Consequently, the function instances that compute those values (passed 62 * to the constructor of this class) must take into account whether this 63 * class's instance will be shared by multiple threads, as this will impact 64 * the training process. 65 * 66 * @since 3.3 67 */ 68 public class KohonenUpdateAction implements UpdateAction { 69 /** Distance function. */ 70 private final DistanceMeasure distance; 71 /** Learning factor update function. */ 72 private final LearningFactorFunction learningFactor; 73 /** Neighbourhood size update function. */ 74 private final NeighbourhoodSizeFunction neighbourhoodSize; 75 /** Number of calls to {@link #update(Network,double[])}. */ 76 private final AtomicLong numberOfCalls = new AtomicLong(0); 77 78 /** 79 * @param distance Distance function. 80 * @param learningFactor Learning factor update function. 81 * @param neighbourhoodSize Neighbourhood size update function. 82 */ 83 public KohonenUpdateAction(DistanceMeasure distance, 84 LearningFactorFunction learningFactor, 85 NeighbourhoodSizeFunction neighbourhoodSize) { 86 this.distance = distance; 87 this.learningFactor = learningFactor; 88 this.neighbourhoodSize = neighbourhoodSize; 89 } 90 91 /** 92 * {@inheritDoc} 93 */ 94 @Override 95 public void update(Network net, 96 double[] features) { 97 final long numCalls = numberOfCalls.incrementAndGet() - 1; 98 final double currentLearning = learningFactor.value(numCalls); 99 final Neuron best = findAndUpdateBestNeuron(net, 100 features, 101 currentLearning); 102 103 final int currentNeighbourhood = neighbourhoodSize.value(numCalls); 104 // The farther away the neighbour is from the winning neuron, the 105 // smaller the learning rate will become. 106 final Gaussian neighbourhoodDecay 107 = new Gaussian(currentLearning, currentNeighbourhood); 108 109 if (currentNeighbourhood > 0) { 110 // Initial set of neurons only contains the winning neuron. 111 Collection<Neuron> neighbours = new HashSet<>(); 112 neighbours.add(best); 113 // Winning neuron must be excluded from the neighbours. 114 final HashSet<Neuron> exclude = new HashSet<>(); 115 exclude.add(best); 116 117 int radius = 1; 118 do { 119 // Retrieve immediate neighbours of the current set of neurons. 120 neighbours = net.getNeighbours(neighbours, exclude); 121 122 // Update all the neighbours. 123 for (final Neuron n : neighbours) { 124 updateNeighbouringNeuron(n, features, neighbourhoodDecay.applyAsDouble(radius)); 125 } 126 127 // Add the neighbours to the exclude list so that they will 128 // not be updated more than once per training step. 129 exclude.addAll(neighbours); 130 ++radius; 131 } while (radius <= currentNeighbourhood); 132 } 133 } 134 135 /** 136 * Retrieves the number of calls to the {@link #update(Network,double[]) update} 137 * method. 138 * 139 * @return the current number of calls. 140 */ 141 public long getNumberOfCalls() { 142 return numberOfCalls.get(); 143 } 144 145 /** 146 * Tries to update a neuron. 147 * 148 * @param n Neuron to be updated. 149 * @param features Training data. 150 * @param learningRate Learning factor. 151 * @return {@code true} if the update succeeded, {@code true} if a 152 * concurrent update has been detected. 153 */ 154 private boolean attemptNeuronUpdate(Neuron n, 155 double[] features, 156 double learningRate) { 157 final double[] expect = n.getFeatures(); 158 final double[] update = computeFeatures(expect, 159 features, 160 learningRate); 161 162 return n.compareAndSetFeatures(expect, update); 163 } 164 165 /** 166 * Atomically updates the given neuron. 167 * 168 * @param n Neuron to be updated. 169 * @param features Training data. 170 * @param learningRate Learning factor. 171 */ 172 private void updateNeighbouringNeuron(Neuron n, 173 double[] features, 174 double learningRate) { 175 while (true) { 176 if (attemptNeuronUpdate(n, features, learningRate)) { 177 break; 178 } 179 } 180 } 181 182 /** 183 * Searches for the neuron whose features are closest to the given 184 * sample, and atomically updates its features. 185 * 186 * @param net Network. 187 * @param features Sample data. 188 * @param learningRate Current learning factor. 189 * @return the winning neuron. 190 */ 191 private Neuron findAndUpdateBestNeuron(Network net, 192 double[] features, 193 double learningRate) { 194 final MapRanking rank = new MapRanking(net, distance); 195 196 while (true) { 197 final Neuron best = rank.rank(features, 1).get(0); 198 199 if (attemptNeuronUpdate(best, features, learningRate)) { 200 return best; 201 } 202 203 // If another thread modified the state of the winning neuron, 204 // it may not be the best match anymore for the given training 205 // sample: Hence, the winner search is performed again. 206 } 207 } 208 209 /** 210 * Computes the new value of the features set. 211 * 212 * @param current Current values of the features. 213 * @param sample Training data. 214 * @param learningRate Learning factor. 215 * @return the new values for the features. 216 */ 217 private double[] computeFeatures(double[] current, 218 double[] sample, 219 double learningRate) { 220 final int len = current.length; 221 final double[] r = new double[len]; 222 for (int i = 0; i < len; i++) { 223 final double c = current[i]; 224 final double s = sample[i]; 225 r[i] = c + learningRate * (s - c); 226 } 227 return r; 228 } 229 230 /** 231 * Gaussian function with zero mean. 232 */ 233 private static class Gaussian implements DoubleUnaryOperator { 234 /** Inverse of twice the square of the standard deviation. */ 235 private final double i2s2; 236 /** Normalization factor. */ 237 private final double norm; 238 239 /** 240 * @param norm Normalization factor. 241 * @param sigma Standard deviation. 242 */ 243 Gaussian(double norm, 244 double sigma) { 245 this.norm = norm; 246 i2s2 = 1d / (2 * sigma * sigma); 247 } 248 249 @Override 250 public double applyAsDouble(double x) { 251 return norm * Math.exp(-x * x * i2s2); 252 } 253 } 254 }