1 | |
|
2 | |
|
3 | |
|
4 | |
|
5 | |
|
6 | |
|
7 | |
|
8 | |
|
9 | |
|
10 | |
|
11 | |
|
12 | |
|
13 | |
|
14 | |
|
15 | |
|
16 | |
|
17 | |
package org.apache.commons.nabla.forward.instructions; |
18 | |
|
19 | |
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; |
20 | |
import org.apache.commons.nabla.DifferentiationException; |
21 | |
import org.apache.commons.nabla.NablaMessages; |
22 | |
import org.apache.commons.nabla.forward.analysis.InstructionsTransformer; |
23 | |
import org.apache.commons.nabla.forward.analysis.MethodDifferentiator; |
24 | |
import org.objectweb.asm.Opcodes; |
25 | |
import org.objectweb.asm.Type; |
26 | |
import org.objectweb.asm.tree.AbstractInsnNode; |
27 | |
import org.objectweb.asm.tree.InsnList; |
28 | |
import org.objectweb.asm.tree.InsnNode; |
29 | |
import org.objectweb.asm.tree.MethodInsnNode; |
30 | |
|
31 | |
|
32 | |
|
33 | |
|
34 | |
public class InvokeStaticTransformer implements InstructionsTransformer { |
35 | |
|
36 | |
|
37 | |
private final boolean stack0Converted; |
38 | |
|
39 | |
|
40 | |
private final boolean stack1Converted; |
41 | |
|
42 | |
|
43 | |
|
44 | |
|
45 | |
|
46 | 43 | public InvokeStaticTransformer(final boolean stack0Converted, final boolean stack1Converted) { |
47 | 43 | this.stack0Converted = stack0Converted; |
48 | 43 | this.stack1Converted = stack1Converted; |
49 | 43 | } |
50 | |
|
51 | |
|
52 | |
public InsnList getReplacement(final AbstractInsnNode insn, |
53 | |
final MethodDifferentiator methodDifferentiator) |
54 | |
throws DifferentiationException { |
55 | |
|
56 | 43 | final MethodInsnNode methodInsn = (MethodInsnNode) insn; |
57 | 43 | if (!methodDifferentiator.isMathImplementationClass(methodInsn.owner)) { |
58 | |
|
59 | 0 | throw new RuntimeException("INVOKESTATIC on non math related classes not handled yet" + |
60 | |
methodInsn.owner + methodInsn.owner); |
61 | |
} |
62 | |
|
63 | 43 | final InsnList list = new InsnList(); |
64 | |
|
65 | 43 | if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) { |
66 | |
|
67 | |
|
68 | |
try { |
69 | |
|
70 | 32 | DerivativeStructure.class.getDeclaredMethod(methodInsn.name); |
71 | 0 | } catch (NoSuchMethodException nsme) { |
72 | 0 | throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD, |
73 | |
methodInsn.owner, methodInsn.name); |
74 | 32 | } |
75 | |
|
76 | 32 | list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, |
77 | |
DS_TYPE.getInternalName(), methodInsn.name, |
78 | |
Type.getMethodDescriptor(DS_TYPE))); |
79 | |
|
80 | 11 | } else if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) { |
81 | |
|
82 | |
|
83 | 11 | if (methodInsn.name.equals("pow")) { |
84 | |
|
85 | |
|
86 | |
|
87 | 3 | if (stack1Converted) { |
88 | 2 | if (!stack0Converted) { |
89 | 1 | list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, |
90 | |
DS_TYPE.getInternalName(), methodInsn.name, |
91 | |
Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE))); |
92 | |
} else { |
93 | 1 | list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, |
94 | |
DS_TYPE.getInternalName(), methodInsn.name, |
95 | |
Type.getMethodDescriptor(DS_TYPE, DS_TYPE))); |
96 | |
} |
97 | |
} else { |
98 | |
|
99 | |
|
100 | 1 | list.add(new InsnNode(Opcodes.DUP_X2)); |
101 | 1 | list.add(new InsnNode(Opcodes.POP)); |
102 | 1 | list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); |
103 | 1 | list.add(new InsnNode(Opcodes.SWAP)); |
104 | |
|
105 | |
|
106 | 1 | list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, |
107 | |
DS_TYPE.getInternalName(), methodInsn.name, |
108 | |
Type.getMethodDescriptor(DS_TYPE, DS_TYPE))); |
109 | |
} |
110 | |
|
111 | |
} else { |
112 | |
|
113 | 8 | if (stack1Converted) { |
114 | 6 | if (!stack0Converted) { |
115 | |
|
116 | 2 | list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); |
117 | |
} |
118 | |
} else { |
119 | |
|
120 | 2 | list.add(new InsnNode(Opcodes.DUP_X2)); |
121 | 2 | list.add(new InsnNode(Opcodes.POP)); |
122 | 2 | list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); |
123 | 2 | list.add(new InsnNode(Opcodes.SWAP)); |
124 | |
} |
125 | |
|
126 | |
|
127 | 8 | list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, |
128 | |
DS_TYPE.getInternalName(), methodInsn.name, |
129 | |
Type.getMethodDescriptor(DS_TYPE, DS_TYPE, DS_TYPE))); |
130 | |
|
131 | |
} |
132 | |
|
133 | |
} else { |
134 | 0 | throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD, |
135 | |
methodInsn.owner, methodInsn.name); |
136 | |
} |
137 | |
|
138 | 43 | return list; |
139 | |
|
140 | |
} |
141 | |
|
142 | |
} |