View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.examples.sofm.chineserings;
19  
20  import java.io.FileNotFoundException;
21  import java.io.PrintWriter;
22  import java.io.UnsupportedEncodingException;
23  import java.nio.charset.StandardCharsets;
24  import java.util.concurrent.Callable;
25  
26  import picocli.CommandLine;
27  import picocli.CommandLine.Option;
28  import picocli.CommandLine.Command;
29  
30  import org.apache.commons.geometry.euclidean.threed.Vector3D;
31  import org.apache.commons.math4.neuralnet.twod.NeuronSquareMesh2D;
32  
33  /**
34   * Application class.
35   */
36  @Command(description = "Run the application",
37           mixinStandardHelpOptions = true)
38  public final class StandAlone implements Callable<Void> {
39      /** The number of rows. */
40      @Option(names = { "-r" }, paramLabel = "numRows",
41              description = "Number of rows of the 2D SOFM (default: ${DEFAULT-VALUE}).")
42      private int numRows = 15;
43      /** The number of columns. */
44      @Option(names = { "-c" }, paramLabel = "numCols",
45              description = "Number of columns of the 2D SOFM (default: ${DEFAULT-VALUE}).")
46      private int numCols = 15;
47      /** The number of samples. */
48      @Option(names = { "-s" }, paramLabel = "numSamples",
49              description = "Number of samples for the training (default: ${DEFAULT-VALUE}).")
50      private long numSamples = 100000;
51      /** The output file. */
52      @Option(names = { "-o" }, paramLabel = "outputFile", required = true,
53              description = "Output file name.")
54      private String outputFile = null;
55  
56      /**
57       * Program entry point.
58       *
59       * @param args Command line arguments and options.
60       */
61      public static void main(String[] args) {
62          CommandLine.call(new StandAlone(), args);
63      }
64  
65      @Override
66      public Void call() throws Exception {
67          final ChineseRings rings = new ChineseRings(Vector3D.of(1, 2, 3),
68                                                      25, 2,
69                                                      20, 1,
70                                                      2000, 1500);
71  
72          final ChineseRingsClassifier classifier = new ChineseRingsClassifier(rings, numRows, numCols);
73          classifier.createSequentialTask(numSamples).run();
74          printResult(outputFile, classifier);
75  
76          return null;
77      }
78  
79      /**
80       * Prints various quality measures of the map to files.
81       *
82       * @param fileName File name.
83       * @param sofm Classifier.
84       * @throws UnsupportedEncodingException If UTF-8 encoding does not exist.
85       * @throws FileNotFoundException If the file cannot be created.
86       */
87      private static void printResult(String fileName,
88                                      ChineseRingsClassifier sofm)
89                                      throws FileNotFoundException, UnsupportedEncodingException {
90          final NeuronSquareMesh2D.DataVisualization result = sofm.computeQualityIndicators();
91  
92          try (PrintWriter out = new PrintWriter(fileName, StandardCharsets.UTF_8.name())) {
93              out.println("# Number of samples: " + result.getNumberOfSamples());
94              out.println("# Quantization error: " + result.getMeanQuantizationError());
95              out.println("# Topographic error: " + result.getMeanTopographicError());
96              out.println();
97  
98              printImage("Quantization error", result.getQuantizationError(), out);
99              printImage("Topographic error", result.getTopographicError(), out);
100             printImage("Normalized hits", result.getNormalizedHits(), out);
101             printImage("U-matrix", result.getUMatrix(), out);
102         }
103     }
104 
105     /**
106      * @param desc Data description.
107      * @param image Data.
108      * @param out Output stream.
109      */
110     private static void printImage(String desc,
111                                    double[][] image,
112                                    PrintWriter out) {
113         out.println("# " + desc);
114         final int nR = image.length;
115         final int nC = image[0].length;
116         for (int i = 0; i < nR; i++) {
117             for (int j = 0; j < nC; j++) {
118                 out.print(image[i][j] + " ");
119             }
120             out.println();
121         }
122         out.println();
123     }
124 }