Coverage Report - org.apache.commons.nabla.automatic.analysis.MethodDifferentiator
 
Classes in this File Line Coverage Branch Coverage Complexity
MethodDifferentiator
85%
210/247
68%
116/170
0
MethodDifferentiator$FlowAnalyzer
100%
10/10
100%
2/2
0
 
 1  
 /*
 2  
  * Licensed to the Apache Software Foundation (ASF) under one or more
 3  
  * contributor license agreements.  See the NOTICE file distributed with
 4  
  * this work for additional information regarding copyright ownership.
 5  
  * The ASF licenses this file to You under the Apache License, Version 2.0
 6  
  * (the "License"); you may not use this file except in compliance with
 7  
  * the License.  You may obtain a copy of the License at
 8  
  *
 9  
  *      http://www.apache.org/licenses/LICENSE-2.0
 10  
  *
 11  
  * Unless required by applicable law or agreed to in writing, software
 12  
  * distributed under the License is distributed on an "AS IS" BASIS,
 13  
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14  
  * See the License for the specific language governing permissions and
 15  
  * limitations under the License.
 16  
  */
 17  
 package org.apache.commons.nabla.automatic.analysis;
 18  
 
 19  
 import java.util.ArrayList;
 20  
 import java.util.HashMap;
 21  
 import java.util.HashSet;
 22  
 import java.util.IdentityHashMap;
 23  
 import java.util.Iterator;
 24  
 import java.util.List;
 25  
 import java.util.Map;
 26  
 import java.util.Set;
 27  
 
 28  
 import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer1;
 29  
 import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer12;
 30  
 import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer2;
 31  
 import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer1;
 32  
 import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer12;
 33  
 import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer2;
 34  
 import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer1;
 35  
 import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer12;
 36  
 import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer2;
 37  
 import org.apache.commons.nabla.automatic.arithmetic.DNegTransformer;
 38  
 import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer1;
 39  
 import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer12;
 40  
 import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer2;
 41  
 import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer1;
 42  
 import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer12;
 43  
 import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer2;
 44  
 import org.apache.commons.nabla.automatic.functions.AcosTransformer;
 45  
 import org.apache.commons.nabla.automatic.functions.AcoshTransformer;
 46  
 import org.apache.commons.nabla.automatic.functions.AsinTransformer;
 47  
 import org.apache.commons.nabla.automatic.functions.AsinhTransformer;
 48  
 import org.apache.commons.nabla.automatic.functions.Atan2Transformer1;
 49  
 import org.apache.commons.nabla.automatic.functions.Atan2Transformer12;
 50  
 import org.apache.commons.nabla.automatic.functions.Atan2Transformer2;
 51  
 import org.apache.commons.nabla.automatic.functions.AtanTransformer;
 52  
 import org.apache.commons.nabla.automatic.functions.AtanhTransformer;
 53  
 import org.apache.commons.nabla.automatic.functions.CbrtTransformer;
 54  
 import org.apache.commons.nabla.automatic.functions.CosTransformer;
 55  
 import org.apache.commons.nabla.automatic.functions.CoshTransformer;
 56  
 import org.apache.commons.nabla.automatic.functions.ExpTransformer;
 57  
 import org.apache.commons.nabla.automatic.functions.Expm1Transformer;
 58  
 import org.apache.commons.nabla.automatic.functions.HypotTransformer1;
 59  
 import org.apache.commons.nabla.automatic.functions.HypotTransformer12;
 60  
 import org.apache.commons.nabla.automatic.functions.HypotTransformer2;
 61  
 import org.apache.commons.nabla.automatic.functions.Log10Transformer;
 62  
 import org.apache.commons.nabla.automatic.functions.Log1pTransformer;
 63  
 import org.apache.commons.nabla.automatic.functions.LogTransformer;
 64  
 import org.apache.commons.nabla.automatic.functions.MathInvocationTransformer;
 65  
 import org.apache.commons.nabla.automatic.functions.PowTransformer1;
 66  
 import org.apache.commons.nabla.automatic.functions.PowTransformer12;
 67  
 import org.apache.commons.nabla.automatic.functions.PowTransformer2;
 68  
 import org.apache.commons.nabla.automatic.functions.SinTransformer;
 69  
 import org.apache.commons.nabla.automatic.functions.SinhTransformer;
 70  
 import org.apache.commons.nabla.automatic.functions.SqrtTransformer;
 71  
 import org.apache.commons.nabla.automatic.functions.TanTransformer;
 72  
 import org.apache.commons.nabla.automatic.functions.TanhTransformer;
 73  
 import org.apache.commons.nabla.automatic.instructions.DLoadTransformer;
 74  
 import org.apache.commons.nabla.automatic.instructions.DReturnTransformer;
 75  
 import org.apache.commons.nabla.automatic.instructions.DStoreTransformer;
 76  
 import org.apache.commons.nabla.automatic.instructions.DcmpTransformer1;
 77  
 import org.apache.commons.nabla.automatic.instructions.DcmpTransformer12;
 78  
 import org.apache.commons.nabla.automatic.instructions.DcmpTransformer2;
 79  
 import org.apache.commons.nabla.automatic.instructions.Dup2Transformer;
 80  
 import org.apache.commons.nabla.automatic.instructions.Dup2X1Transformer;
 81  
 import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer1;
 82  
 import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer12;
 83  
 import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer2;
 84  
 import org.apache.commons.nabla.automatic.instructions.NarrowingTransformer;
 85  
 import org.apache.commons.nabla.automatic.instructions.WideningTransformer;
 86  
 import org.apache.commons.nabla.automatic.trimming.DLoadPop2Trimmer;
 87  
 import org.apache.commons.nabla.automatic.trimming.SwappedDloadTrimmer;
 88  
 import org.apache.commons.nabla.automatic.trimming.SwappedDstoreTrimmer;
 89  
 import org.apache.commons.nabla.core.DifferentialPair;
 90  
 import org.apache.commons.nabla.core.DifferentiationException;
 91  
 import org.objectweb.asm.MethodVisitor;
 92  
 import org.objectweb.asm.Opcodes;
 93  
 import org.objectweb.asm.tree.AbstractInsnNode;
 94  
 import org.objectweb.asm.tree.FieldInsnNode;
 95  
 import org.objectweb.asm.tree.IincInsnNode;
 96  
 import org.objectweb.asm.tree.InsnList;
 97  
 import org.objectweb.asm.tree.InsnNode;
 98  
 import org.objectweb.asm.tree.LabelNode;
 99  
 import org.objectweb.asm.tree.MethodInsnNode;
 100  
 import org.objectweb.asm.tree.MethodNode;
 101  
 import org.objectweb.asm.tree.VarInsnNode;
 102  
 import org.objectweb.asm.tree.analysis.Analyzer;
 103  
 import org.objectweb.asm.tree.analysis.AnalyzerException;
 104  
 import org.objectweb.asm.tree.analysis.BasicValue;
 105  
 import org.objectweb.asm.tree.analysis.Frame;
 106  
 import org.objectweb.asm.tree.analysis.Interpreter;
 107  
 
 108  
 /** Class transforming a method computing a value to a method
 109  
  * computing both a value and its differential.
 110  
  */
 111  432
 public class MethodDifferentiator extends MethodNode {
 112  
 
 113  
     /** Name for the DifferentialPair class. */
 114  1
     public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
 115  
 
 116  
     /** Descriptor for the DifferentialPair class. */
 117  1
     public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";";
 118  
 
 119  
     /** Descriptor for the derivative class f method. */
 120  1
     public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
 121  
 
 122  
     /** Descriptor for <code>double f()</code> methods. */
 123  
     private static final String VOID_RETURN_D_DESCRIPTOR = "()D";
 124  
 
 125  
     /** Math functions transformer. */
 126  1
     private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS =
 127  
         new HashMap<String, MathInvocationTransformer>();
 128  
 
 129  
     static {
 130  1
         MATH_TRANSFORMERS.put("acos",     new AcosTransformer());
 131  1
         MATH_TRANSFORMERS.put("acosh",    new AcoshTransformer());
 132  1
         MATH_TRANSFORMERS.put("asin",     new AsinTransformer());
 133  1
         MATH_TRANSFORMERS.put("asinh",    new AsinhTransformer());
 134  1
         MATH_TRANSFORMERS.put("atan2_12", new Atan2Transformer12());
 135  1
         MATH_TRANSFORMERS.put("atan2_1",  new Atan2Transformer1());
 136  1
         MATH_TRANSFORMERS.put("atan2_2",  new Atan2Transformer2());
 137  1
         MATH_TRANSFORMERS.put("atan",     new AtanTransformer());
 138  1
         MATH_TRANSFORMERS.put("atanh",    new AtanhTransformer());
 139  1
         MATH_TRANSFORMERS.put("cbrt",     new CbrtTransformer());
 140  1
         MATH_TRANSFORMERS.put("cos",      new CosTransformer());
 141  1
         MATH_TRANSFORMERS.put("cosh",     new CoshTransformer());
 142  1
         MATH_TRANSFORMERS.put("exp",      new ExpTransformer());
 143  1
         MATH_TRANSFORMERS.put("expm1",    new Expm1Transformer());
 144  1
         MATH_TRANSFORMERS.put("hypot_12", new HypotTransformer12());
 145  1
         MATH_TRANSFORMERS.put("hypot_1",  new HypotTransformer1());
 146  1
         MATH_TRANSFORMERS.put("hypot_2",  new HypotTransformer2());
 147  1
         MATH_TRANSFORMERS.put("log10",    new Log10Transformer());
 148  1
         MATH_TRANSFORMERS.put("log1p",    new Log1pTransformer());
 149  1
         MATH_TRANSFORMERS.put("log",      new LogTransformer());
 150  1
         MATH_TRANSFORMERS.put("pow_12",   new PowTransformer12());
 151  1
         MATH_TRANSFORMERS.put("pow_1",    new PowTransformer1());
 152  1
         MATH_TRANSFORMERS.put("pow_2",    new PowTransformer2());
 153  1
         MATH_TRANSFORMERS.put("sin",      new SinTransformer());
 154  1
         MATH_TRANSFORMERS.put("sinh",     new SinhTransformer());
 155  1
         MATH_TRANSFORMERS.put("sqrt",     new SqrtTransformer());
 156  1
         MATH_TRANSFORMERS.put("tan",      new TanTransformer());
 157  1
         MATH_TRANSFORMERS.put("tanh",     new TanhTransformer());
 158  1
     }
 159  
 
 160  
     /** Message format for unknown method. */
 161  
     private static final String UNKNOWN_METHOD_FMT = "unknown method {0}.{1}";
 162  
 
 163  
     /** Maximal number of temporary size 2 variables. */
 164  
     private static final int MAX_TEMP = 5;
 165  
 
 166  
     /** Math implementation classes. */
 167  
     private final Set<String> mathClasses;
 168  
 
 169  
     /** Generator to use. */
 170  
     private final MethodVisitor generator;
 171  
 
 172  
     /** Used locals variables array. */
 173  
     private boolean[] usedLocals;
 174  
 
 175  
     /** Primitive class name. */
 176  
     private final String primitiveName;
 177  
 
 178  
     /** Error reporter to use. */
 179  
     private final ErrorReporter errorReporter;
 180  
 
 181  
     /** Set of converted values. */
 182  
     private final Set<TrackingValue> converted;
 183  
 
 184  
     /** Frames for the original method. */
 185  
     private final Map<AbstractInsnNode, Frame> frames;
 186  
 
 187  
     /** Instructions successors array. */
 188  
     private final Map<AbstractInsnNode, Set<AbstractInsnNode>> successors;
 189  
 
 190  
     /** Cloned labels map. */
 191  
     private final Map<LabelNode, LabelNode> clonedLabels;
 192  
 
 193  
     /** Build a differentiator for a method.
 194  
      * @param access access flags of the method
 195  
      * @param name name of the method
 196  
      * @param desc descriptor of the method
 197  
      * @param signature signature of the method
 198  
      * @param exceptions exceptions thrown by the method
 199  
      * @param generator bytecode generator to use for the transformed method
 200  
      * @param primitiveName primitive class name
 201  
      * @param mathClasses math implementation classes
 202  
      * @param errorReporter reporter used for delaying exceptions
 203  
      */
 204  
     public MethodDifferentiator(final int access, final String name, final String desc,
 205  
                                 final String signature, final String[] exceptions,
 206  
                                 final MethodVisitor generator,final  String primitiveName,
 207  
                                 final Set<String> mathClasses,
 208  
                                 final ErrorReporter errorReporter) {
 209  
 
 210  66
         super(access, name, desc, signature, exceptions);
 211  66
         this.generator     = generator;
 212  66
         this.usedLocals    = null;
 213  66
         this.primitiveName = primitiveName;
 214  66
         this.mathClasses   = mathClasses;
 215  66
         this.errorReporter = errorReporter;
 216  66
         this.converted     = new HashSet<TrackingValue>();
 217  66
         this.frames        = new IdentityHashMap<AbstractInsnNode, Frame>();
 218  66
         this.successors    = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>();
 219  66
         this.clonedLabels  = new HashMap<LabelNode, LabelNode>();
 220  
 
 221  66
     }
 222  
 
 223  
     /** {@inheritDoc} */
 224  
     @Override
 225  
     public void visitEnd() {
 226  
         try {
 227  
 
 228  
             // at start, "this" and one differential pair are used
 229  66
             maxLocals  = 2 * (maxLocals + MAX_TEMP) - 1;
 230  66
             usedLocals = new boolean[maxLocals];
 231  66
             useLocal(0, 1);
 232  66
             useLocal(1, 4);
 233  
 
 234  
             // add spare cells to hold new variables if needed
 235  66
             addSpareLocalVariables();
 236  
 
 237  
             // analyze the original code, tracing values production/consumption
 238  66
             final Frame[] array =
 239  
                 new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this);
 240  
 
 241  
             // convert the array into a map, since code changes will shift all indices
 242  347
             for (int i = 0; i < array.length; ++i) {
 243  281
                 frames.put(instructions.get(i), array[i]);
 244  
             }
 245  
 
 246  
             // identify the needed changes
 247  66
             final Set<AbstractInsnNode> changes = identifyChanges();
 248  
 
 249  66
             if (changes.isEmpty()) {
 250  
 
 251  
                 // the method does not depend on the parameter at all!
 252  
                 // we replace all code by a simple "return DifferentialPair.ZERO;"
 253  1
                 instructions.clear();
 254  1
                 instructions.add(new FieldInsnNode(Opcodes.GETSTATIC, DP_NAME, "ZERO", DP_DESCRIPTOR));
 255  1
                 instructions.add(new InsnNode(Opcodes.ARETURN));
 256  
 
 257  
             } else {
 258  
 
 259  
                 // perform the code changes
 260  65
                 changeCode(changes);
 261  
 
 262  
                 // remove the local variables added at the beginning and not used
 263  65
                 removeUnusedSpareLocalVariables();
 264  
 
 265  
                 // trim generated instructions list
 266  65
                 SwappedDloadTrimmer.getInstance().trim(instructions);
 267  65
                 SwappedDstoreTrimmer.getInstance().trim(instructions);
 268  65
                 DLoadPop2Trimmer.getInstance().trim(instructions);
 269  
 
 270  
             }
 271  
 
 272  
             // change the descriptor to its true final value
 273  66
             desc = DP_RETURN_DP_DESCRIPTOR;
 274  
 
 275  
             // generate the method
 276  66
             accept(generator);
 277  
 
 278  0
         } catch (AnalyzerException ae) {
 279  0
             if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
 280  0
                 errorReporter.register((DifferentiationException) ae.getCause());
 281  
             } else {
 282  0
                 final DifferentiationException de =
 283  
                     new DifferentiationException("unable to analyze the {0}.{1} method ({2})",
 284  
                                             new Object[] {
 285  
                                                 primitiveName, name, ae.getMessage()
 286  
                                             });
 287  0
                 errorReporter.register(de);
 288  
             }
 289  0
         } catch (DifferentiationException de) {
 290  0
             errorReporter.register(de);
 291  66
         }
 292  66
     }
 293  
 
 294  
     /** Add spare cells for new local variables.
 295  
      * <p>In order to ease conversion from double values to differential pairs,
 296  
      * we start by reserving one spare cell between each original local variables.
 297  
      * So we have to modify the indices in all instructions referencing local
 298  
      * variables in the original code, to take into account the renumbering
 299  
      * introduced by these spare cells. The spare cells by themselves will
 300  
      * be referenced by the converted instructions in the following passes.</p>
 301  
      * <p>The spare cells that will not be used will be reclaimed after
 302  
      * conversion, to avoid wasting memory.</p>
 303  
      * @exception DifferentiationException if local variables array has not been
 304  
      * expanded appropriately beforehand
 305  
      * @see #removeUnusedSpareLocalVariables()
 306  
      */
 307  
     private void addSpareLocalVariables() throws DifferentiationException {
 308  66
         for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
 309  281
             final AbstractInsnNode insn = (AbstractInsnNode) i.next();
 310  281
             if (insn.getType() == AbstractInsnNode.VAR_INSN) {
 311  95
                 final VarInsnNode varInsn = (VarInsnNode) insn;
 312  95
                 if (varInsn.var > 2) {
 313  15
                     varInsn.var = 2 * varInsn.var - 1;
 314  15
                     final int opcode = varInsn.getOpcode();
 315  15
                     if ((opcode == Opcodes.ILOAD)  || (opcode == Opcodes.FLOAD)  ||
 316  
                         (opcode == Opcodes.ALOAD)  || (opcode == Opcodes.ISTORE) ||
 317  
                         (opcode == Opcodes.FSTORE) || (opcode == Opcodes.ASTORE)) {
 318  4
                         useLocal(varInsn.var, 1);
 319  
                     } else {
 320  11
                         useLocal(varInsn.var, 2);
 321  
                     }
 322  
                 }
 323  95
             } else if (insn.getOpcode() == Opcodes.IINC) {
 324  2
                 final IincInsnNode iincInsn = (IincInsnNode) insn;
 325  2
                 if (iincInsn.var > 2) {
 326  2
                     iincInsn.var = 2 * iincInsn.var - 1;
 327  2
                     useLocal(iincInsn.var, 1);
 328  
                 }
 329  
             }
 330  281
         }
 331  66
     }
 332  
 
 333  
     /** Remove the unused spare cells introduced at conversion start.
 334  
      * @see #addSpareLocalVariables()
 335  
      */
 336  
     private void removeUnusedSpareLocalVariables() {
 337  65
         for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
 338  1973
             final AbstractInsnNode insn = (AbstractInsnNode) i.next();
 339  1973
             if (insn.getType() == AbstractInsnNode.VAR_INSN) {
 340  873
                 shiftVariable((VarInsnNode) insn);
 341  
             }
 342  1973
         }
 343  65
     }
 344  
 
 345  
     /** Identify the instructions that must be changed.
 346  
      * <p>Identification is based on data flow analysis. We start by changing
 347  
      * the local variables in the initial frame to match the parameters of
 348  
      * the derivative method, and propagate these variables following the
 349  
      * instructions path, updating stack cells and local variables as needed.
 350  
      * Instructions that must be changed are the ones that consume changed
 351  
      * variables or stack cells.</p>
 352  
      * @return set containing all the instructions that must be changed
 353  
      */
 354  
     private Set<AbstractInsnNode> identifyChanges() {
 355  
 
 356  
         // the pending set contains the values (local variables or stack cells)
 357  
         // that have been changed, they will trigger changes on the instructions
 358  
         // that consume them
 359  66
         final Set<TrackingValue> pending = new HashSet<TrackingValue>();
 360  
 
 361  
         // the changes set contains the instructions that must be changed
 362  66
         final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>();
 363  
 
 364  
         // start by converting the parameter of the method,
 365  
         // which is kept in local variable 1 of the initial frame
 366  66
         final TrackingValue dpParameter = (TrackingValue) frames.get(instructions.get(0)).getLocal(1);
 367  66
         pending.add(dpParameter);
 368  
 
 369  
         // propagate the values conversions throughout the method
 370  210
         while (!pending.isEmpty()) {
 371  
 
 372  
             // pop one element from the set of changed values
 373  144
             final Iterator<TrackingValue> iterator = pending.iterator();
 374  144
             final TrackingValue value = iterator.next();
 375  144
             iterator.remove();
 376  
 
 377  
             // this value is converted
 378  144
             converted.add(value);
 379  
 
 380  
             // check the consumers instructions for this value
 381  144
             for (final AbstractInsnNode consumer : value.getConsumers()) {
 382  
 
 383  
                 // an instruction consuming a converted value and producing
 384  
                 // a double must be changed to produce a differential pair,
 385  
                 // get the double values produced and add them to the changed set
 386  252
                 for (TrackingValue produced : getProducedDoubleValues(consumer)) {
 387  
 
 388  
                     // add it to the pending set if it has not already been processed
 389  184
                     if (!converted.contains(produced)) {
 390  83
                         pending.add(produced);
 391  
                     }
 392  
 
 393  
                 }
 394  
 
 395  
                 // as a consumer of a converted value, the instruction must be changed
 396  252
                 changes.add(consumer);
 397  
 
 398  
             }
 399  
 
 400  
             // check the producers instructions for this value
 401  144
             for (final AbstractInsnNode producer : value.getProducers()) {
 402  
 
 403  
                 // an instruction producing a converted value must be changed
 404  181
                 changes.add(producer);
 405  
 
 406  
             }
 407  144
         }
 408  
 
 409  66
         return changes;
 410  
 
 411  
     }
 412  
 
 413  
     /** Get the list of double values produced by an instruction.
 414  
      * @param instruction instruction producing the values
 415  
      * @return list of double values produced
 416  
      */
 417  
     private List<TrackingValue> getProducedDoubleValues(final AbstractInsnNode instruction) {
 418  
 
 419  252
         final List<TrackingValue> values = new ArrayList<TrackingValue>();
 420  
 
 421  
         // get the frame before instruction execution
 422  252
         final Frame before = frames.get(instruction);
 423  252
         final int beforeStackSize = before.getStackSize();
 424  252
         final int locals = before.getLocals();
 425  
 
 426  
         // check the frames produced by this instruction
 427  
         // (they correspond to the input frames of its successors)
 428  252
         final Set<AbstractInsnNode> set = successors.get(instruction);
 429  252
         if (set != null) {
 430  
 
 431  
             // loop over the successors of this instruction
 432  185
             for (final AbstractInsnNode successor : set) {
 433  185
                 final Frame produced = frames.get(successor);
 434  
 
 435  
                 // check the stack cells
 436  393
                 for (int i = 0; i < produced.getStackSize(); ++i) {
 437  208
                     final TrackingValue value = (TrackingValue) produced.getStack(i);
 438  208
                     if (((i >= beforeStackSize) || (value != before.getStack(i))) &&
 439  
                         value.getValue().equals(BasicValue.DOUBLE_VALUE)) {
 440  175
                         values.add(value);
 441  
                     }
 442  
                 }
 443  
 
 444  
                 // check the local variables
 445  3128
                 for (int i = 0; i < locals; ++i) {
 446  2943
                     final TrackingValue value = (TrackingValue) produced.getLocal(i);
 447  2943
                     if ((value != before.getLocal(i)) &&
 448  
                         value.getValue().equals(BasicValue.DOUBLE_VALUE)) {
 449  9
                         values.add(value);
 450  
                     }
 451  
                 }
 452  185
             }
 453  
         }
 454  
 
 455  252
         return values;
 456  
 
 457  
     }
 458  
 
 459  
     /** Perform the code changes.
 460  
      * @param changes instructions that must be changed
 461  
      * @exception DifferentiationException if some instruction cannot be handled
 462  
      */
 463  
     private void changeCode(final Set<AbstractInsnNode> changes)
 464  
         throws DifferentiationException {
 465  
 
 466  
         // insert the parameter conversion code at method start
 467  65
         final InsnList list = new InsnList();
 468  65
         list.add(new VarInsnNode(Opcodes.ALOAD, 1));
 469  65
         list.add(new InsnNode(Opcodes.DUP));
 470  65
         list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
 471  
                                     "getValue", VOID_RETURN_D_DESCRIPTOR));
 472  65
         list.add(new VarInsnNode(Opcodes.DSTORE, 1));
 473  65
         list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
 474  
                                     "getFirstDerivative", VOID_RETURN_D_DESCRIPTOR));
 475  65
         list.add(new VarInsnNode(Opcodes.DSTORE, 3));
 476  
 
 477  65
         instructions.insertBefore(instructions.get(0), list);
 478  
 
 479  
         // transform the existing instructions
 480  65
         for (final AbstractInsnNode insn : changes) {
 481  235
             instructions.insert(insn, getReplacement(insn));
 482  235
             instructions.remove(insn);
 483  
         }
 484  
 
 485  65
     }
 486  
 
 487  
     /** Get the replacement list for an instruction.
 488  
      * @param insn instruction to replace
 489  
      * @return replacement instructions list
 490  
      * @exception DifferentiationException if some instruction cannot be handled
 491  
      * or if no temporary variable can be reserved
 492  
      */
 493  
     private InsnList getReplacement(final AbstractInsnNode insn)
 494  
         throws DifferentiationException {
 495  
 
 496  
         // get the frame at the start of the instruction
 497  235
         final Frame frame = frames.get(insn);
 498  235
         final int size = frame.getStackSize();
 499  235
         final boolean stack1Converted = (size > 0) && converted.contains(frame.getStack(size - 2));
 500  235
         final boolean stack0Converted = (size > 1) && converted.contains(frame.getStack(size - 1));
 501  
 
 502  235
         switch(insn.getOpcode()) {
 503  
         case Opcodes.DLOAD :
 504  86
             useLocal(((VarInsnNode) insn).var, 4);
 505  86
             return  DLoadTransformer.getInstance().getReplacement(insn, this);
 506  
         case Opcodes.DALOAD :
 507  
             // TODO add support for DALOAD differentiation
 508  0
             throw new RuntimeException("DALOAD not handled yet");
 509  
         case Opcodes.DSTORE :
 510  5
             useLocal(((VarInsnNode) insn).var, 4);
 511  5
             return  DStoreTransformer.getInstance().getReplacement(insn, this);
 512  
         case Opcodes.DASTORE :
 513  
             // TODO add support for DASTORE differentiation
 514  0
             throw new RuntimeException("DASTORE not handled yet");
 515  
         case Opcodes.DUP2 :
 516  0
             return Dup2Transformer.getInstance().getReplacement(insn, this);
 517  
         case Opcodes.DUP2_X1 :
 518  0
             return Dup2X1Transformer.getInstance().getReplacement(insn, this);
 519  
         case Opcodes.DUP2_X2 :
 520  0
             if (stack1Converted) {
 521  0
                 if (stack0Converted) {
 522  0
                     return Dup2X2Transformer12.getInstance().getReplacement(insn, this);
 523  
                 }
 524  0
                 return Dup2X2Transformer1.getInstance().getReplacement(insn, this);
 525  
             }
 526  0
             return Dup2X2Transformer2.getInstance().getReplacement(insn, this);
 527  
         case Opcodes.DADD :
 528  8
             if (stack1Converted) {
 529  7
                 if (stack0Converted) {
 530  1
                     return DAddTransformer12.getInstance().getReplacement(insn, this);
 531  
                 }
 532  6
                 return DAddTransformer1.getInstance().getReplacement(insn, this);
 533  
             }
 534  1
             return  DAddTransformer2.getInstance().getReplacement(insn, this);
 535  
         case Opcodes.DSUB :
 536  5
             if (stack1Converted) {
 537  4
                 if (stack0Converted) {
 538  1
                     return DSubTransformer12.getInstance().getReplacement(insn, this);
 539  
                 }
 540  3
                 return DSubTransformer1.getInstance().getReplacement(insn, this);
 541  
             }
 542  1
             return  DSubTransformer2.getInstance().getReplacement(insn, this);
 543  
         case Opcodes.DMUL :
 544  12
             if (stack1Converted) {
 545  9
                 if (stack0Converted) {
 546  8
                     return  DMulTransformer12.getInstance().getReplacement(insn, this);
 547  
                 }
 548  1
                 return  DMulTransformer1.getInstance().getReplacement(insn, this);
 549  
             }
 550  3
             return  DMulTransformer2.getInstance().getReplacement(insn, this);
 551  
         case Opcodes.DDIV :
 552  4
             if (stack1Converted) {
 553  2
                 if (stack0Converted) {
 554  1
                     return  DDivTransformer12.getInstance().getReplacement(insn, this);
 555  
                 }
 556  1
                 return  DDivTransformer1.getInstance().getReplacement(insn, this);
 557  
             }
 558  2
             return  DDivTransformer2.getInstance().getReplacement(insn, this);
 559  
         case Opcodes.DREM :
 560  3
             if (stack1Converted) {
 561  2
                 if (stack0Converted) {
 562  1
                     return  DRemTransformer12.getInstance().getReplacement(insn, this);
 563  
                 }
 564  1
                 return  DRemTransformer1.getInstance().getReplacement(insn, this);
 565  
             }
 566  1
             return  DRemTransformer2.getInstance().getReplacement(insn, this);
 567  
         case Opcodes.DNEG :
 568  1
             return  DNegTransformer.getInstance().getReplacement(insn, this);
 569  
         case Opcodes.DCONST_0 :
 570  
         case Opcodes.DCONST_1 :
 571  
         case Opcodes.LDC :
 572  
         case Opcodes.I2D :
 573  
         case Opcodes.L2D :
 574  
         case Opcodes.F2D :
 575  2
             return WideningTransformer.getInstance().getReplacement(insn, this);
 576  
         case Opcodes.POP2 :
 577  
         case Opcodes.D2I :
 578  
         case Opcodes.D2L :
 579  
         case Opcodes.D2F :
 580  1
             return NarrowingTransformer.getInstance().getReplacement(insn, this);
 581  
         case Opcodes.DCMPL :
 582  
         case Opcodes.DCMPG :
 583  0
             if (stack1Converted) {
 584  0
                 if (stack0Converted) {
 585  0
                     return  DcmpTransformer12.getInstance().getReplacement(insn, this);
 586  
                 }
 587  0
                 return  DcmpTransformer1.getInstance().getReplacement(insn, this);
 588  
             }
 589  0
             return  DcmpTransformer2.getInstance().getReplacement(insn, this);
 590  
         case Opcodes.DRETURN :
 591  65
             return  DReturnTransformer.getInstance().getReplacement(insn, this);
 592  
         case Opcodes.GETSTATIC :
 593  
             // TODO add support for GETSTATIC differentiation
 594  0
             throw new RuntimeException("GETSTATIC not handled yet");
 595  
         case Opcodes.PUTSTATIC :
 596  
             // TODO add support for PUTSTATIC differentiation
 597  0
             throw new RuntimeException("PUTSTATIC not handled yet");
 598  
         case Opcodes.GETFIELD :
 599  
             // TODO add support for GETFIELD differentiation
 600  0
             throw new RuntimeException("GETFIELD not handled yet");
 601  
         case Opcodes.PUTFIELD :
 602  
             // TODO add support for PUTFIELD differentiation
 603  0
             throw new RuntimeException("PUTFIELD not handled yet");
 604  
         case Opcodes.INVOKEVIRTUAL :
 605  
             // TODO add support for INVOKEVIRTUAL differentiation
 606  0
             throw new RuntimeException("INVOKEVIRTUAL not handled yet");
 607  
         case Opcodes.INVOKESPECIAL :
 608  
             // TODO add support for INVOKESPECIAL differentiation
 609  0
             throw new RuntimeException("INVOKESPECIAL not handled yet");
 610  
         case Opcodes.INVOKESTATIC :
 611  43
             return replaceInvocation((MethodInsnNode) insn,
 612  
                                      stack1Converted, stack0Converted);
 613  
         case Opcodes.INVOKEINTERFACE :
 614  
             // TODO add support for INVOKEINTERFACE differentiation
 615  0
             throw new RuntimeException("INVOKEINTERFACE not handled yet");
 616  
         case Opcodes.NEWARRAY :
 617  
             // TODO add support for NEWARRAY differentiation
 618  0
             throw new RuntimeException("NEWARRAY not handled yet");
 619  
         case Opcodes.ANEWARRAY :
 620  
             // TODO add support for ANEWARRAY differentiation
 621  0
             throw new RuntimeException("ANEWARRAY not handled yet");
 622  
         case Opcodes.MULTIANEWARRAY :
 623  
             // TODO add support for MULTIANEWARRAY differentiation
 624  0
             throw new RuntimeException("MULTIANEWARRAY not handled yet");
 625  
         default:
 626  0
             throw new DifferentiationException("unable to handle instruction with opcode {0}",
 627  
                                           new Object[] {
 628  
                                               Integer.valueOf(insn.getOpcode())
 629  
                                           });
 630  
         }
 631  
 
 632  
     }
 633  
 
 634  
     /** Replace an INVOKESTATIC instruction.
 635  
      * @param methodInsn invocation instruction
 636  
      * @param stack1Converted if true, the stack sub-head has been
 637  
      * converted to differential pair
 638  
      * @param stack0Converted if true, the stack head has been
 639  
      * converted to differential pair
 640  
      * @return replacement instructions list
 641  
      * @exception DifferentiationException if the instruction cannot be replaced
 642  
      */
 643  
     private InsnList replaceInvocation(final MethodInsnNode methodInsn,
 644  
                                        final boolean stack1Converted,
 645  
                                        final boolean stack0Converted)
 646  
         throws DifferentiationException {
 647  43
         if (isMathImplementationClass(methodInsn.owner)) {
 648  43
             if ("(D)D".equals(methodInsn.desc)) {
 649  
                 // this is a univariate method like sin, cos, exp ...
 650  32
                 final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(methodInsn.name);
 651  32
                 if (transformer == null) {
 652  0
                     throw new DifferentiationException(UNKNOWN_METHOD_FMT,
 653  
                                                   methodInsn.owner, methodInsn.name);
 654  
                 }
 655  32
                 return transformer.getReplacementList(methodInsn.owner, this);
 656  11
             } else if ("(DD)D".equals(methodInsn.desc)) {
 657  
                 // this is a bivariate method like atan2, pow ...
 658  
 
 659  
                 // we may want to differentiate against first, second or both parameters
 660  11
                 String name = null;
 661  11
                 if (stack1Converted) {
 662  8
                     if (stack0Converted) {
 663  5
                         name = methodInsn.name + "_12";
 664  
                     } else {
 665  3
                         name = methodInsn.name + "_1";
 666  
                     }
 667  3
                 } else if (stack0Converted) {
 668  3
                     name = methodInsn.name + "_2";
 669  
                 }
 670  
 
 671  11
                 if (name != null) {
 672  11
                     final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(name);
 673  11
                     if (transformer == null) {
 674  0
                         throw new DifferentiationException(UNKNOWN_METHOD_FMT,
 675  
                                                       methodInsn.owner, methodInsn.name);
 676  
                     }
 677  11
                     return transformer.getReplacementList(methodInsn.owner, this);
 678  
                 }
 679  
             }
 680  
         }
 681  0
         throw new DifferentiationException("unexpected instruction {0}",
 682  
                                            Integer.valueOf(methodInsn.getOpcode()));
 683  
     }
 684  
 
 685  
     /** Test if a class is a math implementation class.
 686  
      * @param name name of the class to test
 687  
      * @return true if the named class is a math implementation class
 688  
      */
 689  
     public boolean isMathImplementationClass(final String name) {
 690  43
         return mathClasses.contains(name);
 691  
     }
 692  
 
 693  
     /** Set a local variable as used by the modified code.
 694  
      * @param index index of the variable
 695  
      * @param size size of the variable (1 or 2 for standard variables,
 696  
      * 4 for special expanded differential pairs)
 697  
      * @exception DifferentiationException if the number of the
 698  
      * temporary variable lies outside of the allowed range
 699  
      */
 700  
     public void useLocal(final int index, final int size)
 701  
         throws DifferentiationException {
 702  394
         if ((index < 0) || ((index + size - 1) >= usedLocals.length)) {
 703  0
             throw new DifferentiationException("index of size {0} local variable ({1}) " +
 704  
                                           "outside of [{2}, {3}] range",
 705  
                                           Integer.valueOf(size), Integer.valueOf(index),
 706  
                                           Integer.valueOf(1), Integer.valueOf(MAX_TEMP));
 707  
         }
 708  1554
         for (int i = index; i < index + size; ++i) {
 709  1160
             usedLocals[i] = true;
 710  
         }
 711  394
     }
 712  
 
 713  
     /** Get index of a size 2 temporary variable.
 714  
      * <p>Temporary variables can be used for very short duration
 715  
      * storage of size 2 values (i.e one double, or one long or two
 716  
      * integers). These variables are reused in many replacement
 717  
      * instructions sequences, so their content may be overridden
 718  
      * at any time: they should be considered to have a scope
 719  
      * limited to one replacement sequence only. This means that
 720  
      * one should <em>not</em> store a value in a variable in one
 721  
      * replacement sequence and retrieve it later in another
 722  
      * replacement sequence as it may have been overridden in
 723  
      * between.</p>
 724  
      * <p>At most 5 temporary variables may be used.</p>
 725  
      * @param number number of the temporary variable (must be
 726  
      * between 1 and the maximal number of available temporary
 727  
      * variables)
 728  
      * @return index of the variable within the local variables
 729  
      * array
 730  
      * @exception DifferentiationException if the number of the
 731  
      * temporary variable lies outside of the allowed range
 732  
      */
 733  
     public int getTmp(final int number) throws DifferentiationException {
 734  89
         if ((number < 0) || (number > MAX_TEMP)) {
 735  0
             throw new DifferentiationException("number of temporary variable ({0}) outside of [{1}, {2}] range",
 736  
                                                Integer.valueOf(number),
 737  
                                                Integer.valueOf(1),
 738  
                                                Integer.valueOf(MAX_TEMP));
 739  
         }
 740  89
         final int index = usedLocals.length - 2 * number;
 741  89
         useLocal(index, 2);
 742  89
         return index;
 743  
     }
 744  
 
 745  
     /** Shifted the index of a variable instruction.
 746  
      * @param insn variable instruction
 747  
      */
 748  
     public void shiftVariable(final VarInsnNode insn) {
 749  873
         int shifted = 0;
 750  5026
         for (int i = 0; i < insn.var; ++i) {
 751  4153
             if (usedLocals[i]) {
 752  3141
                 ++shifted;
 753  
             }
 754  
         }
 755  873
         insn.var = shifted;
 756  873
     }
 757  
 
 758  
     /** Clone an instruction.
 759  
      * @param insn instruction to clone
 760  
      * @return cloned instruction
 761  
      */
 762  
     public AbstractInsnNode clone(final AbstractInsnNode insn) {
 763  3
         return insn.clone(clonedLabels);
 764  
     }
 765  
 
 766  
     /** Analyzer preserving instructions successors information. */
 767  
     private class FlowAnalyzer extends Analyzer {
 768  
 
 769  
         /** Simple constructor.
 770  
          * @param interpreter associated interpreter
 771  
          */
 772  66
         public FlowAnalyzer(final Interpreter interpreter) {
 773  66
             super(interpreter);
 774  66
         }
 775  
 
 776  
         /** Store a new edge.
 777  
          * @param insn current instruction
 778  
          * @param successor successor instruction
 779  
          */
 780  
         protected void newControlFlowEdge(final int insn, final int successor) {
 781  
             // store the successor information
 782  217
             final AbstractInsnNode node = instructions.get(insn);
 783  217
             Set<AbstractInsnNode> set = successors.get(node);
 784  217
             if (set == null) {
 785  215
                 set = new HashSet<AbstractInsnNode>();
 786  215
                 successors.put(node, set);
 787  
             }
 788  217
             set.add(instructions.get(successor));
 789  217
         }
 790  
 
 791  
     }
 792  
 
 793  
 }