Coverage Report - org.apache.commons.nabla.forward.analysis.ClassDifferentiator
 
Classes in this File Line Coverage Branch Coverage Complexity
ClassDifferentiator
96%
73/76
68%
11/16
2.667
 
 1  
 /*
 2  
  * Licensed to the Apache Software Foundation (ASF) under one or more
 3  
  * contributor license agreements.  See the NOTICE file distributed with
 4  
  * this work for additional information regarding copyright ownership.
 5  
  * The ASF licenses this file to You under the Apache License, Version 2.0
 6  
  * (the "License"); you may not use this file except in compliance with
 7  
  * the License.  You may obtain a copy of the License at
 8  
  *
 9  
  *      http://www.apache.org/licenses/LICENSE-2.0
 10  
  *
 11  
  * Unless required by applicable law or agreed to in writing, software
 12  
  * distributed under the License is distributed on an "AS IS" BASIS,
 13  
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14  
  * See the License for the specific language governing permissions and
 15  
  * limitations under the License.
 16  
  */
 17  
 package org.apache.commons.nabla.forward.analysis;
 18  
 
 19  
 import java.io.IOException;
 20  
 import java.io.InputStream;
 21  
 import java.lang.reflect.Field;
 22  
 import java.util.Set;
 23  
 
 24  
 import org.apache.commons.math3.analysis.UnivariateFunction;
 25  
 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
 26  
 import org.apache.commons.nabla.DifferentiationException;
 27  
 import org.apache.commons.nabla.NablaMessages;
 28  
 import org.objectweb.asm.ClassReader;
 29  
 import org.objectweb.asm.Label;
 30  
 import org.objectweb.asm.Opcodes;
 31  
 import org.objectweb.asm.Type;
 32  
 import org.objectweb.asm.tree.ClassNode;
 33  
 import org.objectweb.asm.tree.FieldNode;
 34  
 import org.objectweb.asm.tree.MethodNode;
 35  
 
 36  
 /**
 37  
  * Differentiator for classes using forward mode.
 38  
  * <p>
 39  
  * This differentiator transforms classes implementing the
 40  
  * {@link UnivariateFunction UnivariateFunction} interface and convert
 41  
  * them to classes implementing the {@link UnivariateDifferentiableFunction
 42  
  * UnivariateDifferentiableFunction} interface.
 43  
  * </p>
 44  
  * <p>
 45  
  * The differentiator creates a new class in the same package as the primitive class and
 46  
  * which only preserve a private reference to the primitive instance. They access the
 47  
  * current value of all necessary primitive instance fields thanks to reflection and
 48  
  * bypassing access restrictions.
 49  
  * </p>
 50  
  * <p>
 51  
  * The original class bytecode is not changed at all.
 52  
  * </p>
 53  
  * @version $Id$
 54  
  */
 55  
 public class ClassDifferentiator {
 56  
 
 57  
     /** Name for the primitive instance field. */
 58  
     private static final String PRIMITIVE_FIELD = "primitive";
 59  
 
 60  
     /** Name fo the constructor methods. */
 61  
     private static final String INIT = "<init>";
 62  
 
 63  
     /** Math implementation classes. */
 64  
     private final Set<String> mathClasses;
 65  
 
 66  
     /** Class to differentiate. */
 67  
     private final Class<? extends UnivariateFunction> primitiveClass;
 68  
 
 69  
     /** Node of the class to differentiate. */
 70  
     private final ClassNode primitiveNode;
 71  
 
 72  
     /** Class to differentiate. */
 73  
     private final ClassNode classNode;
 74  
 
 75  
     /**
 76  
      * Simple constructor.
 77  
      * @param primitiveClass primitive class
 78  
      * @param mathClasses math implementation classes
 79  
      * @exception DifferentiationException if class cannot be differentiated
 80  
      * @throws IOException if class cannot be read
 81  
      */
 82  
     public ClassDifferentiator(final Class<? extends UnivariateFunction> primitiveClass,
 83  
                                final Set<String> mathClasses)
 84  67
         throws DifferentiationException, IOException {
 85  
 
 86  
         // get the original class
 87  67
         this.primitiveClass = primitiveClass;
 88  67
         final String classResourceName = "/" + primitiveClass.getName().replace('.', '/') + ".class";
 89  67
         final InputStream stream = primitiveClass.getResourceAsStream(classResourceName);
 90  67
         final ClassReader reader = new ClassReader(stream);
 91  67
         primitiveNode = new ClassNode(Opcodes.ASM4);
 92  67
         reader.accept(primitiveNode, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
 93  67
         this.mathClasses = mathClasses;
 94  67
         classNode = new ClassNode(Opcodes.ASM4);
 95  
 
 96  
         // check the UnivariateFunction interface is implemented
 97  67
         final Class<UnivariateFunction> uFuncClass = UnivariateFunction.class;
 98  67
         boolean isDifferentiable = false;
 99  67
         for (String interf : primitiveNode.interfaces) {
 100  67
             final String interfName = interf.replace('/', '.');
 101  67
             Class<?> interfClass = null;
 102  
             try {
 103  67
                 interfClass = Class.forName(interfName);
 104  0
             } catch (ClassNotFoundException cnfe) {
 105  
                 // this should never occur since class has already been loaded
 106  
                 // and an instance already exists ...
 107  0
                 throw new DifferentiationException(NablaMessages.INTERFACE_NOT_FOUND_WHILE_DIFFERENTIATING,
 108  
                                                    interfName, primitiveNode.name);
 109  67
             }
 110  67
             if (interfClass != null) {
 111  67
                 isDifferentiable = isDifferentiable || uFuncClass.isAssignableFrom(interfClass);
 112  
             }
 113  67
         }
 114  
 
 115  67
         if (!isDifferentiable) {
 116  0
             throw new DifferentiationException(NablaMessages.CLASS_DOES_NOT_IMPLEMENT_INTERFACE,
 117  
                                                primitiveNode.name, uFuncClass.getName());
 118  
         }
 119  
 
 120  
         // change the class properties for the derived class
 121  67
         classNode.visit(primitiveNode.version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
 122  
                         primitiveNode.name + "_NablaForwardModeUnivariateDerivative",
 123  
                         null, Type.getType(Object.class).getInternalName(),
 124  
                         new String[] {
 125  
                             UnivariateDifferentiableFunction.class.getName().replace('.', '/')
 126  
                         });
 127  
 
 128  
         // add boilerplate code
 129  67
         addPrimitiveField();
 130  67
         addConstructor();
 131  67
         addGetPrimitiveFieldMethod();
 132  
 
 133  67
     }
 134  
 
 135  
     /**
 136  
      * Differentiate a method.
 137  
      * @param name of the method
 138  
      * @param primitiveDesc descriptor of the method in the primitive class
 139  
      * @param derivativeDesc descriptor of the method in the derivative class
 140  
      * @exception DifferentiationException if method cannot be differentiated
 141  
      */
 142  
     public void differentiateMethod(final String name, final String primitiveDesc,
 143  
                                     final String derivativeDesc)
 144  
         throws DifferentiationException {
 145  
 
 146  67
         for (final MethodNode method : primitiveNode.methods) {
 147  192
             if (method.name.equals(name) && method.desc.equals(primitiveDesc)) {
 148  
 
 149  67
                 final MethodDifferentiator differentiator =
 150  
                         new MethodDifferentiator(mathClasses, classNode.name);
 151  67
                 differentiator.differentiate(primitiveNode.name, method);
 152  67
                 classNode.methods.add(method);
 153  
 
 154  192
             }
 155  
         }
 156  67
     }
 157  
 
 158  
     /**
 159  
      * Get the derived class.
 160  
      * @return derived class
 161  
      */
 162  
     public ClassNode getDerivedClass() {
 163  67
         return classNode;
 164  
     }
 165  
 
 166  
     /** Add the primitive field.
 167  
      */
 168  
     private void addPrimitiveField() {
 169  67
         final FieldNode primitiveField =
 170  
             new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
 171  
                           PRIMITIVE_FIELD, Type.getDescriptor(primitiveClass), null, null);
 172  67
         classNode.fields.add(primitiveField);
 173  67
     }
 174  
 
 175  
     /** Add the class constructor.
 176  
      */
 177  
     private void addConstructor() {
 178  67
         final MethodNode constructor =
 179  
             new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
 180  
                            Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)),
 181  
                            null, null);
 182  67
         constructor.visitVarInsn(Opcodes.ALOAD, 0);
 183  67
         constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getType(Object.class).getInternalName(),
 184  
                                     INIT, "()V");
 185  67
         constructor.visitVarInsn(Opcodes.ALOAD, 0);
 186  67
         constructor.visitVarInsn(Opcodes.ALOAD, 1);
 187  67
         constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD,
 188  
                                    Type.getDescriptor(primitiveClass));
 189  67
         constructor.visitInsn(Opcodes.RETURN);
 190  67
         constructor.visitMaxs(0, 0);
 191  67
         classNode.methods.add(constructor);
 192  67
     }
 193  
 
 194  
     /** Add the getPrimitiveField method.
 195  
      */
 196  
     private void addGetPrimitiveFieldMethod() {
 197  67
         final MethodNode method =
 198  
             new MethodNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_SYNTHETIC, "getPrimitiveField",
 199  
                            Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(String.class)),
 200  
                            null, null);
 201  67
         final Label start     = new Label();
 202  67
         final Label end       = new Label();
 203  67
         method.visitTryCatchBlock(start, end, end, Type.getInternalName(IllegalAccessException.class));
 204  67
         method.visitTryCatchBlock(start, end, end, Type.getInternalName(NoSuchFieldException.class));
 205  67
         method.visitLabel(start);
 206  67
         method.visitLdcInsn(Type.getType(primitiveClass));
 207  67
         method.visitVarInsn(Opcodes.ALOAD, 1);
 208  67
         method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
 209  
                                "getDeclaredField",
 210  
                                Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)));
 211  67
         method.visitVarInsn(Opcodes.ASTORE, 2);
 212  67
         method.visitVarInsn(Opcodes.ALOAD, 2);
 213  67
         method.visitInsn(Opcodes.ICONST_1);
 214  67
         method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
 215  
                                "setAccessible",
 216  
                                Type.getMethodDescriptor(Type.VOID_TYPE, Type.BOOLEAN_TYPE));
 217  67
         method.visitVarInsn(Opcodes.ALOAD, 2);
 218  67
         method.visitVarInsn(Opcodes.ALOAD, 0);
 219  67
         method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD,
 220  
                               Type.getDescriptor(primitiveClass));
 221  67
         method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
 222  
                                "get",
 223  
                                Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)));
 224  67
         method.visitInsn(Opcodes.ARETURN);
 225  67
         method.visitLabel(end);
 226  67
         method.visitVarInsn(Opcodes.ASTORE, 2);
 227  67
         method.visitTypeInsn(Opcodes.NEW, Type.getInternalName(RuntimeException.class));
 228  67
         method.visitInsn(Opcodes.DUP);
 229  67
         method.visitVarInsn(Opcodes.ALOAD, 2);
 230  67
         method.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(RuntimeException.class),
 231  
                                INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Throwable.class)));
 232  67
         method.visitInsn(Opcodes.ATHROW);
 233  67
         classNode.methods.add(method);
 234  67
     }
 235  
 
 236  
 }