ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sboi...@apache.org
Subject [21/30] ignite git commit: IGNITE-6872: Linear regression should implement Model API
Date Fri, 08 Dec 2017 08:34:52 GMT
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
index acc5649..d0d1247 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
@@ -20,12 +20,14 @@ package org.apache.ignite.ml;
 import java.io.IOException;
 import java.nio.file.Files;
 import java.nio.file.Path;
-import java.nio.file.Paths;
+import java.util.function.Function;
 import org.apache.ignite.ml.clustering.KMeansLocalClusterer;
 import org.apache.ignite.ml.clustering.KMeansModel;
 import org.apache.ignite.ml.math.EuclideanDistance;
 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.junit.After;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModel;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModelFormat;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionTrainer;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -34,39 +36,68 @@ import org.junit.Test;
  */
 public class LocalModelsTest {
     /** */
-    private String mdlFilePath = "model.mlmod";
-
-    /**
-     *
-     */
-    @After
-    public void cleanUp() throws IOException {
-        Files.deleteIfExists(Paths.get(mdlFilePath));
+    @Test
+    public void importExportKMeansModelTest() throws IOException {
+        executeModelTest(mdlFilePath -> { // mdlFilePath is a temp file that executeModelTest deletes afterwards.
+            KMeansModel mdl = getClusterModel();
+
+            Exporter<KMeansModelFormat, String> exporter = new FileExporter<>();
+
+            mdl.saveModel(exporter, mdlFilePath);
+
+            KMeansModelFormat load = exporter.load(mdlFilePath);
+
+            Assert.assertNotNull(load);
+
+            KMeansModel importedMdl = new KMeansModel(load.getCenters(), load.getDistance()); // Rebuild from the exported centers and distance.
+
+            Assert.assertTrue("", mdl.equals(importedMdl));
+
+            return null; // Required by executeModelTest's Function<String, Void> parameter.
+        });
     }
 
-    /**
-     *
-     */
+    /** */
     @Test
-    public void importExportKMeansModelTest() {
-        Path mdlPath = Paths.get(mdlFilePath);
+    public void importExportOLSMultipleLinearRegressionModelTest() throws IOException {
+        executeModelTest(mdlFilePath -> {
+            OLSMultipleLinearRegressionModel mdl = getAbstractMultipleLinearRegressionModel();
+
+            Exporter<OLSMultipleLinearRegressionModelFormat, String> exporter = new
FileExporter<>();
 
-        KMeansModel mdl = getClusterModel();
+            mdl.saveModel(exporter, mdlFilePath);
 
-        Exporter<KMeansModelFormat, String> exporter = new FileExporter<>();
-        mdl.saveModel(exporter, mdlFilePath);
+            OLSMultipleLinearRegressionModelFormat load = exporter.load(mdlFilePath);
 
-        Assert.assertTrue(String.format("File %s not found.", mdlPath.toString()), Files.exists(mdlPath));
+            Assert.assertNotNull(load);
 
-        KMeansModelFormat load = exporter.load(mdlFilePath);
-        KMeansModel importedMdl = new KMeansModel(load.getCenters(), load.getDistance());
+            OLSMultipleLinearRegressionModel importedMdl = load.getOLSMultipleLinearRegressionModel();
 
-        Assert.assertTrue("", mdl.equals(importedMdl));
+            Assert.assertTrue("", mdl.equals(importedMdl));
+
+            return null;
+        });
     }
 
-    /**
-     *
-     */
+    /** Creates a temp file, passes its absolute path to {@code code}, and deletes the file afterwards. */
+    private void executeModelTest(Function<String, Void> code) throws IOException {
+        Path mdlPath = Files.createTempFile(null, null);
+
+        Assert.assertNotNull(mdlPath);
+
+        try {
+            String mdlFilePath = mdlPath.toAbsolutePath().toString();
+
+            Assert.assertTrue(String.format("File %s not found.", mdlFilePath), Files.exists(mdlPath));
+
+            code.apply(mdlFilePath); // The Void result is ignored.
+        }
+        finally {
+            Files.deleteIfExists(mdlPath); // Runs even if code throws, e.g. on a failed assertion.
+        }
+    }
+
+    /** */
     private KMeansModel getClusterModel() {
         KMeansLocalClusterer clusterer = new KMeansLocalClusterer(new EuclideanDistance(),
1, 1L);
 
@@ -77,4 +108,22 @@ public class LocalModelsTest {
 
         return clusterer.cluster(points, 1);
     }
+
+    /** Trains an {@link OLSMultipleLinearRegressionModel} on a fixed sample with {@code nobs = 6} and {@code nvars = 5}. */
+    private OLSMultipleLinearRegressionModel getAbstractMultipleLinearRegressionModel() {
+        double[] data = new double[] {
+            0, 0, 0, 0, 0, 0, // IMPL NOTE values in this row are later replaced (with 1.0)
+            0, 2.0, 0, 0, 0, 0,
+            0, 0, 3.0, 0, 0, 0,
+            0, 0, 0, 4.0, 0, 0,
+            0, 0, 0, 0, 5.0, 0,
+            0, 0, 0, 0, 0, 6.0};
+
+        final int nobs = 6, nvars = 5;
+
+        OLSMultipleLinearRegressionTrainer trainer
+            = new OLSMultipleLinearRegressionTrainer(0, nobs, nvars, new DenseLocalOnHeapMatrix(1, 1));
+
+        return trainer.train(data);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
index 216fd7b..af2154e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.math;
 import org.apache.ignite.ml.math.decompositions.CholeskyDecompositionTest;
 import org.apache.ignite.ml.math.decompositions.EigenDecompositionTest;
 import org.apache.ignite.ml.math.decompositions.LUDecompositionTest;
+import org.apache.ignite.ml.math.decompositions.QRDSolverTest;
 import org.apache.ignite.ml.math.decompositions.QRDecompositionTest;
 import org.apache.ignite.ml.math.decompositions.SingularValueDecompositionTest;
 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOffHeapMatrixConstructorTest;
@@ -116,6 +117,7 @@ import org.junit.runners.Suite;
     EigenDecompositionTest.class,
     CholeskyDecompositionTest.class,
     QRDecompositionTest.class,
+    QRDSolverTest.class,
     SingularValueDecompositionTest.class
 })
 public class MathImplLocalTestSuite {

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/math/decompositions/QRDSolverTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/decompositions/QRDSolverTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/decompositions/QRDSolverTest.java
new file mode 100644
index 0000000..d3e8e76
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/decompositions/QRDSolverTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.decompositions;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests for {@link QRDSolver} and the {@link QRDecomposition} whose Q and R it consumes. */
+public class QRDSolverTest {
+    /** Checks Q*Q^T == I, R upper triangular, Q*R == m, and that solving a 3x10 right-hand side yields 3x10 zeros. */
+    @Test
+    public void basicTest() {
+        Matrix m = new DenseLocalOnHeapMatrix(new double[][] {
+            {2.0d, -1.0d, 0.0d},
+            {-1.0d, 2.0d, -1.0d},
+            {0.0d, -1.0d, 2.0d}
+        });
+
+        QRDecomposition dec = new QRDecomposition(m);
+        assertTrue("Unexpected value for full rank in decomposition " + dec, dec.hasFullRank());
+
+        Matrix q = dec.getQ();
+        Matrix r = dec.getR();
+
+        assertNotNull("Matrix q is expected to be not null.", q);
+        assertNotNull("Matrix r is expected to be not null.", r);
+
+        Matrix qSafeCp = safeCopy(q);
+
+        Matrix expIdentity = qSafeCp.times(qSafeCp.transpose());
+
+        final double delta = 0.0001;
+
+        for (int row = 0; row < expIdentity.rowSize(); row++)
+            for (int col = 0; col < expIdentity.columnSize(); col++)
+                assertEquals("Unexpected identity matrix value at (" + row + "," + col + ").",
+                    row == col ? 1d : 0d, expIdentity.get(row, col), delta);
+
+        for (int row = 0; row < r.rowSize(); row++)
+            for (int col = 0; col < row; col++) // Every entry strictly below the diagonal, subdiagonal included.
+                assertEquals("Unexpected upper triangular matrix value at (" + row + "," + col + ").",
+                    0d, r.get(row, col), delta);
+
+        Matrix recomposed = qSafeCp.times(r);
+
+        for (int row = 0; row < m.rowSize(); row++)
+            for (int col = 0; col < m.columnSize(); col++)
+                assertEquals("Unexpected recomposed matrix value at (" + row + "," + col + ").",
+                    m.get(row, col), recomposed.get(row, col), delta);
+
+        Matrix sol = new QRDSolver(q, r).solve(new DenseLocalOnHeapMatrix(3, 10));
+        assertEquals("Unexpected rows in solution matrix.", 3, sol.rowSize());
+        assertEquals("Unexpected cols in solution matrix.", 10, sol.columnSize());
+
+        for (int row = 0; row < sol.rowSize(); row++)
+            for (int col = 0; col < sol.columnSize(); col++)
+                assertEquals("Unexpected solution matrix value at (" + row + "," + col + ").",
+                    0d, sol.get(row, col), delta);
+
+        dec.destroy();
+    }
+
+    /** Returns a new {@link DenseLocalOnHeapMatrix} of the same size holding a copy of {@code orig}'s values. */
+    private Matrix safeCopy(Matrix orig) {
+        return new DenseLocalOnHeapMatrix(orig.rowSize(), orig.columnSize()).assign(orig);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
index a482737..8c9d429 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedBlockOLSMultipleLinearRegressionTest.java
@@ -35,7 +35,6 @@ import org.junit.Assert;
 /**
  * Tests for {@link OLSMultipleLinearRegression}.
  */
-
 @GridCommonTest(group = "Distributed Models")
 public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonAbstractTest
{
     /** */
@@ -95,7 +94,7 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
     }
 
     /** */
-    protected OLSMultipleLinearRegression createRegression() {
+    private OLSMultipleLinearRegression createRegression() {
         OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
         regression.newSampleData(new SparseBlockDistributedVector(y), new SparseBlockDistributedMatrix(x));
         return regression;
@@ -243,7 +242,6 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
         // Check R-Square statistics against R
         Assert.assertEquals(0.9999670130706, mdl.calculateRSquared(), 1E-12);
         Assert.assertEquals(0.999947220913, mdl.calculateAdjustedRSquared(), 1E-12);
-
     }
 
     /**
@@ -533,12 +531,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
 
         try {
             createRegression().newSampleData(null, new SparseBlockDistributedMatrix(new double[][]
{{1}}));
-            fail("NullArgumentException");
+            fail("Expected NullArgumentException was not caught.");
         }
         catch (NullArgumentException e) {
             return;
         }
-        fail("NullArgumentException");
+        fail("Expected NullArgumentException was not caught.");
     }
 
     /** */
@@ -547,13 +545,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
 
         try {
             createRegression().newSampleData(new SparseBlockDistributedVector(new double[]
{1}), null);
-            fail("NullArgumentException");
+            fail("Expected NullArgumentException was not caught.");
         }
         catch (NullArgumentException e) {
             return;
         }
-        fail("NullArgumentException");
-
+        fail("Expected NullArgumentException was not caught.");
     }
 
     /**
@@ -830,17 +827,16 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
     public void testSingularCalculateBeta() {
         IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
         OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(1e-15);
-        mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseBlockDistributedMatrix());
 
         try {
+            mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseBlockDistributedMatrix());
             mdl.calculateBeta();
-            fail("SingularMatrixException");
+            fail("Expected SingularMatrixException was not caught.");
         }
         catch (SingularMatrixException e) {
             return;
         }
-        fail("SingularMatrixException");
-
+        fail("Expected SingularMatrixException was not caught.");
     }
 
     /** */
@@ -850,13 +846,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
 
         try {
             mdl.calculateBeta();
-            fail("java.lang.NullPointerException");
+            fail("Expected NullPointerException was not caught.");
         }
         catch (NullPointerException e) {
             return;
         }
-        fail("java.lang.NullPointerException");
-
+        fail("Expected NullPointerException was not caught.");
     }
 
     /** */
@@ -866,12 +861,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
 
         try {
             mdl.calculateHat();
-            fail("java.lang.NullPointerException");
+            fail("Expected NullPointerException was not caught.");
         }
         catch (NullPointerException e) {
             return;
         }
-        fail("java.lang.NullPointerException");
+        fail("Expected NullPointerException was not caught.");
     }
 
     /** */
@@ -881,13 +876,12 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
 
         try {
             mdl.calculateTotalSumOfSquares();
-            fail("java.lang.NullPointerException");
+            fail("Expected NullPointerException was not caught.");
         }
         catch (NullPointerException e) {
             return;
         }
-        fail("java.lang.NullPointerException");
-
+        fail("Expected NullPointerException was not caught.");
     }
 
     /** */
@@ -897,11 +891,11 @@ public class DistributedBlockOLSMultipleLinearRegressionTest extends GridCommonA
 
         try {
             mdl.validateSampleData(new SparseBlockDistributedMatrix(1, 2), new SparseBlockDistributedVector(1));
-            fail("MathIllegalArgumentException");
+            fail("Expected MathIllegalArgumentException was not caught.");
         }
         catch (MathIllegalArgumentException e) {
             return;
         }
-        fail("MathIllegalArgumentException");
+        fail("Expected MathIllegalArgumentException was not caught.");
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
index a2d1e5f..f720406 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/DistributedOLSMultipleLinearRegressionTest.java
@@ -35,7 +35,6 @@ import org.junit.Assert;
 /**
  * Tests for {@link OLSMultipleLinearRegression}.
  */
-
 @GridCommonTest(group = "Distributed Models")
 public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstractTest {
     /** */
@@ -58,9 +57,7 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
 
     /** */
     public DistributedOLSMultipleLinearRegressionTest() {
-
         super(false);
-
     }
 
     /** {@inheritDoc} */
@@ -97,7 +94,7 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
     }
 
     /** */
-    protected OLSMultipleLinearRegression createRegression() {
+    private OLSMultipleLinearRegression createRegression() {
         OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
         regression.newSampleData(new SparseDistributedVector(y), new SparseDistributedMatrix(x));
         return regression;
@@ -245,7 +242,6 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
         // Check R-Square statistics against R
         Assert.assertEquals(0.9999670130706, mdl.calculateRSquared(), 1E-12);
         Assert.assertEquals(0.999947220913, mdl.calculateAdjustedRSquared(), 1E-12);
-
     }
 
     /**
@@ -526,7 +522,6 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
         }
         for (int i = 0; i < combinedY.size(); i++)
             Assert.assertEquals(combinedY.get(i), regression.getY().get(i), PRECISION);
-
     }
 
     /** */
@@ -535,12 +530,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
 
         try {
             createRegression().newSampleData(null, new SparseDistributedMatrix(new double[][]
{{1}}));
-            fail("NullArgumentException");
+            fail("Expected NullArgumentException was not caught.");
         }
         catch (NullArgumentException e) {
             return;
         }
-        fail("NullArgumentException");
+        fail("Expected NullArgumentException was not caught.");
     }
 
     /** */
@@ -549,12 +544,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
 
         try {
             createRegression().newSampleData(new SparseDistributedVector(new double[] {1}),
null);
-            fail("NullArgumentException");
+            fail("Expected NullArgumentException was not caught.");
         }
         catch (NullArgumentException e) {
             return;
         }
-        fail("NullArgumentException");
+        fail("Expected NullArgumentException was not caught.");
 
     }
 
@@ -832,16 +827,16 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
     public void testSingularCalculateBeta() {
         IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
         OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(1e-15);
-        mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseDistributedMatrix());
 
         try {
+            mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new SparseDistributedMatrix());
             mdl.calculateBeta();
-            fail("SingularMatrixException");
+            fail("Expected SingularMatrixException was not caught.");
         }
         catch (SingularMatrixException e) {
             return;
         }
-        fail("SingularMatrixException");
+        fail("Expected SingularMatrixException was not caught.");
 
     }
 
@@ -852,12 +847,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
 
         try {
             mdl.calculateBeta();
-            fail("java.lang.NullPointerException");
+            fail("Expected NullPointerException was not caught.");
         }
         catch (NullPointerException e) {
             return;
         }
-        fail("java.lang.NullPointerException");
+        fail("Expected NullPointerException was not caught.");
 
     }
 
@@ -868,12 +863,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
 
         try {
             mdl.calculateHat();
-            fail("java.lang.NullPointerException");
+            fail("Expected NullPointerException was not caught.");
         }
         catch (NullPointerException e) {
             return;
         }
-        fail("java.lang.NullPointerException");
+        fail("Expected NullPointerException was not caught.");
     }
 
     /** */
@@ -883,13 +878,12 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
 
         try {
             mdl.calculateTotalSumOfSquares();
-            fail("java.lang.NullPointerException");
+            fail("Expected NullPointerException was not caught.");
         }
         catch (NullPointerException e) {
             return;
         }
-        fail("java.lang.NullPointerException");
-
+        fail("Expected NullPointerException was not caught.");
     }
 
     /** */
@@ -899,11 +893,11 @@ public class DistributedOLSMultipleLinearRegressionTest extends GridCommonAbstra
 
         try {
             mdl.validateSampleData(new SparseDistributedMatrix(1, 2), new SparseDistributedVector(1));
-            fail("MathIllegalArgumentException");
+            fail("Expected MathIllegalArgumentException was not caught.");
         }
         catch (MathIllegalArgumentException e) {
             return;
         }
-        fail("MathIllegalArgumentException");
+        fail("Expected MathIllegalArgumentException was not caught.");
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
new file mode 100644
index 0000000..37c972c
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.junit.Test;
+
+/**
+ * Tests for {@link OLSMultipleLinearRegressionModel}.
+ */
+public class OLSMultipleLinearRegressionModelTest {
+    /** Trains a model on a fixed sample and checks that {@code predict(val)} reproduces {@code val} within 1e-13. */
+    @Test
+    public void testPerfectFit() {
+        Vector val = new DenseLocalOnHeapVector(new double[] {11.0, 12.0, 13.0, 14.0, 15.0, 16.0});
+
+        double[] data = new double[] {
+            0, 0, 0, 0, 0, 0, // IMPL NOTE values in this row are later replaced (with 1.0)
+            0, 2.0, 0, 0, 0, 0,
+            0, 0, 3.0, 0, 0, 0,
+            0, 0, 0, 4.0, 0, 0,
+            0, 0, 0, 0, 5.0, 0,
+            0, 0, 0, 0, 0, 6.0};
+
+        final int nobs = 6, nvars = 5;
+
+        OLSMultipleLinearRegressionTrainer trainer
+            = new OLSMultipleLinearRegressionTrainer(0, nobs, nvars, new DenseLocalOnHeapMatrix(1, 1));
+
+        OLSMultipleLinearRegressionModel mdl = trainer.train(data);
+
+        TestUtils.assertEquals(new double[] {0d, 0d, 0d, 0d, 0d, 0d},
+            val.minus(mdl.predict(val)).getStorage().data(), 1e-13);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
index 2a0b111..be71934 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
@@ -25,7 +25,10 @@ import org.junit.runners.Suite;
  */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
-    OLSMultipleLinearRegressionTest.class, DistributedOLSMultipleLinearRegressionTest.class,
DistributedBlockOLSMultipleLinearRegressionTest.class
+    OLSMultipleLinearRegressionTest.class,
+    DistributedOLSMultipleLinearRegressionTest.class,
+    DistributedBlockOLSMultipleLinearRegressionTest.class,
+    OLSMultipleLinearRegressionModelTest.class
 })
 public class RegressionsTestSuite {
     // No-op.


Mime
View raw message