mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jman...@apache.org
Subject svn commit: r1188491 - in /mahout/trunk: core/src/main/java/org/apache/mahout/math/hadoop/solver/ core/src/test/java/org/apache/mahout/math/hadoop/solver/ math/src/main/java/org/apache/mahout/math/solver/ math/src/test/java/org/apache/mahout/math/solver/
Date Tue, 25 Oct 2011 02:00:00 GMT
Author: jmannix
Date: Tue Oct 25 01:59:58 2011
New Revision: 1188491

URL: http://svn.apache.org/viewvc?rev=1188491&view=rev
Log:
MAHOUT-672 on behalf of jtraupman

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/
    mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/
    mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,164 @@
+package org.apache.mahout.math.hadoop.solver;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.solver.ConjugateGradientSolver;
+import org.apache.mahout.math.solver.Preconditioner;
+
+/**
+ * 
+ * Distributed implementation of the conjugate gradient solver. More or less, this is just the standard solver
+ * but wrapped with some methods that make it easy to run it on a DistributedRowMatrix.
+ *  
+ */
+
+public class DistributedConjugateGradientSolver extends ConjugateGradientSolver implements Tool
+{
+  private Configuration conf; 
+  private Map<String, String> parsedArgs;
+
+  /**
+   * 
+   * Runs the distributed conjugate gradient solver programmatically to solve the system (A + lambda*I)x = b.
+   * 
+   * @param inputPath      Path to the matrix A
+   * @param tempPath       Path to scratch output path, deleted after the solver completes
+   * @param numRows        Number of rows in A
+   * @param numCols        Number of columns in A
+   * @param b              Vector b
+   * @param preconditioner Optional preconditioner for the system
+   * @param maxIterations  Maximum number of iterations to run, defaults to numCols
+   * @param maxError       Maximum error tolerated in the result. If the norm of the residual falls below this, then the 
+   *                       algorithm stops and returns. 
+
+   * @return               The vector that solves the system.
+   */
+  public Vector runJob(Path inputPath, 
+                       Path tempPath, 
+                       int numRows, 
+                       int numCols, 
+                       Vector b, 
+                       Preconditioner preconditioner, 
+                       int maxIterations, 
+                       double maxError) {
+    DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, tempPath, numRows, numCols);
+    matrix.setConf(conf);
+        
+    return solve(matrix, b, preconditioner, maxIterations, maxError);
+  }
+  
+  @Override
+  public Configuration getConf()
+  {
+    return conf;
+  }
+
+  @Override
+  public void setConf(Configuration conf)
+  {
+    this.conf = conf;    
+  }
+
+  @Override
+  public int run(String[] strings) throws Exception
+  {
+    Path inputPath = new Path(parsedArgs.get("--input"));
+    Path outputPath = new Path(parsedArgs.get("--output"));
+    Path tempPath = new Path(parsedArgs.get("--tempDir"));
+    Path vectorPath = new Path(parsedArgs.get("--vector"));
+    int numRows = Integer.parseInt(parsedArgs.get("--numRows"));
+    int numCols = Integer.parseInt(parsedArgs.get("--numCols"));
+    int maxIterations = parsedArgs.containsKey("--maxIter") ? Integer.parseInt(parsedArgs.get("--maxIter")) : numCols;
+    double maxError = parsedArgs.containsKey("--maxError") 
+        ? Double.parseDouble(parsedArgs.get("--maxError")) 
+        : ConjugateGradientSolver.DEFAULT_MAX_ERROR;
+
+    Vector b = loadInputVector(vectorPath);
+    Vector x = runJob(inputPath, tempPath, numRows, numCols, b, null, maxIterations, maxError);
+    saveOutputVector(outputPath, x);
+    tempPath.getFileSystem(conf).delete(tempPath, true);
+    
+    return 0;
+  }
+  
+  public DistributedConjugateGradientSolverJob job() {
+    return new DistributedConjugateGradientSolverJob();
+  }
+  
+  private Vector loadInputVector(Path path) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+    IntWritable key = new IntWritable();
+    VectorWritable value = new VectorWritable();
+    
+    try {
+      if (!reader.next(key, value)) {
+        throw new IOException("Input vector file is empty.");      
+      }
+      return value.get();
+    } finally {
+      reader.close();
+    }
+  }
+  
+  private void saveOutputVector(Path path, Vector v) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
+    
+    try {
+      writer.append(new IntWritable(0), new VectorWritable(v));
+    } finally {
+      writer.close();
+    }
+  }
+  
+  public class DistributedConjugateGradientSolverJob extends AbstractJob {
+    @Override
+    public void setConf(Configuration conf) {
+      DistributedConjugateGradientSolver.this.setConf(conf);
+    }
+    
+    @Override
+    public Configuration getConf() {
+      return DistributedConjugateGradientSolver.this.getConf();
+    }
+    
+    @Override
+    public int run(String[] args) throws Exception
+    {
+      addInputOption();
+      addOutputOption();
+      addOption("numRows", "nr", "Number of rows in the input matrix", true);
+      addOption("numCols", "nc", "Number of columns in the input matrix", true);
+      addOption("vector", "b", "Vector to solve against", true);
+      addOption("lambda", "l", "Scalar in A + lambda * I [default = 0]", "0.0");
+      addOption("symmetric", "sym", "Is the input matrix square and symmetric?", "true");
+      addOption("maxIter", "x", "Maximum number of iterations to run");
+      addOption("maxError", "err", "Maximum residual error to allow before stopping");
+
+      DistributedConjugateGradientSolver.this.parsedArgs = parseArguments(args);
+      if (DistributedConjugateGradientSolver.this.parsedArgs == null) {
+        return -1;
+      } else {
+        DistributedConjugateGradientSolver.this.setConf(new Configuration());
+        return DistributedConjugateGradientSolver.this.run(args);
+      }
+    }    
+  }
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new DistributedConjugateGradientSolver().job(), args);
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolver.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,43 @@
+package org.apache.mahout.math.hadoop.solver;
+
+import java.io.File;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
+import org.junit.Test;
+
+
+public class TestDistributedConjugateGradientSolver extends MahoutTestCase {
+
+  /**
+   * Creates a dense vector whose entries are Gaussian samples from a generator seeded with 1234, each
+   * multiplied by the given factor.
+   *
+   * @param size      number of entries
+   * @param entryMean factor applied to every Gaussian sample
+   * @return the generated vector
+   */
+  private Vector randomVector(int size, double entryMean) {
+    DenseVector result = new DenseVector(size);
+    Random generator = new Random(1234L);
+
+    int index = 0;
+    while (index < size) {
+      result.setQuick(index, generator.nextGaussian() * entryMean);
+      index++;
+    }
+
+    return result;
+  }
+
+  /**
+   * Solves A x = b for a randomly generated distributed matrix and checks that A times the computed x is
+   * within EPSILON of b.
+   */
+  @Test
+  public void testSolver() throws Exception {
+    File dataDir = getTestTempDir("testdata");
+    String dataDirName = dataDir.getAbsolutePath();
+    TestDistributedRowMatrix matrixFactory = new TestDistributedRowMatrix();
+    DistributedRowMatrix a = matrixFactory.randomDistributedMatrix(10, 10, 10, 10, 10.0, true, dataDirName);
+    a.setConf(new Configuration());
+    Vector b = randomVector(a.numCols(), 10.0);
+
+    Vector x = new DistributedConjugateGradientSolver().solve(a, b);
+
+    // Euclidean distance between b and the reconstructed right-hand side A * x.
+    Vector reconstructed = a.times(x);
+    double residualNorm = Math.sqrt(b.getDistanceSquared(reconstructed));
+    assertEquals(0.0, residualNorm, EPSILON);
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/solver/TestDistributedConjugateGradientSolverCLI.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,94 @@
+package org.apache.mahout.math.hadoop.solver;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
+import org.junit.Test;
+
+public class TestDistributedConjugateGradientSolverCLI extends MahoutTestCase
+{
+  private Vector randomVector(int size, double entryMean) {
+    DenseVector v = new DenseVector(size);
+    Random r = new Random(1234L);
+    
+    for (int i = 0; i < size; ++i) {
+      v.setQuick(i, r.nextGaussian() * entryMean);
+    }
+    
+    return v;
+  }
+
+  private Path saveVector(Configuration conf, Path path, Vector v) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
+    
+    try {
+      writer.append(new IntWritable(0), new VectorWritable(v));
+    } finally {
+      writer.close();
+    }
+    return path;
+  }
+  
+  private Vector loadVector(Configuration conf, Path path) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+    IntWritable key = new IntWritable();
+    VectorWritable value = new VectorWritable();
+    
+    try {
+      if (!reader.next(key, value)) {
+        throw new IOException("Input vector file is empty.");      
+      }
+      return value.get();
+    } finally {
+      reader.close();
+    }
+  }
+
+  @Test
+  public void testSolver() throws Exception {
+    Configuration conf = new Configuration();
+    Path testData = getTestTempDirPath("testdata");
+    DistributedRowMatrix matrix = new TestDistributedRowMatrix().randomDistributedMatrix(
+        10, 10, 10, 10, 10.0, true, testData.toString());
+    matrix.setConf(conf);
+    Path output = getTestTempFilePath("output");
+    Path vectorPath = getTestTempFilePath("vector");
+    Path tempPath = getTestTempDirPath("tmp");
+
+    Vector vector = randomVector(matrix.numCols(), 10.0);
+    saveVector(conf, vectorPath, vector);
+        
+    String[] args = {
+        "-i", matrix.getRowPath().toString(),
+        "-o", output.toString(),
+        "--tempDir", tempPath.toString(),
+        "--vector", vectorPath.toString(),
+        "--numRows", "10",
+        "--numCols", "10",
+        "--symmetric", "true"        
+    };
+    
+    DistributedConjugateGradientSolver solver = new DistributedConjugateGradientSolver();
+    solver.job().run(args);
+    
+    Vector x = loadVector(conf, output);
+    
+    Vector solvedVector = matrix.times(x);    
+    double distance = Math.sqrt(vector.getDistanceSquared(solvedVector));
+    assertEquals(0.0, distance, EPSILON);
+  }
+}

Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,199 @@
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.PlusMult;
+import org.apache.mahout.math.function.TimesFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>Implementation of a conjugate gradient iterative solver for linear systems. Implements both
+ * standard conjugate gradient and pre-conditioned conjugate gradient. 
+ * 
+ * <p>Conjugate gradient requires the matrix A in the linear system Ax = b to be symmetric and positive
+ * definite. For convenience, this implementation allows the input matrix to be non-symmetric, in
+ * which case the system A'Ax = b is solved. Because this requires only one pass through the matrix A, it
+ * is faster than explicitly computing A'A, then passing the results to the solver.
+ * 
+ * <p>For inputs that may be ill conditioned (often the case for highly sparse input), this solver
+ * also accepts a parameter, lambda, which adds a scaled identity to the matrix A, solving the system
+ * (A + lambda*I)x = b. This obviously changes the solution, but it will guarantee solvability. The
+ * ridge regression approach to linear regression is a common use of this feature.
+ * 
+ * <p>If only an approximate solution is required, the maximum number of iterations or the error threshold
+ * may be specified to end the algorithm early at the expense of accuracy. When the matrix A is ill conditioned,
+ * it may sometimes be necessary to increase the maximum number of iterations above the default of A.numCols()
+ * due to numerical issues.
+ * 
+ * <p>By default the solver will run a.numCols() iterations or until the residual falls below 1E-9.
+ * 
+ * <p>For more information on the conjugate gradient algorithm, see Golub & van Loan, "Matrix Computations", 
+ * sections 10.2 and 10.3 or the <a href="http://en.wikipedia.org/wiki/Conjugate_gradient">conjugate gradient
+ * wikipedia article</a>.
+ */
+
+public class ConjugateGradientSolver {
+
+  /** Residual norm below which the solver stops when no explicit tolerance is given. */
+  public static final double DEFAULT_MAX_ERROR = 1e-9;
+
+  private static final Logger log = LoggerFactory.getLogger(ConjugateGradientSolver.class);
+
+  /** Number of iterations performed by the most recent call to solve(). */
+  private int iterations;
+
+  /** Squared norm of the residual after the most recent iteration; NaN before the first solve. */
+  private double residualNormSquared;
+
+  public ConjugateGradientSolver() {
+    this.iterations = 0;
+    this.residualNormSquared = Double.NaN;
+  }
+
+  /**
+   * Solves the system Ax = b with default termination criteria. A must be symmetric, square, and positive definite.
+   * Only the squareness of a is checked, since testing for symmetry and positive definiteness are too expensive. If
+   * an invalid matrix is specified, then the algorithm may not yield a valid result.
+   *
+   * @param a  The linear operator A.
+   * @param b  The vector b.
+   * @return The result x of solving the system.
+   * @throws IllegalArgumentException if a is not square.
+   * @throws CardinalityException if the size of b is not equal to the number of columns of a.
+   */
+  public Vector solve(VectorIterable a, Vector b) {
+    return solve(a, b, null, b.size(), DEFAULT_MAX_ERROR);
+  }
+
+  /**
+   * Solves the system Ax = b with default termination criteria using the specified preconditioner. A must be
+   * symmetric, square, and positive definite. Only the squareness of a is checked, since testing for symmetry
+   * and positive definiteness are too expensive. If an invalid matrix is specified, then the algorithm may not
+   * yield a valid result.
+   *
+   * @param a  The linear operator A.
+   * @param b  The vector b.
+   * @param precond A preconditioner to use on A during the solution process.
+   * @return The result x of solving the system.
+   * @throws IllegalArgumentException if a is not square.
+   * @throws CardinalityException if the size of b is not equal to the number of columns of a.
+   */
+  public Vector solve(VectorIterable a, Vector b, Preconditioner precond) {
+    return solve(a, b, precond, b.size(), DEFAULT_MAX_ERROR);
+  }
+
+  /**
+   * Solves the system Ax = b, where A is a linear operator and b is a vector. Uses the specified preconditioner
+   * to improve numeric stability and possibly speed convergence. This version of solve() allows control over the
+   * termination and iteration parameters.
+   *
+   * @param a  The matrix A.
+   * @param b  The vector b.
+   * @param preconditioner The preconditioner to apply, or null to run without preconditioning.
+   * @param maxIterations The maximum number of iterations to run.
+   * @param maxError The maximum amount of residual error to tolerate. The algorithm will run until the residual falls
+   * below this value or until maxIterations are completed.
+   * @return The result x of solving the system.
+   * @throws IllegalArgumentException if the matrix is not square, if maxError is less than zero, or if
+   * maxIterations is not positive.
+   * @throws CardinalityException if the size of b is not equal to the number of columns of A.
+   */
+  public Vector solve(VectorIterable a,
+                      Vector b,
+                      Preconditioner preconditioner,
+                      int maxIterations,
+                      double maxError) {
+
+    if (a.numRows() != a.numCols()) {
+      throw new IllegalArgumentException("Matrix must be square, symmetric and positive definite.");
+    }
+
+    if (a.numCols() != b.size()) {
+      throw new CardinalityException(a.numCols(), b.size());
+    }
+
+    if (maxIterations <= 0) {
+      throw new IllegalArgumentException("Max iterations must be positive.");
+    }
+
+    if (maxError < 0.0) {
+      throw new IllegalArgumentException("Max error must be non-negative.");
+    }
+
+    // The functor is re-parameterised on every iteration, so it is created per call rather than shared
+    // statically; a shared instance would let concurrent solves overwrite each other's step size.
+    PlusMult plusMult = new PlusMult(1.0);
+
+    // Start from the zero vector as the initial guess.
+    Vector x = new DenseVector(b.size());
+
+    iterations = 0;
+    Vector residual = b.minus(a.times(x));
+    residualNormSquared = residual.dot(residual);
+
+    double conditionedNormSqr;
+    double previousConditionedNormSqr = 0.0;
+
+    Vector updateDirection = null;
+
+    log.info("Conjugate gradient initial residual norm = {}", Math.sqrt(residualNormSquared));
+    while (Math.sqrt(residualNormSquared) > maxError && iterations < maxIterations) {
+      Vector conditionedResidual;
+      if (preconditioner == null) {
+        conditionedResidual = residual;
+        conditionedNormSqr = residualNormSquared;
+      } else {
+        conditionedResidual = preconditioner.precondition(residual);
+        conditionedNormSqr = residual.dot(conditionedResidual);
+      }
+
+      ++iterations;
+
+      if (iterations == 1) {
+        // Copy, because the residual is modified in place below and must not alias the search direction.
+        updateDirection = new DenseVector(conditionedResidual);
+      } else {
+        double beta = conditionedNormSqr / previousConditionedNormSqr;
+
+        // updateDirection = conditionedResidual + beta * updateDirection
+        updateDirection.assign(Functions.MULT, beta);
+        updateDirection.assign(conditionedResidual, Functions.PLUS);
+      }
+
+      Vector aTimesUpdate = a.times(updateDirection);
+
+      double alpha = conditionedNormSqr / updateDirection.dot(aTimesUpdate);
+
+      // x = x + alpha * updateDirection
+      plusMult.setMultiplicator(alpha);
+      x.assign(updateDirection, plusMult);
+
+      // residual = residual - alpha * A * updateDirection
+      plusMult.setMultiplicator(-alpha);
+      residual.assign(aTimesUpdate, plusMult);
+
+      previousConditionedNormSqr = conditionedNormSqr;
+      residualNormSquared = residual.dot(residual);
+
+      log.info("Conjugate gradient iteration {} residual norm = {}", iterations, Math.sqrt(residualNormSquared));
+    }
+    return x;
+  }
+
+  /**
+   * Returns the number of iterations run once the solver is complete.
+   *
+   * @return The number of iterations run.
+   */
+  public int getIterations() {
+    return iterations;
+  }
+
+  /**
+   * Returns the norm of the residual at the completion of the solver. Usually this should be close to zero except in
+   * the case of a non positive definite matrix A, which results in an unsolvable system, or for ill conditioned A, in
+   * which case more iterations than the default may be needed.
+   *
+   * @return The norm of the residual in the solution; NaN if solve() has not been called yet.
+   */
+  public double getResidualNorm() {
+    return Math.sqrt(residualNormSquared);
+  }
+}

Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,33 @@
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * 
+ * Implements the Jacobi preconditioner for a matrix A. This is defined as inv(diag(A)).
+ *
+ */
+public class JacobiConditioner implements Preconditioner
+{
+  private DenseVector inverseDiagonal;
+
+  public JacobiConditioner(Matrix a) {
+    if (a.numCols() != a.numRows()) {
+      throw new IllegalArgumentException("Matrix must be square.");
+    }
+    
+    inverseDiagonal = new DenseVector(a.numCols());
+    for (int i = 0; i < a.numCols(); ++i) {
+      inverseDiagonal.setQuick(i, 1.0 / a.getQuick(i, i));
+    }
+  }
+  
+  @Override
+  public Vector precondition(Vector v)
+  {
+    return v.times(inverseDiagonal);
+  }
+
+}

Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/LSMR.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,582 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Solves sparse least-squares using the LSMR algorithm.
+ * <p/>
+ * LSMR solves the system of linear equations A * X = B. If the system is inconsistent, it solves
+ * the least-squares problem min ||b - Ax||_2. A is a rectangular matrix of dimension m-by-n, where
+ * all cases are allowed: m=n, m>n, or m&lt;n. B is a vector of length m. The matrix A may be dense
+ * or sparse (usually sparse).
+ * <p/>
+ * Some additional configurable properties adjust the behavior of the algorithm.
+ * <p/>
+ * If you set lambda to a non-zero value then LSMR solves the regularized least-squares problem
+ * min || [B; 0] - [A; lambda*I] * X ||_2, where [P; Q] denotes P stacked on top of Q and LAMBDA
+ * is a scalar.  If LAMBDA is not set, the system is solved without regularization.
+ * <p/>
+ * You can also set aTolerance and bTolerance.  These cause LSMR to iterate until a certain backward
+ * error estimate is smaller than some quantity depending on ATOL and BTOL.  Let RES = B - A*X be
+ * the residual vector for the current approximate solution X.  If A*X = B seems to be consistent,
+ * LSMR terminates when NORM(RES) <= ATOL*NORM(A)*NORM(X) + BTOL*NORM(B). Otherwise, LSMR terminates
+ * when NORM(A'*RES) <= ATOL*NORM(A)*NORM(RES). If both tolerances are 1.0e-6 (say), the final
+ * NORM(RES) should be accurate to about 6 digits. (The final X will usually have fewer correct
+ * digits, depending on cond(A) and the size of LAMBDA.)
+ * <p/>
+ * The default value for ATOL and BTOL is 1e-6.
+ * <p/>
+ * Ideally, they should be estimates of the relative error in the entries of A and B respectively.
+ * For example, if the entries of A have 7 correct digits, set ATOL = 1e-7. This prevents the
+ * algorithm from doing unnecessary work beyond the uncertainty of the input data.
+ * <p/>
+ * You can also set conditionLimit.  In that case, LSMR terminates if an estimate of cond(A) exceeds
+ * conditionLimit. For compatible systems Ax = b, conditionLimit could be as large as 1.0e+12 (say).
+ * For least-squares problems, conditionLimit should be less than 1.0e+8. If conditionLimit is not
+ * set, the default value is 1e+8. Maximum precision can be obtained by setting aTolerance =
+ * bTolerance = conditionLimit = 0, but the number of iterations may then be excessive.
+ * <p/>
+ * Setting iterationLimit causes LSMR to terminate if the number of iterations reaches
+ * iterationLimit.  The default is iterationLimit = min(m,n).   For ill-conditioned systems, a
+ * larger value of ITNLIM may be needed.
+ * <p/>
+ * Setting localSize causes LSMR to run with reorthogonalization on the last localSize v_k's.
+ * (v-vectors generated by Golub-Kahan bidiagonalization) If localSize is not set, LSMR runs without
+ * reorthogonalization. A localSize > max(n,m) performs reorthogonalization on all v_k's.
+ * Reorthogonalizing only u_k, or both u_k and v_k, is not an option here. Details are discussed in
+ * the SIAM paper.
+ * <p/>
+ * getTerminationReason() gives the reason for termination. ISTOP  = 0 means X=0 is a solution. = 1
+ * means X is an approximate solution to A*X = B, according to ATOL and BTOL. = 2 means X
+ * approximately solves the least-squares problem according to ATOL. = 3 means COND(A) seems to be
+ * greater than CONLIM. = 4 is the same as 1 with ATOL = BTOL = EPS. = 5 is the same as 2 with ATOL
+ * = EPS. = 6 is the same as 3 with CONLIM = 1/EPS. = 7 means ITN reached ITNLIM before the other
+ * stopping conditions were satisfied.
+ * <p/>
+ * getIterationCount() gives ITN = the number of LSMR iterations.
+ * <p/>
+ * getResidualNorm() gives an estimate of the residual norm: NORMR = norm(B-A*X).
+ * <p/>
+ * getNormalEquationResidual() gives an estimate of the residual for the normal equation: NORMAR =
+ * NORM(A'*(B-A*X)).
+ * <p/>
+ * getANorm() gives an estimate of the Frobenius norm of A.
+ * <p/>
+ * getCondition() gives an estimate of the condition number of A.
+ * <p/>
+ * getXNorm() gives an estimate of NORM(X).
+ * <p/>
+ * LSMR uses an iterative method. For further information, see D. C.-L. Fong and M. A. Saunders
+ * LSMR: An iterative algorithm for least-square problems Draft of 03 Apr 2010, to be submitted to
+ * SISC.
+ * <p/>
+ * David Chin-lung Fong            clfong@stanford.edu Institute for Computational and Mathematical
+ * Engineering Stanford University
+ * <p/>
+ * Michael Saunders                saunders@stanford.edu Systems Optimization Laboratory Dept of
+ * MS&E, Stanford University. -----------------------------------------------------------------------
+ */
+public class LSMR {
+  private static final Logger log = LoggerFactory.getLogger(LSMR.class);
+
+  // Configuration, see the setters at the bottom of the class.
+  private double lambda;
+  private int localSize;
+  private int iterationLimit;
+  private double conditionLimit;
+  private double bTolerance;
+  private double aTolerance;
+
+  // Working state and result estimates of the most recent call to solve().
+  private int localPointer;
+  private Vector v;
+  private Vector[] localV;
+  private double residualNorm;
+  private double normalEquationResidual;
+  // Running sum of alpha^2 and beta^2; normA is its square root taken part-way through an update.
+  private double aNorm;
+  private double xNorm;
+  private int iteration;
+  private double normA;
+  private double condA;
+
+  /** @return the iteration counter reached by the most recent call to {@link #solve}. */
+  public int getIterationCount() {
+    return iteration;
+  }
+
+  /** @return the estimate of norm(b - A*x) maintained by the most recent call to {@link #solve}. */
+  public double getResidualNorm() {
+    return residualNorm;
+  }
+
+  /** @return the estimate of norm(A'*(b - A*x)) maintained by the most recent call to {@link #solve}. */
+  public double getNormalEquationResidual() {
+    return normalEquationResidual;
+  }
+
+  /** @return the estimate of the norm of A computed by the most recent call to {@link #solve}. */
+  public double getANorm() {
+    return normA;
+  }
+
+  /** @return the estimate of cond(A) computed by the most recent call to {@link #solve}. */
+  public double getCondition() {
+    return condA;
+  }
+
+  /** @return the 2-norm of the current solution, as of the last completed iteration. */
+  public double getXNorm() {
+    return xNorm;
+  }
+
+  /**
+   * Creates a solver with the default parameters: lambda = 0, aTolerance = bTolerance = 1e-6,
+   * conditionLimit = 1e8, iterationLimit = -1 (meaning min(rows, columns) of the matrix being
+   * solved) and localSize = 0 (no reorthogonalization).
+   * <p/>
+   * LSMR uses an iterative method to solve a linear system. For further information, see D. C.-L.
+   * Fong and M. A. Saunders LSMR: An iterative algorithm for least-square problems Draft of 03 Apr
+   * 2010, to be submitted to SISC.
+   * <p/>
+   * 08 Dec 2009: First release version of LSMR. 09 Apr 2010: Updated documentation and default
+   * parameters. 14 Apr 2010: Updated documentation. 03 Jun 2010: LSMR with local
+   * reorthogonalization (full reorthogonalization is also implemented)
+   * <p/>
+   * David Chin-lung Fong            clfong@stanford.edu Institute for Computational and
+   * Mathematical Engineering Stanford University
+   * <p/>
+   * Michael Saunders                saunders@stanford.edu Systems Optimization Laboratory Dept of
+   * MS&amp;E, Stanford University.
+   */
+  public LSMR() {
+    // Set default parameters.
+    setLambda(0);
+    setAtolerance(1e-6);
+    setBtolerance(1e-6);
+    setConditionLimit(1e8);
+    setIterationLimit(-1);
+    setLocalSize(0);
+  }
+
+  /**
+   * Runs the LSMR iteration for the system A * x = b.
+   * <p/>
+   * The estimates exposed by the getters of this class are updated as a side effect. The
+   * configured iteration limit is not modified: when it is -1 the limit min(rows, columns) is
+   * derived for this call only, so the same solver can be reused on matrices of different sizes.
+   * Note that the loop condition is {@code iteration <= limit} while the limit test is
+   * {@code iteration > limit}, so up to limit + 1 passes are performed.
+   *
+   * @param A the system matrix (m rows, n columns).
+   * @param b the right-hand side.
+   * @return the approximate solution, a dense vector with n entries; all zeros if
+   *         norm(b) * norm(A'b / norm(b)) is zero.
+   */
+  public Vector solve(Matrix A, Vector b) {
+    log.debug("   itn         x(1)     norm r   norm A'r");
+    log.debug("   compatible   LS      norm A   cond A");
+
+    // Determine dimensions m and n, and form the first vectors u and v.
+    // These satisfy  beta*u = b,  alpha*v = A'u.
+    Matrix transposedA = A.transpose();
+    Vector u = b;
+
+    double beta = u.norm(2);
+    if (beta > 0) {
+      u = u.divide(beta);
+    }
+
+    v = transposedA.times(u);
+    int m = A.numRows();
+    int n = A.numCols();
+
+    int minDim = Math.min(m, n);
+    // Resolve the default into a local: writing it back to the field would make the limit of the
+    // first problem stick for every later call on this instance.
+    int maxIterations = iterationLimit == -1 ? minDim : iterationLimit;
+
+    if (log.isDebugEnabled()) {
+      log.debug("LSMR - Least-squares solution of  Ax = b, based on Matlab Version 1.02, 14 Apr 2010, Mahout version {}",
+        this.getClass().getPackage().getImplementationVersion());
+      log.debug(String.format("The matrix A has %d rows  and %d cols, lambda = %.4g, atol = %g, btol = %g",
+        m, n, getLambda(), getAtolerance(), getBtolerance()));
+    }
+
+    double alpha = v.norm(2);
+    if (alpha > 0) {
+      v.assign(Functions.div(alpha));
+    }
+
+
+    // Initialization for local reorthogonalization
+    localPointer = 0;
+
+    // Preallocate storage for storing the last few v_k. Since with
+    // orthogonal v_k's, Krylov subspace method would converge in not
+    // more iterations than the number of singular values, more
+    // space is not necessary.
+    // A non-positive localSize (or an empty matrix) yields an empty buffer, i.e. no
+    // reorthogonalization, instead of an array-size/index exception.
+    localV = new Vector[Math.max(0, Math.min(localSize, minDim))];
+    boolean localOrtho = localV.length > 0;
+    if (localOrtho) {
+      localV[0] = v;
+    }
+
+
+    // Initialize variables for 1st iteration.
+
+    iteration = 0;
+    double zetabar = alpha * beta;
+    double alphabar = alpha;
+    double rho = 1;
+    double rhobar = 1;
+    double cbar = 1;
+    double sbar = 0;
+
+    Vector h = v;
+    Vector hbar = zeros(n);
+    Vector x = zeros(n);
+
+    // Initialize variables for estimation of ||r||.
+
+    double betadd = beta;
+    double betad = 0;
+    double rhodold = 1;
+    double tautildeold = 0;
+    double thetatilde = 0;
+    double zeta = 0;
+    double d = 0;
+
+    // Initialize variables for estimation of ||A|| and cond(A)
+
+    aNorm = alpha * alpha;
+    double maxrbar = 0;
+    double minrbar = 1e+100;
+
+    // Items for use in stopping rules.
+    double normb = beta;
+
+    StopCode stop = StopCode.CONTINUE;
+
+    double ctol = 0;
+    if (conditionLimit > 0) {
+      ctol = 1 / conditionLimit;
+    }
+    residualNorm = beta;
+
+    // Exit if b=0 or A'b = 0.
+
+    normalEquationResidual = alpha * beta;
+    if (normalEquationResidual == 0) {
+      log.debug("Finished: {}", StopCode.TRIVIAL.getMessage());
+      return x;
+    }
+
+    // Heading for iteration log.
+
+
+    if (log.isDebugEnabled()) {
+      double test1 = 1;
+      double test2 = alpha / beta;
+      log.debug("{} {}", iteration, x.get(0));
+      log.debug("{} {}", residualNorm, normalEquationResidual);
+      log.debug("{} {}", test1, test2);
+    }
+
+
+    //------------------------------------------------------------------
+    //     Main iteration loop.
+    //------------------------------------------------------------------
+    while (iteration <= maxIterations && stop == StopCode.CONTINUE) {
+
+      iteration = iteration + 1;
+
+      // Perform the next step of the bidiagonalization to obtain the
+      // next beta, u, alpha, v.  These satisfy the relations
+      //      beta*u  =  A*v  - alpha*u,
+      //      alpha*v  =  A'*u - beta*v.
+
+      u = A.times(v).minus(u.times(alpha));
+      beta = u.norm(2);
+      if (beta > 0) {
+        u.assign(Functions.div(beta));
+
+        // store data for local-reorthogonalization of V
+        if (localOrtho) {
+          localVEnqueue(v);
+        }
+        v = transposedA.times(u).minus(v.times(beta));
+        // local-reorthogonalization of V
+        if (localOrtho) {
+          v = localVOrtho(v);
+        }
+        alpha = v.norm(2);
+        if (alpha > 0) {
+          v.assign(Functions.div(alpha));
+        }
+      }
+
+      // At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
+
+      // Construct rotation Qhat_{k,2k+1}.
+
+      double alphahat = Math.hypot(alphabar, lambda);
+      double chat = alphabar / alphahat;
+      double shat = lambda / alphahat;
+
+      // Use a plane rotation (Q_i) to turn B_i to R_i
+
+      double rhoold = rho;
+      rho = Math.hypot(alphahat, beta);
+      double c = alphahat / rho;
+      double s = beta / rho;
+      double thetanew = s * alpha;
+      alphabar = c * alpha;
+
+      // Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar
+
+      double rhobarold = rhobar;
+      double zetaold = zeta;
+      double thetabar = sbar * rho;
+      double rhotemp = cbar * rho;
+      rhobar = Math.hypot(cbar * rho, thetanew);
+      cbar = cbar * rho / rhobar;
+      sbar = thetanew / rhobar;
+      zeta = cbar * zetabar;
+      zetabar = -sbar * zetabar;
+
+
+      // Update h, h_hat, x.
+
+      hbar = h.minus(hbar.times(thetabar * rho / (rhoold * rhobarold)));
+
+      x.assign(hbar.times(zeta / (rho * rhobar)), Functions.PLUS);
+      h = v.minus(h.times(thetanew / rho));
+
+      // Estimate of ||r||.
+
+      // Apply rotation Qhat_{k,2k+1}.
+      double betaacute = chat * betadd;
+      double betacheck = -shat * betadd;
+
+      // Apply rotation Q_{k,k+1}.
+      double betahat = c * betaacute;
+      betadd = -s * betaacute;
+
+      // Apply rotation Qtilde_{k-1}.
+      // betad = betad_{k-1} here.
+
+      double thetatildeold = thetatilde;
+      double rhotildeold = Math.hypot(rhodold, thetabar);
+      double ctildeold = rhodold / rhotildeold;
+      double stildeold = thetabar / rhotildeold;
+      thetatilde = stildeold * rhobar;
+      rhodold = ctildeold * rhobar;
+      betad = -stildeold * betad + ctildeold * betahat;
+
+      // betad   = betad_k here.
+      // rhodold = rhod_k  here.
+
+      tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold;
+      double taud = (zeta - thetatilde * tautildeold) / rhodold;
+      d = d + betacheck * betacheck;
+      residualNorm = Math.sqrt(d + (betad - taud) * (betad - taud) + betadd * betadd);
+
+      // Estimate ||A||.
+      aNorm = aNorm + beta * beta;
+      normA = Math.sqrt(aNorm);
+      aNorm = aNorm + alpha * alpha;
+
+      // Estimate cond(A).
+      maxrbar = Math.max(maxrbar, rhobarold);
+      if (iteration > 1) {
+        minrbar = Math.min(minrbar, rhobarold);
+      }
+      condA = Math.max(maxrbar, rhotemp) / Math.min(minrbar, rhotemp);
+
+      // Test for convergence.
+
+      // Compute norms for convergence testing.
+      normalEquationResidual = Math.abs(zetabar);
+      xNorm = x.norm(2);
+
+      // Now use these norms to estimate certain other quantities,
+      // some of which will be small near a solution.
+
+      double test1 = residualNorm / normb;
+      double test2 = normalEquationResidual / (normA * residualNorm);
+      double test3 = 1 / condA;
+      double t1 = test1 / (1 + normA * xNorm / normb);
+      double rtol = bTolerance + aTolerance * normA * xNorm / normb;
+
+      // The following tests guard against extremely small values of
+      // atol, btol or ctol.  (The user may have set any or all of
+      // the parameters atol, btol, conlim  to 0.)
+      // The effect is equivalent to the normal tests using
+      // atol = eps,  btol = eps,  conlim = 1/eps.
+      //
+      // The tests are ordered from lowest to highest priority: a later match overrides an
+      // earlier one, so the order of the statements below must be preserved.
+
+      if (iteration > maxIterations) {
+        stop = StopCode.ITERATION_LIMIT;
+      }
+      if (1 + test3 <= 1) {
+        stop = StopCode.CONDITION_MACHINE_TOLERANCE;
+      }
+      if (1 + test2 <= 1) {
+        stop = StopCode.LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE;
+      }
+      if (1 + t1 <= 1) {
+        stop = StopCode.CONVERGED_MACHINE_TOLERANCE;
+      }
+
+      // Allow for tolerances set by the user.
+
+      if (test3 <= ctol) {
+        stop = StopCode.CONDITION;
+      }
+      if (test2 <= aTolerance) {
+        // Least-squares criterion: norm(A'r) is small relative to norm(A) * norm(r).
+        stop = StopCode.LEAST_SQUARE_CONVERGED;
+      }
+      if (test1 <= rtol) {
+        // Compatible-system criterion: norm(r) is small given atol and btol.
+        stop = StopCode.CONVERGED;
+      }
+
+      // See if it is time to print something.
+
+      if (log.isDebugEnabled()) {
+        boolean worthPrinting = (n <= 40)
+            || (iteration <= 10)
+            || (iteration >= maxIterations - 10)
+            || ((iteration % 10) == 0)
+            || (test3 <= 1.1 * ctol)
+            || (test2 <= 1.1 * aTolerance)
+            || (test1 <= 1.1 * rtol)
+            || (stop != StopCode.CONTINUE);
+        if (worthPrinting) {
+          statusDump(x, normA, condA, test1, test2);
+        }
+      }
+    } // iteration loop
+
+    // Print the stopping condition.
+    log.debug("Finished: {}", stop.getMessage());
+
+    return x;
+  }
+
+  /** Logs the current estimates at debug level. */
+  private void statusDump(Vector x, double normA, double condA, double test1, double test2) {
+    log.debug("{} {}", residualNorm, normalEquationResidual);
+    log.debug("{} {}", iteration, x.get(0));
+    log.debug("{} {}", test1, test2);
+    log.debug("{} {}", normA, condA);
+  }
+
+  /** @return a new dense vector with n entries. */
+  private Vector zeros(int n) {
+    return new DenseVector(n);
+  }
+
+  //-----------------------------------------------------------------------
+  // stores v into the circular buffer localV
+  //-----------------------------------------------------------------------
+
+  private void localVEnqueue(Vector v) {
+    if (localV.length > 0) {
+      localV[localPointer] = v;
+      localPointer = (localPointer + 1) % localV.length;
+    }
+  }
+
+  //-----------------------------------------------------------------------
+  // Perform local reorthogonalization of V
+  //-----------------------------------------------------------------------
+
+  private Vector localVOrtho(Vector v) {
+    for (Vector old : localV) {
+      if (old != null) {
+        double x = v.dot(old);
+        v = v.minus(old.times(x));
+      }
+    }
+    return v;
+  }
+
+  /** Reasons for leaving the main loop, each with a human-readable message for the log. */
+  private enum StopCode {
+    CONTINUE("Not done"),
+    TRIVIAL("The exact solution is  x = 0"),
+    CONVERGED("Ax - b is small enough, given atol, btol"),
+    LEAST_SQUARE_CONVERGED("The least-squares solution is good enough, given atol"),
+    CONDITION("The estimate of cond(Abar) has exceeded condition limit"),
+    CONVERGED_MACHINE_TOLERANCE("Ax - b is small enough for this machine"),
+    LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE("The least-squares solution is good enough for this machine"),
+    CONDITION_MACHINE_TOLERANCE("Cond(Abar) seems to be too large for this machine"),
+    ITERATION_LIMIT("The iteration limit has been reached");
+
+    private final String message;
+
+    StopCode(String message) {
+      this.message = message;
+    }
+
+    public String getMessage() {
+      return message;
+    }
+  }
+
+  public void setAtolerance(double aTolerance) {
+    this.aTolerance = aTolerance;
+  }
+
+  public void setBtolerance(double bTolerance) {
+    this.bTolerance = bTolerance;
+  }
+
+  public void setConditionLimit(double conditionLimit) {
+    this.conditionLimit = conditionLimit;
+  }
+
+  /**
+   * @param iterationLimit the iteration limit; -1 selects min(rows, columns) of the matrix passed
+   *                       to each call of {@link #solve}.
+   */
+  public void setIterationLimit(int iterationLimit) {
+    this.iterationLimit = iterationLimit;
+  }
+
+  /**
+   * @param localSize number of most recent v vectors kept for local reorthogonalization; values
+   *                  of zero or less disable reorthogonalization.
+   */
+  public void setLocalSize(int localSize) {
+    this.localSize = localSize;
+  }
+
+  private void setLambda(double lambda) {
+    this.lambda = lambda;
+  }
+
+  public double getLambda() {
+    return lambda;
+  }
+
+  public double getAtolerance() {
+    return aTolerance;
+  }
+
+  public double getBtolerance() {
+    return bTolerance;
+  }
+}

Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/solver/Preconditioner.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,20 @@
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Contract for preconditioners, which are used to improve the performance and/or stability of
+ * linear system solvers.
+ */
+public interface Preconditioner {
+
+  /**
+   * Applies this preconditioner to a vector.
+   *
+   * @param v The vector to precondition.
+   * @return The preconditioned vector.
+   */
+  Vector precondition(Vector v);
+}

Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/LSMRTest.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.solver.LSMR;
+import org.junit.Test;
+
+import java.util.Random;
+
+/**
+ * Checks the {@link LSMR} solver on a small but ill-conditioned system (the 5x5 Hilbert matrix),
+ * verifying both the residuals of the returned solution and the solver's own residual estimates.
+ */
+public class LSMRTest extends MahoutTestCase {
+  @Test
+  public void basics() {
+    Matrix m = hilbert(5);
+
+    // make sure it is the hilbert matrix we know and love
+    assertEquals(1, m.get(0, 0), 0);
+    assertEquals(0.5, m.get(0, 1), 0);
+    assertEquals(1 / 6.0, m.get(2, 3), 1e-9);
+
+    Vector x = new DenseVector(new double[]{5, -120, 630, -1120, 630});
+
+    Vector b = new DenseVector(5);
+    b.assign(1);
+
+    assertEquals(0, m.times(x).minus(b).norm(2), 1e-9);
+
+    LSMR r = new LSMR();
+    Vector x1 = r.solve(m, b);
+
+    // the ideal solution is  [5  -120   630 -1120   630] but the 5x5 hilbert matrix
+    // has a condition number of almost 500,000 and the normal equation condition
+    // number is that squared.  This means that we don't get the exact answer with
+    // a fast iterative solution.
+    // Thus, we have to check the residuals rather than testing that the answer matched
+    // the ideal.
+    // (expected value first, so that a failure message reports the numbers the right way round)
+    assertEquals(0, m.times(x1).minus(b).norm(2), 1e-2);
+    assertEquals(0, m.transpose().times(m).times(x1).minus(m.transpose().times(b)).norm(2), 1e-7);
+
+    // and we need to check that the error estimates are pretty good.
+    assertEquals(m.times(x1).minus(b).norm(2), r.getResidualNorm(), 1e-5);
+    assertEquals(m.transpose().times(m).times(x1).minus(m.transpose().times(b)).norm(2), r.getNormalEquationResidual(), 1e-9);
+  }
+
+  /** Builds the n-by-n Hilbert matrix, whose (i, j) entry (zero-based) is 1 / (i + j + 1). */
+  private Matrix hilbert(int n) {
+    Matrix r = new DenseMatrix(n, n);
+    for (int i = 0; i < n; i++) {
+      for (int j = 0; j < n; j++) {
+        r.set(i, j, 1.0 / (i + j + 1));
+      }
+    }
+    return r;
+  }
+
+  /**
+   * Builds a 2n-by-n matrix filled with Gaussian random values drawn from
+   * {@link RandomUtils#getRandom()}. Not referenced by any test in this class yet.
+   */
+  private Matrix overDetermined(int n) {
+    Random rand = RandomUtils.getRandom();
+    Matrix r = new DenseMatrix(2 * n, n);
+    for (int i = 0; i < 2 * n; i++) {
+      for (int j = 0; j < n; j++) {
+        r.set(i, j, rand.nextGaussian());
+      }
+    }
+    return r;
+  }
+}

Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java?rev=1188491&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/solver/TestConjugateGradientSolver.java Tue Oct 25 01:59:58 2011
@@ -0,0 +1,212 @@
+package org.apache.mahout.math.solver;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Unit tests for {@link ConjugateGradientSolver}, exercised both without a
+ * preconditioner and with a {@link JacobiConditioner}.
+ *
+ * <p>All dense fixtures are built through {@link #reshape(double[], int, int)}, which
+ * consumes its values in column-major order.
+ */
+public class TestConjugateGradientSolver extends MahoutTestCase {
+
+  /** Solves the well-behaved 10x10 system with default settings and checks residual and iteration count. */
+  @Test
+  public void testConjugateGradientSolver() {
+    Matrix a = getA();
+    Vector b = getB();
+
+    ConjugateGradientSolver solver = new ConjugateGradientSolver();
+    Vector x = solver.solve(a, b);
+
+    assertEquals(0.0, residualNorm(a, x, b), EPSILON);
+    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+    assertEquals(10, solver.getIterations());
+  }
+
+  /**
+   * Solves the ill-conditioned system twice, first with no preconditioner and then with a
+   * Jacobi preconditioner, and checks that the preconditioned run converges in one iteration less.
+   */
+  @Test
+  public void testConditionedConjugateGradientSolver() {
+    Matrix a = getIllConditionedMatrix();
+    Vector b = getB();
+    Preconditioner conditioner = new JacobiConditioner(a);
+    ConjugateGradientSolver solver = new ConjugateGradientSolver();
+
+    // baseline: no preconditioner
+    Vector x = solver.solve(a, b, null, 100, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+
+    assertEquals(0.0, residualNorm(a, x, b), EPSILON);
+    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+    assertEquals(16, solver.getIterations());
+
+    Vector x2 = solver.solve(a, b, conditioner, 100, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+
+    // the Jacobi preconditioner isn't very good, but it does result in one less iteration to converge
+    assertEquals(0.0, residualNorm(a, x2, b), EPSILON);
+    assertEquals(0.0, solver.getResidualNorm(), ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+    assertEquals(15, solver.getIterations());
+  }
+
+  /**
+   * Checks that the solver stops early, with a less accurate answer, either when a looser
+   * error bound is given or when the iteration count is capped.
+   */
+  @Test
+  public void testEarlyStop() {
+    Matrix a = getA();
+    Vector b = getB();
+    ConjugateGradientSolver solver = new ConjugateGradientSolver();
+
+    // specifying a looser max error will result in fewer iterations but less accurate results
+    Vector x = solver.solve(a, b, null, 10, 0.1);
+    double distance = residualNorm(a, x, b);
+    assertTrue(distance > EPSILON);
+    assertEquals(0.0, distance, 0.1); // should be equal to within the error specified
+    assertEquals(7, solver.getIterations()); // should have taken fewer iterations
+
+    // can get a similar effect by bounding the number of iterations
+    x = solver.solve(a, b, null, 7, ConjugateGradientSolver.DEFAULT_MAX_ERROR);
+    distance = residualNorm(a, x, b);
+    assertTrue(distance > EPSILON);
+    assertEquals(0.0, distance, 0.1);
+    assertEquals(7, solver.getIterations());
+  }
+
+  /**
+   * Returns the distance between {@code a * x} and {@code b}, i.e. the square root of
+   * {@code a.times(x).getDistanceSquared(b)}.
+   */
+  private static double residualNorm(Matrix a, Vector x, Vector b) {
+    return Math.sqrt(a.times(x).getDistanceSquared(b));
+  }
+
+  /**
+   * Returns the 10x10 coefficient matrix used by the basic and early-stop tests.
+   * The listed values mirror each other across the diagonal, so the matrix is symmetric.
+   */
+  private static Matrix getA() {
+    return reshape(new double[] {
+        11.7155649822793997, -0.7125253363083646, 4.6473613961860183,  1.6020939468348456, -4.6789817799137134,
+        -0.8140416763434970, -4.5995617505618345, -1.1749070042775340, -1.6747995811678336, 3.1922255171058342,
+        -0.7125253363083646, 12.3400579683994867, -2.6498099427000645, 0.5264507222630669, 0.3783428369189767,
+        -2.1170186159188811, 2.3695134252190528, 3.8182131490333013, 6.5285942298270347, 2.8564814419366353,
+         4.6473613961860183, -2.6498099427000645, 16.1317933921668484, -0.0409475448061225, 1.4805687075608227,
+        -2.9958076484628950, -2.5288893025027264, -0.9614557539842487, -2.2974738351519077, -1.5516184284572598,
+         1.6020939468348456, 0.5264507222630669, -0.0409475448061225, 4.1946802122694482, -2.5210038046912198,
+         0.6634899962909317, 0.4036187419205338, -0.2829211393003727, -0.2283091172980954, 1.1253516563552464,
+        -4.6789817799137134, 0.3783428369189767, 1.4805687075608227, -2.5210038046912198, 19.4307361862733430,
+        -2.5200132222091787, 2.3748511971444510, 11.6426598443305522, -0.1508136510863874, 4.3471343888063512,
+        -0.8140416763434970, -2.1170186159188811, -2.9958076484628950, 0.6634899962909317, -2.5200132222091787,
+         7.6712334419700747, -3.8687773629502851, -3.0453418711591529, -0.1155580876143619, -2.4025459467422121,
+        -4.5995617505618345, 2.3695134252190528, -2.5288893025027264, 0.4036187419205338, 2.3748511971444510,
+        -3.8687773629502851, 10.4681666057470082, 1.6527180866171229, 2.9341795819365384, -2.1708176372763099,
+        -1.1749070042775340, 3.8182131490333013, -0.9614557539842487, -0.2829211393003727, 11.6426598443305522,
+        -3.0453418711591529, 1.6527180866171229, 16.0050616934176233, 1.1689747208793086, 1.6665090945954870,
+        -1.6747995811678336, 6.5285942298270347, -2.2974738351519077, -0.2283091172980954, -0.1508136510863874,
+        -0.1155580876143619, 2.9341795819365384, 1.1689747208793086, 6.4794329751637481, -1.9197339981871877,
+         3.1922255171058342, 2.8564814419366353, -1.5516184284572598, 1.1253516563552464, 4.3471343888063512,
+        -2.4025459467422121, -2.1708176372763099, 1.6665090945954870, -1.9197339981871877, 18.9149021356344598
+    }, 10, 10);
+  }
+
+  /** Returns the 10-element right-hand-side vector shared by all tests in this class. */
+  private static Vector getB() {
+    return new DenseVector(new double[] {
+        -0.552252, 0.038430, 0.058392, -1.234496, 1.240369, 0.373649, 0.505113, 0.503723, 1.215340, -0.391908
+    });
+  }
+
+  /**
+   * Returns the 10x10 matrix used by the preconditioner test. Its entries mirror each other
+   * across the diagonal up to differences in the last printed digits.
+   */
+  private static Matrix getIllConditionedMatrix() {
+    return reshape(new double[] {
+        0.00695278043678842, 0.09911830022078683, 0.01309584636255063, 0.00652917453032394, 0.04337631487735064,
+        0.14232165273321387, 0.05808722912361313, -0.06591965049732287, 0.06055771542862332, 0.00577423310349649,
+        0.09911830022078683, 1.50071402418061428, 0.14988743575884242, 0.07195514527480981, 0.63747362341752722,
+        1.30711819020414688, 0.82151609385115953, -0.72616125524587938, 1.03490136002022948, 0.12800239664439328,
+        0.01309584636255063, 0.14988743575884242, 0.04068462583124965, 0.02147022047006482, 0.07388113580146650,
+        0.58070223915076002, 0.11280336266257514, -0.21690068430020618, 0.04065087561300068, -0.00876895259593769,
+        0.00652917453032394, 0.07195514527480981, 0.02147022047006482, 0.01140105250542524, 0.03624164348693958,
+        0.31291554581393255, 0.05648457235205666, -0.11507583016077780, 0.01475756130709823, -0.00584453679519805,
+        0.04337631487735064, 0.63747362341752722, 0.07388113580146649, 0.03624164348693959, 0.27491543200760571,
+        0.73410543168748121, 0.36120630002843257, -0.36583546331208316, 0.41472509341940017, 0.04581458758255480,
+        0.14232165273321387, 1.30711819020414666, 0.58070223915076002, 0.31291554581393255, 0.73410543168748121,
+        9.02536073121807014, 1.25426385582883104, -3.16186335125594642, -0.19740140818905436, -0.26613760880058035,
+        0.05808722912361314, 0.82151609385115953, 0.11280336266257514, 0.05648457235205667, 0.36120630002843257,
+        1.25426385582883126, 0.48661058451606820, -0.57030511336562195, 0.49151280464818098, 0.04428280690189127,
+       -0.06591965049732286, -0.72616125524587938, -0.21690068430020618, -0.11507583016077781, -0.36583546331208316,
+       -3.16186335125594642, -0.57030511336562195, 1.16270815038078945, -0.14837898963724327, 0.05917203395002889,
+        0.06055771542862331, 1.03490136002022926, 0.04065087561300068, 0.01475756130709823, 0.41472509341940023,
+       -0.19740140818905436, 0.49151280464818103, -0.14837898963724327, 0.86693820682049716, 0.14089688752570340,
+        0.00577423310349649, 0.12800239664439328, -0.00876895259593769, -0.00584453679519805, 0.04581458758255480,
+       -0.26613760880058035, 0.04428280690189126, 0.05917203395002889, 0.14089688752570340, 0.02901858439788401
+    }, 10, 10);
+  }
+
+  /** Returns a 20x5 (non-square) fixture. Not referenced by any test in this class at present. */
+  private static Matrix getAsymmetricMatrix() {
+    return reshape(new double[] {
+        0.1586493402398226, -0.8668244036239467, 0.4335233711065471, -1.1025223577469705, 1.1344100191664601,
+         -0.1399944083742454, 0.8879750333144295, -1.2139664527957903, 0.7154591081557057, -0.6320890356949669,
+        -2.4546945723009581, 0.6354748667295935, -0.1931993736354496, -0.1210449542073575, -1.0668745874463414,
+         0.6539061600017384, 2.4045520271091063,-0.3387572116155693, 0.1575188740437142, 1.1791073500243496,
+        -0.6418745429181755, 0.6836410530720005, -1.2447493564334062, -1.8840081252627843, 0.5663864914859502,
+         0.0819203791124956, 0.2004407540793239, 0.7350145066687849, 1.6525377683305262, -0.3156915229969668,
+        -0.1866701463141060, -0.3929673444397022, -0.4440946700501859, 0.1366803303987421, -0.2138101381625466,
+         0.5399874351478779, -1.0088091882703056, 0.0978023083150833, 1.8795777615527958, 0.3782417618354363,
+        -0.4564752186043173, 0.4014814252832269, 1.9691150950571501, 0.2424686682362568, 1.0965758964799504,
+         0.2751725463132324, -0.6652756564294597, -0.6256564536463288, 1.0332457212107204, -0.0330851504958215,
+        -1.0402096493279287, -0.6850389655533707, -1.8896839974451625, 1.1533231017445102, -0.5387306882127710,
+         0.0181850207098213, -0.2416652193929706, -0.9868171673047287, -1.5872573189377035, -0.8492253650362955,
+         1.1949977792951225, 0.7901168665120927, 0.9832676055718492, -0.0752834029327588, 1.0555006468941126,
+         0.6842531633106009, 0.2589700378872499, 0.3565253337268334, 0.1869608474650344, -0.1696524825242293,
+         0.6919898638809949, -1.4937187919435133, 1.0039151841775080, -0.2580993333173019, 0.1243386429912411,
+         1.3945380460721688,  0.3078165489952902, 1.1248734111054359,  0.5613308856003306, -0.9013329415656699,
+        -0.9197179846787753,  0.1167372728291174, -0.7807620712716467,  0.2210918047063067, -0.4813869727362010,
+         0.3870067788770671,  1.1974416632199159, 2.4676804711420330,  1.8492990765211168, -1.3089887830472471,
+        -0.7587845769668021, -1.0354138253278353, -0.3907902473275445, -2.1292895670916168, -0.7544686049709807,
+        -0.3431317172534703, 1.4959721683724390, 0.6004852467523584, 1.2140230344223786, 0.1279148299232956
+    }, 20, 5);
+  }
+
+  /** Returns a 5-element right-hand-side vector. Not referenced by any test in this class at present. */
+  private static Vector getSmallB() {
+    return new DenseVector(new double[] {
+        0.114065955249272,
+        0.953981568944476,
+        -2.611106316607759,
+        0.652190962446307,
+        1.298055218126384
+    });
+  }
+
+  /**
+   * Returns the 5x5 matrix {@code u * u' + v * v'}, the sum of two outer products.
+   * Not referenced by any test in this class at present.
+   */
+  private static Matrix getLowrankSymmetricMatrix() {
+    Matrix m = new DenseMatrix(5,5);
+    Vector u = new DenseVector(new double[] {
+        -0.0364638798936962,
+        1.0219291133418171,
+        -0.5649933120375343,
+        -1.0050553315595800,
+        -0.5264178580727512
+    });
+    Vector v = new DenseVector(new double[] {
+        -1.345847117891187,
+        0.553386426498032,
+        1.912020072696648,
+        -0.820959934779948,
+        1.223358044171859
+    });
+
+    return m.plus(u.cross(u)).plus(v.cross(v));
+  }
+
+  /**
+   * Returns a 20x5 matrix that is zero except for its first two rows, which hold the
+   * vectors {@code u} and {@code v}. Not referenced by any test in this class at present.
+   */
+  private static Matrix getLowrankAsymmetricMatrix() {
+    Matrix m = new DenseMatrix(20,5);
+    Vector u = new DenseVector(new double[] {
+        -0.0364638798936962,
+        1.0219291133418171,
+        -0.5649933120375343,
+        -1.0050553315595800,
+        -0.5264178580727512
+    });
+    Vector v = new DenseVector(new double[] {
+        -1.345847117891187,
+        0.553386426498032,
+        1.912020072696648,
+        -0.820959934779948,
+        1.223358044171859
+    });
+
+    // u and v must land in different rows; writing both to row 0 would discard u
+    m.assignRow(0, u);
+    m.assignRow(1, v);
+
+    return m;
+  }
+
+  /**
+   * Packs a flat array into a dense matrix in column-major order: value {@code i} is stored
+   * at row {@code i % rows}, column {@code i / rows}.
+   *
+   * @param values the entries, listed one column after another
+   * @param rows the number of rows in the result
+   * @param columns the number of columns in the result
+   * @return a {@code rows}-by-{@code columns} dense matrix holding {@code values}
+   */
+  private static Matrix reshape(double[] values, int rows, int columns) {
+    Matrix m = new DenseMatrix(rows, columns);
+    int i = 0;
+    for (double v : values) {
+      m.set(i % rows, i / rows, v);
+      i++;
+    }
+    return m;
+  }
+}



Mime
View raw message