mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r1164337 - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/discriminative/ main/java/org/apache/mahout/math/hadoop/stochasticsvd/ test/java/org/apache/mahout/math/hadoop/stochasticsvd/
Date Fri, 02 Sep 2011 03:05:32 GMT
Author: tdunning
Date: Fri Sep  2 03:05:32 2011
New Revision: 1164337

URL: http://svn.apache.org/viewvc?rev=1164337&view=rev
Log:
MAHOUT-790 - Iterator may return zeros.  Upper triangular rewrite.  Ignore iterator test for now.

MAHOUT-790 - Clean javadoc

MAHOUT-790 - Narrow exception declaration.

Added:
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangularTest.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/GivensThinSolver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangular.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/LocalSSVDSolverDenseTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java?rev=1164337&r1=1164336&r2=1164337&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java Fri Sep  2 03:05:32 2011
@@ -71,7 +71,9 @@ public class WinnowTrainer extends Linea
       Iterator<Vector.Element> iter = updateVector.iterateNonZero();
       while (iter.hasNext()) {
         Vector.Element element = iter.next();
-        model.timesDelta(element.index(), element.get());
+        if (element.get() != 0) {
+          model.timesDelta(element.index(), element.get());
+        }
       }
     } else {
       // case two
@@ -80,7 +82,9 @@ public class WinnowTrainer extends Linea
       Iterator<Vector.Element> iter = updateVector.iterateNonZero();
       while (iter.hasNext()) {
         Vector.Element element = iter.next();
-        model.timesDelta(element.index(), element.get());
+        if (element.get() != 0) {
+          model.timesDelta(element.index(), element.get());
+        }
       }
     }
     log.info(model.toString());

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/GivensThinSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/GivensThinSolver.java?rev=1164337&r1=1164336&r2=1164337&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/GivensThinSolver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/GivensThinSolver.java Fri Sep  2 03:05:32 2011
@@ -219,7 +219,8 @@ public class GivensThinSolver {
   }
 
   private double[] getRRow(int row) {
-    return mR[(row += rStartRow) >= n ? row - n : row];
+    row += rStartRow;
+    return mR[row >= n ? row - n : row];
   }
 
   private void setRRow(int row, double[] rrow) {
@@ -235,7 +236,7 @@ public class GivensThinSolver {
   public UpperTriangular getRTilde() {
     UpperTriangular packedR = new UpperTriangular(n);
     for (int i = 0; i < n; i++) {
-      packedR.assignRow(i, getRRow(i));
+      packedR.assignNonZeroElementsInRow(i, getRRow(i));
     }
     return packedR;
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangular.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangular.java?rev=1164337&r1=1164336&r2=1164337&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangular.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangular.java Fri Sep  2 03:05:32 2011
@@ -18,7 +18,10 @@
 package org.apache.mahout.math.hadoop.stochasticsvd;
 
 import org.apache.mahout.math.AbstractMatrix;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.IndexException;
 import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixView;
 import org.apache.mahout.math.Vector;
 
 /**
@@ -48,18 +51,19 @@ public class UpperTriangular extends Abs
     this.n = n;
   }
 
-  public UpperTriangular(Vector data) {
-    this((int) Math.round((-1 + Math.sqrt(1 + 8 * data.size())) / 2), data);
-  }
-
   public UpperTriangular(double[] data, boolean shallow) {
-    this((int) Math.round((-1 + Math.sqrt(1 + 8 * data.length)) / 2), data, shallow);
+    this(data != null ? data.length : 0, elementsToMatrixSize(data != null ? data.length : 0));
+    if (data == null) {
+      throw new IllegalArgumentException("data");
+    }
+    values = shallow ? data : data.clone();
   }
 
-  private UpperTriangular(int rows, Vector data) {
-    super(rows, rows);
+  public UpperTriangular(Vector data) {
+    this(data.size(), elementsToMatrixSize(data.size()));
+
     values = new double[n * (n + 1) / 2];
-    int n = data.size();
+    n = data.size();
     // if ( data instanceof DenseVector )
     // ((DenseVector)data).
     // system.arraycopy would've been much faster, but this way it's a drag
@@ -69,12 +73,13 @@ public class UpperTriangular extends Abs
     }
   }
 
-  private UpperTriangular(int rows, double[] data, boolean shallow) {
+  private UpperTriangular(int n, int rows) {
     super(rows, rows);
-    if (data == null) {
-      throw new IllegalArgumentException("data");
-    }
-    values = shallow ? data : data.clone();
+    this.n = n;
+  }
+
+  private static int elementsToMatrixSize(int size) {
+    return (int) Math.round((-1 + Math.sqrt(1 + 8 * size)) / 2);
   }
 
   // copy-constructor
@@ -84,13 +89,25 @@ public class UpperTriangular extends Abs
 
   @Override
   public Matrix assignColumn(int column, Vector other) {
-    throw new UnsupportedOperationException();
+    if (columnSize() != other.size()) {
+      throw new IndexException(columnSize(), other.size());
+    }
+    if (other.viewPart(column + 1, other.size() - column - 1).norm(1) > 1e-14) {
+      throw new IllegalArgumentException("Cannot set lower portion of triangular matrix to non-zero");
+    }
+    for (Vector.Element element : other.viewPart(0, column)) {
+      setQuick(element.index(), column, element.get());
+    }
+    return this;
   }
 
   @Override
   public Matrix assignRow(int row, Vector other) {
+    if (columnSize() != other.size()) {
+      throw new IndexException(numCols(), other.size());
+    }
     for (int i = 0; i < row; i++) {
-      if (other.getQuick(i) > EPSILON) {
+      if (Math.abs(other.getQuick(i)) > EPSILON) {
         throw new IllegalArgumentException("non-triangular source");
       }
     }
@@ -100,7 +117,7 @@ public class UpperTriangular extends Abs
     return this;
   }
 
-  public Matrix assignRow(int row, double[] other) {
+  public Matrix assignNonZeroElementsInRow(int row, double[] other) {
     System.arraycopy(other, row, values, getL(row, row), n - row);
     return this;
   }
@@ -110,21 +127,24 @@ public class UpperTriangular extends Abs
     if (row > column) {
       return 0;
     }
-    return values[getL(row, column)];
+    int i = getL(row, column);
+    return values[i];
   }
 
   private int getL(int row, int col) {
-    return (((n << 1) - row + 1) * row >> 1) + col - row;
+    // each row starts with some zero elements that we don't store.
+    // this accumulates an offset of (row+1)*row/2
+    return col + row * numCols() - (row + 1) * row / 2;
   }
 
   @Override
   public Matrix like() {
-    throw new UnsupportedOperationException();
+    return like(rowSize(), columnSize());
   }
 
   @Override
   public Matrix like(int rows, int columns) {
-    throw new UnsupportedOperationException();
+    return new DenseMatrix(rows, columns);
   }
 
   @Override
@@ -139,7 +159,7 @@ public class UpperTriangular extends Abs
 
   @Override
   public Matrix viewPart(int[] offset, int[] size) {
-    throw new UnsupportedOperationException();
+    return new MatrixView(this, offset, size);
   }
 
   double[] getData() {

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/LocalSSVDSolverDenseTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/LocalSSVDSolverDenseTest.java?rev=1164337&r1=1164336&r2=1164337&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/LocalSSVDSolverDenseTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/LocalSSVDSolverDenseTest.java Fri Sep  2 03:05:32 2011
@@ -19,6 +19,7 @@ package org.apache.mahout.math.hadoop.st
 
 import java.io.Closeable;
 import java.io.File;
+import java.io.IOException;
 import java.util.Deque;
 import java.util.LinkedList;
 import java.util.Random;
@@ -57,7 +58,7 @@ public class LocalSSVDSolverDenseTest ex
   private static final double s_epsilon = 1.0E-10d;
 
   @Test
-  public void testSSVDSolver() throws Exception {
+  public void testSSVDSolver() throws IOException {
 
     Configuration conf = new Configuration();
     conf.set("mapred.job.tracker", "local");

Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangularTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangularTest.java?rev=1164337&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangularTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/stochasticsvd/UpperTriangularTest.java Fri Sep  2 03:05:32 2011
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.function.Functions;
+import org.junit.Test;
+
+public class UpperTriangularTest extends MahoutTestCase {
+  @Test
+  public void testBasics() {
+    Matrix a = new UpperTriangular(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, false);
+    assertEquals(0, a.viewDiagonal().minus(new DenseVector(new double[]{1, 5, 8, 10})).norm(1), 1e-10);
+    assertEquals(0, a.viewPart(0, 3, 1, 3).viewDiagonal().minus(new DenseVector(new double[]{2, 6, 9})).norm(1), 1e-10);
+    assertEquals(4, a.get(0, 3), 1e-10);
+    print(a);
+    Matrix m = new DenseMatrix(4, 4).assign(a);
+    assertEquals(0, m.minus(a).aggregate(Functions.PLUS, Functions.ABS), 1e-10);
+    print(m);
+
+    assertEquals(0, m.transpose().times(m).minus(a.transpose().times(a)).aggregate(Functions.PLUS, Functions.ABS), 1e-10);
+    assertEquals(0, m.plus(m).minus(a.plus(a)).aggregate(Functions.PLUS, Functions.ABS), 1e-10);
+  }
+
+  private void print(Matrix m) {
+    for (int i = 0; i < m.rowSize(); i++) {
+      for (int j = 0; j < m.columnSize(); j++) {
+        if (Math.abs(m.get(i, j)) > 1e-10) {
+          System.out.printf("%10.3f ", m.get(i, j));
+        } else {
+          System.out.printf("%10s ", (i + j) % 3 == 0 ? "." : "");
+        }
+      }
+      System.out.printf("\n");
+    }
+    System.out.printf("\n");
+  }
+}



Mime
View raw message