1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.examples.sofm.chineserings;
19
20 import java.util.Iterator;
21 import java.util.NoSuchElementException;
22
23 import org.apache.commons.rng.UniformRandomProvider;
24 import org.apache.commons.rng.simple.RandomSource;
25 import org.apache.commons.geometry.euclidean.threed.Vector3D;
26
27 import org.apache.commons.math4.neuralnet.SquareNeighbourhood;
28 import org.apache.commons.math4.neuralnet.FeatureInitializer;
29 import org.apache.commons.math4.neuralnet.FeatureInitializerFactory;
30 import org.apache.commons.math4.neuralnet.DistanceMeasure;
31 import org.apache.commons.math4.neuralnet.EuclideanDistance;
32 import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D;
33 import org.apache.commons.math4.neuralnet.sofm.LearningFactorFunction;
34 import org.apache.commons.math4.neuralnet.sofm.LearningFactorFunctionFactory;
35 import org.apache.commons.math4.neuralnet.sofm.NeighbourhoodSizeFunction;
36 import org.apache.commons.math4.neuralnet.sofm.NeighbourhoodSizeFunctionFactory;
37 import org.apache.commons.math4.neuralnet.sofm.KohonenUpdateAction;
38 import org.apache.commons.math4.neuralnet.sofm.KohonenTrainingTask;
39 import org.apache.commons.math4.legacy.stat.descriptive.SummaryStatistics;
40
41
42
43
44 class ChineseRingsClassifier {
45
46 private final NeuronSquareMesh2D sofm;
47
48 private final ChineseRings rings;
49
50 private final DistanceMeasure distance = new EuclideanDistance();
51
52
53
54
55
56
57 ChineseRingsClassifier(ChineseRings rings,
58 int dim1,
59 int dim2) {
60 this.rings = rings;
61 sofm = new NeuronSquareMesh2D(dim1, false,
62 dim2, false,
63 SquareNeighbourhood.MOORE,
64 makeInitializers());
65 }
66
67
68
69
70
71
72
73
74 public Runnable[] createParallelTasks(int numTasks,
75 long numSamplesPerTask) {
76 final Runnable[] tasks = new Runnable[numTasks];
77 final LearningFactorFunction learning
78 = LearningFactorFunctionFactory.exponentialDecay(1e-1,
79 5e-2,
80 numSamplesPerTask / 2);
81 final double numNeurons = Math.sqrt((double) sofm.getNumberOfRows() * sofm.getNumberOfColumns());
82 final NeighbourhoodSizeFunction neighbourhood
83 = NeighbourhoodSizeFunctionFactory.exponentialDecay(0.5 * numNeurons,
84 0.2 * numNeurons,
85 numSamplesPerTask / 2);
86
87 for (int i = 0; i < numTasks; i++) {
88 final KohonenUpdateAction action = new KohonenUpdateAction(distance,
89 learning,
90 neighbourhood);
91 tasks[i] = new KohonenTrainingTask(sofm.getNetwork(),
92 createRandomIterator(numSamplesPerTask),
93 action);
94 }
95
96 return tasks;
97 }
98
99
100
101
102
103
104
105 public Runnable createSequentialTask(long numSamples) {
106 return createParallelTasks(1, numSamples)[0];
107 }
108
109
110
111
112
113
114 public NeuronSquareMesh2D.DataVisualization computeQualityIndicators() {
115 return sofm.computeQualityIndicators(rings.createIterable());
116 }
117
118
119
120
121
122
123
124
125
126 private FeatureInitializer[] makeInitializers() {
127 final SummaryStatistics[] centre = {
128 new SummaryStatistics(),
129 new SummaryStatistics(),
130 new SummaryStatistics()
131 };
132 for (final Vector3D p : rings.getPoints()) {
133 centre[0].addValue(p.getX());
134 centre[1].addValue(p.getY());
135 centre[2].addValue(p.getZ());
136 }
137
138 final double[] mean = {
139 centre[0].getMean(),
140 centre[1].getMean(),
141 centre[2].getMean()
142 };
143 final double[] dev = {
144 0.1 * centre[0].getStandardDeviation(),
145 0.1 * centre[1].getStandardDeviation(),
146 0.1 * centre[2].getStandardDeviation()
147 };
148
149 final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();
150 return new FeatureInitializer[] {
151 FeatureInitializerFactory.uniform(rng, mean[0] - dev[0], mean[0] + dev[0]),
152 FeatureInitializerFactory.uniform(rng, mean[1] - dev[1], mean[1] + dev[1]),
153 FeatureInitializerFactory.uniform(rng, mean[2] - dev[2], mean[2] + dev[2])
154 };
155 }
156
157
158
159
160
161
162
163
164 private Iterator<double[]> createRandomIterator(final long numSamples) {
165 return new Iterator<double[]>() {
166
167 private final Vector3D[] points = rings.getPoints();
168
169 private final UniformRandomProvider rng = RandomSource.KISS.create();
170
171 private long n;
172
173
174 @Override
175 public boolean hasNext() {
176 return n < numSamples;
177 }
178
179
180 @Override
181 public double[] next() {
182 if (!hasNext()) {
183 throw new NoSuchElementException();
184 }
185 ++n;
186 return points[rng.nextInt(points.length)].toArray();
187 }
188
189
190 @Override
191 public void remove() {
192 throw new UnsupportedOperationException();
193 }
194 };
195 }
196 }