hama-commits mailing list archives

From tomm...@apache.org
Subject svn commit: r1531245 - in /hama/trunk: examples/src/test/resources/ ml/src/main/java/org/apache/hama/ml/math/ ml/src/main/java/org/apache/hama/ml/regression/ ml/src/test/java/org/apache/hama/ml/regression/
Date Fri, 11 Oct 2013 09:46:29 GMT
Author: tommaso
Date: Fri Oct 11 09:46:29 2013
New Revision: 1531245

URL: http://svn.apache.org/r1531245
Log:
HAMA-809 - applied draft patch to solve LR cost function problems

Modified:
    hama/trunk/examples/src/test/resources/logistic_regression_sample.txt
    hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
    hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LinearRegressionModelTest.java
    hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LogisticRegressionModelTest.java

Modified: hama/trunk/examples/src/test/resources/logistic_regression_sample.txt
URL: http://svn.apache.org/viewvc/hama/trunk/examples/src/test/resources/logistic_regression_sample.txt?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/examples/src/test/resources/logistic_regression_sample.txt (original)
+++ hama/trunk/examples/src/test/resources/logistic_regression_sample.txt Fri Oct 11 09:46:29 2013
@@ -1,6 +1,8 @@
-2>1 9 2 4 5 6 7
+0>1 9 2 4 5 6 7
 1>3 4 5 6 9 1 3
 1>1 1 3 1 1 1 1
-2>2 4 1 1 4 1 8
+0>2 4 1 1 4 1 8
 1>3 4 5 6 7 8 9
 1>1 3 4 1 4 5 1
+0>1 10 3 2 1 6 1
+0>1 2 2 2 2 6 2
\ No newline at end of file

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java Fri Oct 11 09:46:29 2013
@@ -17,6 +17,7 @@
  */
 package org.apache.hama.ml.math;
 
+import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -357,11 +358,11 @@ public final class DenseDoubleVector imp
    */
   @Override
   public double dotUnsafe(DoubleVector vector) {
-    double dotProduct = 0.0d;
+    BigDecimal dotProduct = BigDecimal.valueOf(0.0d);
     for (int i = 0; i < getLength(); i++) {
-      dotProduct += this.get(i) * vector.get(i);
+      dotProduct = dotProduct.add(BigDecimal.valueOf(this.get(i)).multiply(BigDecimal.valueOf(vector.get(i))));
     }
-    return dotProduct;
+    return dotProduct.doubleValue();
   }
 
   /*

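Note (not part of the commit): the switch to a BigDecimal accumulator in dotUnsafe is about avoiding intermediate rounding when terms differ wildly in magnitude. A minimal standalone sketch of that effect; the class name and the data are illustrative only:

import java.math.BigDecimal;

// Illustrative only: a plain double accumulator can silently lose small terms,
// while exact BigDecimal accumulation keeps them until the final conversion.
public class DotUnsafeSketch {
  public static void main(String[] args) {
    double[] a = { 1e16, 1, -1e16 };
    double[] b = { 1, 1, 1 };

    double plain = 0d;
    BigDecimal exact = BigDecimal.ZERO;
    for (int i = 0; i < a.length; i++) {
      plain += a[i] * b[i];
      exact = exact.add(BigDecimal.valueOf(a[i]).multiply(BigDecimal.valueOf(b[i])));
    }
    System.out.println(plain);               // 0.0 -- the 1 was swallowed
    System.out.println(exact.doubleValue()); // 1.0
  }
}
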
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java Fri Oct 11 09:46:29 2013
@@ -17,6 +17,8 @@
  */
 package org.apache.hama.ml.regression;
 
+import java.math.BigDecimal;
+
 import org.apache.hama.ml.math.DoubleVector;
 
 /**
@@ -33,9 +35,9 @@ public interface CostFunction {
    * @param m the number of existing items
    * @param theta the parameters vector theta
    * @param hypothesis the hypothesis function to model the problem
-   * @return the calculated cost for input x and output y
+   * @return the calculated cost for input x and output y as a <code>BigDecimal</code>
    */
-  public double calculateCostForItem(DoubleVector x, double y, int m,
+  public BigDecimal calculateCostForItem(DoubleVector x, double y, int m,
       DoubleVector theta, HypothesisFunction hypothesis);
 
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Fri Oct 11 09:46:29 2013
@@ -17,6 +17,10 @@
  */
 package org.apache.hama.ml.regression;
 
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.Arrays;
+
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hama.bsp.BSP;
 import org.apache.hama.bsp.BSPPeer;
@@ -28,9 +32,6 @@ import org.apache.hama.util.KeyValuePair
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
-import java.util.Arrays;
-
 /**
  * A gradient descent (see
  * <code>http://en.wikipedia.org/wiki/Gradient_descent</code>) BSP based
@@ -133,9 +134,7 @@ public class GradientDescentBSP
       updateTheta(newTheta);
 
       if (log.isDebugEnabled()) {
-        log.debug(new StringBuilder(peer.getPeerName())
-            .append(": new theta for cost ").append(cost).append(" is ")
-            .append(theta.toString()).toString());
+        log.debug("{}: new theta for cost {} is {}", new Object[]{peer.getPeerName(), cost,
theta});
       }
       // master writes down the output
       if (master) {
@@ -206,8 +205,7 @@ public class GradientDescentBSP
     } else {
       cost = totalCost;
       if (log.isDebugEnabled()) {
-        log.debug(new StringBuilder(peer.getPeerName())
-            .append(": current cost is ").append(cost).toString());
+        log.debug("{}: current cost is {}", peer.getPeerName(), cost);
       }
       return false;
     }
@@ -224,7 +222,7 @@ public class GradientDescentBSP
       // calculate cost for given input
       double y = kvp.getValue().get();
       DoubleVector x = kvp.getKey().getVector();
-      double costForX = regressionModel.calculateCostForItem(x, y, m, theta);
+      double costForX = regressionModel.calculateCostForItem(x, y, m, theta).doubleValue();
 
       // adds to local cost
       localCost += costForX;
@@ -250,9 +248,9 @@ public class GradientDescentBSP
     while ((kvp = peer.readNext()) != null) {
       DoubleVector x = kvp.getKey().getVector();
       double y = kvp.getValue().get();
-      double difference = regressionModel.applyHypothesis(theta, x) - y;
+      BigDecimal difference = regressionModel.applyHypothesis(theta, x).subtract(BigDecimal.valueOf(y));
       for (int j = 0; j < theta.getLength(); j++) {
-        thetaDelta[j] += difference * x.get(j);
+        thetaDelta[j] += difference.multiply(BigDecimal.valueOf(x.get(j))).doubleValue();
       }
     }
     return thetaDelta;
@@ -266,9 +264,7 @@ public class GradientDescentBSP
     if (master) {
       peer.write(new VectorWritable(theta), new DoubleWritable(cost));
       if (log.isInfoEnabled()) {
-        log.info(new StringBuilder(peer.getPeerName())
-            .append(":computation finished with cost ").append(cost)
-            .append(" for theta ").append(theta).toString());
+        log.info("{}:computation finished with cost {} and theta {}", new Object[]{peer.getPeerName(),
cost, theta});
       }
     }
   }
@@ -283,14 +279,12 @@ public class GradientDescentBSP
             INITIAL_THETA_VALUES, 1));
         broadcastVector(peer, theta.toArray());
         if (log.isDebugEnabled()) {
-          log.debug(new StringBuilder(peer.getPeerName()).append(
-              ": sending theta").toString());
+          log.debug("{}: sending theta", peer.getPeerName());
         }
         peer.sync();
       } else {
         if (log.isDebugEnabled()) {
-          log.debug(new StringBuilder(peer.getPeerName()).append(
-              ": getting theta").toString());
+          log.debug("{}: getting theta", peer.getPeerName());
         }
         peer.sync();
         VectorWritable vectorWritable = peer.getCurrentMessage();

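Note (not part of the commit): the theta-delta hunk above feeds the standard batch gradient, i.e. the per-component sum over items of (h(x) - y) * x_j. A single-process sketch of that textbook update, with illustrative data, learning rate and names -- not the distributed logic GradientDescentBSP actually runs:

import java.util.Arrays;

// Textbook batch gradient descent on a tiny in-memory linear regression
// problem (y = 1 + x); purely illustrative.
public class GradientStepSketch {
  public static void main(String[] args) {
    double[][] xs = { { 1, 1 }, { 1, 2 }, { 1, 3 } }; // bias term + one feature
    double[] ys = { 2, 3, 4 };
    double[] theta = { 0, 0 };
    double alpha = 0.1;
    int m = xs.length;

    for (int iter = 0; iter < 5000; iter++) {
      double[] thetaDelta = new double[theta.length];
      for (int i = 0; i < m; i++) {
        double h = 0d;                                // hypothesis theta'x
        for (int j = 0; j < theta.length; j++) {
          h += theta[j] * xs[i][j];
        }
        double difference = h - ys[i];
        for (int j = 0; j < theta.length; j++) {
          thetaDelta[j] += difference * xs[i][j];     // same shape as the hunk above
        }
      }
      for (int j = 0; j < theta.length; j++) {
        theta[j] -= alpha * thetaDelta[j] / m;
      }
    }
    System.out.println(Arrays.toString(theta));       // approximately [1.0, 1.0]
  }
}
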
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java Fri Oct 11 09:46:29 2013
@@ -17,6 +17,8 @@
  */
 package org.apache.hama.ml.regression;
 
+import java.math.BigDecimal;
+
 import org.apache.hama.ml.math.DoubleVector;
 
 /**
@@ -25,13 +27,13 @@ import org.apache.hama.ml.math.DoubleVec
 public interface HypothesisFunction {
 
   /**
-   * Applies the applyHypothesis given a set of parameters theta to a given
+   * Applies this <code>HypothesisFunction</code> given a set of parameters theta and
    * input x
    * 
    * @param theta the parameters vector
    * @param x the input
-   * @return a <code>double</code> number
+   * @return a <code>BigDecimal</code> representing the number
    */
-  public double applyHypothesis(DoubleVector theta, DoubleVector x);
+  public BigDecimal applyHypothesis(DoubleVector theta, DoubleVector x);
 
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java Fri Oct 11 09:46:29 2013
@@ -17,6 +17,8 @@
  */
 package org.apache.hama.ml.regression;
 
+import java.math.BigDecimal;
+
 import org.apache.hama.ml.math.DoubleVector;
 
 /**
@@ -29,20 +31,20 @@ public class LinearRegressionModel imple
   public LinearRegressionModel() {
     costFunction = new CostFunction() {
       @Override
-      public double calculateCostForItem(DoubleVector x, double y, int m,
+      public BigDecimal calculateCostForItem(DoubleVector x, double y, int m,
           DoubleVector theta, HypothesisFunction hypothesis) {
-        return y * Math.pow(applyHypothesis(theta, x) - y, 2) / (2 * m);
+        return BigDecimal.valueOf(y * Math.pow(applyHypothesis(theta, x).doubleValue() - y, 2) / (2 * m));
       }
     };
   }
 
   @Override
-  public double applyHypothesis(DoubleVector theta, DoubleVector x) {
-    return theta.dotUnsafe(x);
+  public BigDecimal applyHypothesis(DoubleVector theta, DoubleVector x) {
+    return BigDecimal.valueOf(theta.dotUnsafe(x));
   }
 
   @Override
-  public double calculateCostForItem(DoubleVector x, double y, int m,
+  public BigDecimal calculateCostForItem(DoubleVector x, double y, int m,
       DoubleVector theta) {
     return costFunction.calculateCostForItem(x, y, m, theta, this);
   }

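Note (not part of the commit): a quick arithmetic check of the cost exactly as written above, using the same inputs LinearRegressionModelTest uses further down; the wrapper class is illustrative:

import java.math.BigDecimal;

// With theta = {1,1,1}, x = {2,3,4}, y = 1, m = 2:
// hypothesis = theta'x = 9 and cost = y * (9 - 1)^2 / (2 * 2) = 16.
public class LinearCostCheck {
  public static void main(String[] args) {
    double y = 1;
    int m = 2;
    double h = 1 * 2 + 1 * 3 + 1 * 4;
    BigDecimal cost = BigDecimal.valueOf(y * Math.pow(h - y, 2) / (2 * m));
    System.out.println(cost); // 16.0
  }
}
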
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java Fri Oct 11 09:46:29 2013
@@ -27,59 +27,71 @@ import org.apache.hama.ml.math.DoubleVec
  */
 public class LogisticRegressionModel implements RegressionModel {
 
+  private static final MathContext DEFAULT_PRECISION = MathContext.DECIMAL128;
+
   private final CostFunction costFunction;
 
   public LogisticRegressionModel() {
     costFunction = new CostFunction() {
       @Override
-      public double calculateCostForItem(DoubleVector x, double y, int m,
-          DoubleVector theta, HypothesisFunction hypothesis) {
-        return (-1d * y
-            * ln(applyHypothesisWithPrecision(theta, x)).doubleValue() + (1d - y)
-            * ln(
-                applyHypothesisWithPrecision(theta, x).subtract(
-                    BigDecimal.valueOf(1))).doubleValue())
-            / m;
+      public BigDecimal calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta,
+              HypothesisFunction hypothesis) {
+        // -1/m*(y*ln(hx) + (1-y)*ln(1-hx))
+        BigDecimal hx = applyHypothesisWithPrecision(theta, x);
+        BigDecimal first = BigDecimal.valueOf(y).multiply(ln(hx));
+        BigDecimal logarg = BigDecimal.valueOf(1).subtract(hx, DEFAULT_PRECISION);
+        BigDecimal ln = ln(logarg);
+        BigDecimal second = BigDecimal.valueOf(1d - y).multiply(ln);
+        BigDecimal num = first.add(second);
+        BigDecimal den = BigDecimal.valueOf(-1*m);
+        BigDecimal res = num.divide(den, DEFAULT_PRECISION);
+        return res;
       }
     };
   }
 
   @Override
-  public double applyHypothesis(DoubleVector theta, DoubleVector x) {
-    return applyHypothesisWithPrecision(theta, x).doubleValue();
+  public BigDecimal applyHypothesis(DoubleVector theta, DoubleVector x) {
+    return applyHypothesisWithPrecision(theta, x);
   }
 
-  private BigDecimal applyHypothesisWithPrecision(DoubleVector theta,
-      DoubleVector x) {
-    return BigDecimal.valueOf(1).divide(
-        BigDecimal.valueOf(1d).add(
-            BigDecimal.valueOf(Math.exp(-1d * theta.dotUnsafe(x)))),
-        MathContext.DECIMAL128);
+  private BigDecimal applyHypothesisWithPrecision(DoubleVector theta, DoubleVector x) {
+    // 1 / (1 + (e^(-theta'x)))
+    double dotUnsafe = theta.multiply(-1d).dotUnsafe(x);
+    double d = Math.exp(dotUnsafe);
+    BigDecimal exp = BigDecimal.valueOf(d);
+    BigDecimal den = BigDecimal.valueOf(1d).add(exp);
+    BigDecimal remainder = BigDecimal.valueOf(1).subtract(den, DEFAULT_PRECISION);
+    BigDecimal res = BigDecimal.valueOf(1).divide(den, DEFAULT_PRECISION);
+    if (res.doubleValue() == 1 && remainder.doubleValue() < 0) {
+      res = res.add(remainder);
+    }
+    return res;
   }
 
   private BigDecimal ln(BigDecimal x) {
-    if (x.equals(BigDecimal.ONE)) {
-      return BigDecimal.ZERO;
-    }
-    x = x.subtract(BigDecimal.ONE);
-    int iterations = 1000;
-    BigDecimal ret = new BigDecimal(iterations + 1);
-    for (long i = iterations; i >= 0; i--) {
-      BigDecimal N = new BigDecimal(i / 2 + 1).pow(2);
-      N = N.multiply(x, MathContext.DECIMAL128);
-      ret = N.divide(ret, MathContext.DECIMAL128);
-
-      N = new BigDecimal(i + 1);
-      ret = ret.add(N, MathContext.DECIMAL128);
-
-    }
-    ret = x.divide(ret, MathContext.DECIMAL128);
-    return ret;
+//    if (x.equals(BigDecimal.ONE)) {
+//      return BigDecimal.ZERO;
+//    }
+//    x = x.subtract(BigDecimal.ONE);
+//    int iterations = 10000000;
+//    BigDecimal ret = new BigDecimal(iterations + 1);
+//    for (long i = iterations; i >= 0; i--) {
+//      BigDecimal N = new BigDecimal(i / 2 + 1).pow(2);
+//      N = N.multiply(x, DEFAULT_PRECISION);
+//      ret = N.divide(ret, DEFAULT_PRECISION);
+//
+//      N = new BigDecimal(i + 1);
+//      ret = ret.add(N, DEFAULT_PRECISION);
+//
+//    }
+//    ret = x.divide(ret, DEFAULT_PRECISION);
+//    return ret;
+    return BigDecimal.valueOf(Math.log(x.doubleValue()));
   }
 
   @Override
-  public double calculateCostForItem(DoubleVector x, double y, int m,
-      DoubleVector theta) {
+  public BigDecimal calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta) {
     return costFunction.calculateCostForItem(x, y, m, theta, this);
   }
 }

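Note (not part of the commit): the remainder correction in applyHypothesisWithPrecision matters when theta'x is large, because 1 / (1 + e^(-theta'x)) rounds to exactly 1 at DECIMAL128 precision and ln(1 - h) in the cost would then blow up. A standalone sketch mirroring that hunk; the value of z and the class name are illustrative:

import java.math.BigDecimal;
import java.math.MathContext;

// For large z, a plain double sigmoid and even the DECIMAL128 division both
// round to 1; adding back the remainder (about -e^(-z)) keeps h strictly
// below 1 so that ln(1 - h) stays finite in the cost function.
public class SigmoidSaturationSketch {
  public static void main(String[] args) {
    double z = 240d;                                    // a large theta'x
    System.out.println(1d / (1d + Math.exp(-z)));       // 1.0

    BigDecimal exp = BigDecimal.valueOf(Math.exp(-z));
    BigDecimal den = BigDecimal.valueOf(1d).add(exp);
    BigDecimal res = BigDecimal.valueOf(1).divide(den, MathContext.DECIMAL128);
    BigDecimal remainder = BigDecimal.valueOf(1).subtract(den, MathContext.DECIMAL128);
    if (res.doubleValue() == 1 && remainder.doubleValue() < 0) {
      res = res.add(remainder);                         // nudge h just below 1
    }
    System.out.println(BigDecimal.ONE.subtract(res));                            // ~5.9E-105, not 0
    System.out.println(Math.log(BigDecimal.ONE.subtract(res).doubleValue()));    // ~-240, finite
  }
}
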
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java Fri Oct 11 09:46:29 2013
@@ -17,6 +17,8 @@
  */
 package org.apache.hama.ml.regression;
 
+import java.math.BigDecimal;
+
 import org.apache.hama.ml.math.DoubleVector;
 
 /**
@@ -32,9 +34,9 @@ public interface RegressionModel extends
    * @param y the learned output for x
    * @param m the total number of existing items
    * @param theta the parameters vector theta
-   * @return the calculated cost for input x and output y
+   * @return the calculated cost for input x and output y as a <code>BigDecimal</code>
    */
-  public double calculateCostForItem(DoubleVector x, double y, int m,
+  public BigDecimal calculateCostForItem(DoubleVector x, double y, int m,
       DoubleVector theta);
 
 }

Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LinearRegressionModelTest.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LinearRegressionModelTest.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LinearRegressionModelTest.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LinearRegressionModelTest.java Fri Oct 11 09:46:29 2013
@@ -17,12 +17,14 @@
  */
 package org.apache.hama.ml.regression;
 
+import static org.junit.Assert.assertEquals;
+
+import java.math.BigDecimal;
+
 import org.apache.hama.ml.math.DenseDoubleVector;
 import org.apache.hama.ml.math.DoubleVector;
 import org.junit.Test;
 
-import static org.junit.Assert.assertEquals;
-
 /**
  * Testcase for {@link LinearRegressionModel}
  */
@@ -34,15 +36,15 @@ public class LinearRegressionModelTest {
     DoubleVector x = new DenseDoubleVector(new double[]{2, 3, 4});
     double y = 1;
     DoubleVector theta = new DenseDoubleVector(new double[]{1, 1, 1});
-    Double cost = linearRegressionModel.calculateCostForItem(x, y, 2, theta);
-    assertEquals("wrong cost calculation for linear regression", Double.valueOf(16d), cost);
+    BigDecimal cost = linearRegressionModel.calculateCostForItem(x, y, 2, theta);
+    assertEquals("wrong cost calculation for linear regression", BigDecimal.valueOf(16d),
cost);
   }
 
   @Test
   public void testCorrectHypothesisCalculation() throws Exception {
     LinearRegressionModel linearRegressionModel = new LinearRegressionModel();
-    Double hypothesisValue = linearRegressionModel.applyHypothesis(new DenseDoubleVector(new double[]{1, 1, 1}),
+    BigDecimal hypothesisValue = linearRegressionModel.applyHypothesis(new DenseDoubleVector(new double[]{1, 1, 1}),
             new DenseDoubleVector(new double[]{2, 3, 4}));
-    assertEquals("wrong hypothesis value for linear regression", Double.valueOf(9), hypothesisValue);
+    assertEquals("wrong hypothesis value for linear regression", BigDecimal.valueOf(9d),
hypothesisValue);
   }
 }

Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LogisticRegressionModelTest.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LogisticRegressionModelTest.java?rev=1531245&r1=1531244&r2=1531245&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LogisticRegressionModelTest.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/regression/LogisticRegressionModelTest.java Fri Oct 11 09:46:29 2013
@@ -17,12 +17,15 @@
  */
 package org.apache.hama.ml.regression;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import java.math.BigDecimal;
+
 import org.apache.hama.ml.math.DenseDoubleVector;
 import org.apache.hama.ml.math.DoubleVector;
 import org.junit.Test;
 
-import static org.junit.Assert.assertEquals;
-
 /**
  * Testcase for {@link LogisticRegressionModel}
  */
@@ -34,15 +37,40 @@ public class LogisticRegressionModelTest
     DoubleVector x = new DenseDoubleVector(new double[]{2, 3, 4});
     double y = 1;
     DoubleVector theta = new DenseDoubleVector(new double[]{1, 1, 1});
-    Double cost = logisticRegressionModel.calculateCostForItem(x, y, 2, theta);
-    assertEquals("wrong cost calculation for logistic regression", Double.valueOf(6.170109486162941E-5),
cost);
+    BigDecimal cost = logisticRegressionModel.calculateCostForItem(x, y, 2, theta);
+    assertEquals("wrong cost calculation for logistic regression", 6.170109486162941E-5d,
cost.doubleValue(), 0.000001);
   }
 
   @Test
   public void testCorrectHypothesisCalculation() throws Exception {
     LogisticRegressionModel logisticRegressionModel = new LogisticRegressionModel();
-    Double hypothesisValue = logisticRegressionModel.applyHypothesis(new DenseDoubleVector(new double[]{1, 1, 1}),
+    BigDecimal hypothesisValue = logisticRegressionModel.applyHypothesis(new DenseDoubleVector(new double[]{1, 1, 1}),
             new DenseDoubleVector(new double[]{2, 3, 4}));
-    assertEquals("wrong hypothesis value for logistic regression", Double.valueOf(0.9998766054240138),
hypothesisValue);
+    assertEquals("wrong hypothesis value for logistic regression", 0.9998766054240137682597533152954043d,
hypothesisValue.doubleValue(), 0.000001);
+  }
+  
+  @Test
+  public void testMultipleCostCalculation() throws Exception {
+    LogisticRegressionModel logisticRegressionModel = new LogisticRegressionModel();
+    double[] theta1 = new double[] { 10.010000000474975, 10.050000002374873, 10.01600000075996,
+        10.018000000854954, 10.024000001139939, 10.038000001804903, 10.036000001709908 };
+    double[] theta2 = new double[] { 13.000000142492354, 25.00000071246177, 14.800000227987766,
+        15.400000256486237, 17.20000034198165, 21.400000541470945, 20.800000512972474 };
+
+    DenseDoubleVector theta1Vector = new DenseDoubleVector(theta1);
+    DenseDoubleVector theta2Vector = new DenseDoubleVector(theta2);
+
+    DenseDoubleVector x = new DenseDoubleVector(new double[] { 1, 10, 3, 2, 1, 6, 1 });
+
+    BigDecimal res1 = logisticRegressionModel.applyHypothesis(theta1Vector, x);
+    BigDecimal res2 = logisticRegressionModel.applyHypothesis(theta2Vector, x);
+
+    assertFalse(res1 + " shouldn't be the same as " + res2, res1.equals(res2));
+
+    BigDecimal itemCost1 = logisticRegressionModel.calculateCostForItem(x, 2, 8, theta1Vector);
+    BigDecimal itemCost2 = logisticRegressionModel.calculateCostForItem(x, 2, 8, theta2Vector);
+
+    assertFalse(itemCost1 + " shouldn't be the same as " + itemCost2, itemCost1.equals(itemCost2));
+
   }
 }

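Note (not part of the commit): the expected values in the updated tests follow directly from the logistic formulas above; a plain-double check with an illustrative wrapper class, ignoring the BigDecimal plumbing:

// theta = {1,1,1}, x = {2,3,4}: theta'x = 9, so h = sigmoid(9) and, with
// y = 1 and m = 2, cost = -(y*ln(h) + (1-y)*ln(1-h)) / m.
public class LogisticExpectedValues {
  public static void main(String[] args) {
    double z = 1 * 2 + 1 * 3 + 1 * 4;
    double h = 1d / (1d + Math.exp(-z));
    double y = 1, m = 2;
    double cost = -(y * Math.log(h) + (1 - y) * Math.log(1 - h)) / m;
    System.out.println(h);    // ~0.99987660542, as asserted in testCorrectHypothesisCalculation
    System.out.println(cost); // ~6.1701E-5, as asserted in testCorrectCostCalculation
  }
}
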

