View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.examples.sofm.chineserings;
19  
20  import java.util.Iterator;
21  import java.util.NoSuchElementException;
22  
23  import org.apache.commons.rng.UniformRandomProvider;
24  import org.apache.commons.rng.simple.RandomSource;
25  import org.apache.commons.geometry.euclidean.threed.Vector3D;
26  
27  import org.apache.commons.math4.neuralnet.SquareNeighbourhood;
28  import org.apache.commons.math4.neuralnet.FeatureInitializer;
29  import org.apache.commons.math4.neuralnet.FeatureInitializerFactory;
30  import org.apache.commons.math4.neuralnet.DistanceMeasure;
31  import org.apache.commons.math4.neuralnet.EuclideanDistance;
32  import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D;
33  import org.apache.commons.math4.neuralnet.sofm.LearningFactorFunction;
34  import org.apache.commons.math4.neuralnet.sofm.LearningFactorFunctionFactory;
35  import org.apache.commons.math4.neuralnet.sofm.NeighbourhoodSizeFunction;
36  import org.apache.commons.math4.neuralnet.sofm.NeighbourhoodSizeFunctionFactory;
37  import org.apache.commons.math4.neuralnet.sofm.KohonenUpdateAction;
38  import org.apache.commons.math4.neuralnet.sofm.KohonenTrainingTask;
39  import org.apache.commons.math4.legacy.stat.descriptive.SummaryStatistics;
40  
41  /**
42   * SOFM for categorizing points that belong to each of two intertwined rings.
43   */
44  class ChineseRingsClassifier {
45      /** SOFM. */
46      private final NeuronSquareMesh2D sofm;
47      /** Rings. */
48      private final ChineseRings rings;
49      /** Distance function. */
50      private final DistanceMeasure distance = new EuclideanDistance();
51  
52      /**
53       * @param rings Training data.
54       * @param dim1 Number of rows of the SOFM.
55       * @param dim2 Number of columns of the SOFM.
56       */
57      ChineseRingsClassifier(ChineseRings rings,
58                             int dim1,
59                             int dim2) {
60          this.rings = rings;
61          sofm = new NeuronSquareMesh2D(dim1, false,
62                                        dim2, false,
63                                        SquareNeighbourhood.MOORE,
64                                        makeInitializers());
65      }
66  
67      /**
68       * Creates training tasks.
69       *
70       * @param numTasks Number of tasks to create.
71       * @param numSamplesPerTask Number of training samples per task.
72       * @return the created tasks.
73       */
74      public Runnable[] createParallelTasks(int numTasks,
75                                            long numSamplesPerTask) {
76          final Runnable[] tasks = new Runnable[numTasks];
77          final LearningFactorFunction learning
78              = LearningFactorFunctionFactory.exponentialDecay(1e-1,
79                                                               5e-2,
80                                                               numSamplesPerTask / 2);
81          final double numNeurons = Math.sqrt((double) sofm.getNumberOfRows() * sofm.getNumberOfColumns());
82          final NeighbourhoodSizeFunction neighbourhood
83              = NeighbourhoodSizeFunctionFactory.exponentialDecay(0.5 * numNeurons,
84                                                                  0.2 * numNeurons,
85                                                                  numSamplesPerTask / 2);
86  
87          for (int i = 0; i < numTasks; i++) {
88              final KohonenUpdateAction action = new KohonenUpdateAction(distance,
89                                                                         learning,
90                                                                         neighbourhood);
91              tasks[i] = new KohonenTrainingTask(sofm.getNetwork(),
92                                                 createRandomIterator(numSamplesPerTask),
93                                                 action);
94          }
95  
96          return tasks;
97      }
98  
99      /**
100      * Creates a training task.
101      *
102      * @param numSamples Number of training samples.
103      * @return the created task.
104      */
105     public Runnable createSequentialTask(long numSamples) {
106         return createParallelTasks(1, numSamples)[0];
107     }
108 
109     /**
110      * Computes various quality measures.
111      *
112      * @return the indicators.
113      */
114     public NeuronSquareMesh2D.DataVisualization computeQualityIndicators() {
115         return sofm.computeQualityIndicators(rings.createIterable());
116     }
117 
118     /**
119      * Creates the features' initializers.
120      * They are sampled from a uniform distribution around the barycentre of
121      * the rings.
122      *
123      * @return an array containing the initializers for the x, y and
124      * z coordinates of the features array of the neurons.
125      */
126     private FeatureInitializer[] makeInitializers() {
127         final SummaryStatistics[] centre = {
128             new SummaryStatistics(),
129             new SummaryStatistics(),
130             new SummaryStatistics()
131         };
132         for (final Vector3D p : rings.getPoints()) {
133             centre[0].addValue(p.getX());
134             centre[1].addValue(p.getY());
135             centre[2].addValue(p.getZ());
136         }
137 
138         final double[] mean = {
139             centre[0].getMean(),
140             centre[1].getMean(),
141             centre[2].getMean()
142         };
143         final double[] dev = {
144             0.1 * centre[0].getStandardDeviation(),
145             0.1 * centre[1].getStandardDeviation(),
146             0.1 * centre[2].getStandardDeviation()
147         };
148 
149         final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
150         return new FeatureInitializer[] {
151             FeatureInitializerFactory.uniform(rng, mean[0] - dev[0], mean[0] + dev[0]),
152             FeatureInitializerFactory.uniform(rng, mean[1] - dev[1], mean[1] + dev[1]),
153             FeatureInitializerFactory.uniform(rng, mean[2] - dev[2], mean[2] + dev[2])
154         };
155     }
156 
157     /**
158      * Creates an iterator that will present a series of points coordinates in
159      * a random order.
160      *
161      * @param numSamples Number of samples.
162      * @return the iterator.
163      */
164     private Iterator<double[]> createRandomIterator(final long numSamples) {
165         return new Iterator<double[]>() {
166             /** Data. */
167             private final Vector3D[] points = rings.getPoints();
168             /** RNG. */
169             private final UniformRandomProvider rng = RandomSource.KISS.create();
170             /** Number of samples. */
171             private long n;
172 
173             /** {@inheritDoc} */
174             @Override
175             public boolean hasNext() {
176                 return n < numSamples;
177             }
178 
179             /** {@inheritDoc} */
180             @Override
181             public double[] next() {
182                 if (!hasNext()) {
183                     throw new NoSuchElementException();
184                 }
185                 ++n;
186                 return points[rng.nextInt(points.length)].toArray();
187             }
188 
189             /** {@inheritDoc} */
190             @Override
191             public void remove() {
192                 throw new UnsupportedOperationException();
193             }
194         };
195     }
196 }