Coverage Report - org.apache.commons.nabla.forward.instructions.InvokeStaticTransformer
 
Classes in this File Line Coverage Branch Coverage Complexity
InvokeStaticTransformer
88%
31/35
87%
14/16
7
 
 1  
 /*
 2  
  * Licensed to the Apache Software Foundation (ASF) under one or more
 3  
  * contributor license agreements.  See the NOTICE file distributed with
 4  
  * this work for additional information regarding copyright ownership.
 5  
  * The ASF licenses this file to You under the Apache License, Version 2.0
 6  
  * (the "License"); you may not use this file except in compliance with
 7  
  * the License.  You may obtain a copy of the License at
 8  
  *
 9  
  *      http://www.apache.org/licenses/LICENSE-2.0
 10  
  *
 11  
  * Unless required by applicable law or agreed to in writing, software
 12  
  * distributed under the License is distributed on an "AS IS" BASIS,
 13  
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14  
  * See the License for the specific language governing permissions and
 15  
  * limitations under the License.
 16  
  */
 17  
 package org.apache.commons.nabla.forward.instructions;
 18  
 
 19  
 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
 20  
 import org.apache.commons.nabla.DifferentiationException;
 21  
 import org.apache.commons.nabla.NablaMessages;
 22  
 import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
 23  
 import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
 24  
 import org.objectweb.asm.Opcodes;
 25  
 import org.objectweb.asm.Type;
 26  
 import org.objectweb.asm.tree.AbstractInsnNode;
 27  
 import org.objectweb.asm.tree.InsnList;
 28  
 import org.objectweb.asm.tree.InsnNode;
 29  
 import org.objectweb.asm.tree.MethodInsnNode;
 30  
 
 31  
 /** Differentiation transformer for INVOKESTATIC instructions.
 32  
  * @version $Id$
 33  
  */
 34  
 public class InvokeStaticTransformer implements InstructionsTransformer {
 35  
 
 36  
     /** Indicator for top stack element conversion. */
 37  
     private final boolean stack0Converted;
 38  
 
 39  
     /** Indicator for next to top stack element conversion. */
 40  
     private final boolean stack1Converted;
 41  
 
 42  
     /** Simple constructor.
 43  
      * @param stack0Converted if true, the top level stack element has already been converted
 44  
      * @param stack1Converted if true, the next to top level stack element has already been converted
 45  
      */
 46  43
     public InvokeStaticTransformer(final boolean stack0Converted, final boolean stack1Converted) {
 47  43
         this.stack0Converted = stack0Converted;
 48  43
         this.stack1Converted = stack1Converted;
 49  43
     }
 50  
 
 51  
     /** {@inheritDoc} */
 52  
     public InsnList getReplacement(final AbstractInsnNode insn,
 53  
                                    final MethodDifferentiator methodDifferentiator)
 54  
         throws DifferentiationException {
 55  
 
 56  43
         final MethodInsnNode methodInsn = (MethodInsnNode) insn;
 57  43
         if (!methodDifferentiator.isMathImplementationClass(methodInsn.owner)) {
 58  
             // TODO: handle INVOKESTATIC on non math related classes
 59  0
             throw new RuntimeException("INVOKESTATIC on non math related classes not handled yet" +
 60  
                     methodInsn.owner + methodInsn.owner);
 61  
         }
 62  
 
 63  43
         final InsnList list = new InsnList();
 64  
 
 65  43
         if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
 66  
             // this is a univariate method like sin, cos, exp ...
 67  
 
 68  
             try {
 69  
                 // check that a corresponding method exist for DerivativeStructure
 70  32
                 DerivativeStructure.class.getDeclaredMethod(methodInsn.name);
 71  0
             } catch (NoSuchMethodException nsme) {
 72  0
                 throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
 73  
                                                    methodInsn.owner, methodInsn.name);
 74  32
             }
 75  
 
 76  32
             list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
 77  
                                         DS_TYPE.getInternalName(), methodInsn.name,
 78  
                                         Type.getMethodDescriptor(DS_TYPE)));
 79  
 
 80  11
         } else if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
 81  
             // this is a bivariate method like atan2, pow ...
 82  
 
 83  11
             if (methodInsn.name.equals("pow")) {
 84  
                 // special case for pow: in DerivativeStructure, it is an instance method,
 85  
                 // not a static method as the other two parameters functions like atan2 or hypot
 86  
 
 87  3
                 if (stack1Converted) {
 88  2
                     if (!stack0Converted) {
 89  1
                         list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
 90  
                                                     DS_TYPE.getInternalName(), methodInsn.name,
 91  
                                                     Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE)));
 92  
                     } else {
 93  1
                         list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
 94  
                                                     DS_TYPE.getInternalName(), methodInsn.name,
 95  
                                                     Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
 96  
                     }
 97  
                 } else {
 98  
 
 99  
                     // initial stack state: x, ds_y
 100  1
                     list.add(new InsnNode(Opcodes.DUP_X2));                                 // => ds_y, x, ds_y
 101  1
                     list.add(new InsnNode(Opcodes.POP));                                    // => ds_y, x
 102  1
                     list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); // => ds_y, ds_x
 103  1
                     list.add(new InsnNode(Opcodes.SWAP));                                   // => ds_x, ds_y
 104  
 
 105  
                     // call the static two parameters method for DerivativeStructure
 106  1
                     list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
 107  
                                                 DS_TYPE.getInternalName(), methodInsn.name,
 108  
                                                 Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
 109  
                 }
 110  
 
 111  
             } else {
 112  
 
 113  8
                 if (stack1Converted) {
 114  6
                     if (!stack0Converted) {
 115  
                         // the top level element is not a DerivativeStructure, convert it
 116  2
                         list.add(methodDifferentiator.doubleToDerivativeStructureConversion());
 117  
                     }
 118  
                 } else {
 119  
                     // initial stack state: x, ds_y
 120  2
                     list.add(new InsnNode(Opcodes.DUP_X2));                                 // => ds_y, x, ds_y
 121  2
                     list.add(new InsnNode(Opcodes.POP));                                    // => ds_y, x
 122  2
                     list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); // => ds_y, ds_x
 123  2
                     list.add(new InsnNode(Opcodes.SWAP));                                   // => ds_x, ds_y
 124  
                 }
 125  
 
 126  
                 // call the static two parameters method for DerivativeStructure
 127  8
                 list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
 128  
                                             DS_TYPE.getInternalName(), methodInsn.name,
 129  
                                             Type.getMethodDescriptor(DS_TYPE, DS_TYPE, DS_TYPE)));
 130  
 
 131  
             }
 132  
 
 133  
         } else {
 134  0
             throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
 135  
                                                methodInsn.owner, methodInsn.name);
 136  
         }
 137  
 
 138  43
         return list;
 139  
 
 140  
     }
 141  
 
 142  
 }