Coverage Report - org.apache.commons.nabla.algorithmic.forward.analysis.ClassDifferentiator
 
Classes in this File | Line Coverage | Branch Coverage | Complexity
ClassDifferentiator | 88% (63/71) | 61% (16/26) | 2.067
 
 1  
 /*
 2  
  * Licensed to the Apache Software Foundation (ASF) under one or more
 3  
  * contributor license agreements.  See the NOTICE file distributed with
 4  
  * this work for additional information regarding copyright ownership.
 5  
  * The ASF licenses this file to You under the Apache License, Version 2.0
 6  
  * (the "License"); you may not use this file except in compliance with
 7  
  * the License.  You may obtain a copy of the License at
 8  
  *
 9  
  *      http://www.apache.org/licenses/LICENSE-2.0
 10  
  *
 11  
  * Unless required by applicable law or agreed to in writing, software
 12  
  * distributed under the License is distributed on an "AS IS" BASIS,
 13  
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14  
  * See the License for the specific language governing permissions and
 15  
  * limitations under the License.
 16  
  */
 17  
 package org.apache.commons.nabla.algorithmic.forward.analysis;
 18  
 
 19  
 import java.util.Set;
 20  
 
 21  
 import org.apache.commons.nabla.core.DifferentiationException;
 22  
 import org.apache.commons.nabla.core.UnivariateDerivative;
 23  
 import org.apache.commons.nabla.core.UnivariateDifferentiable;
 24  
 import org.objectweb.asm.AnnotationVisitor;
 25  
 import org.objectweb.asm.Attribute;
 26  
 import org.objectweb.asm.ClassVisitor;
 27  
 import org.objectweb.asm.FieldVisitor;
 28  
 import org.objectweb.asm.MethodVisitor;
 29  
 import org.objectweb.asm.Opcodes;
 30  
 
 31  
 /**
 32  
  * Visitor (in asm sense) for differentiating classes.
 33  
  * <p>
 34  
  * This visitor visits classes implementing the
 35  
  * {@link UnivariateDifferentiable UnivariateDifferentiable} interface and convert
 36  
  * them to classes implementing the {@link UnivariateDerivative
 37  
  * UnivariateDerivative} interface.
 38  
  * </p>
 39  
  * <p>
 40  
  * The visitor creates a new class as an inner class of the visited class.
 41  
  * Instances of the generated class are therefore automatically bound to their
 42  
  * primitive instance which is their directly enclosing instance. As such they
 43  
  * have access to the current value of all fields.
 44  
  * </p>
 45  
  * <p>
 46  
  * The visited class bytecode is not changed at all.
 47  
  * </p>
 48  
  */
 49  
 public class ClassDifferentiator implements ClassVisitor {
 50  
 
 51  
     /** Name for the primitive instance field. */
 52  
     private static final String PRIMITIVE_FIELD = "primitive";
 53  
 
 54  
     /** Math implementation classes. */
 55  
     private final Set<String> mathClasses;
 56  
 
 57  
     /** Class generating visitor. */
 58  
     private final ClassVisitor generator;
 59  
 
 60  
     /** Error reporter. */
 61  
     private final ErrorReporter errorReporter;
 62  
 
 63  
     /** Primitive class name. */
 64  
     private String primitiveName;
 65  
 
 66  
     /** Descriptor for the primitive class. */
 67  
     private String primitiveDesc;
 68  
 
 69  
     /** Derivative class name. */
 70  
     private String derivativeName;
 71  
 
 72  
     /** Indicator for specific fields and method addition. */
 73  
     private boolean specificMembersAdded;
 74  
 
 75  
     /**
 76  
      * Simple constructor.
 77  
      * @param mathClasses math implementation classes
 78  
      * @param generator visitor to which class generation calls will be delegated
 79  
      */
 80  
     public ClassDifferentiator(final Set<String> mathClasses,
 81  66
                           final ClassVisitor generator) {
 82  66
         this.mathClasses = mathClasses;
 83  66
         this.generator   = generator;
 84  66
         errorReporter    = new ErrorReporter();
 85  66
     }
 86  
 
 87  
     /**
 88  
      * Get the name of the derivative class.
 89  
      * @return name of the (generated) derivative class
 90  
      */
 91  
     public String getDerivativeClassName() {
 92  66
         return derivativeName;
 93  
     }
 94  
 
 95  
     /** {@inheritDoc} */
 96  
     public void visit(final int version, final int access,
 97  
                       final String name, final String signature,
 98  
                       final String superName, final String[] interfaces) {
 99  
         // set up the various names
 100  66
         primitiveName = name;
 101  66
         derivativeName   = primitiveName + "$NablaUnivariateDerivative";
 102  66
         primitiveDesc = "L" + primitiveName + ";";
 103  
 
 104  
         // check the UnivariateDifferentiable interface is implemented
 105  66
         final Class<UnivariateDifferentiable> uDerClass = UnivariateDifferentiable.class;
 106  66
         boolean isDifferentiable = false;
 107  132
         for (String interf : interfaces) {
 108  66
             final String interfName = interf.replace('/', '.');
 109  66
             Class<?> interfClass = null;
 110  
             try {
 111  66
                 interfClass = Class.forName(interfName);
 112  0
             } catch (ClassNotFoundException cnfe) {
 113  
                 // this should never occur since class has already been loaded
 114  
                 // and an instance already exists ...
 115  0
                 errorReporter.register(new DifferentiationException("interface {0} not found " +
 116  
                                                                     "while differentiating class {1}",
 117  
                                                                     interfName, name));
 118  66
             }
 119  66
             if (interfClass != null) {
 120  66
                 isDifferentiable = isDifferentiable || uDerClass.isAssignableFrom(interfClass);
 121  
             }
 122  
         }
 123  
 
 124  66
         if (isDifferentiable) {
 125  
             // generate the new class implementing the UnivariateDerivative interface
 126  66
             generator.visit(version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
 127  
                             derivativeName, signature, superName,
 128  
                             new String[] {
 129  
                                 UnivariateDerivative.class.getName().replace('.', '/')
 130  
                             });
 131  
         } else {
 132  0
             errorReporter.register(new DifferentiationException("the {0} class does not implement " +
 133  
                                                                 "the {1} interface",
 134  
                                                                 name, uDerClass.getName()));
 135  
         }
 136  
 
 137  66
         specificMembersAdded = false;
 138  
 
 139  66
     }
 140  
 
 141  
     /** {@inheritDoc} */
 142  
     public MethodVisitor visitMethod(final int access, final String name,
 143  
                                      final String desc, final String signature,
 144  
                                      final String[] exceptions) {
 145  
 
 146  
         // don't do anything if an error has already been encountered
 147  188
         if (errorReporter.hasError()) {
 148  0
             return null;
 149  
         }
 150  
 
 151  188
         if (!specificMembersAdded) {
 152  
             // add the specific members we need
 153  66
             addPrimitiveField();
 154  66
             addConstructor();
 155  66
             addGetPrimitive();
 156  66
             specificMembersAdded = true;
 157  
         }
 158  
 
 159  
         // is it the "public double f(double)" method we want to differentiate ?
 160  188
         if (((access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) &&
 161  
                 "f".equals(name) && "(D)D".equals(desc) &&
 162  
                 ((exceptions == null) || (exceptions.length == 0))) {
 163  
 
 164  
             // get a generator for the method we are going to create
 165  66
             final MethodVisitor visitor =
 166  
                 generator.visitMethod(access | Opcodes.ACC_SYNTHETIC, name,
 167  
                                       MethodDifferentiator.DP_RETURN_DP_DESCRIPTOR, null, null);
 168  
 
 169  
             // make sure our own differentiator will be used to transform the code
 170  66
             return new MethodDifferentiator(access, name, desc, signature, exceptions,
 171  
                                        visitor, primitiveName, mathClasses, errorReporter);
 172  
 
 173  
         }
 174  
 
 175  
         // we are not interested in this method
 176  122
         return null;
 177  
 
 178  
     }
 179  
 
 180  
     /** {@inheritDoc} */
 181  
     public FieldVisitor visitField(final int access, final String name,
 182  
                                    final String desc, final  String signature,
 183  
                                    final Object value) {
 184  
         // we are not interested in any fields
 185  66
         return null;
 186  
     }
 187  
 
 188  
     /** {@inheritDoc} */
 189  
     public void visitSource(final String source, final String debug) {
 190  0
     }
 191  
 
 192  
     /** {@inheritDoc} */
 193  
     public void visitOuterClass(final String owner, final String name,
 194  
                                 final String desc) {
 195  66
     }
 196  
 
 197  
     /** {@inheritDoc} */
 198  
     public AnnotationVisitor visitAnnotation(final String desc,
 199  
                                              final boolean visible) {
 200  0
         return null;
 201  
     }
 202  
 
 203  
     /** {@inheritDoc} */
 204  
     public void visitAttribute(final Attribute attr) {
 205  0
     }
 206  
 
 207  
     /** {@inheritDoc} */
 208  
     public void visitInnerClass(final String name, final String outerName,
 209  
                                 final String innerName, final int access) {
 210  69
     }
 211  
 
 212  
     /** {@inheritDoc} */
 213  
     public void visitEnd() {
 214  
 
 215  
         // don't do anything if an error has already been encountered
 216  66
         if (errorReporter.hasError()) {
 217  0
             return;
 218  
         }
 219  
 
 220  66
         generator.visitEnd();
 221  
 
 222  66
     }
 223  
 
 224  
     /** Add the primitive field.
 225  
      */
 226  
     private void addPrimitiveField() {
 227  66
         final FieldVisitor visitor =
 228  
             generator.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
 229  
                                  PRIMITIVE_FIELD, primitiveDesc, null, null);
 230  66
         visitor.visitEnd();
 231  66
     }
 232  
 
 233  
     /** Add the class constructor.
 234  
      */
 235  
     private void addConstructor() {
 236  66
         final String init = "<init>";
 237  66
         final MethodVisitor visitor =
 238  
             generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init,
 239  
                                   "(" + primitiveDesc + ")V", null, null);
 240  66
         visitor.visitCode();
 241  66
         visitor.visitVarInsn(Opcodes.ALOAD, 0);
 242  66
         visitor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V");
 243  66
         visitor.visitVarInsn(Opcodes.ALOAD, 0);
 244  66
         visitor.visitVarInsn(Opcodes.ALOAD, 1);
 245  66
         visitor.visitFieldInsn(Opcodes.PUTFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
 246  66
         visitor.visitInsn(Opcodes.RETURN);
 247  66
         visitor.visitMaxs(0, 0);
 248  66
         visitor.visitEnd();
 249  66
     }
 250  
 
 251  
     /** Add the {@link UnivariateDerivative#getPrimitive() getPrimitive()} method.
 252  
      */
 253  
     private void addGetPrimitive() {
 254  66
         final MethodVisitor visitor =
 255  
             generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive",
 256  
                                   "()" + primitiveDesc, null, null);
 257  66
         visitor.visitCode();
 258  66
         visitor.visitVarInsn(Opcodes.ALOAD, 0);
 259  66
         visitor.visitFieldInsn(Opcodes.GETFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
 260  66
         visitor.visitInsn(Opcodes.ARETURN);
 261  66
         visitor.visitMaxs(0, 0);
 262  66
         visitor.visitEnd();
 263  66
     }
 264  
 
 265  
     /** Report the errors that may have occurred during analysis.
 266  
      * @exception DifferentiationException if the derivative class
 267  
      * could not be generated
 268  
      */
 269  
     public void reportErrors() throws DifferentiationException {
 270  66
         errorReporter.reportErrors();
 271  66
     }
 272  
 
 273  
 }