1 | |
|
2 | |
|
3 | |
|
4 | |
|
5 | |
|
6 | |
|
7 | |
|
8 | |
|
9 | |
|
10 | |
|
11 | |
|
12 | |
|
13 | |
|
14 | |
|
15 | |
|
16 | |
|
17 | |
package org.apache.commons.nabla.automatic.analysis; |
18 | |
|
19 | |
import java.util.Set; |
20 | |
|
21 | |
import org.apache.commons.nabla.core.DifferentiationException; |
22 | |
import org.apache.commons.nabla.core.UnivariateDerivative; |
23 | |
import org.apache.commons.nabla.core.UnivariateDifferentiable; |
24 | |
import org.objectweb.asm.AnnotationVisitor; |
25 | |
import org.objectweb.asm.Attribute; |
26 | |
import org.objectweb.asm.ClassVisitor; |
27 | |
import org.objectweb.asm.FieldVisitor; |
28 | |
import org.objectweb.asm.MethodVisitor; |
29 | |
import org.objectweb.asm.Opcodes; |
30 | |
|
31 | |
|
32 | |
|
33 | |
|
34 | |
|
35 | |
|
36 | |
|
37 | |
|
38 | |
|
39 | |
|
40 | |
|
41 | |
|
42 | |
|
43 | |
|
44 | |
|
45 | |
|
46 | |
|
47 | |
|
48 | |
|
49 | |
public class ClassDifferentiator implements ClassVisitor { |
50 | |
|
51 | |
|
52 | |
private static final String PRIMITIVE_FIELD = "primitive"; |
53 | |
|
54 | |
|
55 | |
private final Set<String> mathClasses; |
56 | |
|
57 | |
|
58 | |
private final ClassVisitor generator; |
59 | |
|
60 | |
|
61 | |
private final ErrorReporter errorReporter; |
62 | |
|
63 | |
|
64 | |
private String primitiveName; |
65 | |
|
66 | |
|
67 | |
private String primitiveDesc; |
68 | |
|
69 | |
|
70 | |
private String derivativeName; |
71 | |
|
72 | |
|
73 | |
private boolean specificMembersAdded; |
74 | |
|
75 | |
|
76 | |
|
77 | |
|
78 | |
|
79 | |
|
80 | |
public ClassDifferentiator(final Set<String> mathClasses, |
81 | 66 | final ClassVisitor generator) { |
82 | 66 | this.mathClasses = mathClasses; |
83 | 66 | this.generator = generator; |
84 | 66 | errorReporter = new ErrorReporter(); |
85 | 66 | } |
86 | |
|
87 | |
|
88 | |
|
89 | |
|
90 | |
|
91 | |
public String getDerivativeClassName() { |
92 | 66 | return derivativeName; |
93 | |
} |
94 | |
|
95 | |
|
96 | |
public void visit(final int version, final int access, |
97 | |
final String name, final String signature, |
98 | |
final String superName, final String[] interfaces) { |
99 | |
|
100 | 66 | primitiveName = name; |
101 | 66 | derivativeName = primitiveName + "$NablaUnivariateDerivative"; |
102 | 66 | primitiveDesc = "L" + primitiveName + ";"; |
103 | |
|
104 | |
|
105 | 66 | final Class<UnivariateDifferentiable> uDerClass = UnivariateDifferentiable.class; |
106 | 66 | boolean isDifferentiable = false; |
107 | 132 | for (String interf : interfaces) { |
108 | 66 | final String interfName = interf.replace('/', '.'); |
109 | 66 | Class<?> interfClass = null; |
110 | |
try { |
111 | 66 | interfClass = Class.forName(interfName); |
112 | 0 | } catch (ClassNotFoundException cnfe) { |
113 | |
|
114 | |
|
115 | 0 | errorReporter.register(new DifferentiationException("interface {0} not found " + |
116 | |
"while differentiating class {1}", |
117 | |
interfName, name)); |
118 | 66 | } |
119 | 66 | if (interfClass != null) { |
120 | 66 | isDifferentiable = isDifferentiable || uDerClass.isAssignableFrom(interfClass); |
121 | |
} |
122 | |
} |
123 | |
|
124 | 66 | if (isDifferentiable) { |
125 | |
|
126 | 66 | generator.visit(version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, |
127 | |
derivativeName, signature, superName, |
128 | |
new String[] { |
129 | |
UnivariateDerivative.class.getName().replace('.', '/') |
130 | |
}); |
131 | |
} else { |
132 | 0 | errorReporter.register(new DifferentiationException("the {0} class does not implement " + |
133 | |
"the {1} interface", |
134 | |
name, uDerClass.getName())); |
135 | |
} |
136 | |
|
137 | 66 | specificMembersAdded = false; |
138 | |
|
139 | 66 | } |
140 | |
|
141 | |
|
142 | |
public MethodVisitor visitMethod(final int access, final String name, |
143 | |
final String desc, final String signature, |
144 | |
final String[] exceptions) { |
145 | |
|
146 | |
|
147 | 188 | if (errorReporter.hasError()) { |
148 | 0 | return null; |
149 | |
} |
150 | |
|
151 | 188 | if (!specificMembersAdded) { |
152 | |
|
153 | 66 | addPrimitiveField(); |
154 | 66 | addConstructor(); |
155 | 66 | addGetPrimitive(); |
156 | 66 | specificMembersAdded = true; |
157 | |
} |
158 | |
|
159 | |
|
160 | 188 | if (((access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) && |
161 | |
"f".equals(name) && "(D)D".equals(desc) && |
162 | |
((exceptions == null) || (exceptions.length == 0))) { |
163 | |
|
164 | |
|
165 | 66 | final MethodVisitor visitor = |
166 | |
generator.visitMethod(access | Opcodes.ACC_SYNTHETIC, name, |
167 | |
MethodDifferentiator.DP_RETURN_DP_DESCRIPTOR, null, null); |
168 | |
|
169 | |
|
170 | 66 | return new MethodDifferentiator(access, name, desc, signature, exceptions, |
171 | |
visitor, primitiveName, mathClasses, errorReporter); |
172 | |
|
173 | |
} |
174 | |
|
175 | |
|
176 | 122 | return null; |
177 | |
|
178 | |
} |
179 | |
|
180 | |
|
181 | |
public FieldVisitor visitField(final int access, final String name, |
182 | |
final String desc, final String signature, |
183 | |
final Object value) { |
184 | |
|
185 | 66 | return null; |
186 | |
} |
187 | |
|
188 | |
|
189 | |
public void visitSource(final String source, final String debug) { |
190 | 0 | } |
191 | |
|
192 | |
|
193 | |
public void visitOuterClass(final String owner, final String name, |
194 | |
final String desc) { |
195 | 66 | } |
196 | |
|
197 | |
|
198 | |
public AnnotationVisitor visitAnnotation(final String desc, |
199 | |
final boolean visible) { |
200 | 0 | return null; |
201 | |
} |
202 | |
|
203 | |
|
204 | |
public void visitAttribute(final Attribute attr) { |
205 | 0 | } |
206 | |
|
207 | |
|
208 | |
public void visitInnerClass(final String name, final String outerName, |
209 | |
final String innerName, final int access) { |
210 | 69 | } |
211 | |
|
212 | |
|
213 | |
public void visitEnd() { |
214 | |
|
215 | |
|
216 | 66 | if (errorReporter.hasError()) { |
217 | 0 | return; |
218 | |
} |
219 | |
|
220 | 66 | generator.visitEnd(); |
221 | |
|
222 | 66 | } |
223 | |
|
224 | |
|
225 | |
|
226 | |
private void addPrimitiveField() { |
227 | 66 | final FieldVisitor visitor = |
228 | |
generator.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC, |
229 | |
PRIMITIVE_FIELD, primitiveDesc, null, null); |
230 | 66 | visitor.visitEnd(); |
231 | 66 | } |
232 | |
|
233 | |
|
234 | |
|
235 | |
private void addConstructor() { |
236 | 66 | final String init = "<init>"; |
237 | 66 | final MethodVisitor visitor = |
238 | |
generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init, |
239 | |
"(" + primitiveDesc + ")V", null, null); |
240 | 66 | visitor.visitCode(); |
241 | 66 | visitor.visitVarInsn(Opcodes.ALOAD, 0); |
242 | 66 | visitor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V"); |
243 | 66 | visitor.visitVarInsn(Opcodes.ALOAD, 0); |
244 | 66 | visitor.visitVarInsn(Opcodes.ALOAD, 1); |
245 | 66 | visitor.visitFieldInsn(Opcodes.PUTFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc); |
246 | 66 | visitor.visitInsn(Opcodes.RETURN); |
247 | 66 | visitor.visitMaxs(0, 0); |
248 | 66 | visitor.visitEnd(); |
249 | 66 | } |
250 | |
|
251 | |
|
252 | |
|
253 | |
private void addGetPrimitive() { |
254 | 66 | final MethodVisitor visitor = |
255 | |
generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive", |
256 | |
"()" + primitiveDesc, null, null); |
257 | 66 | visitor.visitCode(); |
258 | 66 | visitor.visitVarInsn(Opcodes.ALOAD, 0); |
259 | 66 | visitor.visitFieldInsn(Opcodes.GETFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc); |
260 | 66 | visitor.visitInsn(Opcodes.ARETURN); |
261 | 66 | visitor.visitMaxs(0, 0); |
262 | 66 | visitor.visitEnd(); |
263 | 66 | } |
264 | |
|
265 | |
|
266 | |
|
267 | |
|
268 | |
|
269 | |
public void reportErrors() throws DifferentiationException { |
270 | 66 | errorReporter.reportErrors(); |
271 | 66 | } |
272 | |
|
273 | |
} |