mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r774521 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/utils/CosineDistanceMeasure.java test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java
Date Wed, 13 May 2009 20:42:22 GMT
Author: jeastman
Date: Wed May 13 20:42:22 2009
New Revision: 774521

URL: http://svn.apache.org/viewvc?rev=774521&view=rev
Log:
- committing MAHOUT-109, CosineDistanceMeasure with one change:
   - removed abstract from test class definition

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java?rev=774521&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java Wed May 13 20:42:22 2009
@@ -0,0 +1,87 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils;
+
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.matrix.CardinalityException;
+import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.utils.parameters.Parameter;
+
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * This class implements a cosine distance metric by dividing the dot product
+ * of two vectors by the product of their lengths
+ */
+public class CosineDistanceMeasure implements DistanceMeasure {
+
+  @Override
+  public void configure(JobConf job) {
+    // nothing to do
+  }
+
+  @Override
+  public Collection<Parameter<?>> getParameters() {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public void createParameters(String prefix, JobConf jobConf) {
+    // nothing to do
+  }
+
+  public static double distance(double[] p1, double[] p2) {
+    double dotProduct = 0.0;
+    double lengthSquaredp1 = 0.0;
+    double lengthSquaredp2 = 0.0;
+    for (int i = 0; i < p1.length; i++) { 
+      lengthSquaredp1 += p1[i] * p1[i];
+      lengthSquaredp2 += p2[i] * p2[i];
+      dotProduct += p1[i] * p2[i];
+    }
+    double denominator = Math.sqrt(lengthSquaredp1) * Math.sqrt(lengthSquaredp2);
+    
+    // correct for floating-point rounding errors
+    if(denominator < dotProduct)
+      denominator = dotProduct;
+    
+    return 1.0 - (dotProduct / denominator);
+  }
+
+  @Override
+  public double distance(Vector v1, Vector v2) {
+    if (v1.cardinality() != v2.cardinality())
+      throw new CardinalityException();
+	  double lengthSquaredv1 = 0.0;
+	  double lengthSquaredv2 = 0.0;
+	  for (int i = 0; i < v1.cardinality(); i++) {
+	    lengthSquaredv1 += v1.getQuick(i) * v1.getQuick(i);
+	    lengthSquaredv2 += v2.getQuick(i) * v2.getQuick(i);
+	  }
+	  double dotProduct = v1.dot(v2);
+	  double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2);
+	  
+	  // correct for floating-point rounding errors
+    if(denominator < dotProduct)
+      denominator = dotProduct;
+
+	  return 1.0 - (dotProduct / denominator);
+  }
+  
+}

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java?rev=774521&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java Wed May 13 20:42:22 2009
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils;
+
+import junit.framework.TestCase;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+
+/** Checks zero self-distance and relative ordering of cosine distances. */
+public class CosineDistanceMeasureTest extends TestCase {
+
+  /** Tolerance used when comparing a computed distance against 0.0. */
+  private static final double EPSILON = 1.0e-12;
+
+  public void testMeasure() {
+
+    DistanceMeasure distanceMeasure = new CosineDistanceMeasure();
+
+    Vector[] vectors = {
+        new DenseVector(new double[]{1, 0, 0, 0, 0, 0}),
+        new DenseVector(new double[]{1, 1, 1, 0, 0, 0}),
+        new DenseVector(new double[]{1, 1, 1, 1, 1, 1})
+    };
+
+    double[][] distanceMatrix = new double[vectors.length][vectors.length];
+    for (int a = 0; a < vectors.length; a++) {
+      for (int b = 0; b < vectors.length; b++) {
+        distanceMatrix[a][b] = distanceMeasure.distance(vectors[a], vectors[b]);
+      }
+    }
+
+    assertEquals(0.0, distanceMatrix[0][0], EPSILON);
+    assertTrue(distanceMatrix[0][0] < distanceMatrix[0][1]);
+    assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]);
+
+    assertEquals(0.0, distanceMatrix[1][1], EPSILON);
+    assertTrue(distanceMatrix[1][0] > distanceMatrix[1][1]);
+    assertTrue(distanceMatrix[1][2] < distanceMatrix[1][0]);
+
+    assertEquals(0.0, distanceMatrix[2][2], EPSILON);
+    assertTrue(distanceMatrix[2][0] > distanceMatrix[2][1]);
+    assertTrue(distanceMatrix[2][1] > distanceMatrix[2][2]);
+  }
+}



Mime
View raw message