Classes in this File | Line Coverage | Branch Coverage | Complexity | ||||||||
AutomaticDifferentiator |
|
| 0.0;0 |
1 | /* |
|
2 | * Licensed to the Apache Software Foundation (ASF) under one or more |
|
3 | * contributor license agreements. See the NOTICE file distributed with |
|
4 | * this work for additional information regarding copyright ownership. |
|
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 |
|
6 | * (the "License"); you may not use this file except in compliance with |
|
7 | * the License. You may obtain a copy of the License at |
|
8 | * |
|
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
|
10 | * |
|
11 | * Unless required by applicable law or agreed to in writing, software |
|
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
|
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
14 | * See the License for the specific language governing permissions and |
|
15 | * limitations under the License. |
|
16 | */ |
|
17 | package org.apache.commons.nabla.automatic.analysis; |
|
18 | ||
19 | import java.io.IOException; |
|
20 | import java.io.InputStream; |
|
21 | import java.io.OutputStream; |
|
22 | import java.lang.reflect.Constructor; |
|
23 | import java.lang.reflect.InvocationTargetException; |
|
24 | import java.util.HashMap; |
|
25 | import java.util.HashSet; |
|
26 | import java.util.Set; |
|
27 | ||
28 | import org.apache.commons.nabla.core.DifferentialPair; |
|
29 | import org.apache.commons.nabla.core.DifferentiationException; |
|
30 | import org.apache.commons.nabla.core.UnivariateDerivative; |
|
31 | import org.apache.commons.nabla.core.UnivariateDifferentiable; |
|
32 | import org.apache.commons.nabla.core.UnivariateDifferentiator; |
|
33 | import org.objectweb.asm.ClassReader; |
|
34 | import org.objectweb.asm.ClassWriter; |
|
35 | ||
36 | /** Automatic differentiator class based on bytecode analysis. |
|
37 | * <p>This class is an implementation of the {@link UnivariateDifferentiator} |
|
38 | * interface that computes <em>exact</em> differentials completely automatically |
|
39 | * and generate java classes and instances that compute the differential |
|
40 | * of the function as if they were hand-coded and compiled.</p> |
|
41 | * <p>The derivative bytecode created the first time an instance of a given class |
|
42 | * is differentiated is cached and will be reused if other instances of the same class |
|
43 | * are to be created later. The cache can also be dumped in a jar file for |
|
44 | * use in an application without bringing the full nabla library and its |
|
45 | * dependencies.</p> |
|
46 | * <p>This differentiator can handle only pure bytecode methods and known methods |
|
47 | * from math implementation classes like {@link java.lang.Math Math} or |
|
48 | * {@link java.lang.StrictMath StrictMath}. Pure bytecode methods are analyzed |
|
49 | * and converted. Methods from math implementation classes are only recognized |
|
50 | * by class and name and replaced by predefined derivative code.</p> |
|
51 | * @see org.apache.commons.nabla.Fetchdifferentiator |
|
52 | */ |
|
53 | public class AutomaticDifferentiator implements UnivariateDifferentiator { |
|
54 | ||
55 | /** Name for the DifferentialPair class. */ |
|
56 | 9 | public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/'); |
57 | ||
58 | /** Descriptor for the DifferentialPair class. */ |
|
59 | 9 | public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";"; |
60 | ||
61 | /** Descriptor for the derivative class f method. */ |
|
62 | 9 | public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR; |
63 | ||
64 | /** UnivariateDifferentiable/UnivariateDerivative map. */ |
|
65 | private final HashMap<Class<? extends UnivariateDifferentiable>, |
|
66 | Class<? extends UnivariateDerivative>> map; |
|
67 | ||
68 | /** Math implementation classes. */ |
|
69 | private final Set<String> mathClasses; |
|
70 | ||
71 | /** Simple constructor. |
|
72 | * <p>Build a AutomaticDifferentiator instance with an empty cache.</p> |
|
73 | */ |
|
74 | 504 | public AutomaticDifferentiator() { |
75 | 504 | map = new HashMap<Class<? extends UnivariateDifferentiable>, |
76 | Class<? extends UnivariateDerivative>>(); |
|
77 | 504 | mathClasses = new HashSet<String>(); |
78 | 504 | addMathImplementation(Math.class); |
79 | 504 | addMathImplementation(StrictMath.class); |
80 | 504 | } |
81 | ||
82 | /** Add an implementation class for mathematical functions. |
|
83 | * <p>At construction, the differentiator considers only the {@link |
|
84 | * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath} |
|
85 | * classes are math implementation classes. It may be useful to add |
|
86 | * other class for example to add some missing functions like |
|
87 | * inverse hyperbolic cosine that are not provided by the standard |
|
88 | * java classes as of Java 1.6.</p> |
|
89 | * @param mathClass implementation class for mathematical functions |
|
90 | */ |
|
91 | public void addMathImplementation(final Class<?> mathClass) { |
|
92 | 1512 | mathClasses.add(mathClass.getName().replace('.', '/')); |
93 | 1512 | } |
94 | ||
95 | /** Dump the cache into a stream. |
|
96 | * @param out output stream where to dump the cache |
|
97 | */ |
|
98 | public void dumpCache(final OutputStream out) { |
|
99 | // TODO |
|
100 | 0 | throw new RuntimeException("not implemented yet"); |
101 | } |
|
102 | ||
103 | /** {@inheritDoc} */ |
|
104 | public UnivariateDerivative differentiate(final UnivariateDifferentiable d) |
|
105 | throws DifferentiationException { |
|
106 | ||
107 | // get the derivative class |
|
108 | 504 | final Class<? extends UnivariateDerivative> derivativeClass = |
109 | getDerivativeClass(d.getClass()); |
|
110 | ||
111 | try { |
|
112 | ||
113 | // create the instance |
|
114 | 504 | final Constructor<? extends UnivariateDerivative> constructor = |
115 | derivativeClass.getConstructor(d.getClass()); |
|
116 | 504 | return constructor.newInstance(d); |
117 | ||
118 | 0 | } catch (InstantiationException ie) { |
119 | 0 | throw new DifferentiationException("abstract class {0} cannot be instantiated ({1})", |
120 | derivativeClass.getName(), ie.getMessage()); |
|
121 | 0 | } catch (IllegalAccessException iae) { |
122 | 0 | throw new DifferentiationException("illegal access to class {0} constructor ({1})", |
123 | derivativeClass.getName(), iae.getMessage()); |
|
124 | 0 | } catch (NoSuchMethodException nsme) { |
125 | 0 | throw new DifferentiationException("class {0} cannot be built from an instance of class {1} ({2})", |
126 | derivativeClass.getName(), d.getClass().getName(), nsme.getMessage()); |
|
127 | 0 | } catch (InvocationTargetException ite) { |
128 | 0 | throw new DifferentiationException("class {0} instantiation from an instance of class {1} failed ({2})", |
129 | derivativeClass.getName(), d.getClass().getName(), ite.getMessage()); |
|
130 | } |
|
131 | ||
132 | } |
|
133 | ||
134 | /** Get the derivative class of a differentiable class. |
|
135 | * <p>The derivative class is either built on the fly |
|
136 | * or retrieved from the cache if it has been built previously.</p> |
|
137 | * @param differentiableClass class to differentiate |
|
138 | * @return derivative class |
|
139 | * @throws DifferentiationException if the class cannot be differentiated |
|
140 | */ |
|
141 | private Class<? extends UnivariateDerivative> |
|
142 | getDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass) |
|
143 | throws DifferentiationException { |
|
144 | ||
145 | // lookup in the map if the class has already been differentiated |
|
146 | 504 | Class<? extends UnivariateDerivative> derivativeClass = |
147 | map.get(differentiableClass); |
|
148 | ||
149 | // build the derivative class if it does not exist yet |
|
150 | 504 | if (derivativeClass == null) { |
151 | // perform analytical differentiation |
|
152 | 504 | derivativeClass = createDerivativeClass(differentiableClass); |
153 | ||
154 | // put the newly created class in the map |
|
155 | 504 | map.put(differentiableClass, derivativeClass); |
156 | ||
157 | } |
|
158 | ||
159 | // return the derivative class |
|
160 | 504 | return derivativeClass; |
161 | ||
162 | } |
|
163 | ||
164 | /** Build a derivative class of a differentiable class. |
|
165 | * @param differentiableClass class to differentiate |
|
166 | * @return derivative class |
|
167 | * @throws DifferentiationException if the class cannot be differentiated |
|
168 | */ |
|
169 | private Class<? extends UnivariateDerivative> |
|
170 | createDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass) |
|
171 | throws DifferentiationException { |
|
172 | try { |
|
173 | ||
174 | // set up both ends of the class transform chain |
|
175 | 504 | final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class"; |
176 | 504 | final InputStream stream = differentiableClass.getResourceAsStream(classResourceName); |
177 | 504 | final ClassReader reader = new ClassReader(stream); |
178 | 504 | final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES); |
179 | ||
180 | // differentiate the function embedded in the differentiable class |
|
181 | 504 | final ClassDifferentiator differentiator = new ClassDifferentiator(mathClasses, writer); |
182 | 504 | reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); |
183 | 504 | differentiator.reportErrors(); |
184 | ||
185 | // create the derivative class |
|
186 | 504 | return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer); |
187 | ||
188 | 0 | } catch (IOException ioe) { |
189 | 0 | throw new DifferentiationException("class {0} cannot be read ({1})", |
190 | differentiableClass.getName(), ioe.getMessage()); |
|
191 | } |
|
192 | } |
|
193 | ||
194 | /** Class loader generating derivative classes. */ |
|
195 | private static class DerivativeLoader extends ClassLoader { |
|
196 | ||
197 | /** Simple constructor. |
|
198 | * @param differentiableClass differentiable class |
|
199 | */ |
|
200 | public DerivativeLoader(final Class<? extends UnivariateDifferentiable> differentiableClass) { |
|
201 | 504 | super(differentiableClass.getClassLoader()); |
202 | 504 | } |
203 | ||
204 | /** Define a derivative class. |
|
205 | * @param differentiator class differentiator |
|
206 | * @param writer class writer |
|
207 | * @return a generated derivative class |
|
208 | */ |
|
209 | @SuppressWarnings("unchecked") |
|
210 | public Class<? extends UnivariateDerivative> |
|
211 | defineClass(final ClassDifferentiator differentiator, final ClassWriter writer) { |
|
212 | 504 | final String name = differentiator.getDerivativeClassName().replace('/', '.'); |
213 | 504 | final byte[] bytecode = writer.toByteArray(); |
214 | 504 | return (Class<? extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length); |
215 | } |
|
216 | } |
|
217 | ||
218 | } |