Coverage Report - org.apache.commons.nabla.algorithmic.forward.ForwardAlgorithmicDifferentiator
 
Classes in this File Line Coverage Branch Coverage Complexity
ForwardAlgorithmicDifferentiator
68%
24/35
50%
1/2
2.75
ForwardAlgorithmicDifferentiator$DerivativeLoader
100%
5/5
N/A
2.75
 
 1  
 /*
 2  
  * Licensed to the Apache Software Foundation (ASF) under one or more
 3  
  * contributor license agreements.  See the NOTICE file distributed with
 4  
  * this work for additional information regarding copyright ownership.
 5  
  * The ASF licenses this file to You under the Apache License, Version 2.0
 6  
  * (the "License"); you may not use this file except in compliance with
 7  
  * the License.  You may obtain a copy of the License at
 8  
  *
 9  
  *      http://www.apache.org/licenses/LICENSE-2.0
 10  
  *
 11  
  * Unless required by applicable law or agreed to in writing, software
 12  
  * distributed under the License is distributed on an "AS IS" BASIS,
 13  
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14  
  * See the License for the specific language governing permissions and
 15  
  * limitations under the License.
 16  
  */
 17  
 package org.apache.commons.nabla.algorithmic.forward;
 18  
 
 19  
 import java.io.IOException;
 20  
 import java.io.InputStream;
 21  
 import java.io.OutputStream;
 22  
 import java.lang.reflect.Constructor;
 23  
 import java.lang.reflect.InvocationTargetException;
 24  
 import java.util.HashMap;
 25  
 import java.util.HashSet;
 26  
 import java.util.Set;
 27  
 
 28  
 import org.apache.commons.nabla.algorithmic.forward.analysis.ClassDifferentiator;
 29  
 import org.apache.commons.nabla.core.DifferentiationException;
 30  
 import org.apache.commons.nabla.core.UnivariateDerivative;
 31  
 import org.apache.commons.nabla.core.UnivariateDifferentiable;
 32  
 import org.apache.commons.nabla.core.UnivariateDifferentiator;
 33  
 import org.objectweb.asm.ClassReader;
 34  
 import org.objectweb.asm.ClassWriter;
 35  
 
 36  
 /** Automatic differentiator class based on bytecode analysis.
 37  
  * <p>This class is an implementation of the {@link UnivariateDifferentiator}
 38  
  * interface that computes <em>exact</em> differentials completely automatically
 39  
  * and generate java classes and instances that compute the differential
 40  
  * of the function as if they were hand-coded and compiled.</p>
 41  
  * <p>The derivative bytecode created the first time an instance of a given class
 42  
  * is differentiated is cached and will be reused if other instances of the same class
 43  
  * are to be created later. The cache can also be dumped in a jar file for
 44  
  * use in an application without bringing the full nabla library and its
 45  
  * dependencies.</p>
 46  
  * <p>This differentiator can handle only pure bytecode methods and known methods
 47  
  * from math implementation classes like {@link java.lang.Math Math} or
 48  
  * {@link java.lang.StrictMath StrictMath}. Pure bytecode methods are analyzed
 49  
  * and converted. Methods from math implementation classes are only recognized
 50  
  * by class and name and replaced by predefined derivative code.</p>
 51  
  * @see org.apache.commons.nabla.caching.FetchDifferentiator
 52  
  */
 53  
 public class ForwardAlgorithmicDifferentiator implements UnivariateDifferentiator {
 54  
 
 55  
     /** UnivariateDifferentiable/UnivariateDerivative map. */
 56  
     private final HashMap<Class<? extends UnivariateDifferentiable>,
 57  
     Class<? extends UnivariateDerivative>> map;
 58  
 
 59  
     /** Math implementation classes. */
 60  
     private final Set<String> mathClasses;
 61  
 
 62  
     /** Simple constructor.
 63  
      * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
 64  
      */
 65  66
     public ForwardAlgorithmicDifferentiator() {
 66  66
         map = new HashMap<Class<? extends UnivariateDifferentiable>,
 67  
         Class<? extends UnivariateDerivative>>();
 68  66
         mathClasses = new HashSet<String>();
 69  66
         addMathImplementation(Math.class);
 70  66
         addMathImplementation(StrictMath.class);
 71  66
     }
 72  
 
 73  
     /** Add an implementation class for mathematical functions.
 74  
      * <p>At construction, the differentiator considers only the {@link
 75  
      * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath}
 76  
      * classes are math implementation classes. It may be useful to add
 77  
      * other class for example to add some missing functions like
 78  
      * inverse hyperbolic cosine that are not provided by the standard
 79  
      * java classes as of Java 1.6.</p>
 80  
      * @param mathClass implementation class for mathematical functions
 81  
      */
 82  
     public void addMathImplementation(final Class<?> mathClass) {
 83  188
         mathClasses.add(mathClass.getName().replace('.', '/'));
 84  188
     }
 85  
 
 86  
     /** Dump the cache into a stream.
 87  
      * @param out output stream where to dump the cache
 88  
      */
 89  
     public void dumpCache(final OutputStream out) {
 90  
         // TODO implement cache persistence
 91  0
         throw new RuntimeException("not implemented yet");
 92  
     }
 93  
 
 94  
     /** {@inheritDoc} */
 95  
     public UnivariateDerivative differentiate(final UnivariateDifferentiable d)
 96  
         throws DifferentiationException {
 97  
 
 98  
         // get the derivative class
 99  66
         final Class<? extends UnivariateDerivative> derivativeClass =
 100  
             getDerivativeClass(d.getClass());
 101  
 
 102  
         try {
 103  
 
 104  
             // create the instance
 105  66
             final Constructor<? extends UnivariateDerivative> constructor =
 106  
                 derivativeClass.getConstructor(d.getClass());
 107  66
             return constructor.newInstance(d);
 108  
 
 109  0
         } catch (InstantiationException ie) {
 110  0
             throw new DifferentiationException("abstract class {0} cannot be instantiated ({1})",
 111  
                                                derivativeClass.getName(), ie.getMessage());
 112  0
         } catch (IllegalAccessException iae) {
 113  0
             throw new DifferentiationException("illegal access to class {0} constructor ({1})",
 114  
                                                derivativeClass.getName(), iae.getMessage());
 115  0
         } catch (NoSuchMethodException nsme) {
 116  0
             throw new DifferentiationException("class {0} cannot be built from an instance of class {1} ({2})",
 117  
                                                derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
 118  0
         } catch (InvocationTargetException ite) {
 119  0
             throw new DifferentiationException("class {0} instantiation from an instance of class {1} failed ({2})",
 120  
                                                derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
 121  
         }
 122  
 
 123  
     }
 124  
 
 125  
     /** Get the derivative class of a differentiable class.
 126  
      * <p>The derivative class is either built on the fly
 127  
      * or retrieved from the cache if it has been built previously.</p>
 128  
      * @param differentiableClass class to differentiate
 129  
      * @return derivative class
 130  
      * @throws DifferentiationException if the class cannot be differentiated
 131  
      */
 132  
     private Class<? extends UnivariateDerivative>
 133  
     getDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass)
 134  
         throws DifferentiationException {
 135  
 
 136  
         // lookup in the map if the class has already been differentiated
 137  66
         Class<? extends UnivariateDerivative> derivativeClass =
 138  
             map.get(differentiableClass);
 139  
 
 140  
         // build the derivative class if it does not exist yet
 141  66
         if (derivativeClass == null) {
 142  
             // perform analytical differentiation
 143  66
             derivativeClass = createDerivativeClass(differentiableClass);
 144  
 
 145  
             // put the newly created class in the map
 146  66
             map.put(differentiableClass, derivativeClass);
 147  
 
 148  
         }
 149  
 
 150  
         // return the derivative class
 151  66
         return derivativeClass;
 152  
 
 153  
     }
 154  
 
 155  
     /** Build a derivative class of a differentiable class.
 156  
      * @param differentiableClass class to differentiate
 157  
      * @return derivative class
 158  
      * @throws DifferentiationException if the class cannot be differentiated
 159  
      */
 160  
     private Class<? extends UnivariateDerivative>
 161  
     createDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass)
 162  
         throws DifferentiationException {
 163  
         try {
 164  
 
 165  
             // set up both ends of the class transform chain
 166  66
             final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class";
 167  66
             final InputStream stream = differentiableClass.getResourceAsStream(classResourceName);
 168  66
             final ClassReader reader = new ClassReader(stream);
 169  66
             final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
 170  
 
 171  
             // differentiate the function embedded in the differentiable class
 172  66
             final ClassDifferentiator differentiator = new ClassDifferentiator(mathClasses, writer);
 173  66
             reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
 174  66
             differentiator.reportErrors();
 175  
 
 176  
             // create the derivative class
 177  66
             return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer);
 178  
 
 179  0
         } catch (IOException ioe) {
 180  0
             throw new DifferentiationException("class {0} cannot be read ({1})",
 181  
                                           differentiableClass.getName(), ioe.getMessage());
 182  
         }
 183  
     }
 184  
 
 185  
     /** Class loader generating derivative classes. */
 186  
     private static class DerivativeLoader extends ClassLoader {
 187  
 
 188  
         /** Simple constructor.
 189  
          * @param differentiableClass differentiable class
 190  
          */
 191  
         public DerivativeLoader(final Class<? extends UnivariateDifferentiable> differentiableClass) {
 192  66
             super(differentiableClass.getClassLoader());
 193  66
         }
 194  
 
 195  
         /** Define a derivative class.
 196  
          * @param differentiator class differentiator
 197  
          * @param writer class writer
 198  
          * @return a generated derivative class
 199  
          */
 200  
         @SuppressWarnings("unchecked")
 201  
         public Class<? extends UnivariateDerivative>
 202  
         defineClass(final ClassDifferentiator differentiator, final ClassWriter writer) {
 203  66
             final String name = differentiator.getDerivativeClassName().replace('/', '.');
 204  66
             final byte[] bytecode = writer.toByteArray();
 205  66
             return (Class<? extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length);
 206  
         }
 207  
     }
 208  
 
 209  
 }