Coverage Report - org.apache.commons.nabla.automatic.analysis.AutomaticDifferentiator
 
Classes in this File Line Coverage Branch Coverage Complexity
AutomaticDifferentiator
74% 
100% 
0
 
 1  
 /*
 2  
  * Licensed to the Apache Software Foundation (ASF) under one or more
 3  
  * contributor license agreements.  See the NOTICE file distributed with
 4  
  * this work for additional information regarding copyright ownership.
 5  
  * The ASF licenses this file to You under the Apache License, Version 2.0
 6  
  * (the "License"); you may not use this file except in compliance with
 7  
  * the License.  You may obtain a copy of the License at
 8  
  *
 9  
  *      http://www.apache.org/licenses/LICENSE-2.0
 10  
  *
 11  
  * Unless required by applicable law or agreed to in writing, software
 12  
  * distributed under the License is distributed on an "AS IS" BASIS,
 13  
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14  
  * See the License for the specific language governing permissions and
 15  
  * limitations under the License.
 16  
  */
 17  
 package org.apache.commons.nabla.automatic.analysis;
 18  
 
 19  
 import java.io.IOException;
 20  
 import java.io.InputStream;
 21  
 import java.io.OutputStream;
 22  
 import java.lang.reflect.Constructor;
 23  
 import java.lang.reflect.InvocationTargetException;
 24  
 import java.util.HashMap;
 25  
 import java.util.HashSet;
 26  
 import java.util.Set;
 27  
 
 28  
 import org.apache.commons.nabla.core.DifferentialPair;
 29  
 import org.apache.commons.nabla.core.DifferentiationException;
 30  
 import org.apache.commons.nabla.core.UnivariateDerivative;
 31  
 import org.apache.commons.nabla.core.UnivariateDifferentiable;
 32  
 import org.apache.commons.nabla.core.UnivariateDifferentiator;
 33  
 import org.objectweb.asm.ClassReader;
 34  
 import org.objectweb.asm.ClassWriter;
 35  
 
 36  
 /** Automatic differentiator class based on bytecode analysis.
 37  
  * <p>This class is an implementation of the {@link UnivariateDifferentiator}
 38  
  * interface that computes <em>exact</em> differentials completely automatically
 39  
  * and generate java classes and instances that compute the differential
 40  
  * of the function as if they were hand-coded and compiled.</p>
 41  
  * <p>The derivative bytecode created the first time an instance of a given class
 42  
  * is differentiated is cached and will be reused if other instances of the same class
 43  
  * are to be created later. The cache can also be dumped in a jar file for
 44  
  * use in an application without bringing the full nabla library and its
 45  
  * dependencies.</p>
 46  
  * <p>This differentiator can handle only pure bytecode methods and known methods
 47  
  * from math implementation classes like {@link java.lang.Math Math} or
 48  
  * {@link java.lang.StrictMath StrictMath}. Pure bytecode methods are analyzed
 49  
  * and converted. Methods from math implementation classes are only recognized
 50  
  * by class and name and replaced by predefined derivative code.</p>
 51  
  * @see org.apache.commons.nabla.Fetchdifferentiator
 52  
  */
 53  
 public class AutomaticDifferentiator implements UnivariateDifferentiator {
 54  
 
 55  
     /** Name for the DifferentialPair class. */
 56  9
     public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
 57  
 
 58  
     /** Descriptor for the DifferentialPair class. */
 59  9
     public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";";
 60  
 
 61  
     /** Descriptor for the derivative class f method. */
 62  9
     public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
 63  
 
 64  
     /** UnivariateDifferentiable/UnivariateDerivative map. */
 65  
     private final HashMap<Class<? extends UnivariateDifferentiable>,
 66  
     Class<? extends UnivariateDerivative>> map;
 67  
 
 68  
     /** Math implementation classes. */
 69  
     private final Set<String> mathClasses;
 70  
 
 71  
     /** Simple constructor.
 72  
      * <p>Build a AutomaticDifferentiator instance with an empty cache.</p>
 73  
      */
 74  504
     public AutomaticDifferentiator() {
 75  504
         map = new HashMap<Class<? extends UnivariateDifferentiable>,
 76  
         Class<? extends UnivariateDerivative>>();
 77  504
         mathClasses = new HashSet<String>();
 78  504
         addMathImplementation(Math.class);
 79  504
         addMathImplementation(StrictMath.class);
 80  504
     }
 81  
 
 82  
     /** Add an implementation class for mathematical functions.
 83  
      * <p>At construction, the differentiator considers only the {@link
 84  
      * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath}
 85  
      * classes are math implementation classes. It may be useful to add
 86  
      * other class for example to add some missing functions like
 87  
      * inverse hyperbolic cosine that are not provided by the standard
 88  
      * java classes as of Java 1.6.</p>
 89  
      * @param mathClass implementation class for mathematical functions
 90  
      */
 91  
     public void addMathImplementation(final Class<?> mathClass) {
 92  1512
         mathClasses.add(mathClass.getName().replace('.', '/'));
 93  1512
     }
 94  
 
 95  
     /** Dump the cache into a stream.
 96  
      * @param out output stream where to dump the cache
 97  
      */
 98  
     public void dumpCache(final OutputStream out) {
 99  
         // TODO
 100  0
         throw new RuntimeException("not implemented yet");
 101  
     }
 102  
 
 103  
     /** {@inheritDoc} */
 104  
     public UnivariateDerivative differentiate(final UnivariateDifferentiable d)
 105  
         throws DifferentiationException {
 106  
 
 107  
         // get the derivative class
 108  504
         final Class<? extends UnivariateDerivative> derivativeClass =
 109  
             getDerivativeClass(d.getClass());
 110  
 
 111  
         try {
 112  
 
 113  
             // create the instance
 114  504
             final Constructor<? extends UnivariateDerivative> constructor =
 115  
                 derivativeClass.getConstructor(d.getClass());
 116  504
             return constructor.newInstance(d);
 117  
 
 118  0
         } catch (InstantiationException ie) {
 119  0
             throw new DifferentiationException("abstract class {0} cannot be instantiated ({1})",
 120  
                                                derivativeClass.getName(), ie.getMessage());
 121  0
         } catch (IllegalAccessException iae) {
 122  0
             throw new DifferentiationException("illegal access to class {0} constructor ({1})",
 123  
                                                derivativeClass.getName(), iae.getMessage());
 124  0
         } catch (NoSuchMethodException nsme) {
 125  0
             throw new DifferentiationException("class {0} cannot be built from an instance of class {1} ({2})",
 126  
                                                derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
 127  0
         } catch (InvocationTargetException ite) {
 128  0
             throw new DifferentiationException("class {0} instantiation from an instance of class {1} failed ({2})",
 129  
                                                derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
 130  
         }
 131  
 
 132  
     }
 133  
 
 134  
     /** Get the derivative class of a differentiable class.
 135  
      * <p>The derivative class is either built on the fly
 136  
      * or retrieved from the cache if it has been built previously.</p>
 137  
      * @param differentiableClass class to differentiate
 138  
      * @return derivative class
 139  
      * @throws DifferentiationException if the class cannot be differentiated
 140  
      */
 141  
     private Class<? extends UnivariateDerivative>
 142  
     getDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass)
 143  
         throws DifferentiationException {
 144  
 
 145  
         // lookup in the map if the class has already been differentiated
 146  504
         Class<? extends UnivariateDerivative> derivativeClass =
 147  
             map.get(differentiableClass);
 148  
 
 149  
         // build the derivative class if it does not exist yet
 150  504
         if (derivativeClass == null) {
 151  
             // perform analytical differentiation
 152  504
             derivativeClass = createDerivativeClass(differentiableClass);
 153  
 
 154  
             // put the newly created class in the map
 155  504
             map.put(differentiableClass, derivativeClass);
 156  
 
 157  
         }
 158  
 
 159  
         // return the derivative class
 160  504
         return derivativeClass;
 161  
 
 162  
     }
 163  
 
 164  
     /** Build a derivative class of a differentiable class.
 165  
      * @param differentiableClass class to differentiate
 166  
      * @return derivative class
 167  
      * @throws DifferentiationException if the class cannot be differentiated
 168  
      */
 169  
     private Class<? extends UnivariateDerivative>
 170  
     createDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass)
 171  
         throws DifferentiationException {
 172  
         try {
 173  
 
 174  
             // set up both ends of the class transform chain
 175  504
             final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class";
 176  504
             final InputStream stream = differentiableClass.getResourceAsStream(classResourceName);
 177  504
             final ClassReader reader = new ClassReader(stream);
 178  504
             final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
 179  
 
 180  
             // differentiate the function embedded in the differentiable class
 181  504
             final ClassDifferentiator differentiator = new ClassDifferentiator(mathClasses, writer);
 182  504
             reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
 183  504
             differentiator.reportErrors();
 184  
 
 185  
             // create the derivative class
 186  504
             return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer);
 187  
 
 188  0
         } catch (IOException ioe) {
 189  0
             throw new DifferentiationException("class {0} cannot be read ({1})",
 190  
                                           differentiableClass.getName(), ioe.getMessage());
 191  
         }
 192  
     }
 193  
 
 194  
     /** Class loader generating derivative classes. */
 195  
     private static class DerivativeLoader extends ClassLoader {
 196  
 
 197  
         /** Simple constructor.
 198  
          * @param differentiableClass differentiable class
 199  
          */
 200  
         public DerivativeLoader(final Class<? extends UnivariateDifferentiable> differentiableClass) {
 201  504
             super(differentiableClass.getClassLoader());
 202  504
         }
 203  
 
 204  
         /** Define a derivative class.
 205  
          * @param differentiator class differentiator
 206  
          * @param writer class writer
 207  
          * @return a generated derivative class
 208  
          */
 209  
         @SuppressWarnings("unchecked")
 210  
         public Class<? extends UnivariateDerivative>
 211  
         defineClass(final ClassDifferentiator differentiator, final ClassWriter writer) {
 212  504
             final String name = differentiator.getDerivativeClassName().replace('/', '.');
 213  504
             final byte[] bytecode = writer.toByteArray();
 214  504
             return (Class<? extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length);
 215  
         }
 216  
     }
 217  
 
 218  
 }