ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ch...@apache.org
Subject [1/2] ignite git commit: IGNITE-8176: Integrate gradient descent linear regression with partition based dataset
Date Thu, 12 Apr 2018 08:16:41 GMT
Repository: ignite
Updated Branches:
  refs/heads/master 67023a88b -> df6356d5d


http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
new file mode 100644
index 0000000..fa8fac4
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link LinearRegressionSGDTrainer}.
+ */
+@RunWith(Parameterized.class)
+public class LinearRegressionSGDTrainerTest {
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions")
+    public static Iterable<Integer[]> data() {
+        return Arrays.asList(
+            new Integer[] {1},
+            new Integer[] {2},
+            new Integer[] {3},
+            new Integer[] {5},
+            new Integer[] {7},
+            new Integer[] {100}
+        );
+    }
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /**
+     * Tests {@code fit()} method on a simple small dataset.
+     */
+    @Test
+    public void testSmallDataFit() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107});
+        data.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867});
+        data.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728});
+        data.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991});
+        data.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611});
+        data.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197});
+        data.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012});
+        data.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889});
+        data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949});
+        data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583});
+
+        LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
+            new RPropUpdateCalculator(),
+            RPropParameterUpdate::sumLocal,
+            RPropParameterUpdate::avg
+        ), 100000,  10, 100, 123L);
+
+        LinearRegressionModel mdl = trainer.fit(
+            data,
+            parts,
+            (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
+            (k, v) -> v[4]
+        );
+
+        assertArrayEquals(
+            new double[] {72.26948107, 15.95144674, 24.07403921, 66.73038781},
+            mdl.getWeights().getStorage().data(),
+            1e-1
+        );
+
+        assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-1);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java
deleted file mode 100644
index bea164d..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.regressions.linear;
-
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-
-/**
- * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}.
- */
-public class LocalLinearRegressionSGDTrainerTest extends GenericLinearRegressionTrainerTest {
-    /** */
-    public LocalLinearRegressionSGDTrainerTest() {
-        super(
-            new LinearRegressionSGDTrainer(100_000, 1e-12),
-            DenseLocalOnHeapMatrix::new,
-            DenseLocalOnHeapVector::new,
-            1e-2);
-    }
-}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
index 26ba2fb..0befd9b 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
@@ -17,14 +17,14 @@
 
 package org.apache.ignite.ml.svm;
 
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.junit.Test;
+
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.junit.Test;
 
 /**
  * Tests for {@link SVMLinearBinaryClassificationTrainer}.
@@ -62,7 +62,8 @@ public class SVMBinaryTrainerTest {
         SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer();
 
         SVMLinearBinaryClassificationModel mdl = trainer.fit(
-            new LocalDatasetBuilder<>(data, 10),
+            data,
+            10,
             (k, v) -> Arrays.copyOfRange(v, 1, v.length),
             (k, v) -> v[0]
         );

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
index ad95eb4..31ab4d7 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
@@ -17,14 +17,14 @@
 
 package org.apache.ignite.ml.svm;
 
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.junit.Test;
+
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.junit.Test;
 
 /**
  * Tests for {@link SVMLinearBinaryClassificationTrainer}.
@@ -65,7 +65,8 @@ public class SVMMultiClassTrainerTest {
             .withAmountOfIterations(20);
 
         SVMLinearMultiClassClassificationModel mdl = trainer.fit(
-            new LocalDatasetBuilder<>(data, 10),
+            data,
+            10,
             (k, v) -> Arrays.copyOfRange(v, 1, v.length),
             (k, v) -> v[0]
         );

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
index 94bca3f..d5b0b86 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
@@ -17,16 +17,16 @@
 
 package org.apache.ignite.ml.tree;
 
-import java.util.Arrays;
-import java.util.Random;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
+import java.util.Arrays;
+import java.util.Random;
+
 /**
 * Tests for {@link DecisionTreeClassificationTrainer} that require to start the whole Ignite infrastructure.
  */
@@ -77,7 +77,8 @@ public class DecisionTreeClassificationTrainerIntegrationTest extends GridCommon
         DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
 
         DecisionTreeNode tree = trainer.fit(
-            new CacheBasedDatasetBuilder<>(ignite, data),
+            ignite,
+            data,
             (k, v) -> Arrays.copyOf(v, v.length - 1),
             (k, v) -> v[v.length - 1]
         );

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
index 2599bfe..12ef698 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
@@ -17,17 +17,12 @@
 
 package org.apache.ignite.ml.tree;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.*;
+
 import static junit.framework.TestCase.assertEquals;
 import static junit.framework.TestCase.assertTrue;
 
@@ -68,7 +63,8 @@ public class DecisionTreeClassificationTrainerTest {
         DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
 
         DecisionTreeNode tree = trainer.fit(
-            new LocalDatasetBuilder<>(data, parts),
+            data,
+            parts,
             (k, v) -> Arrays.copyOf(v, v.length - 1),
             (k, v) -> v[v.length - 1]
         );

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
index 754ff20..c2a4638 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
@@ -17,16 +17,16 @@
 
 package org.apache.ignite.ml.tree;
 
-import java.util.Arrays;
-import java.util.Random;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
+import java.util.Arrays;
+import java.util.Random;
+
 /**
 * Tests for {@link DecisionTreeRegressionTrainer} that require to start the whole Ignite infrastructure.
  */
@@ -77,7 +77,8 @@ public class DecisionTreeRegressionTrainerIntegrationTest extends GridCommonAbst
         DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0);
 
         DecisionTreeNode tree = trainer.fit(
-            new CacheBasedDatasetBuilder<>(ignite, data),
+            ignite,
+            data,
             (k, v) -> Arrays.copyOf(v, v.length - 1),
             (k, v) -> v[v.length - 1]
         );

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
index 3bdbf60..bcfb53f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
@@ -17,17 +17,12 @@
 
 package org.apache.ignite.ml.tree;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.*;
+
 import static junit.framework.TestCase.assertEquals;
 import static junit.framework.TestCase.assertTrue;
 
@@ -68,7 +63,8 @@ public class DecisionTreeRegressionTrainerTest {
         DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0);
 
         DecisionTreeNode tree = trainer.fit(
-            new LocalDatasetBuilder<>(data, parts),
+            data,
+            parts,
             (k, v) -> Arrays.copyOf(v, v.length - 1),
             (k, v) -> v[v.length - 1]
         );

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
index b259ec9..35f805e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
@@ -17,13 +17,11 @@
 
 package org.apache.ignite.ml.tree.performance;
 
-import java.io.IOException;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
@@ -31,6 +29,8 @@ import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor;
 import org.apache.ignite.ml.util.MnistUtils;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
+import java.io.IOException;
+
 /**
 * Tests {@link DecisionTreeClassificationTrainer} on the MNIST dataset that require to start the whole Ignite
  * infrastructure. For manual run.
@@ -81,7 +81,8 @@ public class DecisionTreeMNISTIntegrationTest extends GridCommonAbstractTest {
             new SimpleStepFunctionCompressor<>());
 
         DecisionTreeNode mdl = trainer.fit(
-            new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+            ignite,
+            trainingSet,
             (k, v) -> v.getPixels(),
             (k, v) -> (double) v.getLabel()
         );

http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
index 6dbd44c..b40c7ac 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
@@ -17,10 +17,6 @@
 
 package org.apache.ignite.ml.tree.performance;
 
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
@@ -28,6 +24,10 @@ import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor;
 import org.apache.ignite.ml.util.MnistUtils;
 import org.junit.Test;
 
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
 import static junit.framework.TestCase.assertTrue;
 
 /**
@@ -50,7 +50,8 @@ public class DecisionTreeMNISTTest {
             new SimpleStepFunctionCompressor<>());
 
         DecisionTreeNode mdl = trainer.fit(
-            new LocalDatasetBuilder<>(trainingSet, 10),
+            trainingSet,
+            10,
             (k, v) -> v.getPixels(),
             (k, v) -> (double) v.getLabel()
         );


Mime
View raw message