commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From oe...@apache.org
Subject [1/2] [math] MATH-1258: check for equal array lengths in distance functions
Date Thu, 20 Aug 2015 16:01:36 GMT
Repository: commons-math
Updated Branches:
  refs/heads/MATH_3_X 9cb16d5b1 -> 7934bfea1
  refs/heads/master f70741c9b -> 5ca0a1c35


MATH-1258: check for equal array lengths in distance functions

Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/5ca0a1c3
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/5ca0a1c3
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/5ca0a1c3

Branch: refs/heads/master
Commit: 5ca0a1c3564d35293a5ecf03e5f32e6f0f6f445c
Parents: f70741c
Author: Otmar Ertl <otmar.ertl@gmail.com>
Authored: Thu Aug 20 17:29:02 2015 +0200
Committer: Otmar Ertl <otmar.ertl@gmail.com>
Committed: Thu Aug 20 17:29:02 2015 +0200

----------------------------------------------------------------------
 .../math4/ml/distance/CanberraDistance.java     |  6 +-
 .../math4/ml/distance/ChebyshevDistance.java    |  4 +-
 .../math4/ml/distance/DistanceMeasure.java      |  5 +-
 .../math4/ml/distance/EarthMoversDistance.java  |  6 +-
 .../math4/ml/distance/EuclideanDistance.java    |  4 +-
 .../math4/ml/distance/ManhattanDistance.java    |  4 +-
 .../apache/commons/math4/util/MathArrays.java   | 83 ++++++++++++++++----
 7 files changed, 89 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/5ca0a1c3/src/main/java/org/apache/commons/math4/ml/distance/CanberraDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/distance/CanberraDistance.java b/src/main/java/org/apache/commons/math4/ml/distance/CanberraDistance.java
index 0a3aaa3..59b8ff3 100644
--- a/src/main/java/org/apache/commons/math4/ml/distance/CanberraDistance.java
+++ b/src/main/java/org/apache/commons/math4/ml/distance/CanberraDistance.java
@@ -16,7 +16,9 @@
  */
 package org.apache.commons.math4.ml.distance;
 
+import org.apache.commons.math4.exception.DimensionMismatchException;
 import org.apache.commons.math4.util.FastMath;
+import org.apache.commons.math4.util.MathArrays;
 
 /**
  * Calculates the Canberra distance between two points.
@@ -30,7 +32,9 @@ public class CanberraDistance implements DistanceMeasure {
 
     /** {@inheritDoc} */
     @Override
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
+        MathArrays.checkEqualLength(a, b);
         double sum = 0;
         for (int i = 0; i < a.length; i++) {
             final double num = FastMath.abs(a[i] - b[i]);

http://git-wip-us.apache.org/repos/asf/commons-math/blob/5ca0a1c3/src/main/java/org/apache/commons/math4/ml/distance/ChebyshevDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/distance/ChebyshevDistance.java b/src/main/java/org/apache/commons/math4/ml/distance/ChebyshevDistance.java
index 4b6cf28..9b3268e 100644
--- a/src/main/java/org/apache/commons/math4/ml/distance/ChebyshevDistance.java
+++ b/src/main/java/org/apache/commons/math4/ml/distance/ChebyshevDistance.java
@@ -16,6 +16,7 @@
  */
 package org.apache.commons.math4.ml.distance;
 
+import org.apache.commons.math4.exception.DimensionMismatchException;
 import org.apache.commons.math4.util.MathArrays;
 
 /**
@@ -30,7 +31,8 @@ public class ChebyshevDistance implements DistanceMeasure {
 
     /** {@inheritDoc} */
     @Override
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
         return MathArrays.distanceInf(a, b);
     }
 

http://git-wip-us.apache.org/repos/asf/commons-math/blob/5ca0a1c3/src/main/java/org/apache/commons/math4/ml/distance/DistanceMeasure.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/distance/DistanceMeasure.java b/src/main/java/org/apache/commons/math4/ml/distance/DistanceMeasure.java
index 89d4672..a62ce80 100644
--- a/src/main/java/org/apache/commons/math4/ml/distance/DistanceMeasure.java
+++ b/src/main/java/org/apache/commons/math4/ml/distance/DistanceMeasure.java
@@ -18,6 +18,8 @@ package org.apache.commons.math4.ml.distance;
 
 import java.io.Serializable;
 
+import org.apache.commons.math4.exception.DimensionMismatchException;
+
 /**
  * Interface for distance measures of n-dimensional vectors.
  *
@@ -33,6 +35,7 @@ public interface DistanceMeasure extends Serializable {
      * @param a the first vector
      * @param b the second vector
      * @return the distance between the two vectors
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    double compute(double[] a, double[] b);
+    double compute(double[] a, double[] b) throws DimensionMismatchException;
 }

http://git-wip-us.apache.org/repos/asf/commons-math/blob/5ca0a1c3/src/main/java/org/apache/commons/math4/ml/distance/EarthMoversDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/distance/EarthMoversDistance.java b/src/main/java/org/apache/commons/math4/ml/distance/EarthMoversDistance.java
index b5cdd2e..77b0479 100644
--- a/src/main/java/org/apache/commons/math4/ml/distance/EarthMoversDistance.java
+++ b/src/main/java/org/apache/commons/math4/ml/distance/EarthMoversDistance.java
@@ -16,7 +16,9 @@
  */
 package org.apache.commons.math4.ml.distance;
 
+import org.apache.commons.math4.exception.DimensionMismatchException;
 import org.apache.commons.math4.util.FastMath;
+import org.apache.commons.math4.util.MathArrays;
 
 /**
  * Calculates the Earh Mover's distance (also known as Wasserstein metric) between two distributions.
@@ -32,7 +34,9 @@ public class EarthMoversDistance implements DistanceMeasure {
 
     /** {@inheritDoc} */
     @Override
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
+        MathArrays.checkEqualLength(a, b);
         double lastDistance = 0;
         double totalDistance = 0;
         for (int i = 0; i < a.length; i++) {

http://git-wip-us.apache.org/repos/asf/commons-math/blob/5ca0a1c3/src/main/java/org/apache/commons/math4/ml/distance/EuclideanDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/distance/EuclideanDistance.java b/src/main/java/org/apache/commons/math4/ml/distance/EuclideanDistance.java
index c42d02e..7567f61 100644
--- a/src/main/java/org/apache/commons/math4/ml/distance/EuclideanDistance.java
+++ b/src/main/java/org/apache/commons/math4/ml/distance/EuclideanDistance.java
@@ -16,6 +16,7 @@
  */
 package org.apache.commons.math4.ml.distance;
 
+import org.apache.commons.math4.exception.DimensionMismatchException;
 import org.apache.commons.math4.util.MathArrays;
 
 /**
@@ -30,7 +31,8 @@ public class EuclideanDistance implements DistanceMeasure {
 
     /** {@inheritDoc} */
     @Override
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
         return MathArrays.distance(a, b);
     }
 

http://git-wip-us.apache.org/repos/asf/commons-math/blob/5ca0a1c3/src/main/java/org/apache/commons/math4/ml/distance/ManhattanDistance.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/distance/ManhattanDistance.java b/src/main/java/org/apache/commons/math4/ml/distance/ManhattanDistance.java
index b7e1938..a538570 100644
--- a/src/main/java/org/apache/commons/math4/ml/distance/ManhattanDistance.java
+++ b/src/main/java/org/apache/commons/math4/ml/distance/ManhattanDistance.java
@@ -16,6 +16,7 @@
  */
 package org.apache.commons.math4.ml.distance;
 
+import org.apache.commons.math4.exception.DimensionMismatchException;
 import org.apache.commons.math4.util.MathArrays;
 
 /**
@@ -30,7 +31,8 @@ public class ManhattanDistance implements DistanceMeasure {
 
     /** {@inheritDoc} */
     @Override
-    public double compute(double[] a, double[] b) {
+    public double compute(double[] a, double[] b)
+    throws DimensionMismatchException {
         return MathArrays.distance1(a, b);
     }
 

http://git-wip-us.apache.org/repos/asf/commons-math/blob/5ca0a1c3/src/main/java/org/apache/commons/math4/util/MathArrays.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/util/MathArrays.java b/src/main/java/org/apache/commons/math4/util/MathArrays.java
index 325f71b..86f67fd 100644
--- a/src/main/java/org/apache/commons/math4/util/MathArrays.java
+++ b/src/main/java/org/apache/commons/math4/util/MathArrays.java
@@ -194,8 +194,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>1</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static double distance1(double[] p1, double[] p2) {
+    public static double distance1(double[] p1, double[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
         double sum = 0;
         for (int i = 0; i < p1.length; i++) {
             sum += FastMath.abs(p1[i] - p2[i]);
@@ -209,13 +212,16 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>1</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static int distance1(int[] p1, int[] p2) {
-      int sum = 0;
-      for (int i = 0; i < p1.length; i++) {
-          sum += FastMath.abs(p1[i] - p2[i]);
-      }
-      return sum;
+    public static int distance1(int[] p1, int[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
+        int sum = 0;
+        for (int i = 0; i < p1.length; i++) {
+            sum += FastMath.abs(p1[i] - p2[i]);
+        }
+        return sum;
     }
 
     /**
@@ -224,8 +230,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>2</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static double distance(double[] p1, double[] p2) {
+    public static double distance(double[] p1, double[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
         double sum = 0;
         for (int i = 0; i < p1.length; i++) {
             final double dp = p1[i] - p2[i];
@@ -251,8 +260,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>2</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static double distance(int[] p1, int[] p2) {
+    public static double distance(int[] p1, int[] p2)
+    throws DimensionMismatchException {
+      checkEqualLength(p1, p2);
       double sum = 0;
       for (int i = 0; i < p1.length; i++) {
           final double dp = p1[i] - p2[i];
@@ -267,8 +279,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>&infin;</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static double distanceInf(double[] p1, double[] p2) {
+    public static double distanceInf(double[] p1, double[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
         double max = 0;
         for (int i = 0; i < p1.length; i++) {
             max = FastMath.max(max, FastMath.abs(p1[i] - p2[i]));
@@ -282,8 +297,11 @@ public class MathArrays {
      * @param p1 the first point
      * @param p2 the second point
      * @return the L<sub>&infin;</sub> distance between the two points
+     * @throws DimensionMismatchException if the array lengths differ.
      */
-    public static int distanceInf(int[] p1, int[] p2) {
+    public static int distanceInf(int[] p1, int[] p2)
+    throws DimensionMismatchException {
+        checkEqualLength(p1, p2);
         int max = 0;
         for (int i = 0; i < p1.length; i++) {
             max = FastMath.max(max, FastMath.abs(p1[i] - p2[i]));
@@ -399,6 +417,42 @@ public class MathArrays {
         checkEqualLength(a, b, true);
     }
 
+
+    /**
+     * Check that both arrays have the same length.
+     *
+     * @param a Array.
+     * @param b Array.
+     * @param abort Whether to throw an exception if the check fails.
+     * @return {@code true} if the arrays have the same length.
+     * @throws DimensionMismatchException if the lengths differ and
+     * {@code abort} is {@code true}.
+     */
+    public static boolean checkEqualLength(int[] a,
+                                           int[] b,
+                                           boolean abort) {
+        if (a.length == b.length) {
+            return true;
+        } else {
+            if (abort) {
+                throw new DimensionMismatchException(a.length, b.length);
+            }
+            return false;
+        }
+    }
+
+    /**
+     * Check that both arrays have the same length.
+     *
+     * @param a Array.
+     * @param b Array.
+     * @throws DimensionMismatchException if the lengths differ.
+     */
+    public static void checkEqualLength(int[] a,
+                                        int[] b) {
+        checkEqualLength(a, b, true);
+    }
+
     /**
      * Check that the given array is sorted.
      *
@@ -886,10 +940,8 @@ public class MathArrays {
      */
     public static double linearCombination(final double[] a, final double[] b)
         throws DimensionMismatchException {
+        checkEqualLength(a, b);
         final int len = a.length;
-        if (len != b.length) {
-            throw new DimensionMismatchException(len, b.length);
-        }
 
         if (len == 1) {
             // Revert to scalar multiplication.
@@ -1764,9 +1816,6 @@ public class MathArrays {
         }
 
         checkEqualLength(weights, values);
-        if (weights.length != values.length) {
-            throw new DimensionMismatchException(weights.length, values.length);
-        }
 
         boolean containsPositiveWeight = false;
         for (int i = begin; i < begin + length; i++) {


Mime
View raw message