Author: luc Date: Mon Aug 17 19:50:51 2009 New Revision: 805114 URL: http://svn.apache.org/viewvc?rev=805114&view=rev Log: Use the tree API for classes too (in addition to methods). This will allow modifying fields and methods called by the differentiated method.
Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java (with props) Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java?rev=805114&view=auto ============================================================================== --- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java (added) +++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java Mon Aug 17 19:50:51 2009 @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.nabla.algorithmic; + +import org.apache.commons.nabla.core.DifferentialPair; + +/** + * Interface defining methods descriptors for differentials. + */ +public interface Descriptors { + + /** Name for the DifferentialPair class. */ + String DP_NAME = DifferentialPair.class.getName().replace('.', '/'); + + /** Descriptor for the DifferentialPair class. */ + String DP_DESCRIPTOR = "L" + DP_NAME + ";"; + + /** Descriptor for the primitive class f method. */ + String D_RETURN_D_DESCRIPTOR = "(D)D"; + + /** Descriptor for the derivative class f method. */ + String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR; + + /** Descriptor for <code>DifferentialPair f(double)</code> methods. */ + String D_RETURN_DP_DESCRIPTOR = "(D)" + DP_DESCRIPTOR; + + /** Descriptor for <code>double f()</code> methods. */ + String VOID_RETURN_D_DESCRIPTOR = "()D"; + +} Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java ------------------------------------------------------------------------------ svn:eol-style = native Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java ------------------------------------------------------------------------------ svn:keywords = Author Date Id Revision Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff ============================================================================== --- 
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java (original) +++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java Mon Aug 17 19:50:51 2009 @@ -25,6 +25,7 @@ import java.util.HashSet; import java.util.Set; +import org.apache.commons.nabla.algorithmic.Descriptors; import org.apache.commons.nabla.algorithmic.forward.analysis.ForwardModeClassDifferentiator; import org.apache.commons.nabla.core.DifferentiationException; import org.apache.commons.nabla.core.UnivariateDerivative; @@ -32,6 +33,7 @@ import org.apache.commons.nabla.core.UnivariateDifferentiator; import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.tree.ClassNode; /** Algorithmic differentiator class in forward mode based on bytecode analysis. * <p>This class is an implementation of the {@link UnivariateDifferentiator} @@ -163,19 +165,21 @@ throws DifferentiationException { try { - // set up both ends of the class transform chain + // get the original class final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class"; final InputStream stream = differentiableClass.getResourceAsStream(classResourceName); final ClassReader reader = new ClassReader(stream); - final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES); // differentiate the function embedded in the differentiable class - final ForwardModeClassDifferentiator differentiator = new ForwardModeClassDifferentiator(mathClasses, writer); - reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); - differentiator.reportErrors(); + final ForwardModeClassDifferentiator differentiator = + new ForwardModeClassDifferentiator(reader, mathClasses); + differentiator.differentiateMethod("f", Descriptors.D_RETURN_D_DESCRIPTOR, + 
Descriptors.DP_RETURN_DP_DESCRIPTOR); // create the
derivative class - return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer); + final ClassNode derived = differentiator.getDerivedClass(); + final ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES); + return new DerivativeLoader(differentiableClass).defineClass(derived, writer); } catch (IOException ioe) { throw new DifferentiationException("class {0} cannot be read ({1})", @@ -194,14 +198,15 @@ } /** Define a derivative class. - * @param differentiator class differentiator + * @param classNode differentiated class * @param writer class writer * @return a generated derivative class */ @SuppressWarnings("unchecked") public Class<? extends UnivariateDerivative> - defineClass(final ForwardModeClassDifferentiator differentiator, final ClassWriter writer) { - final String name = differentiator.getDerivativeClassName().replace('/', '.'); + defineClass(final ClassNode classNode, final ClassWriter writer) { + final String name = classNode.name.replace('/', '.'); + classNode.accept(writer); final byte[] bytecode = writer.toByteArray(); return (Class<? 
extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length); } Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff ============================================================================== --- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java (original) +++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java Mon Aug 17 19:50:51 2009 @@ -16,37 +16,38 @@ */ package org.apache.commons.nabla.algorithmic.forward.analysis; +import java.util.ArrayList; +import java.util.List; import java.util.Set; import org.apache.commons.nabla.core.DifferentiationException; import org.apache.commons.nabla.core.UnivariateDerivative; import org.apache.commons.nabla.core.UnivariateDifferentiable; -import org.objectweb.asm.AnnotationVisitor; -import org.objectweb.asm.Attribute; -import org.objectweb.asm.ClassVisitor; -import org.objectweb.asm.FieldVisitor; -import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.ClassReader; import org.objectweb.asm.Opcodes; +import org.objectweb.asm.tree.ClassNode; +import org.objectweb.asm.tree.FieldNode; +import org.objectweb.asm.tree.MethodNode; /** - * Visitor (in asm sense) for differentiating classes using forward mode. + * Differentiator for classes using forward mode. 
* <p> - * This visitor visits classes implementing the + * This differentiator transforms classes implementing the * {@link UnivariateDifferentiable UnivariateDifferentiable} interface and convert * them to classes implementing the {@link UnivariateDerivative
* </p> * <p> - * The visitor creates a new class as an inner class of the visited class. + * The differentiator creates a new class as an inner class of the visited class. * Instances of the generated class are therefore automatically bound to their * primitive instance which is their directly enclosing instance. As such they * have access to the current value of all fields. * </p> * <p> - * The visited class bytecode is not changed at all. + * The original class bytecode is not changed at all. * </p> */ -public class ForwardModeClassDifferentiator implements ClassVisitor { +public class ForwardModeClassDifferentiator { /** Name for the primitive instance field. */ private static final String PRIMITIVE_FIELD = "primitive"; @@ -54,57 +55,42 @@ /** Math implementation classes. */ private final Set<String> mathClasses; - /** Class generating visitor. */ - private final ClassVisitor generator; - - /** Error reporter. */ - private final ErrorReporter errorReporter; + /** Class to differentiate. */ + private final ClassNode classNode; /** Primitive class name. */ - private String primitiveName; - - /** Descriptor for the primitive class. */ - private String primitiveDesc; + private final String primitiveName; - /** Derivative class name. */ - private String derivativeName; + /** Primitive class methods. */ + private final List<MethodNode> primitiveMethods; - /** Indicator for specific fields and method addition. */ - private boolean specificMembersAdded; + /** Descriptor for the primitive class. */ + private final String primitiveDesc; /** * Simple constructor. 
+ * @param reader reader for the primitive class * @param mathClasses math implementation classes - * @param generator visitor to which class generation calls will be delegated + * @exception DifferentiationException if class cannot be differentiated */ - public ForwardModeClassDifferentiator(final Set<String> mathClasses, - final ClassVisitor generator) { - this.mathClasses = mathClasses; - this.generator = generator; - errorReporter = new ErrorReporter(); - } + @SuppressWarnings("unchecked") + public ForwardModeClassDifferentiator(final ClassReader reader, + final Set<String> mathClasses) + throws DifferentiationException { - /** - * Get the name of the derivative class. - * @return name of the (generated) derivative class - */ - public String getDerivativeClassName() { - return derivativeName; - } + classNode = new ClassNode(); + reader.accept(classNode, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); + this.mathClasses = mathClasses; - /** {...@inheritdoc} */ - public void visit(final int version, final int access, - final String name, final String signature, - final String superName, final String[] interfaces) { - // set up the various names - primitiveName = name; - derivativeName = primitiveName + "$NablaForwardModeUnivariateDerivative"; - primitiveDesc = "L" + primitiveName + ";"; + // store the primitive class properties + primitiveName = classNode.name; + primitiveDesc = "L" + primitiveName + ";"; + primitiveMethods = classNode.methods; // check the UnivariateDifferentiable interface is implemented final Class<UnivariateDifferentiable> uDerClass = UnivariateDifferentiable.class; boolean isDifferentiable = false; - for (String interf : interfaces) { + for (String interf : (List<String>) classNode.interfaces) { final String interfName = interf.replace('/', '.'); Class<?> interfClass = null; try { @@ -112,162 +98,106 @@ } catch (ClassNotFoundException cnfe) { // this should never occur since class has already been loaded // and an instance already exists 
... - errorReporter.register(new DifferentiationException("interface {0} not found " + - "while differentiating class {1}", - interfName, name)); + throw new DifferentiationException("interface {0} not found " + + "while differentiating class {1}", + interfName, primitiveName); } if (interfClass != null) { isDifferentiable = isDifferentiable || uDerClass.isAssignableFrom(interfClass); } } - if (isDifferentiable) { - // generate the new class implementing the UnivariateDerivative interface - generator.visit(version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, - derivativeName, signature, superName, - new String[] { - UnivariateDerivative.class.getName().replace('.', '/') - }); - } else { - errorReporter.register(new DifferentiationException("the {0} class does not implement " + - "the {1} interface", - name, uDerClass.getName())); - } - - specificMembersAdded = false; + if (!isDifferentiable) { + throw new DifferentiationException("the {0} class does not implement the {1} interface", + primitiveName, uDerClass.getName()); + } + + // change the class properties for the derived class + classNode.access = Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC; + classNode.name = primitiveName + "$NablaForwardModeUnivariateDerivative"; + classNode.fields = new ArrayList<FieldNode>(); + classNode.methods = new ArrayList<MethodNode>(); + classNode.interfaces.clear(); + classNode.interfaces.add(UnivariateDerivative.class.getName().replace('.', '/')); + + // primitive instance field and methods setting/getting it + addPrimitiveField(); + addConstructor(); + addGetPrimitive(); } - /** {...@inheritdoc} */ - public MethodVisitor visitMethod(final int access, final String name, - final String desc, final String signature, - final String[] exceptions) { - - // don't do anything if an error has already been encountered - if (errorReporter.hasError()) { - return null; - } - - if (!specificMembersAdded) { - // add the specific members we need - addPrimitiveField(); - addConstructor(); - 
addGetPrimitive(); - specificMembersAdded = true; - } - - // is it the "public double f(double)" method we want to differentiate ? - if (((access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) && - "f".equals(name) && "(D)D".equals(desc) && - ((exceptions == null) || (exceptions.length == 0))) { - - // get a generator for the method we are going to create - final MethodVisitor visitor = - generator.visitMethod(access | Opcodes.ACC_SYNTHETIC, name, - MethodDifferentiator.DP_RETURN_DP_DESCRIPTOR, null, null); - - // make sure our own differentiator will be used to transform the code - return new MethodDifferentiator(access, name, desc, signature, exceptions, - visitor, primitiveName, mathClasses, errorReporter); + /** + * Differentiate a method. + * @param name of the method + * @param primitiveDesc descriptor of the method in the primitive class + * @param derivativeDesc descriptor of the method in the derivative class + * @exception DifferentiationException if method cannot be differentiated + */ + @SuppressWarnings("unchecked") + public void differentiateMethod(final String name, final String primitiveDesc, + final String derivativeDesc) + throws DifferentiationException { + + for (final MethodNode method : primitiveMethods) { + if (method.name.equals(name) && method.desc.equals(primitiveDesc)) { + + final MethodDifferentiator differentiator = new MethodDifferentiator(mathClasses); + differentiator.differentiate(primitiveName, method); + classNode.methods.add(method); + } } - - // we are not interested in this method - return null; - - } - - /** {...@inheritdoc} */ - public FieldVisitor visitField(final int access, final String name, - final String desc, final String signature, - final Object value) { - // we are not interested in any fields - return null; - } - - /** {...@inheritdoc} */ - public void visitSource(final String source, final String debug) { - } - - /** {...@inheritdoc} */ - public void visitOuterClass(final String owner, final String name, - final 
String desc) { - } - - /** {...@inheritdoc} */ - public AnnotationVisitor visitAnnotation(final String desc, - final boolean visible) { - return null; - } - - /** {...@inheritdoc} */ - public void visitAttribute(final Attribute attr) { - } - - /** {...@inheritdoc} */ - public void visitInnerClass(final String name, final String outerName, - final String innerName, final int access) { } - /** {...@inheritdoc} */ - public void visitEnd() { - - // don't do anything if an error has already been encountered - if (errorReporter.hasError()) { - return; - } - - generator.visitEnd(); - + /** + * Get the derived class. + * @return derived class + */ + public ClassNode getDerivedClass() { + return classNode; } /** Add the primitive field. */ + @SuppressWarnings("unchecked") private void addPrimitiveField() { - final FieldVisitor visitor = - generator.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC, - PRIMITIVE_FIELD, primitiveDesc, null, null); - visitor.visitEnd(); + FieldNode primitiveField = + new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC, + PRIMITIVE_FIELD, primitiveDesc, null, null); + classNode.fields.add(primitiveField); } /** Add the class constructor. 
*/ + @SuppressWarnings("unchecked") private void addConstructor() { final String init = "<init>"; - final MethodVisitor visitor = - generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init, - "(" + primitiveDesc + ")V", null, null); - visitor.visitCode(); - visitor.visitVarInsn(Opcodes.ALOAD, 0); - visitor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V"); - visitor.visitVarInsn(Opcodes.ALOAD, 0); - visitor.visitVarInsn(Opcodes.ALOAD, 1); - visitor.visitFieldInsn(Opcodes.PUTFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc); - visitor.visitInsn(Opcodes.RETURN); - visitor.visitMaxs(0, 0); - visitor.visitEnd(); + final MethodNode constructor = + new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init, + "(" + primitiveDesc + ")V", null, null); + constructor.visitVarInsn(Opcodes.ALOAD, 0); + constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V"); + constructor.visitVarInsn(Opcodes.ALOAD, 0); + constructor.visitVarInsn(Opcodes.ALOAD, 1); + constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD, primitiveDesc); + constructor.visitInsn(Opcodes.RETURN); + constructor.visitMaxs(0, 0); + classNode.methods.add(constructor); } /** Add the {@link UnivariateDerivative#getPrimitive() getPrimitive()} method. */ + @SuppressWarnings("unchecked") private void addGetPrimitive() { - final MethodVisitor visitor = - generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive", - "()" + primitiveDesc, null, null); - visitor.visitCode(); - visitor.visitVarInsn(Opcodes.ALOAD, 0); - visitor.visitFieldInsn(Opcodes.GETFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc); - visitor.visitInsn(Opcodes.ARETURN); - visitor.visitMaxs(0, 0); - visitor.visitEnd(); - } - - /** Report the errors that may have occurred during analysis.
- * @exception DifferentiationException if the derivative class - * could not be generated - */ - public void reportErrors() throws DifferentiationException { - errorReporter.reportErrors(); - } + final MethodNode method = + new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive", + "()" + primitiveDesc, null, null); + method.visitVarInsn(Opcodes.ALOAD, 0); + method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD, primitiveDesc); + method.visitInsn(Opcodes.ARETURN); + method.visitMaxs(0, 0); + classNode.methods.add(method); + } } Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff ============================================================================== --- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java (original) +++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java Mon Aug 17 19:50:51 2009 @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Set; +import org.apache.commons.nabla.algorithmic.Descriptors; import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer1; import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer12; import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer2; @@ -86,9 +87,7 @@ import org.apache.commons.nabla.algorithmic.forward.trimming.DLoadPop2Trimmer; import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDloadTrimmer; import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDstoreTrimmer; -import org.apache.commons.nabla.core.DifferentialPair; import 
org.apache.commons.nabla.core.DifferentiationException; -import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; import org.objectweb.asm.tree.AbstractInsnNode; import org.objectweb.asm.tree.IincInsnNode; @@ -107,22 +106,7 @@ /** Class transforming a method computing a value to a method * computing both a value and its differential. */ -public class MethodDifferentiator extends MethodNode { - - /** Name for the DifferentialPair class. */ - public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/'); - - /** Descriptor for the DifferentialPair class. */ - public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";"; - - /** Descriptor for the derivative class f method. */ - public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR; - - /** Descriptor for <code>DifferentialPair f(double)</code> methods. */ - public static final String D_RETURN_DP_DESCRIPTOR = "(D)" + DP_DESCRIPTOR; - - /** Descriptor for <code>double f()</code> methods. */ - private static final String VOID_RETURN_D_DESCRIPTOR = "()D"; +public class MethodDifferentiator { /** Math functions transformer. */ private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS = @@ -168,18 +152,9 @@ /** Math implementation classes. */ private final Set<String> mathClasses; - /** Generator to use. */ - private final MethodVisitor generator; - /** Used locals variables array. */ private boolean[] usedLocals; - /** Primitive class name. */ - private final String primitiveName; - - /** Error reporter to use. */ - private final ErrorReporter errorReporter; - /** Set of converted values. */ private final Set<TrackingValue> converted; @@ -193,110 +168,95 @@ private final Map<LabelNode, LabelNode> clonedLabels; /** Build a differentiator for a method. 
- * @param access access flags of the method - * @param name name of the method - * @param desc descriptor of the method - * @param signature signature of the method - * @param exceptions exceptions thrown by the method - * @param generator bytecode generator to use for the transformed method - * @param primitiveName primitive class name * @param mathClasses math implementation classes - * @param errorReporter reporter used for delaying exceptions */ - public MethodDifferentiator(final int access, final String name, final String desc, - final String signature, final String[] exceptions, - final MethodVisitor generator,final String primitiveName, - final Set<String> mathClasses, - final ErrorReporter errorReporter) { - - super(access, name, desc, signature, exceptions); - this.generator = generator; + public MethodDifferentiator(final Set<String> mathClasses) { this.usedLocals = null; - this.primitiveName = primitiveName; this.mathClasses = mathClasses; - this.errorReporter = errorReporter; this.converted = new HashSet<TrackingValue>(); this.frames = new IdentityHashMap<AbstractInsnNode, Frame>(); this.successors = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>(); this.clonedLabels = new HashMap<LabelNode, LabelNode>(); - } - /** {...@inheritdoc} */ - @Override - public void visitEnd() { + /** + * Differentiate a method. 
+ * @param primitiveName primitive class name + * @param method method to differentiate (<em>will</em> be modified) + * @exception DifferentiationException if method cannot be differentiated + */ + public void differentiate(final String primitiveName, final MethodNode method) + throws DifferentiationException { try { // at start, "this" and one differential pair are already used - maxLocals = 2 * (maxLocals + MAX_TEMP) - 1; - usedLocals = new boolean[maxLocals]; + method.maxLocals = 2 * (method.maxLocals + MAX_TEMP) - 1; + usedLocals = new boolean[method.maxLocals]; useLocal(0, 1); useLocal(1, 4); // add spare cells to hold new variables if needed - addSpareLocalVariables(); + addSpareLocalVariables(method.instructions); // analyze the original code, tracing values production/consumption - final Frame[] array = - new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this); + final FlowAnalyzer analyzer = + new FlowAnalyzer(new TrackingInterpreter(), method.instructions); + final Frame[] array = analyzer.analyze(primitiveName, method); // convert the array into a map, since code changes will shift all indices for (int i = 0; i < array.length; ++i) { - frames.put(instructions.get(i), array[i]); + frames.put(method.instructions.get(i), array[i]); } // identify the needed changes - final Set<AbstractInsnNode> changes = identifyChanges(); + final Set<AbstractInsnNode> changes = identifyChanges(method.instructions); if (changes.isEmpty()) { // the method does not depend on the parameter at all! 
// we replace all "return d;" by "return DifferentialPair.newConstant(d);" - for (final Iterator<?> i = instructions.iterator(); i.hasNext();) { + for (final Iterator<?> i = method.instructions.iterator(); i.hasNext();) { final AbstractInsnNode insn = (AbstractInsnNode) i.next(); if (insn.getOpcode() == Opcodes.DRETURN) { final InsnList list = new InsnList(); list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, - MethodDifferentiator.DP_NAME, - "newConstant", D_RETURN_DP_DESCRIPTOR)); + Descriptors.DP_NAME, + "newConstant", + Descriptors.D_RETURN_DP_DESCRIPTOR)); list.add(new InsnNode(Opcodes.ARETURN)); - instructions.insert(insn, list); - instructions.remove(insn); + method.instructions.insert(insn, list); + method.instructions.remove(insn); } } } else { // perform the code changes - changeCode(changes); - - // remove the local variables added at the beginning and not used - removeUnusedSpareLocalVariables(); + changeCode(method.instructions, changes); // trim generated instructions list - SwappedDloadTrimmer.getInstance().trim(instructions); - SwappedDstoreTrimmer.getInstance().trim(instructions); - DLoadPop2Trimmer.getInstance().trim(instructions); + SwappedDloadTrimmer.getInstance().trim(method.instructions); + SwappedDstoreTrimmer.getInstance().trim(method.instructions); + DLoadPop2Trimmer.getInstance().trim(method.instructions); } - // change the descriptor to its true final value - desc = DP_RETURN_DP_DESCRIPTOR; + // remove the local variables added at the beginning and not used + removeUnusedSpareLocalVariables(method.instructions); - // generate the method - accept(generator); + // change the method properties to the derivative ones + method.desc = Descriptors.DP_RETURN_DP_DESCRIPTOR; + method.access |= Opcodes.ACC_SYNTHETIC; + method.maxLocals = maxVariables(); } catch (AnalyzerException ae) { + ae.printStackTrace(System.err); if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) { - 
errorReporter.register((DifferentiationException) ae.getCause()); + throw (DifferentiationException) ae.getCause(); } else { - final DifferentiationException de = - new DifferentiationException("unable to analyze the {0}.{1} method ({2})", - primitiveName, name, ae.getMessage()); - errorReporter.register(de); + throw new DifferentiationException("unable to analyze the {0}.{1} method ({2})", + primitiveName, method.name, ae.getMessage()); } - } catch (DifferentiationException de) { - errorReporter.register(de); } } @@ -309,11 +269,13 @@ * be referenced by the converted instructions in the following passes.</p> * <p>The spare cells that will not be used will be reclaimed after * conversion, to avoid wasting memory.</p> + * @param instructions instructions of the method * @exception DifferentiationException if local variables array has not been * expanded appropriately beforehand * @see #removeUnusedSpareLocalVariables() */ - private void addSpareLocalVariables() throws DifferentiationException { + private void addSpareLocalVariables(final InsnList instructions) + throws DifferentiationException { for (final Iterator<?> i = instructions.iterator(); i.hasNext();) { final AbstractInsnNode insn = (AbstractInsnNode) i.next(); if (insn.getType() == AbstractInsnNode.VAR_INSN) { @@ -340,9 +302,10 @@ } /** Remove the unused spare cells introduced at conversion start. + * @param instructions instructions of the method * @see #addSpareLocalVariables() */ - private void removeUnusedSpareLocalVariables() { + private void removeUnusedSpareLocalVariables(final InsnList instructions) { for (final Iterator<?> i = instructions.iterator(); i.hasNext();) { final AbstractInsnNode insn = (AbstractInsnNode) i.next(); if (insn.getType() == AbstractInsnNode.VAR_INSN) { @@ -358,9 +321,10 @@ * instructions path, updating stack cells and local variables as needed. 
* Instructions that must be changed are the ones that consume changed * variables or stack cells.</p> + * @param instructions instructions of the method * @return set containing all the instructions that must be changed */ - private Set<AbstractInsnNode> identifyChanges() { + private Set<AbstractInsnNode> identifyChanges(final InsnList instructions) { // the pending set contains the values (local variables or stack cells) // that have been changed, they will trigger changes on the instructions @@ -461,21 +425,22 @@ } /** Perform the code changes. + * @param instructions instructions of the method * @param changes instructions that must be changed * @exception DifferentiationException if some instruction cannot be handled */ - private void changeCode(final Set<AbstractInsnNode> changes) + private void changeCode(final InsnList instructions, final Set<AbstractInsnNode> changes) throws DifferentiationException { // insert the parameter conversion code at method start final InsnList list = new InsnList(); list.add(new VarInsnNode(Opcodes.ALOAD, 1)); list.add(new InsnNode(Opcodes.DUP)); - list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME, - "getValue", VOID_RETURN_D_DESCRIPTOR)); + list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, Descriptors.DP_NAME, + "getValue", Descriptors.VOID_RETURN_D_DESCRIPTOR)); list.add(new VarInsnNode(Opcodes.DSTORE, 1)); - list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME, - "getFirstDerivative", VOID_RETURN_D_DESCRIPTOR)); + list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, Descriptors.DP_NAME, + "getFirstDerivative", Descriptors.VOID_RETURN_D_DESCRIPTOR)); list.add(new VarInsnNode(Opcodes.DSTORE, 3)); instructions.insertBefore(instructions.get(0), list); @@ -688,7 +653,7 @@ * @param name name of the class to test * @return true if the named class is a math implementation class */ - public boolean isMathImplementationClass(final String name) { + private boolean isMathImplementationClass(final String name) { return 
mathClasses.contains(name); } @@ -744,7 +709,7 @@ /** Shifted the index of a variable instruction. * @param insn variable instruction */ - public void shiftVariable(final VarInsnNode insn) { + private void shiftVariable(final VarInsnNode insn) { int shifted = 0; for (int i = 0; i < insn.var; ++i) { if (usedLocals[i]) { @@ -754,6 +719,19 @@ insn.var = shifted; } + /** Compute the maximal number of used local variables. + * @return maximal number of used local variables + */ + private int maxVariables() { + int max = 0; + for (final boolean isUsed : usedLocals) { + if (isUsed) { + ++max; + } + } + return max; + } + /** Clone an instruction. * @param insn instruction to clone * @return cloned instruction @@ -765,11 +743,17 @@ /** Analyzer preserving instructions successors information. */ private class FlowAnalyzer extends Analyzer { + /** Instructions of the method. */ + private final InsnList instructions; + /** Simple constructor. * @param interpreter associated interpreter + * @param instructions instructions of the method */ - public FlowAnalyzer(final Interpreter interpreter) { + public FlowAnalyzer(final Interpreter interpreter, + final InsnList instructions) { super(interpreter); + this.instructions = instructions; } /** Store a new edge. 
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java?rev=805114&r1=805113&r2=805114&view=diff ============================================================================== --- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java (original) +++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java Mon Aug 17 19:50:51 2009 @@ -16,6 +16,7 @@ */ package org.apache.commons.nabla.algorithmic.forward.instructions; +import org.apache.commons.nabla.algorithmic.Descriptors; import org.apache.commons.nabla.algorithmic.forward.analysis.InstructionsTransformer; import org.apache.commons.nabla.algorithmic.forward.analysis.MethodDifferentiator; import org.apache.commons.nabla.core.DifferentiationException; @@ -68,17 +69,16 @@ final InsnList list = new InsnList(); // operand stack initial state: a0, a1 - list.add(new VarInsnNode(Opcodes.DSTORE, 3)); // => a0 - list.add(new VarInsnNode(Opcodes.DSTORE, 1)); // => - list.add(new TypeInsnNode(Opcodes.NEW, - MethodDifferentiator.DP_NAME)); // => o, - list.add(new InsnNode(Opcodes.DUP)); // => o, o - list.add(new VarInsnNode(Opcodes.DLOAD, 1)); // => o, o, a0 - list.add(new VarInsnNode(Opcodes.DLOAD, 3)); // => o, o, a0, a1 + list.add(new VarInsnNode(Opcodes.DSTORE, 3)); // => a0 + list.add(new VarInsnNode(Opcodes.DSTORE, 1)); // => + list.add(new TypeInsnNode(Opcodes.NEW, Descriptors.DP_NAME)); // => o, + list.add(new InsnNode(Opcodes.DUP)); // => o, o + list.add(new VarInsnNode(Opcodes.DLOAD, 1)); // => o, o, a0 + list.add(new VarInsnNode(Opcodes.DLOAD, 3)); // => o, o, a0, a1 list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, - MethodDifferentiator.DP_NAME, - 
"<init>", "(DD)V")); // => dp - list.add(new InsnNode(Opcodes.ARETURN)); // => + Descriptors.DP_NAME, + "<init>", "(DD)V")); // => dp + list.add(new InsnNode(Opcodes.ARETURN)); // => return list; }