Coverage Report - org.apache.commons.nabla.forward.ForwardModeDifferentiator
 
Classes in this File Line Coverage Branch Coverage Complexity
ForwardModeDifferentiator
69%
29/42
50%
1/2
3
ForwardModeDifferentiator$DerivativeLoader
100%
3/3
N/A
3
 
 1  
 /*
 2  
  * Licensed to the Apache Software Foundation (ASF) under one or more
 3  
  * contributor license agreements.  See the NOTICE file distributed with
 4  
  * this work for additional information regarding copyright ownership.
 5  
  * The ASF licenses this file to You under the Apache License, Version 2.0
 6  
  * (the "License"); you may not use this file except in compliance with
 7  
  * the License.  You may obtain a copy of the License at
 8  
  *
 9  
  *      http://www.apache.org/licenses/LICENSE-2.0
 10  
  *
 11  
  * Unless required by applicable law or agreed to in writing, software
 12  
  * distributed under the License is distributed on an "AS IS" BASIS,
 13  
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14  
  * See the License for the specific language governing permissions and
 15  
  * limitations under the License.
 16  
  */
 17  
 package org.apache.commons.nabla.forward;
 18  
 
 19  
 import java.io.IOException;
 20  
 import java.io.OutputStream;
 21  
 import java.lang.reflect.Constructor;
 22  
 import java.lang.reflect.InvocationTargetException;
 23  
 import java.util.HashMap;
 24  
 import java.util.HashSet;
 25  
 import java.util.Set;
 26  
 
 27  
 import org.apache.commons.math3.analysis.UnivariateFunction;
 28  
 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
 29  
 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
 30  
 import org.apache.commons.math3.analysis.differentiation.UnivariateFunctionDifferentiator;
 31  
 import org.apache.commons.math3.util.FastMath;
 32  
 import org.apache.commons.nabla.DifferentiationException;
 33  
 import org.apache.commons.nabla.NablaMessages;
 34  
 import org.apache.commons.nabla.forward.analysis.ClassDifferentiator;
 35  
 import org.objectweb.asm.ClassWriter;
 36  
 import org.objectweb.asm.Type;
 37  
 import org.objectweb.asm.tree.ClassNode;
 38  
 
 39  
 /** Algorithmic differentiator class in forward mode based on bytecode analysis.
 40  
  * <p>This class is an implementation of the {@link UnivariateFunctionDifferentiator}
 41  
  * interface that computes <em>exact</em> differentials completely automatically
 42  
  * and generate java classes and instances that compute the differential
 43  
  * of the function as if they were hand-coded and compiled.</p>
 44  
  * <p>The derivative bytecode created the first time an instance of a given class
 45  
  * is differentiated is cached and will be reused if other instances of the same class
 46  
  * are to be created later. The cache can also be dumped in a jar file for
 47  
  * use in an application without bringing the full nabla library and its
 48  
  * dependencies.</p>
 49  
  * <p>This differentiator can handle only pure bytecode methods and known methods
 50  
  * from math implementation classes like {@link java.lang.Math Math}, {@link
 51  
  * java.lang.StrictMath StrictMath} or {@link FastMath}. Pure bytecode methods are
 52  
  * analyzed and converted. Methods from math implementation classes are only
 53  
  * recognized by class and name and replaced by predefined derivative code.</p>
 54  
  * @see org.apache.commons.nabla.caching.FetchDifferentiator
 55  
  * @version $Id$
 56  
  */
 57  
 public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator {
 58  
 
 59  
     /** UnivariateFunction/UnivariateDifferentiableFunction map. */
 60  
     private final HashMap<Class<? extends UnivariateFunction>,
 61  
                           Class<? extends UnivariateDifferentiableFunction>> map;
 62  
 
 63  
     /** Class name/ bytecode map. */
 64  
     private final HashMap<String, byte[]> byteCodeMap;
 65  
 
 66  
     /** Math implementation classes. */
 67  
     private final Set<String> mathClasses;
 68  
 
 69  
     /** Simple constructor.
 70  
      * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
 71  
      */
 72  67
     public ForwardModeDifferentiator() {
 73  67
         map         = new HashMap<Class<? extends UnivariateFunction>,
 74  
                                   Class<? extends UnivariateDifferentiableFunction>>();
 75  67
         byteCodeMap = new HashMap<String, byte[]>();
 76  67
         mathClasses = new HashSet<String>();
 77  67
         addMathImplementation(Math.class);
 78  67
         addMathImplementation(StrictMath.class);
 79  67
         addMathImplementation(FastMath.class);
 80  67
     }
 81  
 
 82  
     /** Add an implementation class for mathematical functions.
 83  
      * <p>At construction, the differentiator considers only the {@link
 84  
      * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath}
 85  
      * classes are math implementation classes. It may be useful to add
 86  
      * other classes for example to add some missing functions like
 87  
      * inverse hyperbolic cosine that are not provided by the standard
 88  
      * java classes as of Java 1.6.</p>
 89  
      * @param mathClass implementation class for mathematical functions
 90  
      */
 91  
     public void addMathImplementation(final Class<?> mathClass) {
 92  257
         mathClasses.add(mathClass.getName().replace('.', '/'));
 93  257
     }
 94  
 
 95  
     /** Dump the cache into a stream.
 96  
      * @param out output stream where to dump the cache
 97  
      */
 98  
     public void dumpCache(final OutputStream out) {
 99  
         // TODO: implement cache persistence
 100  0
         throw new RuntimeException("not implemented yet");
 101  
     }
 102  
 
 103  
     /** {@inheritDoc} */
 104  
     public UnivariateDifferentiableFunction differentiate(final UnivariateFunction d) {
 105  
 
 106  
         // get the derivative class
 107  67
         final Class<? extends UnivariateDifferentiableFunction> derivativeClass =
 108  
             getDerivativeClass(d.getClass());
 109  
 
 110  
         try {
 111  
 
 112  
             // create the instance
 113  67
             final Constructor<? extends UnivariateDifferentiableFunction> constructor =
 114  
                 derivativeClass.getConstructor(d.getClass());
 115  67
             return constructor.newInstance(d);
 116  
 
 117  0
         } catch (InstantiationException ie) {
 118  0
             throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_ABSTRACT_CLASS,
 119  
                                                derivativeClass.getName(), ie.getMessage());
 120  0
         } catch (IllegalAccessException iae) {
 121  0
             throw new DifferentiationException(NablaMessages.ILLEGAL_ACCESS_TO_CONSTRUCTOR,
 122  
                                                derivativeClass.getName(), iae.getMessage());
 123  0
         } catch (NoSuchMethodException nsme) {
 124  0
             throw new DifferentiationException(NablaMessages.CANNOT_BUILD_CLASS_FROM_OTHER_CLASS,
 125  
                                                derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
 126  0
         } catch (InvocationTargetException ite) {
 127  0
             throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE,
 128  
                                                derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
 129  0
         } catch (VerifyError ve) {
 130  0
             throw new DifferentiationException(NablaMessages.INCORRECT_GENERATED_CODE,
 131  
                                                derivativeClass.getName(), d.getClass().getName(), ve.getMessage());
 132  
         }
 133  
 
 134  
     }
 135  
 
 136  
     /** Get the derivative class of a differentiable class.
 137  
      * <p>The derivative class is either built on the fly
 138  
      * or retrieved from the cache if it has been built previously.</p>
 139  
      * @param differentiableClass class to differentiate
 140  
      * @return derivative class
 141  
      * @throws DifferentiationException if the class cannot be differentiated
 142  
      */
 143  
     private Class<? extends UnivariateDifferentiableFunction>
 144  
     getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
 145  
         throws DifferentiationException {
 146  
 
 147  
         // lookup in the map if the class has already been differentiated
 148  67
         Class<? extends UnivariateDifferentiableFunction> derivativeClass =
 149  
             map.get(differentiableClass);
 150  
 
 151  
         // build the derivative class if it does not exist yet
 152  67
         if (derivativeClass == null) {
 153  
 
 154  
             // perform algorithmic differentiation
 155  67
             derivativeClass = createDerivativeClass(differentiableClass);
 156  
 
 157  
             // put the newly created class in the map
 158  67
             map.put(differentiableClass, derivativeClass);
 159  
 
 160  
         }
 161  
 
 162  
         // return the derivative class
 163  67
         return derivativeClass;
 164  
 
 165  
     }
 166  
 
 167  
     /** Build a derivative class of a differentiable class.
 168  
      * @param differentiableClass class to differentiate
 169  
      * @return derivative class
 170  
      * @throws DifferentiationException if the class cannot be differentiated
 171  
      */
 172  
     private Class<? extends UnivariateDifferentiableFunction>
 173  
     createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
 174  
         throws DifferentiationException {
 175  
         try {
 176  
 
 177  
             // differentiate the function embedded in the differentiable class
 178  67
             final ClassDifferentiator differentiator =
 179  
                 new ClassDifferentiator(differentiableClass, mathClasses);
 180  67
             final Type dsType = Type.getType(DerivativeStructure.class);
 181  67
             differentiator.differentiateMethod("value",
 182  
                                                Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
 183  
                                                Type.getMethodDescriptor(dsType, dsType));
 184  
 
 185  
             // create the derivative class
 186  67
             final ClassNode   derived = differentiator.getDerivedClass();
 187  67
             final ClassWriter writer  = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
 188  67
             final String name = derived.name.replace('/', '.');
 189  67
             derived.accept(writer);
 190  67
             final byte[] bytecode = writer.toByteArray();
 191  
 
 192  67
             final Class<? extends UnivariateDifferentiableFunction> dClass =
 193  
                     new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
 194  67
             byteCodeMap.put(name, bytecode);
 195  67
             return dClass;
 196  
 
 197  0
         } catch (IOException ioe) {
 198  0
             throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
 199  
                                                differentiableClass.getName(), ioe.getMessage());
 200  
         }
 201  
     }
 202  
 
 203  
     /** Class loader generating derivative classes. */
 204  
     private static class DerivativeLoader extends ClassLoader {
 205  
 
 206  
         /** Simple constructor.
 207  
          * @param differentiableClass differentiable class
 208  
          */
 209  
         public DerivativeLoader(final Class<? extends UnivariateFunction> differentiableClass) {
 210  67
             super(differentiableClass.getClassLoader());
 211  67
         }
 212  
 
 213  
         /** Define a derivative class.
 214  
          * @param name name of the differentiated class
 215  
          * @param bytecode bytecode of the differentiated class
 216  
          * @return a generated derivative class
 217  
          */
 218  
         @SuppressWarnings("unchecked")
 219  
         public Class<? extends UnivariateDifferentiableFunction>
 220  
         defineClass(final String name, final byte[] bytecode) {
 221  67
             return (Class<? extends UnivariateDifferentiableFunction>) defineClass(name, bytecode, 0, bytecode.length);
 222  
         }
 223  
     }
 224  
 
 225  
 }