commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r1404270 - in /commons/sandbox/nabla/trunk/src: main/java/org/apache/commons/nabla/forward/ main/java/org/apache/commons/nabla/forward/analysis/ main/java/org/apache/commons/nabla/forward/instructions/ test/java/org/apache/commons/nabla/for...
Date Wed, 31 Oct 2012 17:49:01 GMT
Author: luc
Date: Wed Oct 31 17:49:01 2012
New Revision: 1404270

URL: http://svn.apache.org/viewvc?rev=1404270&view=rev
Log:
Finalized work on fields.

Fields that are used only to pass intermediate data between several
methods are now identified, and the methods that use them are
differentiated as required.

Modified:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java
    commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java?rev=1404270&r1=1404269&r2=1404270&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java Wed Oct 31 17:49:01 2012
@@ -26,7 +26,6 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 
 import org.apache.commons.math3.analysis.UnivariateFunction;
@@ -90,9 +89,6 @@ public class ForwardModeDifferentiator i
     /** Processed fields differentiations. */
     private final Set<DifferentiatedElementSpecification> processedFields;
 
-    /** Map associating a field with the methods using it. */
-    private final Map<DifferentiatedElementSpecification, List<PrimitiveElementSpecification>> fieldsUsers;
-
     /** Simple constructor.
      * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
      */
@@ -106,7 +102,6 @@ public class ForwardModeDifferentiator i
         processedMethods = new HashSet<DifferentiatedElementSpecification>();
         pendingFields    = new HashSet<DifferentiatedElementSpecification>();
         processedFields  = new HashSet<DifferentiatedElementSpecification>();
-        fieldsUsers      = new HashMap<DifferentiatedElementSpecification, List<PrimitiveElementSpecification>>();
         addMathImplementation(Math.class);
         addMathImplementation(StrictMath.class);
         addMathImplementation(FastMath.class);
@@ -147,53 +142,54 @@ public class ForwardModeDifferentiator i
      */
     public DifferentiatedElementSpecification getUsedDifferentiatedField(final PrimitiveElementSpecification method) {
 
-        // loop over all fields and check their users
-        for (final Map.Entry<DifferentiatedElementSpecification, List<PrimitiveElementSpecification>> entry : fieldsUsers.entrySet()) {
-            if (entry.getValue().contains(method)) {
-                return entry.getKey();
-            }
-        }
-
-        // none of the fields we are interested in is used by this method
-        return null;
- 
-    }
-
-    /** Get a differentiated element.
-     * @param primitive primitive specification to check
-     * @return a differentiated element corresponding to the primitive, null if no elements correspond
-     */
-    public DifferentiatedElementSpecification getDifferentiated(final PrimitiveElementSpecification primitive) {
+        final ClassDifferentiator differentiator = getDifferentiator(method);
 
-        // loop over all pending fields
+        // look for the field in the pending fields set
         for (final DifferentiatedElementSpecification field : pendingFields) {
-            if (field.getPrimitiveSpec().equals(primitive)) {
+            final List<PrimitiveElementSpecification> users =
+                    differentiator.getMethodsUsingField(field.getPrimitiveSpec());
+            if (users.contains(method)) {
+                // the method uses the pending field
                 return field;
             }
         }
 
-        // loop over all processed fields
+        // look for the field in the processed fields set
         for (final DifferentiatedElementSpecification field : processedFields) {
-            if (field.getPrimitiveSpec().equals(primitive)) {
+            final List<PrimitiveElementSpecification> users =
+                    differentiator.getMethodsUsingField(field.getPrimitiveSpec());
+            if (users.contains(method)) {
+                // the method uses the processed field
                 return field;
             }
         }
 
+        // none of the fields we are interested in is used by this method
+        return null;
+ 
+    }
+
+    /** Get a differentiated method.
+     * @param primitiveMethod primitive method specification to check
+     * @return a differentiated method corresponding to the primitive method, null if no methods correspond
+     */
+    public DifferentiatedElementSpecification getDifferentiatedMethod(final PrimitiveElementSpecification primitiveMethod) {
+
         // loop over all pending methods
         for (final DifferentiatedElementSpecification method : pendingMethods) {
-            if (method.getPrimitiveSpec().equals(primitive)) {
+            if (method.getPrimitiveSpec().equals(primitiveMethod)) {
                 return method;
             }
         }
 
         // loop over all processed methods
         for (final DifferentiatedElementSpecification method : processedMethods) {
-            if (method.getPrimitiveSpec().equals(primitive)) {
+            if (method.getPrimitiveSpec().equals(primitiveMethod)) {
                 return method;
             }
         }
 
-        // none of the element we are interested in correspond to the specification
+        // none of the methods we are interested in correspond to the specification
         return null;
  
     }
@@ -203,11 +199,53 @@ public class ForwardModeDifferentiator i
      */
     public void requestMethodDifferentiation(final DifferentiatedElementSpecification method) {
         if (!processedMethods.contains(method)) {
+
             // schedule the request if method has not been processed yet
             pendingMethods.add(method);
+
+            // get the callers for this method
+            for (final ClassDifferentiator differentiator : differentiators) {
+                for (final PrimitiveElementSpecification caller : differentiator.getMethodsCallingMethod(method.getPrimitiveSpec())) {
+                    for (final Iterator<DifferentiatedElementSpecification> iterator = processedMethods.iterator(); iterator.hasNext();) {
+                        DifferentiatedElementSpecification processed = iterator.next();
+                        if (processed.getPrimitiveSpec().equals(caller)) {
+                            // the newly differentiated method was called by an already processed method
+                            // we need to redo the differentiation as typically the returned type may have been changed
+                            iterator.remove();
+                            pendingMethods.add(processed);
+                        }
+                    }
+                }
+            }
+
         }
     }
 
+    /** Get a differentiated field.
+     * @param primitiveField primitive field specification to check
+     * @return a differentiated field corresponding to the primitive field, null if no fields correspond
+     */
+    public DifferentiatedElementSpecification getDifferentiatedField(final PrimitiveElementSpecification primitiveField) {
+
+        // loop over all pending fields
+        for (final DifferentiatedElementSpecification field : pendingFields) {
+            if (field.getPrimitiveSpec().equals(primitiveField)) {
+                return field;
+            }
+        }
+
+        // loop over all processed fields
+        for (final DifferentiatedElementSpecification field : processedFields) {
+            if (field.getPrimitiveSpec().equals(primitiveField)) {
+                return field;
+            }
+        }
+
+        // none of the fields we are interested in correspond to the specification
+        return null;
+ 
+    }
+
     /** Request differentiation of a field.
      * @param field field to differentiate
      */
@@ -219,16 +257,55 @@ public class ForwardModeDifferentiator i
             pendingFields.add(field);
 
             // get the users for this field
-            final List<PrimitiveElementSpecification> users = new ArrayList<PrimitiveElementSpecification>();
             for (final ClassDifferentiator differentiator : differentiators) {
-                users.addAll(differentiator.getMethodsUsingField(field));
+                for (final PrimitiveElementSpecification user : differentiator.getMethodsUsingField(field.getPrimitiveSpec())) {
+                    for (final Iterator<DifferentiatedElementSpecification> iterator = processedMethods.iterator(); iterator.hasNext();) {
+                        DifferentiatedElementSpecification processed = iterator.next();
+                        if (processed.getPrimitiveSpec().equals(user)) {
+                            // the newly differentiated field was used by an already processed method
+                            // we need to redo the differentiation as typically the type may have been changed
+                            iterator.remove();
+                            pendingMethods.add(processed);
+                        }
+                    }
+                }
             }
-            fieldsUsers.put(field, users);
+
 
         }
 
     }
 
+    /** Check if a type is a double or a double array with any dimension.
+     * @param type type to check
+     * @return true if the type is a double type or a double array type with any dimension
+     */
+    public boolean isDoubleOrDoubleArray(final Type type) {
+        return type.equals(Type.DOUBLE_TYPE) ||
+                (type.getSort() == Type.ARRAY && type.getElementType().equals(Type.DOUBLE_TYPE));
+    }
+
+    /** Convert a double or a double array type to the corresponding {@link DerivativeStructure} type.
+     * @param type type to convert
+     * @return converted type
+     */
+    public Type convertDoubleOrDoubleArray(final Type type) {
+        final Type dsType = Type.getType(DerivativeStructure.class);
+        if (type.equals(Type.DOUBLE_TYPE)) {
+            return dsType;
+        } else if (type.getSort() == Type.ARRAY && type.getElementType().equals(Type.DOUBLE_TYPE)) {
+            final StringBuilder desc = new StringBuilder();
+            for (int i = 0; i < type.getDimensions(); ++i) {
+                desc.append('[');
+            }
+            desc.append(dsType.getDescriptor());
+            return Type.getType(desc.toString());
+        } else {
+            // don't change the type
+            return type;
+        }
+    }
+
     /** {@inheritDoc} */
     public NablaUnivariateDifferentiableFunction differentiate(final UnivariateFunction d) {
 
@@ -283,56 +360,35 @@ public class ForwardModeDifferentiator i
         throws DifferentiationException {
 
         // bootstrap differentiation starting from the top level value function defined by the UnivariateFunction interface
-        final Type dsType = Type.getType(DerivativeStructure.class);
         requestMethodDifferentiation(new DifferentiatedElementSpecification("value", false,
                                      Type.getInternalName(differentiableClass),
                                      Type.getMethodType(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
-                                     Type.getMethodType(dsType, dsType)));
+                                     Type.getMethodType(convertDoubleOrDoubleArray(Type.DOUBLE_TYPE),
+                                                        convertDoubleOrDoubleArray(Type.DOUBLE_TYPE))));
 
         // loop while there is still something to differentiate (either fields or methods)
         while (!(pendingFields.isEmpty() && pendingMethods.isEmpty())) {
 
-            for (final DifferentiatedElementSpecification df : pendingFields) {
-
-                // differentiate the field
-                getDifferentiator(df).differentiateField(df);
-
-                // get all the methods that use this field
-                final List<PrimitiveElementSpecification> users = fieldsUsers.get(df);
-
-                if (users != null) {
-
-                    // this field is used by some methods in the classes we monitor,
-                    // in order to make sure it is taken into account properly,
-                    // we invalidate the differentiation of these methods and process them again
-                    for (final PrimitiveElementSpecification user : users) {
-                        for (final Iterator<DifferentiatedElementSpecification> iterator = processedMethods.iterator(); iterator.hasNext();) {
-                            final DifferentiatedElementSpecification method = iterator.next();
-                            if (user.equals(method.getPrimitiveSpec())) {
-                                // the differentiated method uses the field, we need to differentiate it again
-                                iterator.remove();
-                                getDifferentiator(method).clearDifferentiatedElement(method);
-                                pendingMethods.add(method);
-                            }
-                        }
-                    }
-
-                }
-
-            }
-
-            // move all fields from pending to processed
-            processedFields.addAll(pendingFields);
+            // save the current status of elements to process at this iteration
+            // to protect the pending lists from concurrent modifications
+            final List<DifferentiatedElementSpecification> fields =
+                    new ArrayList<DifferentiatedElementSpecification>(pendingFields);
             pendingFields.clear();
+            final List<DifferentiatedElementSpecification> methods =
+                    new ArrayList<DifferentiatedElementSpecification>(pendingMethods);
+            pendingMethods.clear();
+
+            // differentiate the fields
+            for (final DifferentiatedElementSpecification field : fields) {
+                getDifferentiator(field.getPrimitiveSpec()).differentiateField(field);
+                processedFields.add(field);
+            }
 
-            // move one method from pending to processed
-            final Iterator<DifferentiatedElementSpecification> iterator = pendingMethods.iterator();
-            final DifferentiatedElementSpecification dm = iterator.next();
-            iterator.remove();
-            processedMethods.add(dm);
-
-            // differentiate the method
-            getDifferentiator(dm).differentiateMethod(dm);
+            // differentiate the methods
+            for (final DifferentiatedElementSpecification method : methods) {
+                getDifferentiator(method.getPrimitiveSpec()).differentiateMethod(method);
+                processedMethods.add(method);
+            }
 
         }
 
@@ -376,13 +432,13 @@ public class ForwardModeDifferentiator i
      * @return differentiator for {@code element}
      * @throws DifferentiationException if the class cannot be read
      */
-    private ClassDifferentiator getDifferentiator(final DifferentiatedElementSpecification element)
+    private ClassDifferentiator getDifferentiator(final PrimitiveElementSpecification element)
         throws DifferentiationException {
         try {
 
             // find a differentiator for the class owning the method
             for (ClassDifferentiator differentiator : differentiators) {
-                if (element.getPrimitiveSpec().getPrimitiveClass().equals(differentiator.getPrimitiveInternalName())) {
+                if (element.getPrimitiveClass().equals(differentiator.getPrimitiveInternalName())) {
                     // we have already build a differentiator for the same class, reuse it
                     return differentiator;
                 }
@@ -390,20 +446,15 @@ public class ForwardModeDifferentiator i
 
             // it is the first time we process this class, create a differentiator for it
             final ClassDifferentiator differentiator =
-                    new ClassDifferentiator(element.getPrimitiveSpec().getPrimitiveClass(), mathClasses, this);
+                    new ClassDifferentiator(element.getPrimitiveClass(), mathClasses, this);
             differentiators.add(differentiator);
 
-            // the new class may use some already monitored fields, we need to update the map
-            for (final Map.Entry<DifferentiatedElementSpecification, List<PrimitiveElementSpecification>> entry : fieldsUsers.entrySet()) {
-                entry.getValue().addAll(differentiator.getMethodsUsingField(entry.getKey()));
-            }
-
             return differentiator;
 
             
         } catch (IOException ioe) {
             throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
-                                               element.getPrimitiveSpec().getPrimitiveClass(), ioe.getMessage());
+                                               element.getPrimitiveClass(), ioe.getMessage());
         }
     }
 

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java?rev=1404270&r1=1404269&r2=1404270&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java Wed Oct 31 17:49:01 2012
@@ -23,14 +23,13 @@ import java.util.List;
 import java.util.Set;
 
 import org.apache.commons.math3.analysis.UnivariateFunction;
-import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
 import org.apache.commons.nabla.DifferentiationException;
 import org.apache.commons.nabla.NablaMessages;
-import org.apache.commons.nabla.forward.PrimitiveElementSpecification;
 import org.apache.commons.nabla.forward.DifferentiatedElementSpecification;
 import org.apache.commons.nabla.forward.ForwardModeDifferentiator;
 import org.apache.commons.nabla.forward.NablaDifferentiated;
 import org.apache.commons.nabla.forward.NablaUnivariateDifferentiableFunction;
+import org.apache.commons.nabla.forward.PrimitiveElementSpecification;
 import org.objectweb.asm.ClassReader;
 import org.objectweb.asm.Label;
 import org.objectweb.asm.Opcodes;
@@ -39,6 +38,7 @@ import org.objectweb.asm.tree.AbstractIn
 import org.objectweb.asm.tree.ClassNode;
 import org.objectweb.asm.tree.FieldInsnNode;
 import org.objectweb.asm.tree.FieldNode;
+import org.objectweb.asm.tree.MethodInsnNode;
 import org.objectweb.asm.tree.MethodNode;
 
 /**
@@ -222,18 +222,46 @@ public class ClassDifferentiator {
         }
     }
 
+    /** Get all methods that call a method.
+     * @param method method to check methods against
+     * @return a list of all methods that call the specified method
+     */
+    public List<PrimitiveElementSpecification> getMethodsCallingMethod(final PrimitiveElementSpecification method) {
+
+        List<PrimitiveElementSpecification> callers = new ArrayList<PrimitiveElementSpecification>();
+        final String owner = method.getPrimitiveClass();
+        final String name  = method.getName();
+        final String desc  = method.getPrimitiveType().getDescriptor();
+
+        for (final MethodNode primitiveMethod : primitiveNode.methods) {
+            for (Iterator<AbstractInsnNode> iterator = primitiveMethod.instructions.iterator(); iterator.hasNext();) {
+                final AbstractInsnNode insn = iterator.next();
+                if (insn.getType() == AbstractInsnNode.METHOD_INSN) {
+                    final MethodInsnNode mInsn = (MethodInsnNode) insn;
+                    if (mInsn.owner.equals(owner) && mInsn.name.equals(name) && mInsn.desc.equals(desc)) {
+                        final Type primitiveMethodType = Type.getMethodType(primitiveMethod.desc);
+                        callers.add(new PrimitiveElementSpecification(primitiveMethod.name,
+                                                                      (primitiveMethod.access & Opcodes.ACC_STATIC) != 0,
+                                                                      getPrimitiveInternalName(), primitiveMethodType));
+                    }
+                }
+            }
+        }
+
+        return callers;
+
+    }
+
     /** Get all methods that use a field.
-     * @param differentiatedField field to check methods against
-     * @return a list of all methods that uses the specified field, with the {@link
-     * PrimitiveElementSpecification#getDifferentiatedElementType() differentiated method type}
-     * set to null as it is not known yet
+     * @param field field to check methods against
+     * @return a list of all methods that use the specified field
      */
-    public List<PrimitiveElementSpecification> getMethodsUsingField(final DifferentiatedElementSpecification differentiatedField) {
+    public List<PrimitiveElementSpecification> getMethodsUsingField(final PrimitiveElementSpecification field) {
 
         List<PrimitiveElementSpecification> users = new ArrayList<PrimitiveElementSpecification>();
-        final String owner = differentiatedField.getPrimitiveSpec().getPrimitiveClass();
-        final String name  = differentiatedField.getPrimitiveSpec().getName();
-        final String desc  = differentiatedField.getPrimitiveSpec().getPrimitiveType().getDescriptor();
+        final String owner = field.getPrimitiveClass();
+        final String name  = field.getName();
+        final String desc  = field.getPrimitiveType().getDescriptor();
 
         for (final MethodNode primitiveMethod : primitiveNode.methods) {
             for (Iterator<AbstractInsnNode> iterator = primitiveMethod.instructions.iterator(); iterator.hasNext();) {
@@ -285,11 +313,12 @@ public class ClassDifferentiator {
         //        }
         //        return x;
         //    }
-        final Type dsType = Type.getType(DerivativeStructure.class);
+        final Type transformedType = convertDoubleOrDoubleArray(Type.DOUBLE_TYPE);
         final MethodNode getter =
                 new MethodNode(access | Opcodes.ACC_SYNTHETIC,
                                GETTER_PREFIX + differentiableField.getPrimitiveSpec().getName(),
-                               Type.getMethodDescriptor(dsType, Type.INT_TYPE, Type.INT_TYPE),
+                               Type.getMethodDescriptor(transformedType,
+                                                        Type.INT_TYPE, Type.INT_TYPE),
                                null, null);
         getter.visitVarInsn(Opcodes.ALOAD, 0);
         if (differentiableField.getPrimitiveSpec().isStatic()) {
@@ -304,7 +333,7 @@ public class ClassDifferentiator {
         Label ifNonNull = new Label();
         getter.visitJumpInsn(Opcodes.IFNONNULL, ifNonNull);
         getter.visitVarInsn(Opcodes.ALOAD, 0);
-        getter.visitTypeInsn(Opcodes.NEW, dsType.getInternalName());
+        getter.visitTypeInsn(Opcodes.NEW, transformedType.getInternalName());
         getter.visitInsn(Opcodes.DUP);
         getter.visitVarInsn(Opcodes.ILOAD, 1);
         getter.visitVarInsn(Opcodes.ILOAD, 2);
@@ -317,17 +346,17 @@ public class ClassDifferentiator {
         getter.visitTypeInsn(Opcodes.CHECKCAST, doubleType.getInternalName());
         getter.visitMethodInsn(Opcodes.INVOKEVIRTUAL, doubleType.getInternalName(), "doubleValue",
                                Type.getMethodDescriptor(Type.DOUBLE_TYPE));
-        getter.visitMethodInsn(Opcodes.INVOKESPECIAL, dsType.getInternalName(), INIT,
+        getter.visitMethodInsn(Opcodes.INVOKESPECIAL, transformedType.getInternalName(), INIT,
                                Type.getMethodDescriptor(Type.VOID_TYPE, Type.INT_TYPE, Type.INT_TYPE, Type.DOUBLE_TYPE));
 
         if (differentiableField.getPrimitiveSpec().isStatic()) {
             getter.visitFieldInsn(Opcodes.PUTSTATIC, getDifferentiatedInternalName(getPrimitiveInternalName()),
                                   differentiableField.getPrimitiveSpec().getName(),
-                                  dsType.getDescriptor());
+                                  transformedType.getDescriptor());
         } else {
             getter.visitFieldInsn(Opcodes.PUTFIELD, getDifferentiatedInternalName(getPrimitiveInternalName()),
                                   differentiableField.getPrimitiveSpec().getName(),
-                                  dsType.getDescriptor());
+                                  transformedType.getDescriptor());
         }
 
         getter.visitLabel(ifNonNull);
@@ -335,10 +364,10 @@ public class ClassDifferentiator {
         getter.visitVarInsn(Opcodes.ALOAD, 0);
         if (differentiableField.getPrimitiveSpec().isStatic()) {
             getter.visitFieldInsn(Opcodes.GETSTATIC, classNode.name, differentiableField.getPrimitiveSpec().getName(),
-                                  dsType.getDescriptor());
+                                  transformedType.getDescriptor());
         } else {
             getter.visitFieldInsn(Opcodes.GETFIELD, classNode.name, differentiableField.getPrimitiveSpec().getName(),
-                                  dsType.getDescriptor());
+                                  transformedType.getDescriptor());
         }
 
         getter.visitInsn(Opcodes.ARETURN);
@@ -363,6 +392,14 @@ public class ClassDifferentiator {
         return forwardDifferentiator.getUsedDifferentiatedField(method);
     }
 
+    /** Get a differentiated method.
+     * @param primitiveMethod primitive method specification to check
+     * @return a differentiated method corresponding to the primitive method, null if no methods correspond
+     */
+    public DifferentiatedElementSpecification getDifferentiatedMethod(final PrimitiveElementSpecification primitiveMethod) {
+        return forwardDifferentiator.getDifferentiatedMethod(primitiveMethod);
+    }
+
     /** Request differentiation of a method.
      * @param method method to differentiate
      */
@@ -370,12 +407,12 @@ public class ClassDifferentiator {
         forwardDifferentiator.requestMethodDifferentiation(method);
     }
 
-    /** Get a differentiated element.
-     * @param primitive primitive specification to check
-     * @return a differentiated element corresponding to the primitive, null if no elements correspond
+    /** Get a differentiated field.
+     * @param primitiveField primitive field specification to check
+     * @return a differentiated field corresponding to the primitive field, null if no fields correspond
      */
-    public DifferentiatedElementSpecification getDifferentiated(final PrimitiveElementSpecification primitive) {
-        return forwardDifferentiator.getDifferentiated(primitive);
+    public DifferentiatedElementSpecification getDifferentiatedField(final PrimitiveElementSpecification primitiveField) {
+        return forwardDifferentiator.getDifferentiatedField(primitiveField);
     }
 
     /** Request differentiation of a field.
@@ -385,4 +422,20 @@ public class ClassDifferentiator {
         forwardDifferentiator.requestFieldDifferentiation(field);
     }
 
+    /** Check if a type is a double or a double array with any dimension.
+     * @param type type to check
+     * @return true if the type is a double type or a double array type with any dimension
+     */
+    public boolean isDoubleOrDoubleArray(final Type type) {
+        return forwardDifferentiator.isDoubleOrDoubleArray(type);
+    }
+
+    /** Convert a double or a double array type to the corresponding transformed type.
+     * @param type type to convert
+     * @return converted type
+     */
+    public Type convertDoubleOrDoubleArray(final Type type) {
+        return forwardDifferentiator.convertDoubleOrDoubleArray(type);
+    }
+
 }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java?rev=1404270&r1=1404269&r2=1404270&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java Wed Oct 31 17:49:01 2012
@@ -28,8 +28,8 @@ import java.util.Set;
 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
 import org.apache.commons.nabla.DifferentiationException;
 import org.apache.commons.nabla.NablaMessages;
-import org.apache.commons.nabla.forward.PrimitiveElementSpecification;
 import org.apache.commons.nabla.forward.DifferentiatedElementSpecification;
+import org.apache.commons.nabla.forward.PrimitiveElementSpecification;
 import org.apache.commons.nabla.forward.arithmetic.DAddTransformer;
 import org.apache.commons.nabla.forward.arithmetic.DDivTransformer;
 import org.apache.commons.nabla.forward.arithmetic.DMulTransformer;
@@ -144,8 +144,8 @@ public class MethodDifferentiator {
             // identify the needed changes in code
             Set<AbstractInsnNode> changes =
                     identifyChanges(method.name, usedLocals,
-                                    differentiatedMethod.getPrimitiveSpec().getPrimitiveType().getArgumentTypes(),
-                                    differentiatedMethod.getDifferentiatedType().getArgumentTypes(),
+                                    differentiatedMethod.getPrimitiveSpec().getPrimitiveType(),
+                                    differentiatedMethod.getDifferentiatedType(),
                                     method.instructions, isStatic);
 
             // perform the code changes
@@ -269,6 +269,20 @@ public class MethodDifferentiator {
         return classDifferentiator.getUsedDifferentiatedField(dm);
     }
 
+    /** Get a differentiated method.
+     * @param method method name
+     * @param isStatic if true, the method is static
+     * @param owner class in which the method is defined
+     * @param primitiveType method type in the primitive
+     * @return a differentiated method corresponding to the primitive method, null if no methods correspond
+     */
+    public DifferentiatedElementSpecification getDifferentiatedMethod(final String method, final boolean isStatic,
+                                                                      final String owner, final Type primitiveType) {
+        final PrimitiveElementSpecification primitiveMethod =
+                new PrimitiveElementSpecification(method, isStatic, owner, primitiveType);
+        return classDifferentiator.getDifferentiatedMethod(primitiveMethod);
+    }
+
     /** Request differentiation of a method.
      * @param method method name
      * @param isStatic if true, the method is static
@@ -285,19 +299,18 @@ public class MethodDifferentiator {
             classDifferentiator.requestMethodDifferentiation(dm);
     }
 
-    /** Get a differentiated element.
-     * @param name element name
+    /** Get a differentiated field.
+     * @param field field name
      * @param isStatic if true, the field is static
      * @param owner class in which the field is defined
-     * @param primitiveType element type in the primitive
-     * @param differentiatedType element type in the differentiated class
-     * @return a differentiated element corresponding to the primitive, null if no elements correspond
-     */
-    public DifferentiatedElementSpecification getDifferentiated(final String name, final boolean isStatic,
-                                                                final String owner, final Type primitiveType) {
-        final PrimitiveElementSpecification primitive =
-                new PrimitiveElementSpecification(name, isStatic, owner, primitiveType);
-        return classDifferentiator.getDifferentiated(primitive);
+     * @param primitiveType field type in the primitive
+     * @return a differentiated field corresponding to the primitive field, null if no fields correspond
+     */
+    public DifferentiatedElementSpecification getDifferentiatedField(final String field, final boolean isStatic,
+                                                                     final String owner, final Type primitiveType) {
+        final PrimitiveElementSpecification primitiveField =
+                new PrimitiveElementSpecification(field, isStatic, owner, primitiveType);
+        return classDifferentiator.getDifferentiatedField(primitiveField);
     }
 
     /** Request differentiation of a field.
@@ -317,6 +330,22 @@ public class MethodDifferentiator {
             classDifferentiator.requestFieldDifferentiation(df);
     }
 
+    /** Check if a type is {@code double} or an array of {@code double} of any dimension (delegates to the class differentiator).
+     * @param type type to check
+     * @return true if the type is {@code double} or a {@code double} array type of any dimension
+     */
+    public boolean isDoubleOrDoubleArray(final Type type) {
+        return classDifferentiator.isDoubleOrDoubleArray(type);
+    }
+
+    /** Convert a {@code double} or {@code double} array type to its transformed type (delegates to the class differentiator).
+     * @param type type to convert
+     * @return transformed type corresponding to {@code type}, as computed by the class differentiator
+     */
+    public Type convertDoubleOrDoubleArray(final Type type) {
+        return classDifferentiator.convertDoubleOrDoubleArray(type);
+    }
+
     /** Identify the instructions that must be changed.
      * <p>Identification is based on data flow analysis. We start by changing
      * the local variables in the initial frame to match the parameters of
@@ -334,15 +363,15 @@ public class MethodDifferentiator {
      * </ul>
      * @param name method name
      * @param usedLocals array of variables use indicators to fill in
-     * @param primitiveArguments type of the method arguments in the primitive class
-     * @param derivedArguments type of the method arguments in the derived class
+     * @param primitiveType type of the method in the primitive class
+     * @param differentiatedType type of the method in the differentiated class
      * @param instructions instructions of the method
      * @param isStatic if true, the method is a static method
      * @return set containing all the instructions that must be changed
      * @exception DifferentiationException if some unsupported bytecode is found
      */
     private Set<AbstractInsnNode> identifyChanges(final String name, final boolean[] usedLocals,
-                                                  final Type[] primitiveArguments, final Type[] derivedArguments,
+                                                  final Type primitiveType, final Type differentiatedType,
                                                   final InsnList instructions, final boolean isStatic)
         throws DifferentiationException {
 
@@ -352,10 +381,16 @@ public class MethodDifferentiator {
         //   - changed method arguments
         //   - changed return values of transformed methods
         //   - changed fields
+        //   - changed returns
         final Set<TrackingValue> pending =
-                identifyChangedArguments(isStatic, instructions, primitiveArguments, derivedArguments, usedLocals);
+                identifyChangedArguments(isStatic, instructions,
+                                         primitiveType.getArgumentTypes(), differentiatedType.getArgumentTypes(),
+                                         usedLocals);
         pending.addAll(identifyChangedMethods(instructions));
         pending.addAll(identifyChangedFields(instructions));
+        if (!primitiveType.getReturnType().equals(differentiatedType.getReturnType())) {
+            pending.addAll(identifyChangedReturns(instructions));
+        }
 
         // the changes set contains the instructions that must be changed
         final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>();
@@ -408,9 +443,6 @@ public class MethodDifferentiator {
             } else if (insn.getType() == AbstractInsnNode.INVOKE_DYNAMIC_INSN) {
                 // TODO: add support for INVOKE_DYNAMIC
                 throw new DifferentiationException(NablaMessages.INVOKE_DYNAMIC_NOT_HANDLED_YET, getPrimitiveInternalName());
-            } else if (insn.getOpcode() == Opcodes.DRETURN) {
-                // the DRETURN instructions must be changed
-                changes.add(insn);
             }
         }
 
@@ -427,8 +459,8 @@ public class MethodDifferentiator {
      * @return set of argument local variables that are changed
      */
     private Set<TrackingValue> identifyChangedArguments(final boolean isStatic, final InsnList instructions,
-                                                 final Type[] primitiveArguments, final Type[] derivedArguments,
-                                                 final boolean[] usedLocals) {
+                                                        final Type[] primitiveArguments, final Type[] derivedArguments,
+                                                        final boolean[] usedLocals) {
 
         final Set<TrackingValue> changedValues = new HashSet<TrackingValue>();
 
@@ -514,9 +546,8 @@ public class MethodDifferentiator {
             if (insn.getType() == AbstractInsnNode.METHOD_INSN) {
                 final MethodInsnNode methodInsn = (MethodInsnNode) insn;
                 final DifferentiatedElementSpecification de =
-                        getDifferentiated(methodInsn.name,
-                                          methodInsn.getOpcode() == Opcodes.INVOKESTATIC,
-                                          methodInsn.owner, Type.getType(methodInsn.desc));
+                        getDifferentiatedMethod(methodInsn.name, methodInsn.getOpcode() == Opcodes.INVOKESTATIC,
+                                                methodInsn.owner, Type.getType(methodInsn.desc));
                 if (de != null) {
 
                     final Type pReturn = de.getPrimitiveSpec().getPrimitiveType().getReturnType();
@@ -563,10 +594,10 @@ public class MethodDifferentiator {
             if (insn.getType() == AbstractInsnNode.FIELD_INSN) {
                 final FieldInsnNode fieldInsn = (FieldInsnNode) insn;
                 final DifferentiatedElementSpecification de =
-                        getDifferentiated(fieldInsn.name,
-                                          fieldInsn.getOpcode() == Opcodes.GETSTATIC ||
-                                          fieldInsn.getOpcode() == Opcodes.PUTSTATIC,
-                                          fieldInsn.owner, Type.getType(fieldInsn.desc));
+                        getDifferentiatedField(fieldInsn.name,
+                                               fieldInsn.getOpcode() == Opcodes.GETSTATIC ||
+                                               fieldInsn.getOpcode() == Opcodes.PUTSTATIC,
+                                               fieldInsn.owner, Type.getType(fieldInsn.desc));
                 if (de != null) {
                     if (fieldInsn.getOpcode() == Opcodes.GETFIELD || fieldInsn.getOpcode() == Opcodes.GETSTATIC) {
                         // this is a get instruction, it produces a value that is changed
@@ -591,6 +622,29 @@ public class MethodDifferentiator {
 
     }
 
+    /** Identify the values returned by {@code DRETURN} instructions, which must be changed (other return opcodes are not examined).
+     * @param instructions instructions of the method
+     * @return set of stack values consumed by {@code DRETURN} instructions
+     */
+    private Set<TrackingValue> identifyChangedReturns(final InsnList instructions) {
+
+        final Set<TrackingValue> changedValues = new HashSet<TrackingValue>();
+
+        for (final ListIterator<AbstractInsnNode> iterator = instructions.iterator(); iterator.hasNext();) {
+            final AbstractInsnNode insn = iterator.next();
+            if (insn.getOpcode() == Opcodes.DRETURN) {
+                // this is a return instruction: it consumes a double value, which must be changed
+                // the consumed value is the top of the stack in the instruction's own frame (NOTE(review): assumes frame is non-null, i.e. reachable code)
+                final Frame<TrackingValue> frame = frames.get(insn);
+                final TrackingValue stackTop = frame.getStack(frame.getStackSize() - 1);
+                changedValues.add(stackTop);
+            }
+        }
+
+        return changedValues;
+
+    }
+
     /** Get the list of double values produced by an instruction and not yet converted.
      * @param instruction instruction producing the values
      * @return list of double values produced
@@ -765,20 +819,20 @@ public class MethodDifferentiator {
 
     }
 
-    /** Create instructions to preserve a reference {@link DerivativeStructure} variable.
+    /** Create instructions to preserve a reference transformed variable.
      * @param differentiatedMethod specification for the differentiated method
      * @param isStatic if true, the method is a static method
-     * @param dsIndex index of the reference {@link DerivativeStructure derivative structure} variable
+     * @param dsIndex index of the reference transformed variable
      * @return list of conversion instructions
      */
     public InsnList preserveReferenceDerivativeStructure(final DifferentiatedElementSpecification differentiatedMethod,
                                                          final boolean isStatic, final int dsIndex) {
 
-        final Type dsType = Type.getType(DerivativeStructure.class);
+        final Type transformedType = convertDoubleOrDoubleArray(Type.DOUBLE_TYPE);
         final Type[] parameterTypes = differentiatedMethod.getDifferentiatedType().getArgumentTypes();
         int var = isStatic ? 0 : 1;
         for (int i = 0; i < parameterTypes.length; ++i) {
-            if (parameterTypes[i].equals(dsType)) {
+            if (parameterTypes[i].equals(transformedType)) {
                 // we have found the first derivative structure parameter
 
                 // preserve the parameter as a new variable

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java?rev=1404270&r1=1404269&r2=1404270&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java Wed Oct 31 17:49:01 2012
@@ -53,9 +53,9 @@ public class GetTransformer implements I
 
         final FieldInsnNode fieldInsn = (FieldInsnNode) insn;
         final DifferentiatedElementSpecification de =
-                methodDifferentiator.getDifferentiated(fieldInsn.name,
-                                                       fieldInsn.getOpcode() == Opcodes.GETSTATIC,
-                                                       fieldInsn.owner, Type.getType(fieldInsn.desc));
+                methodDifferentiator.getDifferentiatedField(fieldInsn.name,
+                                                            fieldInsn.getOpcode() == Opcodes.GETSTATIC,
+                                                            fieldInsn.owner, Type.getType(fieldInsn.desc));
         if (de != null) {
             // the field is differentiated
             return getReplacementTransformedField(fieldInsn, methodDifferentiator, de);

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java?rev=1404270&r1=1404269&r2=1404270&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java Wed Oct 31 17:49:01 2012
@@ -16,7 +16,6 @@
  */
 package org.apache.commons.nabla.forward.instructions;
 
-import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
 import org.apache.commons.nabla.DifferentiationException;
 import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
 import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
@@ -41,16 +40,16 @@ public class InvokeNonMathTransformer im
         final MethodInsnNode methodInsn = (MethodInsnNode) insn;
 
         // build transformed method signature based on stack elements
-        final Type dsType              = Type.getType(DerivativeStructure.class);
+        final Type transformedType     = methodDifferentiator.convertDoubleOrDoubleArray(Type.DOUBLE_TYPE);
         final Type primitiveMethodType = Type.getMethodType(methodInsn.desc);
         final Type[] argumentTypes     = new Type[primitiveMethodType.getArgumentTypes().length];
         for (int i = 0; i < argumentTypes.length; ++i) {
             final int index = argumentTypes.length - 1 - i;
             argumentTypes[i] = methodDifferentiator.stackElementIsConverted(insn, index) ?
-                               dsType : methodDifferentiator.stackElementType(insn, index);
+                               transformedType : methodDifferentiator.stackElementType(insn, index);
         }
-        final Type returnType = (primitiveMethodType.getReturnType() == Type.DOUBLE_TYPE) ?
-                                dsType : primitiveMethodType.getReturnType();
+        final Type returnType =
+                methodDifferentiator.convertDoubleOrDoubleArray(primitiveMethodType.getReturnType());
         final Type differentiatedMethodType = Type.getMethodType(returnType, argumentTypes);
 
         // request the global differentiator to differentiate the invoked method

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java?rev=1404270&r1=1404269&r2=1404270&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java Wed Oct 31 17:49:01 2012
@@ -82,11 +82,11 @@ public class PutTransformer implements I
         throws DifferentiationException {
 
         final Type primitiveFieldType = Type.getType(insn.desc);
-        if (!primitiveFieldType.equals(Type.DOUBLE_TYPE)) {
+        if (!methodDifferentiator.isDoubleOrDoubleArray(primitiveFieldType)) {
             // TODO handle double array types
             throw new RuntimeException("PUTFIELD/PUTSTATIC not handled yet for non-double fields");            
         }
-        final Type differentiatedFieldType = Type.getType(DerivativeStructure.class);
+        final Type differentiatedFieldType = methodDifferentiator.convertDoubleOrDoubleArray(primitiveFieldType);
 
         // we need to add a new field in the transformed class
         final boolean isStatic = insn.getOpcode() == Opcodes.PUTSTATIC;

Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java?rev=1404270&r1=1404269&r2=1404270&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java (original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java Wed Oct 31 17:49:01 2012
@@ -268,7 +268,11 @@ public class ForwardModeDifferentiatorTe
                 intermediateData = t + 3.0;
                 return f();
             }
-            public double firstDerivative(double t) { return 2.0; }
+
+            public double firstDerivative(double t) {
+                return 2.0;
+            }
+
         }, -5.25, 5, 20, 8.0e-15);
     }
 



Mime
View raw message