hama-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yxji...@apache.org
Subject svn commit: r1500189 - in /hama/trunk: ./ ml/src/main/java/org/apache/hama/ml/distance/ ml/src/main/java/org/apache/hama/ml/kmeans/ ml/src/main/java/org/apache/hama/ml/math/ ml/src/main/java/org/apache/hama/ml/regression/ ml/src/main/java/org/apache/ha...
Date Sat, 06 Jul 2013 02:12:01 GMT
Author: yxjiang
Date: Sat Jul  6 02:12:01 2013
New Revision: 1500189

URL: http://svn.apache.org/r1500189
Log:
HAMA-773: Matrix/Vector operation does not validate the input argument.

Modified:
    hama/trunk/CHANGES.txt
    hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java
    hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java
    hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java

Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Sat Jul  6 02:12:01 2013
@@ -12,6 +12,7 @@ Release 0.6.3 (unreleased changes)
   IMPROVEMENTS
 
    HAMA-765: Add apply method to Vector/Matrix (Yexi Jiang)
+	 HAMA-773: Matrix/Vector operation does not validate the input argument (Yexi Jiang)
 
 Release 0.6.2 - June 26, 2013
 

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java Sat Jul  6 02:12:01 2013
@@ -50,7 +50,7 @@ public final class CosineDistance implem
     double lengthSquaredv1 = vec1.pow(2).sum();
     double lengthSquaredv2 = vec2.pow(2).sum();
 
-    double dotProduct = vec2.dot(vec1);
+    double dotProduct = vec2.dotUnsafe(vec1);
     double denominator = Math.sqrt(lengthSquaredv1)
         * Math.sqrt(lengthSquaredv2);
 

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java Sat Jul  6 02:12:01 2013
@@ -36,7 +36,7 @@ public final class EuclidianDistance imp
 
   @Override
   public double measureDistance(DoubleVector vec1, DoubleVector vec2) {
-    return Math.sqrt(vec2.subtract(vec1).pow(2).sum());
+    return Math.sqrt(vec2.subtractUnsafe(vec1).pow(2).sum());
   }
 
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java Sat Jul  6 02:12:01 2013
@@ -162,7 +162,7 @@ public final class KMeansBSP
       if (oldCenter == null) {
         msgCenters[msg.getCenterIndex()] = newCenter;
       } else {
-        msgCenters[msg.getCenterIndex()] = oldCenter.add(newCenter);
+        msgCenters[msg.getCenterIndex()] = oldCenter.addUnsafe(newCenter);
       }
     }
     // divide by how often we globally summed vectors
@@ -177,7 +177,7 @@ public final class KMeansBSP
     for (int i = 0; i < msgCenters.length; i++) {
       final DoubleVector oldCenter = centers[i];
       if (msgCenters[i] != null) {
-        double calculateError = oldCenter.subtract(msgCenters[i]).abs().sum();
+        double calculateError = oldCenter.subtractUnsafe(msgCenters[i]).abs().sum();
         if (calculateError > 0.0d) {
           centers[i] = msgCenters[i];
           convergedCounter++;
@@ -241,7 +241,7 @@ public final class KMeansBSP
     } else {
       // add the vector to the center
       newCenterArray[lowestDistantCenter] = newCenterArray[lowestDistantCenter]
-          .add(key);
+          .addUnsafe(key);
       summationCount[lowestDistantCenter]++;
     }
   }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java Sat Jul  6 02:12:01 2013
@@ -21,6 +21,8 @@ import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Random;
 
+import com.google.common.base.Preconditions;
+
 /**
  * Dense double matrix implementation, internally uses two dimensional double
  * arrays.
@@ -384,7 +386,7 @@ public final class DenseDoubleMatrix imp
    * @see de.jungblut.math.DoubleMatrix#multiply(de.jungblut.math.DoubleMatrix)
    */
   @Override
-  public final DoubleMatrix multiply(DoubleMatrix other) {
+  public final DoubleMatrix multiplyUnsafe(DoubleMatrix other) {
     DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.getRowCount(),
         other.getColumnCount());
 
@@ -412,7 +414,7 @@ public final class DenseDoubleMatrix imp
    * )
    */
   @Override
-  public final DoubleMatrix multiplyElementWise(DoubleMatrix other) {
+  public final DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other) {
     DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.numRows,
         this.numColumns);
 
@@ -431,7 +433,7 @@ public final class DenseDoubleMatrix imp
    * de.jungblut.math.DoubleMatrix#multiplyVector(de.jungblut.math.DoubleVector)
    */
   @Override
-  public final DoubleVector multiplyVector(DoubleVector v) {
+  public final DoubleVector multiplyVectorUnsafe(DoubleVector v) {
     DoubleVector vector = new DenseDoubleVector(this.getRowCount());
     for (int row = 0; row < numRows; row++) {
       double sum = 0.0d;
@@ -494,7 +496,7 @@ public final class DenseDoubleMatrix imp
    * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleMatrix)
    */
   @Override
-  public DoubleMatrix subtract(DoubleMatrix other) {
+  public DoubleMatrix subtractUnsafe(DoubleMatrix other) {
     DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns);
     for (int i = 0; i < numRows; i++) {
       for (int j = 0; j < numColumns; j++) {
@@ -509,7 +511,7 @@ public final class DenseDoubleMatrix imp
    * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleVector)
    */
   @Override
-  public DenseDoubleMatrix subtract(DoubleVector vec) {
+  public DenseDoubleMatrix subtractUnsafe(DoubleVector vec) {
     DenseDoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(),
         this.getColumnCount());
     for (int i = 0; i < this.getColumnCount(); i++) {
@@ -523,7 +525,7 @@ public final class DenseDoubleMatrix imp
    * @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleVector)
    */
   @Override
-  public DoubleMatrix divide(DoubleVector vec) {
+  public DoubleMatrix divideUnsafe(DoubleVector vec) {
     DoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(),
         this.getColumnCount());
     for (int i = 0; i < this.getColumnCount(); i++) {
@@ -532,12 +534,22 @@ public final class DenseDoubleMatrix imp
     return cop;
   }
 
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public DoubleMatrix divide(DoubleVector vec) {
+    Preconditions.checkArgument(this.getColumnCount() == vec.getDimension(),
+        "Dimension mismatch.");
+    return this.divideUnsafe(vec);
+  }
+
   /*
    * (non-Javadoc)
    * @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleMatrix)
    */
   @Override
-  public DoubleMatrix divide(DoubleMatrix other) {
+  public DoubleMatrix divideUnsafe(DoubleMatrix other) {
     DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns);
     for (int i = 0; i < numRows; i++) {
       for (int j = 0; j < numColumns; j++) {
@@ -547,6 +559,13 @@ public final class DenseDoubleMatrix imp
     return m;
   }
 
+  @Override
+  public DoubleMatrix divide(DoubleMatrix other) {
+    Preconditions.checkArgument(this.getRowCount() == other.getRowCount()
+        && this.getColumnCount() == other.getColumnCount());
+    return divideUnsafe(other);
+  }
+
   /*
    * (non-Javadoc)
    * @see de.jungblut.math.DoubleMatrix#divide(double)
@@ -775,7 +794,7 @@ public final class DenseDoubleMatrix imp
    * Just a absolute error function.
    */
   public static double error(DenseDoubleMatrix a, DenseDoubleMatrix b) {
-    return a.subtract(b).sum();
+    return a.subtractUnsafe(b).sum();
   }
 
   @Override
@@ -795,20 +814,91 @@ public final class DenseDoubleMatrix imp
   /**
    * {@inheritDoc}
    */
-  public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun) {
-    if (this.numRows != other.getRowCount()
-        || this.numColumns != other.getColumnCount()) {
-      throw new IllegalArgumentException(
-          "Cannot apply double double function to matrices with different sizes.");
-    }
-    
+  public DoubleMatrix applyToElements(DoubleMatrix other,
+      DoubleDoubleFunction fun) {
+    Preconditions
+        .checkArgument(this.numRows == other.getRowCount()
+            && this.numColumns == other.getColumnCount(),
+            "Cannot apply double double function to matrices with different sizes.");
+
     for (int r = 0; r < this.numRows; ++r) {
       for (int c = 0; c < this.numColumns; ++c) {
         this.set(r, c, fun.apply(this.get(r, c), other.get(r, c)));
       }
     }
-    
+
     return this;
   }
 
+  /*
+   * (non-Javadoc)
+   * @see
+   * org.apache.hama.ml.math.DoubleMatrix#safeMultiply(org.apache.hama.ml.math
+   * .DoubleMatrix)
+   */
+  @Override
+  public DoubleMatrix multiply(DoubleMatrix other) {
+    Preconditions
+        .checkArgument(
+            this.numColumns == other.getRowCount(),
+            String
+                .format(
+                    "Matrix with size [%d, %d] cannot multiple matrix with size [%d, %d]",
+                    this.numRows, this.numColumns, other.getRowCount(),
+                    other.getColumnCount()));
+
+    return this.multiplyUnsafe(other);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see
+   * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyElementWise(org.apache
+   * .hama.ml.math.DoubleMatrix)
+   */
+  @Override
+  public DoubleMatrix multiplyElementWise(DoubleMatrix other) {
+    Preconditions.checkArgument(this.numRows == other.getRowCount()
+        && this.numColumns == other.getColumnCount(),
+        "Matrices with different dimensions cannot be multiplied elementwise.");
+    return this.multiplyElementWiseUnsafe(other);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see
+   * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyVector(org.apache.hama
+   * .ml.math.DoubleVector)
+   */
+  @Override
+  public DoubleVector multiplyVector(DoubleVector v) {
+    Preconditions.checkArgument(this.numColumns == v.getDimension(),
+        "Dimension mismatch.");
+    return this.multiplyVectorUnsafe(v);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math.
+   * DoubleMatrix)
+   */
+  @Override
+  public DoubleMatrix subtract(DoubleMatrix other) {
+    Preconditions.checkArgument(this.numRows == other.getRowCount()
+        && this.numColumns == other.getColumnCount(), "Dimension mismatch.");
+    return subtractUnsafe(other);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math.
+   * DoubleVector)
+   */
+  @Override
+  public DoubleMatrix subtract(DoubleVector vec) {
+    Preconditions.checkArgument(this.numColumns == vec.getDimension(),
+        "Dimension mismatch.");
+    return null;
+  }
+
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java Sat Jul  6 02:12:01 2013
@@ -24,6 +24,7 @@ import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.AbstractIterator;
 
 /**
@@ -112,16 +113,17 @@ public final class DenseDoubleVector imp
   }
 
   /**
-   * {@inheritDoc}}
+   * {@inheritDoc}
    */
   @Override
-  public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func) {
+  public DoubleVector applyToElements(DoubleVector other,
+      DoubleDoubleFunction func) {
     for (int i = 0; i < vector.length; i++) {
       this.vector[i] = func.apply(vector[i], other.get(i));
     }
     return this;
   }
-  
+
   /*
    * (non-Javadoc)
    * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function.
@@ -157,7 +159,7 @@ public final class DenseDoubleVector imp
    * @see de.jungblut.math.DoubleVector#add(de.jungblut.math.DoubleVector)
    */
   @Override
-  public final DoubleVector add(DoubleVector v) {
+  public final DoubleVector addUnsafe(DoubleVector v) {
     DenseDoubleVector newv = new DenseDoubleVector(v.getLength());
     for (int i = 0; i < v.getLength(); i++) {
       newv.set(i, this.get(i) + v.get(i));
@@ -183,7 +185,7 @@ public final class DenseDoubleVector imp
    * @see de.jungblut.math.DoubleVector#subtract(de.jungblut.math.DoubleVector)
    */
   @Override
-  public final DoubleVector subtract(DoubleVector v) {
+  public final DoubleVector subtractUnsafe(DoubleVector v) {
     DoubleVector newv = new DenseDoubleVector(v.getLength());
     for (int i = 0; i < v.getLength(); i++) {
       newv.set(i, this.get(i) - v.get(i));
@@ -235,7 +237,7 @@ public final class DenseDoubleVector imp
    * @see de.jungblut.math.DoubleVector#multiply(de.jungblut.math.DoubleVector)
    */
   @Override
-  public DoubleVector multiply(DoubleVector vector) {
+  public DoubleVector multiplyUnsafe(DoubleVector vector) {
     DoubleVector v = new DenseDoubleVector(this.getLength());
     for (int i = 0; i < v.getLength(); i++) {
       v.set(i, this.get(i) * vector.get(i));
@@ -338,10 +340,10 @@ public final class DenseDoubleVector imp
    * @see de.jungblut.math.DoubleVector#dot(de.jungblut.math.DoubleVector)
    */
   @Override
-  public double dot(DoubleVector s) {
+  public double dotUnsafe(DoubleVector vector) {
     double dotProduct = 0.0d;
     for (int i = 0; i < getLength(); i++) {
-      dotProduct += this.get(i) * s.get(i);
+      dotProduct += this.get(i) * vector.get(i);
     }
     return dotProduct;
   }
@@ -652,4 +654,54 @@ public final class DenseDoubleVector imp
     return null;
   }
 
+  /*
+   * (non-Javadoc)
+   * @see org.apache.hama.ml.math.DoubleVector#safeAdd(org.apache.hama.ml.math.
+   * DoubleVector)
+   */
+  @Override
+  public DoubleVector add(DoubleVector vector) {
+    Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+        "Dimensions of two vectors do not equal.");
+    return this.addUnsafe(vector);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see
+   * org.apache.hama.ml.math.DoubleVector#safeSubtract(org.apache.hama.ml.math
+   * .DoubleVector)
+   */
+  @Override
+  public DoubleVector subtract(DoubleVector vector) {
+    Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+        "Dimensions of two vectors do not equal.");
+    return this.subtractUnsafe(vector);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see
+   * org.apache.hama.ml.math.DoubleVector#safeMultiplay(org.apache.hama.ml.math
+   * .DoubleVector)
+   */
+  @Override
+  public DoubleVector multiply(DoubleVector vector) {
+    Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+        "Dimensions of two vectors do not equal.");
+    return this.multiplyUnsafe(vector);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see org.apache.hama.ml.math.DoubleVector#safeDot(org.apache.hama.ml.math.
+   * DoubleVector)
+   */
+  @Override
+  public double dot(DoubleVector vector) {
+    Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+        "Dimensions of two vectors do not equal.");
+    return this.dotUnsafe(vector);
+  }
+
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java Sat Jul  6 02:12:01 2013
@@ -80,18 +80,47 @@ public interface DoubleMatrix {
 
   /**
    * Multiplies this matrix with the given other matrix.
+   * 
+   * @param other the other matrix.
+   * @return
+   */
+  public DoubleMatrix multiplyUnsafe(DoubleMatrix other);
+
+  /**
+   * Validates the input and multiplies this matrix with the given other matrix.
+   * 
+   * @param other the other matrix.
+   * @return
    */
   public DoubleMatrix multiply(DoubleMatrix other);
 
   /**
    * Multiplies this matrix per element with a given matrix.
    */
+  public DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other);
+
+  /**
+   * Validates the input and multiplies this matrix per element with a given
+   * matrix.
+   * 
+   * @param other the other matrix
+   * @return
+   */
   public DoubleMatrix multiplyElementWise(DoubleMatrix other);
 
   /**
    * Multiplies this matrix with a given vector v. The returning vector contains
    * the sum of the rows.
    */
+  public DoubleVector multiplyVectorUnsafe(DoubleVector v);
+
+  /**
+   * Multiplies this matrix with a given vector v. The returning vector contains
+   * the sum of the rows.
+   * 
+   * @param v the vector
+   * @return
+   */
   public DoubleVector multiplyVector(DoubleVector v);
 
   /**
@@ -114,23 +143,58 @@ public interface DoubleMatrix {
   /**
    * Subtracts this matrix by the given other matrix.
    */
+  public DoubleMatrix subtractUnsafe(DoubleMatrix other);
+
+  /**
+   * Validates the input and subtracts this matrix by the given other matrix.
+   * 
+   * @param other
+   * @return
+   */
   public DoubleMatrix subtract(DoubleMatrix other);
 
   /**
    * Subtracts each element in a column by the related element in the given
    * vector.
    */
+  public DoubleMatrix subtractUnsafe(DoubleVector vec);
+
+  /**
+   * Validates and subtracts each element in a column by the related element in
+   * the given vector.
+   * 
+   * @param vec
+   * @return
+   */
   public DoubleMatrix subtract(DoubleVector vec);
 
   /**
    * Divides each element in a column by the related element in the given
    * vector.
    */
+  public DoubleMatrix divideUnsafe(DoubleVector vec);
+
+  /**
+   * Validates and divides each element in a column by the related element in
+   * the given vector.
+   * 
+   * @param vec
+   * @return
+   */
   public DoubleMatrix divide(DoubleVector vec);
 
   /**
    * Divides this matrix by the given other matrix. (Per element division).
    */
+  public DoubleMatrix divideUnsafe(DoubleMatrix other);
+
+  /**
+   * Validates and divides this matrix by the given other matrix. (Per element
+   * division).
+   * 
+   * @param other
+   * @return
+   */
   public DoubleMatrix divide(DoubleMatrix other);
 
   /**
@@ -203,6 +267,7 @@ public interface DoubleMatrix {
    * @param fun The function that takes two arguments.
    * @return The matrix itself, supply for chain operation.
    */
-  public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction fun);
+  public DoubleMatrix applyToElements(DoubleMatrix other,
+      DoubleDoubleFunction fun);
 
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java Sat Jul  6 02:12:01 2013
@@ -58,7 +58,7 @@ public interface DoubleVector {
    * @param value the value at the index of the vector to set.
    */
   public void set(int index, double value);
-  
+
   /**
    * Apply a given {@link DoubleVectorFunction} to this vector and return a new
    * one.
@@ -68,7 +68,7 @@ public interface DoubleVector {
    */
   @Deprecated
   public DoubleVector apply(DoubleVectorFunction func);
-  
+
   /**
    * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the
    * other given vector.
@@ -97,15 +97,24 @@ public interface DoubleVector {
    * @param func the function to apply on this and the other vector.
    * @return a new vector with the result of the function of the two vectors.
    */
-  public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction func);
+  public DoubleVector applyToElements(DoubleVector other,
+      DoubleDoubleFunction func);
 
   /**
    * Adds the given {@link DoubleVector} to this vector.
    * 
-   * @param v the other vector.
+   * @param vector the other vector.
+   * @return a new vector with the sum of both vectors at each element index.
+   */
+  public DoubleVector addUnsafe(DoubleVector vector);
+
+  /**
+   * Validates the input and adds the given {@link DoubleVector} to this vector.
+   * 
+   * @param vector the other vector.
    * @return a new vector with the sum of both vectors at each element index.
    */
-  public DoubleVector add(DoubleVector v);
+  public DoubleVector add(DoubleVector vector);
 
   /**
    * Adds the given scalar to this vector.
@@ -118,10 +127,19 @@ public interface DoubleVector {
   /**
    * Subtracts this vector by the given {@link DoubleVector}.
    * 
-   * @param v the other vector.
+   * @param vector the other vector.
    * @return a new vector with the difference of both vectors.
    */
-  public DoubleVector subtract(DoubleVector v);
+  public DoubleVector subtractUnsafe(DoubleVector vector);
+
+  /**
+   * Validates the input and subtracts this vector by the given
+   * {@link DoubleVector}.
+   * 
+   * @param vector the other vector.
+   * @return a new vector with the difference of both vectors.
+   */
+  public DoubleVector subtract(DoubleVector vector);
 
   /**
    * Subtracts the given scalar to this vector. (vector - scalar).
@@ -153,6 +171,15 @@ public interface DoubleVector {
    * @param vector the other vector.
    * @return a new vector with the result of the operation.
    */
+  public DoubleVector multiplyUnsafe(DoubleVector vector);
+
+  /**
+   * Validates the input and multiplies the given {@link DoubleVector} with this
+   * vector.
+   * 
+   * @param vector the other vector.
+   * @return a new vector with the result of the operation.
+   */
   public DoubleVector multiply(DoubleVector vector);
 
   /**
@@ -201,10 +228,19 @@ public interface DoubleVector {
   /**
    * Calculates the dot product between this vector and the given vector.
    * 
-   * @param s the given vector s.
+   * @param vector the given vector.
+   * @return the dot product as a double.
+   */
+  public double dotUnsafe(DoubleVector vector);
+
+  /**
+   * Validates the input and calculates the dot product between this vector and
+   * the given vector.
+   * 
+   * @param vector the given vector.
    * @return the dot product as a double.
    */
-  public double dot(DoubleVector s);
+  public double dot(DoubleVector vector);
 
   /**
    * Slices this vector from index 0 to the given length.

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java Sat Jul  6 02:12:01 2013
@@ -38,7 +38,7 @@ public class LinearRegressionModel imple
 
   @Override
   public double applyHypothesis(DoubleVector theta, DoubleVector x) {
-    return theta.dot(x);
+    return theta.dotUnsafe(x);
   }
 
   @Override

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java Sat Jul  6 02:12:01 2013
@@ -53,7 +53,7 @@ public class LogisticRegressionModel imp
       DoubleVector x) {
     return BigDecimal.valueOf(1).divide(
         BigDecimal.valueOf(1d).add(
-            BigDecimal.valueOf(Math.exp(-1d * theta.dot(x)))),
+            BigDecimal.valueOf(Math.exp(-1d * theta.dotUnsafe(x)))),
         MathContext.DECIMAL128);
   }
 

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java Sat Jul  6 02:12:01 2013
@@ -119,7 +119,7 @@ public final class VectorWritable implem
   }
 
   public static int compareVector(DoubleVector a, DoubleVector o) {
-    DoubleVector subtract = a.subtract(o);
+    DoubleVector subtract = a.subtractUnsafe(o);
     return (int) subtract.sum();
   }
 

Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java Sat Jul  6 02:12:01 2013
@@ -19,6 +19,8 @@ package org.apache.hama.ml.math;
 
 import static org.junit.Assert.assertArrayEquals;
 
+import java.util.Arrays;
+
 import org.junit.Test;
 
 /**
@@ -57,12 +59,14 @@ public class TestDenseDoubleMatrix {
   @Test
   public void testDoubleDoubleFunction() {
     double[][] values1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
-    double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } };
-    double[][] result = new double[][] { {3, 5, 7}, {9, 11, 13}, {15, 17, 19}};
+    double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 },
+        { 8, 9, 10 } };
+    double[][] result = new double[][] { { 3, 5, 7 }, { 9, 11, 13 },
+        { 15, 17, 19 } };
 
     DenseDoubleMatrix mat1 = new DenseDoubleMatrix(values1);
     DenseDoubleMatrix mat2 = new DenseDoubleMatrix(values2);
-    
+
     mat1.applyToElements(mat2, new DoubleDoubleFunction() {
 
       @Override
@@ -83,4 +87,153 @@ public class TestDenseDoubleMatrix {
     }
   }
 
+  @Test
+  public void testMultiplyNormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 }, { 2, 1 } };
+    double[][] expMat = new double[][] { { 20, 14 }, { 56, 41 } };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+    DoubleMatrix actMatrix = matrix1.multiply(matrix2);
+    for (int r = 0; r < actMatrix.getRowCount(); ++r) {
+      assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(),
+          0.000001);
+    }
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testMultiplyAbnormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+    matrix1.multiply(matrix2);
+  }
+
+  @Test
+  public void testMultiplyElementWiseNormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } };
+    double[][] expMat = new double[][] { { 6, 10, 12 }, { 12, 10, 6 } };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+    DoubleMatrix actMatrix = matrix1.multiplyElementWise(matrix2);
+    for (int r = 0; r < actMatrix.getRowCount(); ++r) {
+      assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(),
+          0.000001);
+    }
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testMultiplyElementWiseAbnormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+    matrix1.multiplyElementWise(matrix2);
+  }
+
+  @Test
+  public void testMultiplyVectorNormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[] mat2 = new double[] { 6, 5, 4 };
+    double[] expVec = new double[] { 28, 73 };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleVector vector2 = new DenseDoubleVector(mat2);
+    DoubleVector actVec = matrix1.multiplyVector(vector2);
+    assertArrayEquals(expVec, actVec.toArray(), 0.000001);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testMultiplyVectorAbnormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[] vec2 = new double[] { 6, 5 };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleVector vector2 = new DenseDoubleVector(vec2);
+    matrix1.multiplyVector(vector2);
+  }
+
+  @Test
+  public void testSubtractNormal() {
+    double[][] mat1 = new double[][] {
+        {1, 2, 3},
+        {4, 5, 6}
+    };
+    double[][] mat2 = new double[][] {
+        {6, 5, 4},
+        {3, 2, 1}
+    };
+    double[][] expMat = new double[][] {
+        {-5, -3, -1},
+        {1, 3, 5}
+    };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+    DoubleMatrix actMatrix = matrix1.subtract(matrix2);
+    for (int r = 0; r < actMatrix.getRowCount(); ++r) {
+      assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), 0.000001);
+    }
+  }
+  
+  @Test(expected = IllegalArgumentException.class)
+  public void testSubtractAbnormal() {
+    double[][] mat1 = new double[][] {
+        {1, 2, 3},
+        {4, 5, 6}
+    };
+    double[][] mat2 = new double[][] {
+        {6, 5},
+        {4, 3}
+    };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+    matrix1.subtract(matrix2);
+  }
+  
+  @Test
+  public void testDivideVectorNormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[] mat2 = new double[] { 6, 5, 4 };
+    double[][] expVec = new double[][] { {1.0 / 6, 2.0 / 5, 3.0 / 4}, {4.0 / 6, 5.0 / 5, 6.0 / 4} };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleVector vector2 = new DenseDoubleVector(mat2);
+    DoubleMatrix expMat = new DenseDoubleMatrix(expVec);
+    DoubleMatrix actMat = matrix1.divide(vector2);
+    for (int r = 0; r < actMat.getRowCount(); ++r) {
+      assertArrayEquals(expMat.getRowVector(r).toArray(), actMat.getRowVector(r).toArray(), 0.000001);
+    }
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testDivideVectorAbnormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[] vec2 = new double[] { 6, 5 };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleVector vector2 = new DenseDoubleVector(vec2);
+    matrix1.divide(vector2);
+  }
+  
+  @Test
+  public void testDivideNormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } };
+    double[][] expMat = new double[][] { { 1.0 / 6, 2.0 / 5, 3.0 / 4 }, { 4.0 / 3, 5.0 / 2, 6.0 / 1 } };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+    DoubleMatrix actMatrix = matrix1.divide(matrix2);
+    for (int r = 0; r < actMatrix.getRowCount(); ++r) {
+      assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(),
+          0.000001);
+    }
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testDivideAbnormal() {
+    double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } };
+    DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1);
+    DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2);
+    matrix1.divide(matrix2);
+  }
+  
 }

Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java?rev=1500189&r1=1500188&r2=1500189&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java Sat Jul  6 02:12:01 2013
@@ -18,8 +18,11 @@
 package org.apache.hama.ml.math;
 
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
 
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 /**
  * Testcase for {@link DenseDoubleVector}
@@ -77,4 +80,79 @@ public class TestDenseDoubleVector {
     assertArrayEquals(result, vec1.toArray(), 0.0001);
     
   }
+  
+  @Test
+  public void testAddNormal() {
+    double[] arr1 = new double[] {1, 2, 3};
+    double[] arr2 = new double[] {4, 5, 6};
+    DoubleVector vec1 = new DenseDoubleVector(arr1);
+    DoubleVector vec2 = new DenseDoubleVector(arr2);
+    double[] arrExp = new double[] {5, 7, 9};
+    assertArrayEquals(arrExp, vec1.add(vec2).toArray(), 0.000001);
+  }
+  
+  @Test(expected = IllegalArgumentException.class)
+  public void testAddAbnormal() {
+    double[] arr1 = new double[] {1, 2, 3};
+    double[] arr2 = new double[] {4, 5};
+    DoubleVector vec1 = new DenseDoubleVector(arr1);
+    DoubleVector vec2 = new DenseDoubleVector(arr2);
+    vec1.add(vec2);
+  }
+  
+  @Test
+  public void testSubtractNormal() {
+    double[] arr1 = new double[] {1, 2, 3};
+    double[] arr2 = new double[] {4, 5, 6};
+    DoubleVector vec1 = new DenseDoubleVector(arr1);
+    DoubleVector vec2 = new DenseDoubleVector(arr2);
+    double[] arrExp = new double[] {-3, -3, -3};
+    assertArrayEquals(arrExp, vec1.subtract(vec2).toArray(), 0.000001);
+  }
+  
+  @Test(expected = IllegalArgumentException.class)
+  public void testSubtractAbnormal() {
+    double[] arr1 = new double[] {1, 2, 3};
+    double[] arr2 = new double[] {4, 5};
+    DoubleVector vec1 = new DenseDoubleVector(arr1);
+    DoubleVector vec2 = new DenseDoubleVector(arr2);
+    vec1.subtract(vec2);
+  }
+  
+  @Test
+  public void testMultiplyNormal() {
+    double[] arr1 = new double[] {1, 2, 3};
+    double[] arr2 = new double[] {4, 5, 6};
+    DoubleVector vec1 = new DenseDoubleVector(arr1);
+    DoubleVector vec2 = new DenseDoubleVector(arr2);
+    double[] arrExp = new double[] {4, 10, 18};
+    assertArrayEquals(arrExp, vec1.multiply(vec2).toArray(), 0.000001);
+  }
+  
+  @Test(expected = IllegalArgumentException.class)
+  public void testMultiplyAbnormal() {
+    double[] arr1 = new double[] {1, 2, 3};
+    double[] arr2 = new double[] {4, 5};
+    DoubleVector vec1 = new DenseDoubleVector(arr1);
+    DoubleVector vec2 = new DenseDoubleVector(arr2);
+    vec1.multiply(vec2);
+  }
+  
+  @Test
+  public void testDotNormal() {
+    double[] arr1 = new double[] {1, 2, 3};
+    double[] arr2 = new double[] {4, 5, 6};
+    DoubleVector vec1 = new DenseDoubleVector(arr1);
+    DoubleVector vec2 = new DenseDoubleVector(arr2);
+    assertEquals(32.0, vec1.dot(vec2), 0.000001);
+  }
+  
+  @Test(expected = IllegalArgumentException.class) // mismatched lengths must be rejected
+  public void testDotAbnormal() {
+    double[] arr1 = new double[] {1, 2, 3};
+    double[] arr2 = new double[] {4, 5};
+    DoubleVector vec1 = new DenseDoubleVector(arr1);
+    DoubleVector vec2 = new DenseDoubleVector(arr2);
+    vec1.dot(vec2); // exercise dot() (not add()) with a 3-element and a 2-element vector
+  }
 }



Mime
View raw message