commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r1411753 - in /commons/sandbox/nabla/trunk/src: main/java/org/apache/commons/nabla/forward/analysis/ main/java/org/apache/commons/nabla/forward/instructions/ test/java/org/apache/commons/nabla/forward/
Date Tue, 20 Nov 2012 17:22:05 GMT
Author: luc
Date: Tue Nov 20 17:22:04 2012
New Revision: 1411753

URL: http://svn.apache.org/viewvc?rev=1411753&view=rev
Log:
Ensure all differentiated methods have a reference DerivativeStructure.

Some methods that must be differentiated rely only on fields to compute
their return value and do not have double arguments. When
differentiated, they must build DerivativeStructure instances on their
own. However, building DerivativeStructure instances requires knowing
the number of free parameters and the derivation order, which are
usually passed as arguments.

Previously, methods that had differentiated arguments retrieved the
number of free parameters and the derivation order from their arguments,
and other methods retrieved these values from the fields they used. This
did not work if the fields had not been initialized yet and only the
primitive double field was available. So now we make sure that all
differentiated methods get at least one differentiated argument, even if
the primitive method did not have any double arguments. We set up a
custom additional argument if needed. This indeed simplifies handling,
as we can now always rely on the arguments.

Modified:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
    commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java?rev=1411753&r1=1411752&r2=1411753&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
(original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
Tue Nov 20 17:22:04 2012
@@ -137,14 +137,9 @@ public class MethodDifferentiator {
         throws DifferentiationException {
         try {
 
-            // copy the primitive method as a new independent node
-            final MethodNode method =
-                    new MethodNode(primitiveMethod.access | Opcodes.ACC_SYNTHETIC,
-                                   primitiveMethod.name, differentiatedMethod.getDifferentiatedType().getDescriptor(),
-                                   null, primitiveMethod.exceptions.toArray(new String[primitiveMethod.exceptions.size()]));
-            primitiveMethod.accept(method);
+            // create a new independent node with proper arguments
+            final MethodNode method = fixArguments(primitiveMethod, differentiatedMethod);
 
-            final boolean isStatic = (method.access & Opcodes.ACC_STATIC) != 0;
             final boolean[] usedLocals = new boolean[method.maxLocals + 1];
 
             // analyze the original code, tracing values production/consumption
@@ -162,7 +157,8 @@ public class MethodDifferentiator {
                     identifyChanges(method.name, usedLocals,
                                     differentiatedMethod.getPrimitiveSpec().getPrimitiveType(),
                                     differentiatedMethod.getDifferentiatedType(),
-                                    method.instructions, isStatic);
+                                    method.instructions,
+                                    differentiatedMethod.getPrimitiveSpec().isStatic());
 
             // perform the code changes
             for (final AbstractInsnNode insn : changes) {
@@ -181,7 +177,7 @@ public class MethodDifferentiator {
                 // insert the preservation of the reference derivative structure
                 // (we know we have reserved the last local variable for this)
                 method.instructions.insert(preserveReferenceDerivativeStructure(differentiatedMethod,
-                                                                                isStatic, usedLocals.length - 1));
+                                                                                usedLocals.length - 1));
             }
 
             // remove the local variables added at the beginning and not used
@@ -207,6 +203,63 @@ public class MethodDifferentiator {
         }
     }
 
+    /** Fix method node as per by arguments types and number changes.
+     * <p>
+     * Transformed method may have some arguments types changed or additional
+     * arguments appended at the end. We fix the node there to make sure flow
+     * analysis will behave properly.
+     * </p>
+     * @param primitiveMethod method to differentiate
+     * @param differentiatedMethod specification for the differentiated method
+     * @return method with fixed descriptor and local variables (but no instructions changed yet)
+     */
+    private MethodNode fixArguments(final MethodNode primitiveMethod,
+                                    final DifferentiatedElementSpecification differentiatedMethod) {
+
+        final Type primitiveType      = differentiatedMethod.getPrimitiveSpec().getPrimitiveType();
+        final Type differentiatedType = differentiatedMethod.getDifferentiatedType();
+
+        // copy the primitive method as a new independent node
+        final MethodNode method =
+                 new MethodNode(primitiveMethod.access | Opcodes.ACC_SYNTHETIC,
+                               primitiveMethod.name, differentiatedType.getDescriptor(),
+                               null, primitiveMethod.exceptions.toArray(new String[primitiveMethod.exceptions.size()]));
+        primitiveMethod.accept(method);
+
+        // check for possible additional arguments
+        int primitiveArgCount      = primitiveType.getArgumentTypes().length;
+        int differentiatedArgCount = differentiatedType.getArgumentTypes().length;
+        if (differentiatedMethod.getPrimitiveSpec().isStatic()) {
+            // this is a static method, there is no "this" local variable
+            --primitiveArgCount;
+            --differentiatedArgCount;
+        }
+        final int additionalArguments = differentiatedArgCount - primitiveArgCount;
+
+        if (additionalArguments > 0) {
+
+            final int lastArgIndex = primitiveType.getArgumentsAndReturnSizes() >> 2;
+
+            // shift variable instructions indices to make sure they don't mess with the new arguments
+            for (final ListIterator<AbstractInsnNode> iterator = method.instructions.iterator(); iterator.hasNext();) {
+                final AbstractInsnNode insn = iterator.next();
+                if (insn.getType() == AbstractInsnNode.VAR_INSN) {
+                    final VarInsnNode varInsn = (VarInsnNode) insn;
+                    if (varInsn.var > lastArgIndex) {
+                        varInsn.var += additionalArguments;
+                    }
+                }
+            }
+
+            // update the number of local variables used
+            method.maxLocals += additionalArguments;
+
+        }
+
+        return method;
+
+    }
+
     /** Mark local variables usage.
      * @param instructions methods instructions
      * @param usedLocals array of variables use indicators to fill in
@@ -800,21 +853,20 @@ public class MethodDifferentiator {
 
     /** Create instructions to preserve a reference transformed variable.
      * @param differentiatedMethod specification for the differentiated method
-     * @param isStatic if true, the method is a static method
      * @param dsIndex index of the reference transformed variable
      * @return list of conversion instructions
      */
     public InsnList preserveReferenceDerivativeStructure(final DifferentiatedElementSpecification differentiatedMethod,
-                                                         final boolean isStatic, final int dsIndex) {
+                                                         final int dsIndex) {
 
         final Type transformedType = transformType(Type.DOUBLE_TYPE);
         final Type[] parameterTypes = differentiatedMethod.getDifferentiatedType().getArgumentTypes();
-        int var = isStatic ? 0 : 1;
+        int var = differentiatedMethod.getPrimitiveSpec().isStatic() ? 0 : 1;
         for (int i = 0; i < parameterTypes.length; ++i) {
             if (parameterTypes[i].equals(transformedType)) {
-                // we have found the first derivative structure parameter
+                // we have found the first transformed variable argument
 
-                // preserve the parameter as a new variable
+                // preserve the argument as a new variable
                 final InsnList list = new InsnList();
                 list.add(new VarInsnNode(Opcodes.ALOAD, var));
                 list.add(new VarInsnNode(Opcodes.ASTORE, dsIndex));
@@ -826,30 +878,9 @@ public class MethodDifferentiator {
 
         }
 
-        // if we reach this point, this means there are no DerivativeStructure at all among the arguments
-        // so the method must use a differentiated field, we use this field as the reference
-        final DifferentiatedElementSpecification field =
-                classDifferentiator.getUsedDifferentiatedField(differentiatedMethod.getPrimitiveSpec());
-        if (field == null) {
-            // this should never happen as we build class that do use DerivativeStructure
-            // either from their arguments or from class/instance fields
-            throw DifferentiationException.createInternalError(null);
-        }
-
-        // preserve the field as a new variable
-        final InsnList list = new InsnList();
-        if (field.getPrimitiveSpec().isStatic()) {
-            list.add(new FieldInsnNode(Opcodes.GETSTATIC,
-                                       getDifferentiatedInternalName(field.getPrimitiveSpec().getPrimitiveClass()),
-                                       field.getPrimitiveSpec().getName(), field.getDifferentiatedType().getDescriptor()));
-        } else {
-            list.add(new VarInsnNode(Opcodes.ALOAD, 0));
-            list.add(new FieldInsnNode(Opcodes.GETFIELD,
-                                       getDifferentiatedInternalName(field.getPrimitiveSpec().getPrimitiveClass()),
-                                       field.getPrimitiveSpec().getName(), field.getDifferentiatedType().getDescriptor()));
-        }
-        list.add(new VarInsnNode(Opcodes.ASTORE, dsIndex));
-        return list;
+        // this should never happen as we build methods that ensure a transformed variable
+        // is always present as an argument
+        throw DifferentiationException.createInternalError(null);
 
     }
 

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java?rev=1411753&r1=1411752&r2=1411753&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
(original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
Tue Nov 20 17:22:04 2012
@@ -16,6 +16,9 @@
  */
 package org.apache.commons.nabla.forward.instructions;
 
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.commons.nabla.DifferentiationException;
 import org.apache.commons.nabla.forward.DifferentiatedElementSpecification;
 import org.apache.commons.nabla.forward.PrimitiveElementSpecification;
@@ -26,6 +29,7 @@ import org.objectweb.asm.Type;
 import org.objectweb.asm.tree.AbstractInsnNode;
 import org.objectweb.asm.tree.InsnList;
 import org.objectweb.asm.tree.MethodInsnNode;
+import org.objectweb.asm.tree.VarInsnNode;
 
 /** Differentiation transformer for INVOKESPECIAL/INVOKEVIRTUAL/INVOKESTATIC/INVOKEINTERFACE
  * instructions on non-math related classes.
@@ -42,17 +46,29 @@ public class InvokeNonMathTransformer im
         final MethodInsnNode methodInsn = (MethodInsnNode) insn;
 
         // build transformed method signature based on stack elements
-        final Type transformedType     = methodDifferentiator.transformType(Type.DOUBLE_TYPE);
-        final Type primitiveMethodType = Type.getMethodType(methodInsn.desc);
-        final Type[] argumentTypes     = new Type[primitiveMethodType.getArgumentTypes().length];
-        for (int i = 0; i < argumentTypes.length; ++i) {
-            final int index = argumentTypes.length - 1 - i;
-            argumentTypes[i] = methodDifferentiator.stackElementIsConverted(insn, index) ?
-                               transformedType : methodDifferentiator.stackElementType(insn, index);
+        final Type transformedType      = methodDifferentiator.transformType(Type.DOUBLE_TYPE);
+        final Type primitiveMethodType  = Type.getMethodType(methodInsn.desc);
+        final int initialArgumentsCount = primitiveMethodType.getArgumentTypes().length;
+        final List<Type> argumentTypes  = new ArrayList<Type>(initialArgumentsCount);
+        boolean hasTransformedArgument  = false;
+        for (int i = 0; i < initialArgumentsCount; ++i) {
+            final int index = initialArgumentsCount - 1 - i;
+            if (methodDifferentiator.stackElementIsConverted(insn, index)) {
+                hasTransformedArgument = true;
+                argumentTypes.add(transformedType);
+            } else {
+                argumentTypes.add(methodDifferentiator.stackElementType(insn, index));
+            }
+        }
+        if (!hasTransformedArgument) {
+            // none of the original method arguments are transformed,
+            // we need to add a custom argument to pass the reference transformed variable
+            argumentTypes.add(transformedType);
         }
         final Type returnType =
                 methodDifferentiator.transformType(primitiveMethodType.getReturnType());
-        final Type differentiatedMethodType = Type.getMethodType(returnType, argumentTypes);
+        final Type differentiatedMethodType =
+                Type.getMethodType(returnType, argumentTypes.toArray(new Type[argumentTypes.size()]));
 
         // request the global differentiator to differentiate the invoked method
         final PrimitiveElementSpecification primitive =
@@ -63,8 +79,12 @@ public class InvokeNonMathTransformer im
                 new DifferentiatedElementSpecification(primitive, differentiatedMethodType);
         methodDifferentiator.requestMethodDifferentiation(differentiated);
 
-        // create the transformed instruction
+        // create the transformed instructions
         final InsnList list = new InsnList();
+        if (!hasTransformedArgument) {
+            // add the custom argument
+            list.add(new VarInsnNode(Opcodes.ALOAD, dsIndex));
+        }
         list.add(new MethodInsnNode(methodInsn.getOpcode(),
                                     methodDifferentiator.getDifferentiatedInternalName(methodInsn.owner),
                                     methodInsn.name,

Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java?rev=1411753&r1=1411752&r2=1411753&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
(original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
Tue Nov 20 17:22:04 2012
@@ -283,7 +283,7 @@ public class ForwardModeDifferentiatorTe
     public static class IntermediateDataFunction implements ReferenceFunction {
 
         // these fields are changed as side effects of each call to the value method
-        // as they are used to transfer data to the f method
+        // as they are used to transfer data from the value method to the f method
         private static double staticData;
         private double instanceData;
 



Mime
View raw message