mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r786675 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/matrix/ test/java/org/apache/mahout/matrix/
Date Fri, 19 Jun 2009 21:45:46 GMT
Author: jeastman
Date: Fri Jun 19 21:45:46 2009
New Revision: 786675

URL: http://svn.apache.org/viewvc?rev=786675&view=rev
Log:
- MAHOUT-65: made Matrix be Writable and implemented all methods
- added unit tests thereof; all run

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/DenseMatrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/MatrixView.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseColumnMatrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseRowMatrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/VectorView.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorWritable.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java Fri Jun 19 21:45:46 2009
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 import java.lang.reflect.Type;
 import java.util.HashMap;
 import java.util.Map;
@@ -32,9 +35,9 @@
 public abstract class AbstractMatrix implements Matrix {
 
   private Map<String, Integer> columnLabelBindings;
-  
+
   private Map<String, Integer> rowLabelBindings;
-  
+
   @Override
   public double get(String rowLabel, String columnLabel) throws IndexException,
       UnboundLabelException {
@@ -97,7 +100,7 @@
     if (columnLabelBindings == null)
       columnLabelBindings = new HashMap<String, Integer>();
     columnLabelBindings.put(columnLabel, column);
-    
+
     set(row, column, value);
   }
 
@@ -359,4 +362,67 @@
     return result;
   }
 
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    // read the label bindings
+    int colSize = in.readInt();
+    if (colSize > 0) {
+      columnLabelBindings = new HashMap<String, Integer>();
+      for (int i = 0; i < colSize; i++)
+        columnLabelBindings.put(in.readUTF(), in.readInt());
+    }
+    int rowSize = in.readInt();
+    if (rowSize > 0) {
+      rowLabelBindings = new HashMap<String, Integer>();
+      for (int i = 0; i < rowSize; i++)
+        rowLabelBindings.put(in.readUTF(), in.readInt());
+    }
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    // write the label bindings
+    if (columnLabelBindings == null)
+      out.writeInt(0);
+    else {
+      out.writeInt(columnLabelBindings.size());
+      for (String key : columnLabelBindings.keySet()) {
+        out.writeUTF(key);
+        out.writeInt(columnLabelBindings.get(key));
+      }
+    }
+    if (rowLabelBindings == null)
+      out.writeInt(0);
+    else {
+      out.writeInt(rowLabelBindings.size());
+      for (String key : rowLabelBindings.keySet()) {
+        out.writeUTF(key);
+        out.writeInt(rowLabelBindings.get(key));
+      }
+    }
+  }
+
+  protected static Matrix readMatrix(DataInput in) throws IOException {
+    String matrixClassName = in.readUTF();
+    Matrix matrix;
+    try {
+      matrix = Class.forName(matrixClassName).asSubclass(Matrix.class)
+          .newInstance();
+    } catch (ClassNotFoundException e) {
+      throw new RuntimeException(e);
+    } catch (IllegalAccessException e) {
+      throw new RuntimeException(e);
+    } catch (InstantiationException e) {
+      throw new RuntimeException(e);
+    }
+    matrix.readFields(in);
+    return matrix;
+  }
+
+  protected static void writeMatrix(DataOutput out, Matrix matrix)
+      throws IOException {
+    out.writeUTF(matrix.getClass().getName());
+    matrix.write(out);
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractVector.java Fri Jun 19 21:45:46 2009
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 import java.lang.reflect.Type;
 import java.util.HashMap;
 import java.util.Map;
@@ -37,8 +40,8 @@
    * transient so that it will not be serialized with each vector instance.
    */
   private transient Map<String, Integer> bindings;
-  protected String name;
 
+  protected String name;
 
   protected AbstractVector() {
   }
@@ -336,10 +339,13 @@
   public void setName(String name) {
     this.name = name;
 
-  }/* (non-Javadoc)
-   * @see org.apache.mahout.matrix.Vector#asFormatString()
-   */
-  public String asFormatString(){
+  }/*
+    * (non-Javadoc)
+    * 
+    * @see org.apache.mahout.matrix.Vector#asFormatString()
+    */
+
+  public String asFormatString() {
     Type vectorType = new TypeToken<Vector>() {
     }.getType();
     GsonBuilder builder = new GsonBuilder();
@@ -349,9 +355,9 @@
   }
 
   /**
-   * Compare whether two Vector implementations have the same elements, regardless of the
-   * implementation and name. Two Vectors are equivalent if they have the same cardinality
-   * and all of their values are the same.
+   * Compare whether two Vector implementations have the same elements,
+   * regardless of the implementation and name. Two Vectors are equivalent if
+   * they have the same cardinality and all of their values are the same.
    * <p/>
    * Does not compare {@link Vector#getName()}.
    * 
@@ -360,12 +366,13 @@
    * @param right The right hand Vector
    * @return true if the two Vectors have the same cardinality and the same
    *         values
-   *
+   * 
    * @see #strictEquivalence(Vector, Vector)
-   * @see Vector#equals(Object) 
+   * @see Vector#equals(Object)
    */
   public static boolean equivalent(Vector left, Vector right) {
-    if (left == right) return true;
+    if (left == right)
+      return true;
     boolean result = true;
     int leftCardinality = left.size();
     if (leftCardinality == right.size()) {
@@ -383,26 +390,29 @@
 
   /**
    * Compare whether two Vector implementations are the same, including the
-   * underlying implementation. Two Vectors are the same if they have the same cardinality, same name
-   * and all of their values are the same.
-   *
-   *
+   * underlying implementation. Two Vectors are the same if they have the same
+   * cardinality, same name and all of their values are the same.
+   * 
+   * 
    * @param left The left hand Vector to compare
    * @param right The right hand Vector
    * @return true if the two Vectors have the same cardinality and the same
    *         values
    */
   public static boolean strictEquivalence(Vector left, Vector right) {
-    if (left == right) return true;
-    if (!(left.getClass().equals(right.getClass()))) return false;
+    if (left == right)
+      return true;
+    if (!(left.getClass().equals(right.getClass())))
+      return false;
     String leftName = left.getName();
     String rightName = right.getName();
-    if (leftName != null && rightName != null && !leftName.equals(rightName)){
+    if (leftName != null && rightName != null && !leftName.equals(rightName))
{
       return false;
-    } else if ((leftName != null && rightName == null) || (rightName != null &&
leftName == null)){
+    } else if ((leftName != null && rightName == null)
+        || (rightName != null && leftName == null)) {
       return false;
     }
-    
+
     boolean result = true;
     int leftCardinality = left.size();
     if (leftCardinality == right.size()) {
@@ -418,8 +428,9 @@
     return result;
   }
 
-
-  /* (non-Javadoc)
+  /*
+   * (non-Javadoc)
+   * 
    * @see org.apache.mahout.matrix.Vector#get(java.lang.String)
    */
   @Override
@@ -432,7 +443,9 @@
     return get(index);
   }
 
-  /* (non-Javadoc)
+  /*
+   * (non-Javadoc)
+   * 
    * @see org.apache.mahout.matrix.Vector#getLabelBindings()
    */
   @Override
@@ -440,7 +453,9 @@
     return bindings;
   }
 
-  /* (non-Javadoc)
+  /*
+   * (non-Javadoc)
+   * 
    * @see org.apache.mahout.matrix.Vector#set(java.lang.String, double)
    */
   @Override
@@ -454,7 +469,9 @@
     set(index, value);
   }
 
-  /* (non-Javadoc)
+  /*
+   * (non-Javadoc)
+   * 
    * @see org.apache.mahout.matrix.Vector#setLabelBindings(java.util.Map)
    */
   @Override
@@ -462,7 +479,9 @@
     this.bindings = bindings;
   }
 
-  /* (non-Javadoc)
+  /*
+   * (non-Javadoc)
+   * 
    * @see org.apache.mahout.matrix.Vector#set(java.lang.String, int, double)
    */
   @Override
@@ -473,4 +492,43 @@
     set(index, value);
   }
 
+  /**
+   * Read and return a vector from the input
+   * 
+   * @param in
+   * @return
+   * @throws IOException
+   */
+  protected static Vector readVector(DataInput in) throws IOException {
+    String vectorClassName = in.readUTF();
+    Vector vector;
+    try {
+      vector = Class.forName(vectorClassName).asSubclass(Vector.class)
+          .newInstance();
+    } catch (ClassNotFoundException e) {
+      throw new RuntimeException(e);
+    } catch (IllegalAccessException e) {
+      throw new RuntimeException(e);
+    } catch (InstantiationException e) {
+      throw new RuntimeException(e);
+    }
+    vector.readFields(in);
+    return vector;
+  }
+
+  /**
+   * Write the vector to the output
+   * 
+   * @param out
+   * @param vector
+   * @throws IOException
+   */
+  protected static void writeVector(DataOutput out, Vector vector)
+      throws IOException {
+    String vectorClassName = vector.getClass().getName();
+    out.writeUTF(vectorClassName);
+    vector.write(out);
+
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/DenseMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/DenseMatrix.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/DenseMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/DenseMatrix.java Fri Jun 19 21:45:46 2009
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 
 /**
  * Matrix of doubles implemented using a 2-d array
@@ -108,12 +111,6 @@
   }
 
   @Override
-  public double[][] toArray() {
-    DenseMatrix result = new DenseMatrix(values);
-    return result.values;
-  }
-
-  @Override
   public Matrix viewPart(int[] offset, int[] size) {
     if (size[ROW] > rowSize() || size[COL] > columnSize())
       throw new CardinalityException();
@@ -158,4 +155,25 @@
     return new DenseVector(values[row]);
   }
 
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    super.readFields(in);
+    int rows = in.readInt();
+    int columns = in.readInt();
+    this.values = new double[rows][columns];
+    for (int row = 0; row < rows; row++)
+      for (int column = 0; column < columns; column++)
+        this.values[row][column] = in.readDouble();
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    super.write(out);
+    out.writeInt(rowSize());
+    out.writeInt(columnSize());
+    for (double[] row : values)
+      for (double value : row)
+        out.writeDouble(value);
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java Fri Jun 19 21:45:46 2009
@@ -19,10 +19,12 @@
 
 import java.util.Map;
 
+import org.apache.hadoop.io.Writable;
+
 /**
  * The basic interface including numerous convenience functions
  */
-public interface Matrix extends Cloneable {
+public interface Matrix extends Cloneable, Writable {
 
   /**
    * @return a formatted String suitable for output
@@ -270,13 +272,6 @@
   Matrix transpose();
 
   /**
-   * Return the element of the recipient as a double[]
-   * 
-   * @return a double[][]
-   */
-  double[][] toArray();
-
-  /**
    * Return a new matrix containing the subset of the recipient
    * 
    * @param offset an int[2] offset into the receiver

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/MatrixView.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/MatrixView.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/MatrixView.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/MatrixView.java Fri Jun 19 21:45:46 2009
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 
 /**
  * Implements subset view of a Matrix
@@ -38,12 +41,9 @@
   /**
    * Construct a view of the matrix with given offset and cardinality
    * 
-   * @param matrix
-   *            an underlying Matrix
-   * @param offset
-   *            the int[2] offset into the underlying matrix
-   * @param cardinality
-   *            the int[2] cardinality of the view
+   * @param matrix an underlying Matrix
+   * @param offset the int[2] offset into the underlying matrix
+   * @param cardinality the int[2] cardinality of the view
    */
   public MatrixView(Matrix matrix, int[] offset, int[] cardinality) {
     this.matrix = matrix;
@@ -88,16 +88,6 @@
   }
 
   @Override
-  public double[][] toArray() {
-    double[][] result = new double[cardinality[ROW]][cardinality[COL]];
-    for (int row = ROW; row < cardinality[ROW]; row++)
-      for (int col = ROW; col < cardinality[COL]; col++)
-        result[row][col] = matrix
-            .getQuick(offset[ROW] + row, offset[COL] + col);
-    return result;
-  }
-
-  @Override
   public Matrix viewPart(int[] offset, int[] size) {
     if (size[ROW] > cardinality[ROW] || size[COL] > cardinality[COL])
       throw new CardinalityException();
@@ -153,4 +143,24 @@
     return new VectorView(matrix.getRow(row + offset[ROW]), offset[COL],
         cardinality[COL]);
   }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    super.readFields(in);
+    int[] o = { in.readInt(), in.readInt() };
+    this.offset = o;
+    int[] c = { in.readInt(), in.readInt() };
+    this.cardinality = c;
+    this.matrix = readMatrix(in);
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    super.write(out);
+    out.writeInt(offset[ROW]);
+    out.writeInt(offset[COL]);
+    out.writeInt(cardinality[ROW]);
+    out.writeInt(cardinality[COL]);
+    writeMatrix(out, this.matrix);
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseColumnMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseColumnMatrix.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseColumnMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseColumnMatrix.java Fri Jun 19 21:45:46 2009
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 
 /**
  * sparse matrix with general element values whose columns are accessible
@@ -110,16 +113,8 @@
     int[] result = new int[2];
     result[COL] = columns.length;
     for (int col = 0; col < cardinality[COL]; col++)
-      result[ROW] = Math.max(result[ROW], columns[col].getNumNondefaultElements());
-    return result;
-  }
-
-  @Override
-  public double[][] toArray() {
-    double[][] result = new double[cardinality[ROW]][cardinality[COL]];
-    for (int row = 0; row < cardinality[ROW]; row++)
-      for (int col = 0; col < cardinality[COL]; col++)
-        result[row][col] = getQuick(row, col);
+      result[ROW] = Math.max(result[ROW], columns[col]
+          .getNumNondefaultElements());
     return result;
   }
 
@@ -128,8 +123,7 @@
     if (size[COL] > columns.length || size[ROW] > columns[COL].size())
       throw new CardinalityException();
     if (offset[COL] < 0 || offset[COL] + size[COL] > columns.length
-        || offset[ROW] < 0
-        || offset[ROW] + size[ROW] > columns[COL].size())
+        || offset[ROW] < 0 || offset[ROW] + size[ROW] > columns[COL].size())
       throw new IndexException();
     return new MatrixView(this, offset, size);
   }
@@ -168,4 +162,27 @@
     return new DenseVector(d);
   }
 
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    super.readFields(in);
+    int[] card = { in.readInt(), in.readInt() };
+    this.cardinality = card;
+    int colSize = in.readInt();
+    this.columns = new Vector[colSize];
+    for (int col = 0; col < colSize; col++) {
+      columns[col] = AbstractVector.readVector(in);
+    }
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    super.write(out);
+    out.writeInt(cardinality[ROW]);
+    out.writeInt(cardinality[COL]);
+    out.writeInt(columns.length);
+    for (Vector col : columns) {
+      AbstractVector.writeVector(out, col);
+    }
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java Fri Jun 19 21:45:46 2009
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -112,16 +115,8 @@
     int[] result = new int[2];
     result[ROW] = rows.size();
     for (Map.Entry<Integer, Vector> integerVectorEntry : rows.entrySet())
-      result[COL] = Math.max(result[COL], integerVectorEntry.getValue().getNumNondefaultElements());
-    return result;
-  }
-
-  @Override
-  public double[][] toArray() {
-    double[][] result = new double[cardinality[ROW]][cardinality[COL]];
-    for (int row = 0; row < cardinality[ROW]; row++)
-      for (int col = 0; col < cardinality[COL]; col++)
-        result[row][col] = getQuick(row, col);
+      result[COL] = Math.max(result[COL], integerVectorEntry.getValue()
+          .getNumNondefaultElements());
     return result;
   }
 
@@ -182,4 +177,29 @@
     return res;
   }
 
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    super.readFields(in);
+    int[] card = { in.readInt(), in.readInt() };
+    this.cardinality = card;
+    int rowsize = in.readInt();
+    this.rows = new HashMap<Integer, Vector>();
+    for (int row = 0; row < rowsize; row++) {
+      int key = in.readInt();
+      rows.put(key, AbstractVector.readVector(in));
+    }
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    super.write(out);
+    out.writeInt(cardinality[ROW]);
+    out.writeInt(cardinality[COL]);
+    out.writeInt(rows.size());
+    for (Integer row : rows.keySet()) {
+      out.writeInt(row);
+      AbstractVector.writeVector(out, rows.get(row));
+    }
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseRowMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseRowMatrix.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseRowMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseRowMatrix.java Fri Jun 19 21:45:46 2009
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 
 /**
  * sparse matrix with general element values whose rows are accessible quickly.
@@ -35,10 +38,8 @@
   /**
    * Construct a matrix of the given cardinality with the given rows
    * 
-   * @param cardinality
-   *            the int[2] cardinality desired
-   * @param rows
-   *            a SparseVector[] array of rows
+   * @param cardinality the int[2] cardinality desired
+   * @param rows a SparseVector[] array of rows
    */
   public SparseRowMatrix(int[] cardinality, SparseVector[] rows) {
     this.cardinality = cardinality.clone();
@@ -50,8 +51,7 @@
   /**
    * Construct a matrix of the given cardinality
    * 
-   * @param cardinality
-   *            the int[2] cardinality desired
+   * @param cardinality the int[2] cardinality desired
    */
   public SparseRowMatrix(int[] cardinality) {
     this.cardinality = cardinality.clone();
@@ -116,15 +116,6 @@
   }
 
   @Override
-  public double[][] toArray() {
-    double[][] result = new double[cardinality[ROW]][cardinality[COL]];
-    for (int row = 0; row < cardinality[ROW]; row++)
-      for (int col = 0; col < cardinality[COL]; col++)
-        result[row][col] = getQuick(row, col);
-    return result;
-  }
-
-  @Override
   public Matrix viewPart(int[] offset, int[] size) {
     if (size[ROW] > rows.length || size[COL] > rows[ROW].size())
       throw new CardinalityException();
@@ -168,4 +159,27 @@
     return rows[row];
   }
 
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    super.readFields(in);
+    int[] card = { in.readInt(), in.readInt() };
+    this.cardinality = card;
+    int rowsize = in.readInt();
+    this.rows = new Vector[rowsize];
+    for (int row = 0; row < rowsize; row++) {
+      rows[row] = AbstractVector.readVector(in);
+    }
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    super.write(out);
+    out.writeInt(cardinality[ROW]);
+    out.writeInt(cardinality[COL]);
+    out.writeInt(rows.length);
+    for (Vector row : rows) {
+      AbstractVector.writeVector(out, row);
+    }
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/VectorView.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/VectorView.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/VectorView.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/VectorView.java Fri Jun 19 21:45:46 2009
@@ -20,7 +20,6 @@
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
-import java.nio.charset.Charset;
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 
@@ -118,6 +117,7 @@
 
   public class ViewIterator implements Iterator<Vector.Element> {
     private final Iterator<Vector.Element> it;
+
     private Vector.Element el;
 
     public ViewIterator() {
@@ -168,8 +168,8 @@
     }
 
     /**
-     * @throws UnsupportedOperationException
-     *             all the time. method not implemented.
+     * @throws UnsupportedOperationException all the time. method not
+     *         implemented.
      */
     @Override
     public void remove() {
@@ -177,46 +177,26 @@
     }
   }
 
-
   @Override
   public void write(DataOutput dataOutput) throws IOException {
-    dataOutput.writeUTF(this.name==null? "": this.name);
+    dataOutput.writeUTF(this.name == null ? "" : this.name);
     dataOutput.writeInt(offset);
     dataOutput.writeInt(cardinality);
-    String vectorClassName = vector.getClass().getName();
-    dataOutput.writeInt(vectorClassName.length() * 2);
-    dataOutput.write(vectorClassName.getBytes());
-    vector.write(dataOutput);
+    writeVector(dataOutput, vector);
   }
 
   @Override
   public void readFields(DataInput dataInput) throws IOException {
     this.name = dataInput.readUTF();
-    int offset = dataInput.readInt();
-    int cardinality = dataInput.readInt();
-    byte[] buf = new byte[dataInput.readInt()];
-    dataInput.readFully(buf);
-    String vectorClassName = new String(buf, Charset.forName("UTF-8"));
-    Vector vector;
-    try {
-      vector = Class.forName(vectorClassName).asSubclass(Vector.class).newInstance();
-    } catch (ClassNotFoundException e) {
-      throw new RuntimeException(e);
-    } catch (IllegalAccessException e) {
-      throw new RuntimeException(e);
-    } catch (InstantiationException e) {
-      throw new RuntimeException(e);
-    }
-    vector.readFields(dataInput);
-
-    this.offset = offset;
-    this.cardinality = cardinality;
-    this.vector = vector;
+    this.offset = dataInput.readInt();
+    this.cardinality = dataInput.readInt();
+    this.vector = readVector(dataInput);
   }
 
   @Override
   public boolean equals(Object o) {
-    if (this == o) return true;
+    if (this == o)
+      return true;
     return o instanceof Vector && equivalent(this, (Vector) o);
 
   }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java Fri Jun 19 21:45:46 2009
@@ -17,11 +17,16 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
 import junit.framework.TestCase;
 
+import org.apache.hadoop.io.DataOutputBuffer;
+
 public abstract class MatrixTest extends TestCase {
 
   protected static final int ROW = AbstractMatrix.ROW;
@@ -104,15 +109,6 @@
     assertEquals("col size", values[0].length, c[COL]);
   }
 
-  public void testToArray() {
-    double[][] array = test.toArray();
-    int[] c = test.size();
-    for (int row = 0; row < c[ROW]; row++)
-      for (int col = 0; col < c[COL]; col++)
-        assertEquals("value[" + row + "][" + col + ']', values[row][col],
-            array[row][col]);
-  }
-
   public void testViewPart() {
     int[] offset = { 1, 1 };
     int[] size = { 2, 1 };
@@ -539,8 +535,8 @@
       assertTrue(true);
     }
   }
-  
-  public void testLabelBindingSerialization(){
+
+  public void testLabelBindingSerialization() {
     Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
         { 1, 4, 2 } });
     assertNull("row bindings", m.getRowLabelBindings());
@@ -560,4 +556,20 @@
     Matrix mm = AbstractMatrix.decodeMatrix(json);
     assertEquals("Fee", m.get(0, 1), mm.get("Fee", "Bar"));
   }
+
+  public void testMatrixWritable() throws IOException {
+    Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
+        { 1, 4, 2 } });
+    DataOutputBuffer out = new DataOutputBuffer();
+    m.write(out);
+    out.close();
+
+    DataInputStream in = new DataInputStream(new ByteArrayInputStream(out
+        .getData()));
+    Matrix m2 = m.like();
+    m2.readFields(in);
+    in.close();
+    assertEquals("row size", m.size()[ROW], m2.size()[ROW]);
+    assertEquals("col size", m.size()[COL], m2.size()[COL]);
+  }
 }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestMatrixView.java Fri Jun 19 21:45:46 2009
@@ -17,6 +17,14 @@
 
 package org.apache.mahout.matrix;
 
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.hadoop.io.DataOutputBuffer;
+
 import junit.framework.TestCase;
 
 public class TestMatrixView extends TestCase {
@@ -109,15 +117,6 @@
     assertEquals("col size", values[0].length - 1, c[COL]);
   }
 
-  public void testToArray() {
-    double[][] array = test.toArray();
-    int[] c = test.size();
-    for (int row = 0; row < c[ROW]; row++)
-      for (int col = 0; col < c[COL]; col++)
-        assertEquals("value[" + row + "][" + col + ']',
-            values[row + 1][col + 1], array[row][col]);
-  }
-
   public void testViewPart() throws Exception {
     int[] offset = { 1, 1 };
     int[] size = { 2, 1 };
@@ -481,4 +480,76 @@
     }
   }
 
+  public void testMatrixWritable() throws IOException {
+    DataOutputBuffer out = new DataOutputBuffer();
+    test.write(out);
+    out.close();
+
+    DataInputStream in = new DataInputStream(new ByteArrayInputStream(out
+        .getData()));
+    Matrix m2 = test.clone();
+    m2.readFields(in);
+    in.close();
+    assertEquals("row size", test.size()[ROW], m2.size()[ROW]);
+    assertEquals("col size", test.size()[COL], m2.size()[COL]);
+  }
+
+  public void testLabelBindings() {
+    assertNull("row bindings", test.getRowLabelBindings());
+    assertNull("col bindings", test.getColumnLabelBindings());
+    Map<String, Integer> rowBindings = new HashMap<String, Integer>();
+    rowBindings.put("Fee", 0);
+    rowBindings.put("Fie", 1);
+    test.setRowLabelBindings(rowBindings);
+    assertEquals("row", rowBindings, test.getRowLabelBindings());
+    Map<String, Integer> colBindings = new HashMap<String, Integer>();
+    colBindings.put("Foo", 0);
+    colBindings.put("Bar", 1);
+    test.setColumnLabelBindings(colBindings);
+    assertEquals("row", rowBindings, test.getRowLabelBindings());
+    assertEquals("Fee", test.get(0, 1), test.get("Fee", "Bar"));
+
+    double[] newrow = { 9, 8 };
+    test.set("Fie", newrow);
+    assertEquals("FeeBar", test.get(0, 1), test.get("Fee", "Bar"));
+  }
+
+  public void testSettingLabelBindings() {
+    assertNull("row bindings", test.getRowLabelBindings());
+    assertNull("col bindings", test.getColumnLabelBindings());
+    test.set("Fee", "Foo", 1, 1, 9);
+    assertNotNull("row", test.getRowLabelBindings());
+    assertNotNull("row", test.getRowLabelBindings());
+    assertEquals("Fee", 1, test.getRowLabelBindings().get("Fee").intValue());
+    assertEquals("Foo", 1, test.getColumnLabelBindings().get("Foo").intValue());
+    assertEquals("FeeFoo", test.get(1, 1), test.get("Fee", "Foo"));
+    try {
+      test.get("Fie", "Foe");
+      fail("Expected UnboundLabelException");
+    } catch (IndexException e) {
+      fail("Expected UnboundLabelException");
+    } catch (UnboundLabelException e) {
+      assertTrue(true);
+    }
+  }
+
+  public void testLabelBindingSerialization() {
+    assertNull("row bindings", test.getRowLabelBindings());
+    assertNull("col bindings", test.getColumnLabelBindings());
+    Map<String, Integer> rowBindings = new HashMap<String, Integer>();
+    rowBindings.put("Fee", 0);
+    rowBindings.put("Fie", 1);
+    rowBindings.put("Foe", 2);
+    test.setRowLabelBindings(rowBindings);
+    assertEquals("row", rowBindings, test.getRowLabelBindings());
+    Map<String, Integer> colBindings = new HashMap<String, Integer>();
+    colBindings.put("Foo", 0);
+    colBindings.put("Bar", 1);
+    colBindings.put("Baz", 2);
+    test.setColumnLabelBindings(colBindings);
+    String json = test.asFormatString();
+    Matrix mm = AbstractMatrix.decodeMatrix(json);
+    assertEquals("Fee", test.get(0, 1), mm.get("Fee", "Bar"));
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorWritable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorWritable.java?rev=786675&r1=786674&r2=786675&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorWritable.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestVectorWritable.java Fri Jun 19 21:45:46 2009
@@ -35,13 +35,14 @@
     writable.write(out);
     out.close();
 
-    DataInputStream in = new DataInputStream(new ByteArrayInputStream(out.getData()));
+    DataInputStream in = new DataInputStream(new ByteArrayInputStream(out
+        .getData()));
     writable.readFields(in);
     in.close();
 
     assertEquals(cardinality, writable.size());
     for (int i = 0; i < cardinality; i++) {
-      assertEquals((double)i, writable.get(i));
+      assertEquals((double) i, writable.get(i));
     }
 
     in = new DataInputStream(new ByteArrayInputStream(out.getData()));
@@ -50,12 +51,14 @@
 
     assertEquals(cardinality, writable.size());
     for (int i = 0; i < cardinality; i++) {
-      assertEquals((double)i, writable.get(i));
+      assertEquals((double) i, writable.get(i));
     }
   }
 
   public void testVectors() throws Exception {
     doTest(new SparseVector(cardinality));
     doTest(new DenseVector(cardinality));
+    doTest(new VectorView(new SparseVector(cardinality + 1), 1, cardinality));
+    doTest(new VectorView(new DenseVector(cardinality + 1), 1, cardinality));
   }
 }



Mime
View raw message