commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r1404267 - in /commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward: ./ analysis/ instructions/
Date Wed, 31 Oct 2012 17:47:42 GMT
Author: luc
Date: Wed Oct 31 17:47:42 2012
New Revision: 1404267

URL: http://svn.apache.org/viewvc?rev=1404267&view=rev
Log:
Work In Progress for PUTFIELD/PUTSTATIC support.

When DerivativeStructure instances are stored in fields, existing methods may need
to be differentiated again, as they did not know beforehand that the field
would be changed. Nabla now invalidates these methods and
differentiates them again before generating the class.

Added:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableField.java
  (with props)
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableMethod.java
  (with props)
Modified:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java

Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableField.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableField.java?rev=1404267&view=auto
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableField.java
(added)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableField.java
Wed Oct 31 17:47:42 2012
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.nabla.forward;
+
+import org.objectweb.asm.Type;
+
+/** Identifier for a field to differentiate.
+ * @version $Id$
+ */
+public class DifferentiableField {
+
+    /** Primitive class to which the field belongs. */
+    private final Class<?> primitiveClass;
+
+    /** Indicator for static fields. */
+    private final boolean isStatic;
+
+    /** Name of the field. */
+    private final String field;
+
+    /** Type of the field in the primitive class. */
+    private final Type primitiveFieldType;
+
+    /** Type of the field in the differentiated class. */
+    private final Type differentiatedFieldType;
+
+    /** Simple constructor.
+     * @param primitiveClass class in which the field is defined
+     * @param isStatic if true, the field is static
+     * @param field field name
+     * @param primitiveFieldType field type in the primitive
+     * @param differentiatedFieldType field type in the differentiated class
+     */
+    public DifferentiableField(final Class<?> primitiveClass, final boolean isStatic,
+                               final String field, final Type primitiveFieldType,
+                               final Type differentiatedFieldType) {
+        this.primitiveClass          = primitiveClass;
+        this.isStatic                = isStatic;
+        this.field                   = field;
+        this.primitiveFieldType      = primitiveFieldType;
+        this.differentiatedFieldType = differentiatedFieldType;
+    }
+
+    /** Get the primitive class to which the field belongs.
+     * @return primitive class to which the field belongs
+     */
+    public Class<?> getPrimitiveClass() {
+        return primitiveClass;
+    }
+
+    /** Get the name of the field.
+     * @return name of the field
+     */
+    public String getField() {
+        return field;
+    }
+
+    /** Get the static indicator of the field.
+     * @return static indcator of the field
+     */
+    public boolean isStatic() {
+        return isStatic;
+    }
+
+    /** Get the type of the field in the primitive class.
+     * @return type of the field in the primitive class
+     */
+    public Type getPrimitiveFieldType() {
+        return primitiveFieldType;
+    }
+
+    /** Get the type of the field in the differentiated class.
+     * @return type of the field in the differentiated class
+     */
+    public Type getDifferentiatedFieldType() {
+        return differentiatedFieldType;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(final Object other) {
+
+        if (this == other) {
+            return true;
+        }
+
+        if (other instanceof DifferentiableField) {
+            // since fractions are always in lowest terms, numerators and
+            // denominators can be compared directly for equality.
+            DifferentiableField df = (DifferentiableField) other;
+            return primitiveClass.equals(df.primitiveClass)         &&
+                   (isStatic == df.isStatic)                        &&
+                   field.equals(df.field)                           &&
+                   primitiveFieldType.equals(df.primitiveFieldType) &&
+                   differentiatedFieldType.equals(df.differentiatedFieldType);
+        }
+
+        return false;
+
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        // the following coefficients are arbitrarily chosen prime numbers
+        return 311 * primitiveClass.hashCode() +
+               457 * Boolean.valueOf(isStatic).hashCode() +
+               547 * field.hashCode() +
+               643 * primitiveFieldType.hashCode() +
+               733 * differentiatedFieldType.hashCode();
+    }
+
+}

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableField.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableField.java
------------------------------------------------------------------------------
    svn:keywords = "Author Date Id Revision"

Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableMethod.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableMethod.java?rev=1404267&view=auto
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableMethod.java
(added)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableMethod.java
Wed Oct 31 17:47:42 2012
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.nabla.forward;
+
+import org.objectweb.asm.Type;
+
+/** Identifier for a method to differentiate.
+ * @version $Id$
+ */
+public class DifferentiableMethod {
+
+    /** Primitive class to which the method belongs. */
+    private final Class<?> primitiveClass;
+
+    /** Indicator for static methods. */
+    private final boolean isStatic;
+
+    /** Name of the method. */
+    private final String method;
+
+    /** Type of the method in the primitive class. */
+    private final Type primitiveMethodType;
+
+    /** Type of the method in the differentiated class. */
+    private final Type differentiatedMethodType;
+
+    /** Simple constructor.
+     * @param primitiveClass class in which the method is defined
+     * @param isStatic if true, the method is static
+     * @param method method name
+     * @param primitiveMethodType method type in the primitive (includes return and arguments
types)
+     * @param differentiatedMethodType method type in the differentiated class (includes
return and arguments types)
+     */
+    public DifferentiableMethod(final Class<?> primitiveClass, final boolean isStatic,
+                                final String method, final Type primitiveMethodType,
+                                final Type differentiatedMethodType) {
+        this.primitiveClass           = primitiveClass;
+        this.isStatic                 = isStatic;
+        this.method                   = method;
+        this.primitiveMethodType      = primitiveMethodType;
+        this.differentiatedMethodType = differentiatedMethodType;
+    }
+
+    /** Get the primitive class to which the method belongs.
+     * @return primitive class to which the method belongs
+     */
+    public Class<?> getPrimitiveClass() {
+        return primitiveClass;
+    }
+
+    /** Get the name of the method.
+     * @return name of the method
+     */
+    public String getMethod() {
+        return method;
+    }
+
+    /** Get the type of the method in the primitive class.
+     * @return type of the method in the primitive class
+     */
+    public Type getPrimitiveMethodType() {
+        return primitiveMethodType;
+    }
+
+    /** Get the type of the method in the differentiated class.
+     * @return type of the method in the differentiated class
+     */
+    public Type getDifferentiatedMethodType() {
+        return differentiatedMethodType;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(final Object other) {
+
+        if (this == other) {
+            return true;
+        }
+
+        if (other instanceof DifferentiableMethod) {
+            // since fractions are always in lowest terms, numerators and
+            // denominators can be compared directly for equality.
+            final DifferentiableMethod dm = (DifferentiableMethod) other;
+            return primitiveClass.equals(dm.primitiveClass)           &&
+                   (isStatic == dm.isStatic)                          &&
+                   method.equals(dm.method)                           &&
+                   primitiveMethodType.equals(dm.primitiveMethodType) &&
+                   differentiatedMethodType.equals(dm.differentiatedMethodType);
+        }
+
+        return false;
+
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        // the following coefficients are arbitrarily chosen prime numbers
+        return 109 * primitiveClass.hashCode() +
+               479 * Boolean.valueOf(isStatic).hashCode() +
+               601 * method.hashCode() +
+               571 * primitiveMethodType.hashCode() +
+               587 * differentiatedMethodType.hashCode();
+    }
+
+}

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableMethod.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/DifferentiableMethod.java
------------------------------------------------------------------------------
    svn:keywords = "Author Date Id Revision"

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java?rev=1404267&r1=1404266&r2=1404267&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
(original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
Wed Oct 31 17:47:42 2012
@@ -69,22 +69,34 @@ public class ForwardModeDifferentiator i
     /** Math implementation classes. */
     private final Set<String> mathClasses;
 
-    /** Pending differentiations. */
-    private final Set<DifferentiableMethod> pendingDifferentiations;
+    /** Differentiators for various classes. */
+    private final Set<ClassDifferentiator> differentiators;
 
-    /** Processed differentiations. */
-    private final Set<DifferentiableMethod> processedDifferentiations;
+    /** Pending methods differentiations. */
+    private final Set<DifferentiableMethod> pendingMethodsDifferentiations;
+
+    /** Processed methods differentiations. */
+    private final Set<DifferentiableMethod> processedMethodsDifferentiations;
+
+    /** Pending fields differentiations. */
+    private final Set<DifferentiableField> pendingFieldsDifferentiations;
+
+    /** Processed fields differentiations. */
+    private final Set<DifferentiableField> processedFieldsDifferentiations;
 
     /** Simple constructor.
      * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
      */
     public ForwardModeDifferentiator() {
-        map                       = new HashMap<Class<? extends UnivariateFunction>,
-                                                Class<? extends NablaUnivariateDifferentiableFunction>>();
-        byteCodeMap               = new HashMap<String, byte[]>();
-        mathClasses               = new HashSet<String>();
-        pendingDifferentiations   = new HashSet<ForwardModeDifferentiator.DifferentiableMethod>();
-        processedDifferentiations = new HashSet<ForwardModeDifferentiator.DifferentiableMethod>();
+        map                              = new HashMap<Class<? extends UnivariateFunction>,
+                                                       Class<? extends NablaUnivariateDifferentiableFunction>>();
+        byteCodeMap                      = new HashMap<String, byte[]>();
+        mathClasses                      = new HashSet<String>();
+        differentiators                  = new HashSet<ClassDifferentiator>();
+        pendingMethodsDifferentiations   = new HashSet<DifferentiableMethod>();
+        processedMethodsDifferentiations = new HashSet<DifferentiableMethod>();
+        pendingFieldsDifferentiations    = new HashSet<DifferentiableField>();
+        processedFieldsDifferentiations  = new HashSet<DifferentiableField>();
         addMathImplementation(Math.class);
         addMathImplementation(StrictMath.class);
         addMathImplementation(FastMath.class);
@@ -130,9 +142,39 @@ public class ForwardModeDifferentiator i
                     new DifferentiableMethod(Class.forName(owner), isStatic,
                                              method, primitiveMethodType, differentiatedMethodType);
 
-            if (!processedDifferentiations.contains(dm)) {
+            if (!processedMethodsDifferentiations.contains(dm)) {
                 // schedule the request if method has not been processed yet
-                pendingDifferentiations.add(dm);
+                pendingMethodsDifferentiations.add(dm);
+            }
+ 
+        } catch (ClassNotFoundException cnfe) {
+            throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
+                                               owner, cnfe.getMessage());
+        }
+
+    }
+
+    /** Request differentiation of a field.
+     * @param owner class in which the field is defined
+     * @param isStatic if true, the field is static
+     * @param field field name
+     * @param primitiveFieldType field type in the primitive
+     * @param differentiatedFieldType field type in the differentiated class
+     * @exception DifferentiationException if class cannot be found
+     */
+    public void requestFieldDifferentiation(final String owner, final boolean isStatic,
+                                            final String field, final Type primitiveFieldType,
+                                            final Type differentiatedFieldType) {
+
+        try {
+
+            final DifferentiableField df =
+                    new DifferentiableField(Class.forName(owner), isStatic,
+                                            field, primitiveFieldType, differentiatedFieldType);
+
+            if (!processedFieldsDifferentiations.contains(df)) {
+                // schedule the request if field has not been processed yet
+                pendingFieldsDifferentiations.add(df);
             }
  
         } catch (ClassNotFoundException cnfe) {
@@ -194,74 +236,107 @@ public class ForwardModeDifferentiator i
     private Class<? extends NablaUnivariateDifferentiableFunction>
     createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
         throws DifferentiationException {
-        try {
-
-            final Set<ClassDifferentiator> differentiators = new HashSet<ClassDifferentiator>();
 
-            // bootstrap differentiation using the top level value function from the UnivariateFunction
interface
-            final Type dsType = Type.getType(DerivativeStructure.class);
-            requestMethodDifferentiation(differentiableClass.getName(), false, "value",
-                                         Type.getMethodType(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
-                                         Type.getMethodType(dsType, dsType));
-
-            while (!pendingDifferentiations.isEmpty()) {
-
-                // move the method from pending to processed
-                final Iterator<DifferentiableMethod> iterator = pendingDifferentiations.iterator();
-                final DifferentiableMethod dm = iterator.next();
-                iterator.remove();
-                processedDifferentiations.add(dm);
-
-                // find a differentiator for the class owning the method
-                ClassDifferentiator differentiator = null;
-                for (Iterator<ClassDifferentiator> dIter = differentiators.iterator();
-                     differentiator == null && dIter.hasNext();) {
-                    ClassDifferentiator current = dIter.next();
-                    if (Type.getInternalName(dm.getPrimitiveClass()).equals(current.getPrimitive().name))
{
-                        // we have already build a differentiator for the same class, reuse
it
-                        differentiator = current;
-                    }
+        // bootstrap differentiation using the top level value function from the UnivariateFunction
interface
+        final Type dsType = Type.getType(DerivativeStructure.class);
+        requestMethodDifferentiation(differentiableClass.getName(), false, "value",
+                                     Type.getMethodType(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
+                                     Type.getMethodType(dsType, dsType));
+
+        while (!(pendingMethodsDifferentiations.isEmpty() && pendingFieldsDifferentiations.isEmpty()))
{
+
+            for (final DifferentiableField df : pendingFieldsDifferentiations) {
+
+                // differentiate the field
+                getDifferentiator(df.getPrimitiveClass()).differentiateField(df);
+
+                // this field may be used by some methods already differentiated
+                // in order to make sure it is taken into account properly,
+                // we invalidate all existing differentiated methods and redo them all
+                pendingMethodsDifferentiations.addAll(processedMethodsDifferentiations);
+                processedMethodsDifferentiations.clear();
+                for (final ClassDifferentiator differentiator : differentiators) {
+                    differentiator.clearDifferentiatedMethods();
                 }
-                if (differentiator == null) {
-                    // it is the first time we process this class, create a differentiator
for it
-                    differentiator = new ClassDifferentiator(dm.getPrimitiveClass(), mathClasses,
this);
-                    differentiators.add(differentiator);
-                }
-
-                differentiator.differentiateMethod(dm.getMethod(), dm.getPrimitiveMethodType(),
-                                                   dm.getDifferentiatedMethodType());
 
             }
 
-            // create the differential classes
-            Class<? extends NablaUnivariateDifferentiableFunction> nudf = null;
-            for (ClassDifferentiator differentiator : differentiators) {
-                final ClassNode   derived = differentiator.getDerived();
-                final ClassWriter writer  = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
-                final String name = derived.name.replace('/', '.');
-                derived.accept(writer);
-                final byte[] bytecode = writer.toByteArray();
-
-                final Class<? extends NablaDifferentiated> dClass =
-                        new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
-                byteCodeMap.put(name, bytecode);
- 
-                // TODO: remove development trace
+            // move all fields from pending to processed
+            processedFieldsDifferentiations.addAll(pendingFieldsDifferentiations);
+            pendingFieldsDifferentiations.clear();
+
+            // move one method from pending to processed
+            final Iterator<DifferentiableMethod> iterator = pendingMethodsDifferentiations.iterator();
+            final DifferentiableMethod dm = iterator.next();
+            iterator.remove();
+            processedMethodsDifferentiations.add(dm);
+
+            // differentiate the method
+            getDifferentiator(dm.getPrimitiveClass()).differentiateMethod(dm);
+
+        }
+
+        // create the differential classes
+        Class<? extends NablaUnivariateDifferentiableFunction> nudf = null;
+        for (ClassDifferentiator differentiator : differentiators) {
+            final ClassNode   derived = differentiator.getDerived();
+            final ClassWriter writer  = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
+            final String name = derived.name.replace('/', '.');
+            derived.accept(writer);
+            final byte[] bytecode = writer.toByteArray();
+
+            // TODO: remove development trace
+            try {
                 new ClassReader(differentiableClass.getResourceAsStream("/" + Type.getInternalName(differentiableClass)
+ ".class")).accept(new TraceClassVisitor(new PrintWriter(System.out)), 0);
                 new ClassReader(bytecode).accept(new TraceClassVisitor(new PrintWriter(System.err)),
0);
+            } catch (IOException ioe) {
+                throw new RuntimeException(ioe);
+            }
 
-                if (differentiator.getPrimitive().name.equals(Type.getType(differentiableClass).getInternalName()))
{
-                    nudf = (Class<? extends NablaUnivariateDifferentiableFunction>)
dClass;
-                }
+            final Class<? extends NablaDifferentiated> dClass =
+                    new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
+            byteCodeMap.put(name, bytecode);
+
+            if (differentiator.getPrimitiveName().equals(Type.getType(differentiableClass).getInternalName()))
{
+                nudf = (Class<? extends NablaUnivariateDifferentiableFunction>) dClass;
+            }
+
+        }
+
+        // return the top level one
+        return nudf;
+
+    }
+
+    /** Find a differentiator for a class.
+     * <p>
+     * If the differentiator has not been created yet, it will be created here.
+     * </p>
+     * @param primitiveClass primitive class to differentiate
+     * @return differentiator for {@code primitiveClass}
+     * @throws DifferentiationException if the class cannot be read
+     */
+    private ClassDifferentiator getDifferentiator(Class<?> primitiveClass)
+        throws DifferentiationException {
+        try {
 
+            // find a differentiator for the class owning the method
+            for (ClassDifferentiator differentiator : differentiators) {
+                if (Type.getInternalName(primitiveClass).equals(differentiator.getPrimitiveName()))
{
+                    // we have already build a differentiator for the same class, reuse it
+                    return differentiator;
+                }
             }
 
-            // return the top level one
-            return nudf;
+            // it is the first time we process this class, create a differentiator for it
+            final ClassDifferentiator differentiator = new ClassDifferentiator(primitiveClass,
mathClasses, this);
+            differentiators.add(differentiator);
+            return differentiator;
 
+            
         } catch (IOException ioe) {
             throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
-                                               differentiableClass.getName(), ioe.getMessage());
+                                               primitiveClass.getName(), ioe.getMessage());
         }
     }
 
@@ -287,101 +362,4 @@ public class ForwardModeDifferentiator i
         }
     }
 
-    /** Identifier for a method to differentiate. */
-    private static class DifferentiableMethod {
-
-        /** Primitive class to which the method belongs. */
-        private final Class<?> primitiveClass;
-
-        /** Indicator for static methods. */
-        private final boolean isStatic;
-
-        /** Name of the method. */
-        private final String method;
-
-        /** Type of the method in the primitive class. */
-        private final Type primitiveMethodType;
-
-        /** Type of the method in the differentiated class. */
-        private final Type differentiatedMethodType;
-
-        /** Simple constructor.
-         * @param primitiveClass class in which the method is defined
-         * @param isStatic if true, the method is static
-         * @param method method name
-         * @param primitiveMethodType method type in the primitive (includes return and arguments
types)
-         * @param differentiatedMethodType method type in the differentiated class (includes
return and arguments types)
-         */
-        public DifferentiableMethod(final Class<?> primitiveClass, final boolean isStatic,
-                                    final String method, final Type primitiveMethodType,
-                                    final Type differentiatedMethodType) {
-            this.primitiveClass           = primitiveClass;
-            this.isStatic                 = isStatic;
-            this.method                   = method;
-            this.primitiveMethodType      = primitiveMethodType;
-            this.differentiatedMethodType = differentiatedMethodType;
-        }
-
-        /** Get the primitive class to which the method belongs.
-         * @return primitive class to which the method belongs
-         */
-        public Class<?> getPrimitiveClass() {
-            return primitiveClass;
-        }
-
-        /** Get the name of the method.
-         * @return name of the method
-         */
-        public String getMethod() {
-            return method;
-        }
-
-        /** Get the type of the method in the primitive class.
-         * @return type of the method in the primitive class
-         */
-        public Type getPrimitiveMethodType() {
-            return primitiveMethodType;
-        }
-
-        /** Get the type of the method in the differentiated class.
-         * @return type of the method in the differentiated class
-         */
-        public Type getDifferentiatedMethodType() {
-            return differentiatedMethodType;
-        }
-
-        /** {@inheritDoc} */
-        @Override
-        public boolean equals(final Object other) {
-
-            if (this == other) {
-                return true;
-            }
-
-            if (other instanceof DifferentiableMethod) {
-                // since fractions are always in lowest terms, numerators and
-                // denominators can be compared directly for equality.
-                DifferentiableMethod dm = (DifferentiableMethod)other;
-                return (primitiveClass      == dm.primitiveClass) &&
-                       (isStatic            == dm.isStatic)       &&
-                       (method              == dm.method)         &&
-                       (primitiveMethodType == dm.primitiveMethodType);
-            }
-
-            return false;
-
-        }
-
-        /** {@inheritDoc} */
-        @Override
-        public int hashCode() {
-            // the following coefficients are arbitrarily chosen prime numbers
-            return 109 * primitiveClass.hashCode() +
-                   479 * Boolean.valueOf(isStatic).hashCode() +
-                   601 * method.hashCode() +
-                   571 * primitiveMethodType.hashCode();
-        }
-
-    }
-
 }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java?rev=1404267&r1=1404266&r2=1404267&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
(original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
Wed Oct 31 17:47:42 2012
@@ -20,15 +20,20 @@ import java.io.IOException;
 import java.util.Set;
 
 import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
 import org.apache.commons.nabla.DifferentiationException;
 import org.apache.commons.nabla.NablaMessages;
+import org.apache.commons.nabla.forward.DifferentiableField;
+import org.apache.commons.nabla.forward.DifferentiableMethod;
 import org.apache.commons.nabla.forward.ForwardModeDifferentiator;
 import org.apache.commons.nabla.forward.NablaDifferentiated;
 import org.apache.commons.nabla.forward.NablaUnivariateDifferentiableFunction;
 import org.objectweb.asm.ClassReader;
+import org.objectweb.asm.Label;
 import org.objectweb.asm.Opcodes;
 import org.objectweb.asm.Type;
 import org.objectweb.asm.tree.ClassNode;
+import org.objectweb.asm.tree.FieldNode;
 import org.objectweb.asm.tree.MethodNode;
 
 /**
@@ -110,56 +115,79 @@ public class ClassDifferentiator {
                                Type.getType(NablaUnivariateDifferentiableFunction.class)
:
                                Type.getType(NablaDifferentiated.class);
         classNode.visit(primitiveNode.version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
-                        primitiveNode.name + CLASS_SUFFIX,
-                        null, superType.getInternalName(), new String[0]);
+                        getDerivedName(), null, superType.getInternalName(), new String[0]);
+
+    }
+
+    /** Get the name of the primitive class.
+     * @return name of the primitive class
+     */
+    public String getPrimitiveName() {
+        return primitiveNode.name;
+    }
+
+    /** Get the name of the derived class.
+     * @return name of the derived class
+     */
+    public String getDerivedName() {
+        return primitiveNode.name  + CLASS_SUFFIX;
+    }
+
+    /** Get the derived class node.
+     * <p>
+     * The constructor is created here, so this method must be called only once
+     * to avoid multiple definitions.
+     * </p>
+     * @return derived class node
+     */
+    public ClassNode getDerived() {
 
         // add constructor calling superclass constructor
         final MethodNode constructor =
                 new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
-                               Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)),
+                               Type.getMethodDescriptor(Type.VOID_TYPE,
+                                                        Type.getType("L" + primitiveNode.name
+ ";")),
                                null, null);
         constructor.visitVarInsn(Opcodes.ALOAD, 0);
         constructor.visitVarInsn(Opcodes.ALOAD, 1);
-        constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, superType.getInternalName(),
+        constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, classNode.superName,
                                     INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(UnivariateFunction.class)));
+
+        for (final FieldNode field : classNode.fields) {
+            constructor.visitVarInsn(Opcodes.ALOAD, 0);
+            constructor.visitInsn(Opcodes.ACONST_NULL);
+            if ((field.access & Opcodes.ACC_STATIC ) != 0) {
+                constructor.visitFieldInsn(Opcodes.PUTSTATIC, classNode.name, field.name,
field.desc);
+            } else {
+                constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, field.name,
field.desc);
+            }
+        }
+
         constructor.visitInsn(Opcodes.RETURN);
         constructor.visitMaxs(2, 2);
         classNode.methods.add(constructor);
 
-    }
-
-    /** Get the primitive class node.
-     * @return primitive class node
-     */
-    public ClassNode getPrimitive() {
-        return primitiveNode;
-    }
-
-    /** Get the derived class node.
-     * @return derived class node
-     */
-    public ClassNode getDerived() {
         return classNode;
+
     }
 
     /**
      * Differentiate a method.
-     * @param name of the method
-     * @param primitiveMethodType type of the method in the primitive class
-     * @param derivativedMethodType type of the method in the derivative class
+     * @param differentiableMethod method to differentiate
      * @exception DifferentiationException if method cannot be differentiated
      */
-    public void differentiateMethod(final String name, final Type primitiveMethodType,
-                                    final Type derivativedMethodType)
+    public void differentiateMethod(final DifferentiableMethod differentiableMethod)
         throws DifferentiationException {
 
         for (final MethodNode primitiveMethod : primitiveNode.methods) {
-            if (primitiveMethod.name.equals(name) && Type.getType(primitiveMethod.desc).equals(primitiveMethodType))
{
+            if (primitiveMethod.name.equals(differentiableMethod.getMethod()) &&
+                Type.getType(primitiveMethod.desc).equals(differentiableMethod.getPrimitiveMethodType()))
{
 
                 final MethodDifferentiator differentiator = new MethodDifferentiator(mathClasses,
this);
-                final MethodNode differentiatedMethod     = differentiator.differentiate(primitiveMethod,
-                                                                                        
primitiveMethodType,
-                                                                                        
derivativedMethodType);
+                final MethodNode differentiatedMethod =
+                        differentiator.differentiate(primitiveMethod,
+                                                     differentiableMethod.getPrimitiveMethodType(),
+                                                     differentiableMethod.getDifferentiatedMethodType());
                 classNode.methods.add(differentiatedMethod);
 
             }
@@ -167,6 +195,100 @@ public class ClassDifferentiator {
 
     }
 
+    /** Clear all methods already added to the derived class node (this includes any field getters).
+     */
+    public void clearDifferentiatedMethods() {
+        classNode.methods.clear();
+    }
+
+    /**
+     * Differentiate a field.
+     * @param differentiableField field to differentiate
+     * @exception DifferentiationException if field cannot be differentiated
+     */
+    public void differentiateField(final DifferentiableField differentiableField)
+        throws DifferentiationException {
+
+        // set up the field itself
+        classNode.fields.add(new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL,
+                                           differentiableField.getField(),
+                                           differentiableField.getDifferentiatedFieldType().getDescriptor(),
+                                           null, null));
+
+        // get the access mode of the original field
+        int access = 0;
+        for (final FieldNode fieldNode : primitiveNode.fields) {
+            if (fieldNode.name.equals(differentiableField.getField())) {
+                access = fieldNode.access;
+            }
+        }
+
+        // set up a getter for the field, with the same access mode as the original field
+        // the bytecode below is equivalent to the following java code:
+        //    public/protected/private DerivativeStructure get_x(int parameters, int order) {
+        //        if (x == null) {
+        //            x = new DerivativeStructure(parameters, order,
+        //                                        ((Double) getPrimitiveField("x")).doubleValue());
+        //        }
+        //        return x;
+        //    }
+        final Type dsType = Type.getType(DerivativeStructure.class);
+        final MethodNode getter =
+                new MethodNode(access | Opcodes.ACC_SYNTHETIC, "get_" + differentiableField.getField(),
+                               Type.getMethodDescriptor(dsType, Type.INT_TYPE, Type.INT_TYPE),
+                               null, null);
+        getter.visitVarInsn(Opcodes.ALOAD, 0);
+        if (differentiableField.isStatic()) {
+            getter.visitFieldInsn(Opcodes.GETSTATIC, classNode.name, differentiableField.getField(),
+                                  differentiableField.getDifferentiatedFieldType().getDescriptor());
+        } else {
+            getter.visitFieldInsn(Opcodes.GETFIELD, classNode.name, differentiableField.getField(),
+                                  differentiableField.getDifferentiatedFieldType().getDescriptor());
+        }
+        Label ifNonNull = new Label();
+        getter.visitJumpInsn(Opcodes.IFNONNULL, ifNonNull);
+        getter.visitVarInsn(Opcodes.ALOAD, 0);
+        getter.visitTypeInsn(Opcodes.NEW, dsType.getInternalName());
+        getter.visitInsn(Opcodes.DUP);
+        getter.visitVarInsn(Opcodes.ILOAD, 1);
+        getter.visitVarInsn(Opcodes.ILOAD, 2);
+        getter.visitVarInsn(Opcodes.ALOAD, 0);
+        getter.visitLdcInsn(differentiableField.getField());
+        getter.visitMethodInsn(Opcodes.INVOKEVIRTUAL, classNode.name,
+                               "getPrimitiveField", Type.getMethodDescriptor(Type.getType(Object.class),
+                                                                             Type.getType(String.class)));
+        final Type doubleType = Type.getType(Double.class);
+        getter.visitTypeInsn(Opcodes.CHECKCAST, doubleType.getInternalName());
+        getter.visitMethodInsn(Opcodes.INVOKEVIRTUAL, doubleType.getInternalName(), "doubleValue",
+                               Type.getMethodDescriptor(Type.DOUBLE_TYPE));
+        getter.visitMethodInsn(Opcodes.INVOKESPECIAL, dsType.getInternalName(), INIT,
+                               Type.getMethodDescriptor(Type.VOID_TYPE, Type.INT_TYPE, Type.INT_TYPE,
Type.DOUBLE_TYPE));
+
+        if (differentiableField.isStatic()) {
+            getter.visitFieldInsn(Opcodes.PUTSTATIC, classNode.name, differentiableField.getField(),
+                                  dsType.getDescriptor());
+        } else {
+            getter.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, differentiableField.getField(),
+                                  dsType.getDescriptor());
+        }
+
+        getter.visitLabel(ifNonNull);
+
+        getter.visitVarInsn(Opcodes.ALOAD, 0);
+        if (differentiableField.isStatic()) {
+            getter.visitFieldInsn(Opcodes.GETSTATIC, classNode.name, differentiableField.getField(),
+                                  dsType.getDescriptor());
+        } else {
+            getter.visitFieldInsn(Opcodes.GETFIELD, classNode.name, differentiableField.getField(),
+                                  dsType.getDescriptor());
+        }
+
+        getter.visitInsn(Opcodes.ARETURN);
+        getter.visitMaxs(7, 3);
+        classNode.methods.add(getter);
+
+    }
+
     /** Request differentiation of a method.
      * @param owner class in which the method is defined
      * @param isStatic if true, the method is static
@@ -182,4 +304,19 @@ public class ClassDifferentiator {
                                                            method, primitiveMethodType, differentiatedMethodType);
     }
 
+    /** Request differentiation of a field.
+     * @param owner class in which the field is defined
+     * @param isStatic if true, the field is static
+     * @param field field name
+     * @param primitiveFieldType field type in the primitive
+     * @param differentiatedFieldType field type in the differentiated class
+     * @exception DifferentiationException if class cannot be found
+     */
+    public void requestFieldDifferentiation(final String owner, final boolean isStatic,
+                                            final String field, final Type primitiveFieldType,
+                                            final Type differentiatedFieldType) {
+        forwardDifferentiator.requestFieldDifferentiation(owner, isStatic, field,
+                                                          primitiveFieldType, differentiatedFieldType);
+    }
+
 }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java?rev=1404267&r1=1404266&r2=1404267&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java	(original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java	Wed Oct 31 17:47:42 2012
@@ -103,14 +103,14 @@ public class MethodDifferentiator {
      * @return name of the primitive class
      */
     public String getPrimitiveName() {
-        return classDifferentiator.getPrimitive().name;
+        return classDifferentiator.getPrimitiveName();
     }
 
     /** Get the name of the derived class.
      * @return name of the derived class
      */
     public String getDerivedName() {
-        return classDifferentiator.getDerived().name;
+        return classDifferentiator.getDerivedName();
     }
 
     /**
@@ -255,7 +255,6 @@ public class MethodDifferentiator {
      * @param method method name
      * @param primitiveMethodType method type in the primitive (includes return and arguments types)
      * @param differentiatedMethodType method type in the differentiated class (includes return and arguments types)
-     * @return type of the differentiated method
      * @exception DifferentiationException if class cannot be found
      */
     public void requestMethodDifferentiation(final String owner, final boolean isStatic,
@@ -265,6 +264,21 @@ public class MethodDifferentiator {
                                                          primitiveMethodType, differentiatedMethodType);
     }
 
+    /** Request differentiation of a field.
+     * @param owner class in which the field is defined
+     * @param isStatic if true, the field is static
+     * @param field field name
+     * @param primitiveFieldType field type in the primitive
+     * @param differentiatedFieldType field type in the differentiated class
+     * @exception DifferentiationException if class cannot be found
+     */
+    public void requestFieldDifferentiation(final String owner, final boolean isStatic,
+                                            final String field, final Type primitiveFieldType,
+                                            final Type differentiatedFieldType) {
+        classDifferentiator.requestFieldDifferentiation(owner, isStatic, field,
+                                                        primitiveFieldType, differentiatedFieldType);
+    }
+
     /** Identify the instructions that must be changed.
      * <p>Identification is based on data flow analysis. We start by changing
      * the local variables in the initial frame to match the parameters of

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java?rev=1404267&r1=1404266&r2=1404267&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java	(original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java	Wed Oct 31 17:47:42 2012
@@ -80,8 +80,25 @@ public class PutTransformer implements I
                                                     final MethodDifferentiator methodDifferentiator,
                                                     final int dsIndex)
         throws DifferentiationException {
-        // TODO ad support for PUTFIELD/PUTSTATIC in the case of transformed fields
-        throw new RuntimeException("PUTFIELD/PUTSTATIC not handled yet for transformed fields");
+
+        final Type primitiveFieldType = Type.getType(insn.desc);
+        if (!primitiveFieldType.equals(Type.DOUBLE_TYPE)) {
+            // TODO handle double array types
+            throw new RuntimeException("PUTFIELD/PUTSTATIC not handled yet for non-double
fields");            
+        }
+        final Type differentiatedFieldType = Type.getType(DerivativeStructure.class);
+
+        // we need to add a new field in the transformed class
+        final boolean isStatic = insn.getOpcode() == Opcodes.PUTSTATIC;
+        methodDifferentiator.requestFieldDifferentiation(Type.getType("L" + insn.owner +
";").getClassName(),
+                                                         isStatic, insn.name,
+                                                         primitiveFieldType, differentiatedFieldType);
+
+        final InsnList list = new InsnList();
+        list.add(new FieldInsnNode(insn.getOpcode(), methodDifferentiator.getDerivedName(),
+                                   insn.name, differentiatedFieldType.getDescriptor()));
+        return list;
+
     }
 
     /** Get the replacement instructions when the field type is not transformed.



Mime
View raw message