1 | |
|
2 | |
|
3 | |
|
4 | |
|
5 | |
|
6 | |
|
7 | |
|
8 | |
|
9 | |
|
10 | |
|
11 | |
|
12 | |
|
13 | |
|
14 | |
|
15 | |
|
16 | |
|
17 | |
package org.apache.commons.nabla.forward; |
18 | |
|
19 | |
import java.io.IOException; |
20 | |
import java.io.OutputStream; |
21 | |
import java.lang.reflect.Constructor; |
22 | |
import java.lang.reflect.InvocationTargetException; |
23 | |
import java.util.HashMap; |
24 | |
import java.util.HashSet; |
25 | |
import java.util.Set; |
26 | |
|
27 | |
import org.apache.commons.math3.analysis.UnivariateFunction; |
28 | |
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; |
29 | |
import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; |
30 | |
import org.apache.commons.math3.analysis.differentiation.UnivariateFunctionDifferentiator; |
31 | |
import org.apache.commons.math3.util.FastMath; |
32 | |
import org.apache.commons.nabla.DifferentiationException; |
33 | |
import org.apache.commons.nabla.NablaMessages; |
34 | |
import org.apache.commons.nabla.forward.analysis.ClassDifferentiator; |
35 | |
import org.objectweb.asm.ClassWriter; |
36 | |
import org.objectweb.asm.Type; |
37 | |
import org.objectweb.asm.tree.ClassNode; |
38 | |
|
39 | |
|
40 | |
|
41 | |
|
42 | |
|
43 | |
|
44 | |
|
45 | |
|
46 | |
|
47 | |
|
48 | |
|
49 | |
|
50 | |
|
51 | |
|
52 | |
|
53 | |
|
54 | |
|
55 | |
|
56 | |
|
57 | |
public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator { |
58 | |
|
59 | |
|
60 | |
private final HashMap<Class<? extends UnivariateFunction>, |
61 | |
Class<? extends UnivariateDifferentiableFunction>> map; |
62 | |
|
63 | |
|
64 | |
private final HashMap<String, byte[]> byteCodeMap; |
65 | |
|
66 | |
|
67 | |
private final Set<String> mathClasses; |
68 | |
|
69 | |
|
70 | |
|
71 | |
|
72 | 67 | public ForwardModeDifferentiator() { |
73 | 67 | map = new HashMap<Class<? extends UnivariateFunction>, |
74 | |
Class<? extends UnivariateDifferentiableFunction>>(); |
75 | 67 | byteCodeMap = new HashMap<String, byte[]>(); |
76 | 67 | mathClasses = new HashSet<String>(); |
77 | 67 | addMathImplementation(Math.class); |
78 | 67 | addMathImplementation(StrictMath.class); |
79 | 67 | addMathImplementation(FastMath.class); |
80 | 67 | } |
81 | |
|
82 | |
|
83 | |
|
84 | |
|
85 | |
|
86 | |
|
87 | |
|
88 | |
|
89 | |
|
90 | |
|
91 | |
public void addMathImplementation(final Class<?> mathClass) { |
92 | 257 | mathClasses.add(mathClass.getName().replace('.', '/')); |
93 | 257 | } |
94 | |
|
95 | |
|
96 | |
|
97 | |
|
98 | |
public void dumpCache(final OutputStream out) { |
99 | |
|
100 | 0 | throw new RuntimeException("not implemented yet"); |
101 | |
} |
102 | |
|
103 | |
|
104 | |
public UnivariateDifferentiableFunction differentiate(final UnivariateFunction d) { |
105 | |
|
106 | |
|
107 | 67 | final Class<? extends UnivariateDifferentiableFunction> derivativeClass = |
108 | |
getDerivativeClass(d.getClass()); |
109 | |
|
110 | |
try { |
111 | |
|
112 | |
|
113 | 67 | final Constructor<? extends UnivariateDifferentiableFunction> constructor = |
114 | |
derivativeClass.getConstructor(d.getClass()); |
115 | 67 | return constructor.newInstance(d); |
116 | |
|
117 | 0 | } catch (InstantiationException ie) { |
118 | 0 | throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_ABSTRACT_CLASS, |
119 | |
derivativeClass.getName(), ie.getMessage()); |
120 | 0 | } catch (IllegalAccessException iae) { |
121 | 0 | throw new DifferentiationException(NablaMessages.ILLEGAL_ACCESS_TO_CONSTRUCTOR, |
122 | |
derivativeClass.getName(), iae.getMessage()); |
123 | 0 | } catch (NoSuchMethodException nsme) { |
124 | 0 | throw new DifferentiationException(NablaMessages.CANNOT_BUILD_CLASS_FROM_OTHER_CLASS, |
125 | |
derivativeClass.getName(), d.getClass().getName(), nsme.getMessage()); |
126 | 0 | } catch (InvocationTargetException ite) { |
127 | 0 | throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE, |
128 | |
derivativeClass.getName(), d.getClass().getName(), ite.getMessage()); |
129 | 0 | } catch (VerifyError ve) { |
130 | 0 | throw new DifferentiationException(NablaMessages.INCORRECT_GENERATED_CODE, |
131 | |
derivativeClass.getName(), d.getClass().getName(), ve.getMessage()); |
132 | |
} |
133 | |
|
134 | |
} |
135 | |
|
136 | |
|
137 | |
|
138 | |
|
139 | |
|
140 | |
|
141 | |
|
142 | |
|
143 | |
private Class<? extends UnivariateDifferentiableFunction> |
144 | |
getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass) |
145 | |
throws DifferentiationException { |
146 | |
|
147 | |
|
148 | 67 | Class<? extends UnivariateDifferentiableFunction> derivativeClass = |
149 | |
map.get(differentiableClass); |
150 | |
|
151 | |
|
152 | 67 | if (derivativeClass == null) { |
153 | |
|
154 | |
|
155 | 67 | derivativeClass = createDerivativeClass(differentiableClass); |
156 | |
|
157 | |
|
158 | 67 | map.put(differentiableClass, derivativeClass); |
159 | |
|
160 | |
} |
161 | |
|
162 | |
|
163 | 67 | return derivativeClass; |
164 | |
|
165 | |
} |
166 | |
|
167 | |
|
168 | |
|
169 | |
|
170 | |
|
171 | |
|
172 | |
private Class<? extends UnivariateDifferentiableFunction> |
173 | |
createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass) |
174 | |
throws DifferentiationException { |
175 | |
try { |
176 | |
|
177 | |
|
178 | 67 | final ClassDifferentiator differentiator = |
179 | |
new ClassDifferentiator(differentiableClass, mathClasses); |
180 | 67 | final Type dsType = Type.getType(DerivativeStructure.class); |
181 | 67 | differentiator.differentiateMethod("value", |
182 | |
Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE), |
183 | |
Type.getMethodDescriptor(dsType, dsType)); |
184 | |
|
185 | |
|
186 | 67 | final ClassNode derived = differentiator.getDerivedClass(); |
187 | 67 | final ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES); |
188 | 67 | final String name = derived.name.replace('/', '.'); |
189 | 67 | derived.accept(writer); |
190 | 67 | final byte[] bytecode = writer.toByteArray(); |
191 | |
|
192 | 67 | final Class<? extends UnivariateDifferentiableFunction> dClass = |
193 | |
new DerivativeLoader(differentiableClass).defineClass(name, bytecode); |
194 | 67 | byteCodeMap.put(name, bytecode); |
195 | 67 | return dClass; |
196 | |
|
197 | 0 | } catch (IOException ioe) { |
198 | 0 | throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS, |
199 | |
differentiableClass.getName(), ioe.getMessage()); |
200 | |
} |
201 | |
} |
202 | |
|
203 | |
|
204 | |
private static class DerivativeLoader extends ClassLoader { |
205 | |
|
206 | |
|
207 | |
|
208 | |
|
209 | |
public DerivativeLoader(final Class<? extends UnivariateFunction> differentiableClass) { |
210 | 67 | super(differentiableClass.getClassLoader()); |
211 | 67 | } |
212 | |
|
213 | |
|
214 | |
|
215 | |
|
216 | |
|
217 | |
|
218 | |
@SuppressWarnings("unchecked") |
219 | |
public Class<? extends UnivariateDifferentiableFunction> |
220 | |
defineClass(final String name, final byte[] bytecode) { |
221 | 67 | return (Class<? extends UnivariateDifferentiableFunction>) defineClass(name, bytecode, 0, bytecode.length); |
222 | |
} |
223 | |
} |
224 | |
|
225 | |
} |