ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ch...@apache.org
Subject ignite git commit: IGNITE-8667: Splitting of dataset to test and training sets
Date Tue, 05 Jun 2018 16:52:20 GMT
Repository: ignite
Updated Branches:
  refs/heads/master d61c0685c -> 2b4762a49


IGNITE-8667: Splitting of dataset to test and training sets

this closes #4124


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/2b4762a4
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/2b4762a4
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/2b4762a4

Branch: refs/heads/master
Commit: 2b4762a49c021aaa71c9e69b7d120153c0e5a61f
Parents: d61c068
Author: Anton Dmitriev <dmitrievanthony@gmail.com>
Authored: Tue Jun 5 19:52:08 2018 +0300
Committer: Yury Babak <ybabak@gridgain.com>
Committed: Tue Jun 5 19:52:08 2018 +0300

----------------------------------------------------------------------
 .../LinearRegressionLSQRTrainerExample.java     |   9 +-
 .../examples/ml/selection/package-info.java     |  22 +++
 .../split/TrainTestDatasetSplitterExample.java  | 175 +++++++++++++++++++
 .../ml/selection/split/package-info.java        |  22 +++
 .../dataset/impl/cache/util/ComputeUtils.java   |   4 +-
 .../ignite/ml/selection/package-info.java       |  22 +++
 .../split/TrainTestDatasetSplitter.java         | 122 +++++++++++++
 .../ml/selection/split/TrainTestSplit.java      |  59 +++++++
 .../split/mapper/SHA256UniformMapper.java       |  95 ++++++++++
 .../selection/split/mapper/UniformMapper.java   |  38 ++++
 .../ml/selection/split/mapper/package-info.java |  22 +++
 .../ignite/ml/selection/split/package-info.java |  22 +++
 .../ignite/ml/trainers/DatasetTrainer.java      |  51 +++++-
 .../linear/LinearRegressionLSQRTrainerTest.java |   7 +-
 .../split/TrainTestDatasetSplitterTest.java     |  46 +++++
 .../split/mapper/SHA256UniformMapperTest.java   |  70 ++++++++
 16 files changed, 771 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
index 04d1778..bfb4e0a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
@@ -17,6 +17,9 @@
 
 package org.apache.ignite.examples.ml.regression.linear;
 
+import java.util.Arrays;
+import java.util.UUID;
+import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
@@ -29,10 +32,6 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
 import org.apache.ignite.thread.IgniteThread;
 
-import javax.cache.Cache;
-import java.util.Arrays;
-import java.util.UUID;
-
 /**
  * Run linear regression model over cached dataset.
  *
@@ -99,7 +98,7 @@ public class LinearRegressionLSQRTrainerExample {
     /** Run example. */
     public static void main(String[] args) throws InterruptedException {
         System.out.println();
-        System.out.println(">>> Linear regression model over sparse distributed
matrix API usage example started.");
+        System.out.println(">>> Linear regression model over cache based dataset
usage example started.");
         // Start ignite grid.
         try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
             System.out.println(">>> Ignite grid started.");

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/examples/src/main/java/org/apache/ignite/examples/ml/selection/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/package-info.java
b/examples/src/main/java/org/apache/ignite/examples/ml/selection/package-info.java
new file mode 100644
index 0000000..c3a264a
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML model selection examples.
+ */
+package org.apache.ignite.examples.ml.selection;

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
new file mode 100644
index 0000000..ebd899c
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.selection.split;
+
+import java.util.Arrays;
+import java.util.UUID;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
+import org.apache.ignite.ml.selection.split.TrainTestSplit;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run linear regression model over dataset split into train and test subsets.
+ *
+ * @see TrainTestDatasetSplitter
+ */
+public class TrainTestDatasetSplitterExample {
+    /** Sample rows: column 0 is the label (ground truth), the remaining columns are features. */
+    private static final double[][] data = {
+        {8, 78, 284, 9.100000381, 109},
+        {9.300000191, 68, 433, 8.699999809, 144},
+        {7.5, 70, 739, 7.199999809, 113},
+        {8.899999619, 96, 1792, 8.899999619, 97},
+        {10.19999981, 74, 477, 8.300000191, 206},
+        {8.300000191, 111, 362, 10.89999962, 124},
+        {8.800000191, 77, 671, 10, 152},
+        {8.800000191, 168, 636, 9.100000381, 162},
+        {10.69999981, 82, 329, 8.699999809, 150},
+        {11.69999981, 89, 634, 7.599999905, 134},
+        {8.5, 149, 631, 10.80000019, 292},
+        {8.300000191, 60, 257, 9.5, 108},
+        {8.199999809, 96, 284, 8.800000191, 111},
+        {7.900000095, 83, 603, 9.5, 182},
+        {10.30000019, 130, 686, 8.699999809, 129},
+        {7.400000095, 145, 345, 11.19999981, 158},
+        {9.600000381, 112, 1357, 9.699999809, 186},
+        {9.300000191, 131, 544, 9.600000381, 177},
+        {10.60000038, 80, 205, 9.100000381, 127},
+        {9.699999809, 130, 1264, 9.199999809, 179},
+        {11.60000038, 140, 688, 8.300000191, 80},
+        {8.100000381, 154, 354, 8.399999619, 103},
+        {9.800000191, 118, 1632, 9.399999619, 101},
+        {7.400000095, 94, 348, 9.800000191, 117},
+        {9.399999619, 119, 370, 10.39999962, 88},
+        {11.19999981, 153, 648, 9.899999619, 78},
+        {9.100000381, 116, 366, 9.199999809, 102},
+        {10.5, 97, 540, 10.30000019, 95},
+        {11.89999962, 176, 680, 8.899999619, 80},
+        {8.399999619, 75, 345, 9.600000381, 92},
+        {5, 134, 525, 10.30000019, 126},
+        {9.800000191, 161, 870, 10.39999962, 108},
+        {9.800000191, 111, 669, 9.699999809, 77},
+        {10.80000019, 114, 452, 9.600000381, 60},
+        {10.10000038, 142, 430, 10.69999981, 71},
+        {10.89999962, 238, 822, 10.30000019, 86},
+        {9.199999809, 78, 190, 10.69999981, 93},
+        {8.300000191, 196, 867, 9.600000381, 106},
+        {7.300000191, 125, 969, 10.5, 162},
+        {9.399999619, 82, 499, 7.699999809, 95},
+        {9.399999619, 125, 925, 10.19999981, 91},
+        {9.800000191, 129, 353, 9.899999619, 52},
+        {3.599999905, 84, 288, 8.399999619, 110},
+        {8.399999619, 183, 718, 10.39999962, 69},
+        {10.80000019, 119, 540, 9.199999809, 57},
+        {10.10000038, 180, 668, 13, 106},
+        {9, 82, 347, 8.800000191, 40},
+        {10, 71, 345, 9.199999809, 50},
+        {11.30000019, 118, 463, 7.800000191, 35},
+        {11.30000019, 121, 728, 8.199999809, 86},
+        {12.80000019, 68, 383, 7.400000095, 57},
+        {10, 112, 316, 10.39999962, 57},
+        {6.699999809, 109, 388, 8.899999619, 94}
+    };
+
+    /** Run example. */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Linear regression model over cache based dataset
usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                TrainTestDatasetSplitterExample.class.getSimpleName(), () -> {
+                IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+
+                System.out.println(">>> Create new linear regression trainer object.");
+                LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+
+                TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer,
double[]>()
+                    .split(0.75); // Train proportion 0.75 (by key hash); the test subset gets the rest.
+
+                System.out.println(">>> Perform the training to get the model.");
+                LinearRegressionModel mdl = trainer.fit(
+                    ignite,
+                    dataCache,
+                    split.getTrainFilter(),
+                    (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+                    (k, v) -> v[0]
+                );
+
+                System.out.println(">>> Linear regression model: " + mdl);
+
+                System.out.println(">>> ---------------------------------");
+                System.out.println(">>> | Prediction\t| Ground Truth\t|");
+                System.out.println(">>> ---------------------------------");
+
+                ScanQuery<Integer, double[]> qry = new ScanQuery<>();
+                qry.setFilter(split.getTestFilter()); // Predict only for entries excluded from training.
+
+                try (QueryCursor<Cache.Entry<Integer, double[]>> observations
= dataCache.query(qry)) {
+                    for (Cache.Entry<Integer, double[]> observation : observations)
{
+                        double[] val = observation.getValue();
+                        double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+                        double groundTruth = val[0];
+
+                        double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
+
+                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction,
groundTruth);
+                    }
+                }
+
+                System.out.println(">>> ---------------------------------");
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+
+    /**
+     * Creates a new cache with 3 partitions, fills it with {@code data} rows keyed by row index and returns it.
+     *
+     * @param ignite Ignite instance.
+     * @return Filled Ignite Cache.
+     */
+    private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+        CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
+        cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 3));
+
+        IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
+
+        for (int i = 0; i < data.length; i++)
+            cache.put(i, data[i]);
+
+        return cache;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/package-info.java
b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/package-info.java
new file mode 100644
index 0000000..55bfaeb
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML dataset split examples.
+ */
+package org.apache.ignite.examples.ml.selection.split;

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
index b235900..39b3703 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
@@ -173,7 +173,7 @@ public class ComputeUtils {
                     e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
                     Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(cursor.iterator(),
cnt,
-                        "Cache expected to be not modified during dataset data building");
+                        "Cache expected to be not modified during dataset data building [partition="
+ part + ']');
 
                     return partDataBuilder.build(iter, cnt, ctx);
                 }
@@ -215,7 +215,7 @@ public class ComputeUtils {
                 e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
                 Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(cursor.iterator(),
cnt,
-                    "Cache expected to be not modified during dataset context building");
+                    "Cache expected to be not modified during dataset context building [partition="
+ part + ']');
 
                 ctx = ctxBuilder.build(iter, cnt);
             }

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/selection/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/package-info.java
new file mode 100644
index 0000000..bfefe96
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Root package for dataset splitters, cross validation and search through parameters.
+ */
+package org.apache.ignite.ml.selection;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/TrainTestDatasetSplitter.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/TrainTestDatasetSplitter.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/TrainTestDatasetSplitter.java
new file mode 100644
index 0000000..ae94e44
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/TrainTestDatasetSplitter.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.split;
+
+import java.io.Serializable;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
+import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
+
+/**
+ * Dataset splitter that splits dataset into train and test subsets.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class TrainTestDatasetSplitter<K, V> implements Serializable {
+    /** */
+    private static final long serialVersionUID = 3148338796945474491L;
+
+    /** Mapper used to map a key-value pair to a point on the segment [0, 1). */
+    private final UniformMapper<K, V> mapper;
+
+    /**
+     * Constructs a new instance of train test dataset splitter that uses {@link SHA256UniformMapper}.
+     */
+    public TrainTestDatasetSplitter() {
+        this(new SHA256UniformMapper<>());
+    }
+
+    /**
+     * Constructs a new instance of train test dataset splitter.
+     *
+     * @param mapper Mapper used to map a key-value pair to a point on the segment [0, 1).
+     */
+    public TrainTestDatasetSplitter(UniformMapper<K, V> mapper) {
+        this.mapper = mapper;
+    }
+
+    /**
+     * Splits dataset into train and test subsets; the test subset gets the remaining {@code 1 - trainSize}.
+     *
+     * @param trainSize The proportion of the dataset to include in the train split (should
be between 0 and 1).
+     * @return Split with two predicates for training and testing parts.
+     */
+    public TrainTestSplit<K, V> split(double trainSize) {
+        return split(trainSize, 1 - trainSize);
+    }
+
+    /**
+     * Splits dataset into train and test subsets; throws {@link IllegalArgumentException} on invalid proportions.
+     *
+     * @param trainSize The proportion of the dataset to include in the train split (should
be between 0 and 1).
+     * @param testSize The proportion of the dataset to include in the test split (should
be a number between 0 and 1).
+     * @return Split with two predicates for training and testing parts.
+     */
+    public TrainTestSplit<K, V> split(double trainSize, double testSize) {
+        return new TrainTestSplit<>(
+            new DatasetSplitFilter<>(mapper, 0, trainSize),
+            new DatasetSplitFilter<>(mapper, trainSize, trainSize + testSize)
+        );
+    }
+
+    /**
+     * Dataset filter based on the uniform mapping and specified interval. It allows to specify
a mapper that maps key-value
+     * pair to a point on the segment [0, 1) and an interval inside that segment (for example [0, 0.2)). After that this
+     * filter will pass all entries whose mappings lie in the half-open interval [from, to).
+     */
+    static class DatasetSplitFilter<K, V> implements IgniteBiPredicate<K, V> {
+        /** */
+        private static final long serialVersionUID = 2247757751655582254L;
+
+        /** Mapper used to map a key-value pair to a point on the segment [0, 1). */
+        private final UniformMapper<K, V> mapper;
+
+        /** Left (inclusive) point of an interval. */
+        private final double from;
+
+        /** Right (exclusive) point of an interval. */
+        private final double to;
+
+        /**
+         * Constructs a new instance of dataset split filter; throws {@link IllegalArgumentException} on bad bounds.
+         *
+         * @param mapper Mapper used to map a key-value pair to a point on the segment [0, 1).
+         * @param from Left (inclusive) point of an interval, should be in [0, 1].
+         * @param to Right (exclusive) point of an interval, should be in [from, 1].
+         */
+        DatasetSplitFilter(UniformMapper<K, V> mapper, double from, double to) {
+            if (!(from >= 0 && from <= 1)) throw new IllegalArgumentException("Point 'from' should be in interval [0, 1]");
+            if (!(to >= 0 && to <= 1)) throw new IllegalArgumentException("Point 'to' should be in interval [0, 1]");
+            if (from > to) throw new IllegalArgumentException("Point 'from' should be less or equal to point 'to'");
+
+            this.mapper = mapper;
+            this.from = from;
+            this.to = to;
+        }
+
+        /** {@inheritDoc} */
+        @Override public boolean apply(K key, V val) {
+            double pnt = mapper.map(key, val);
+
+            assert pnt >= 0 && pnt <= 1 : "Point should be in interval [0, 1]";
+
+            return pnt >= from && pnt < to;
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/TrainTestSplit.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/TrainTestSplit.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/TrainTestSplit.java
new file mode 100644
index 0000000..410d990
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/TrainTestSplit.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.split;
+
+import java.io.Serializable;
+import org.apache.ignite.lang.IgniteBiPredicate;
+
+/**
+ * Dataset split that holds a pair of key-value filters selecting the train and test subsets.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class TrainTestSplit<K, V> implements Serializable {
+    /** */
+    private static final long serialVersionUID = 2165934349492062372L;
+
+    /** Filter that selects train subset of the dataset. */
+    private final IgniteBiPredicate<K, V> trainFilter;
+
+    /** Filter that selects test subset of the dataset. */
+    private final IgniteBiPredicate<K, V> testFilter;
+
+    /**
+     * Constructs a new instance of train test split.
+     *
+     * @param trainFilter Filter that passes train subset of the dataset.
+     * @param testFilter Filter that passes test subset of the dataset.
+     */
+    public TrainTestSplit(IgniteBiPredicate<K, V> trainFilter, IgniteBiPredicate<K,
V> testFilter) {
+        this.trainFilter = trainFilter;
+        this.testFilter = testFilter;
+    }
+
+    /** @return Filter that selects the train subset of the dataset. */
+    public IgniteBiPredicate<K, V> getTrainFilter() {
+        return trainFilter;
+    }
+
+    /** @return Filter that selects the test subset of the dataset. */
+    public IgniteBiPredicate<K, V> getTestFilter() {
+        return testFilter;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java
new file mode 100644
index 0000000..b0475ca
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.split.mapper;
+
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Implementation of uniform mappers based on SHA-256 hashing algorithm.
+ *
+ * @param <K> Type of a key.
+ * @param <V> Type of a value.
+ */
+public class SHA256UniformMapper<K, V> implements UniformMapper<K, V> {
+    /** */
+    private static final long serialVersionUID = -8179630783617088803L;
+
+    /** Hashing algorithm. */
+    private static final String HASHING_ALGORITHM = "SHA-256";
+
+    /** Per-thread message digest, lazily created by {@link #getDigest()}. */
+    private static final ThreadLocal<MessageDigest> digest = new ThreadLocal<>();
+
+    /** Permutation of bit positions applied to the selected byte of the SHA-256 hash. */
+    private final List<Integer> shuffleStgy = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7);
+
+    /**
+     * Constructs a new instance of SHA-256 uniform mapper that keeps the bit order unchanged.
+     */
+    public SHA256UniformMapper() {}
+
+    /**
+     * Constructs a new instance of SHA-256 uniform mapper.
+     *
+     * @param random Random used to define shuffle strategy.
+     */
+    public SHA256UniformMapper(Random random) {
+        Collections.shuffle(shuffleStgy, random);
+    }
+
+    /** {@inheritDoc} Returns one of the 256 values {@code k / 256}, i.e. a point in [0, 1). */
+    @Override public double map(K key, V val) {
+        int h = key.hashCode();
+        String str = String.valueOf(h);
+
+        byte[] hash = getDigest().digest(str.getBytes(StandardCharsets.UTF_8));
+
+        byte hashByte = hash[Math.floorMod(h, hash.length)]; // '%' would be negative for negative hash codes.
+
+        byte resByte = 0;
+
+        for (int i = 0; i < 8; i++)
+            resByte = (byte)(resByte << 1 | ((hashByte >> shuffleStgy.get(i))
& 0x1));
+
+        return 1.0 * (resByte & 0xFF) / 256;
+    }
+
+    /**
+     * Creates instance of digest in case it doesn't exist, otherwise returns existing instance.
+     *
+     * @return Instance of message digest.
+     */
+    private MessageDigest getDigest() {
+        if (digest.get() == null) {
+            try {
+                digest.set(MessageDigest.getInstance(HASHING_ALGORITHM));
+            }
+            catch (NoSuchAlgorithmException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        return digest.get();
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/UniformMapper.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/UniformMapper.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/UniformMapper.java
new file mode 100644
index 0000000..fce31e8
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/UniformMapper.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.split.mapper;
+
+import java.io.Serializable;
+
+/**
+ * Interface for util mappers that map a key-value pair to a point on the half-open segment {@code [0, 1)}.
+ * The lower bound is included: a mapper may return exactly {@code 0.0}.
+ *
+ * @param <K> Type of a key.
+ * @param <V> Type of a value.
+ */
+@FunctionalInterface
+public interface UniformMapper<K, V> extends Serializable {
+    /**
+     * Maps key-value pair to a point on the segment {@code [0, 1)}.
+     *
+     * @param key Key.
+     * @param val Value.
+     * @return Point on the segment {@code [0, 1)}.
+     */
+    public double map(K key, V val);
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/package-info.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/package-info.java
new file mode 100644
index 0000000..19d9de4
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Root package for mappers used in dataset splitters.
+ */
+package org.apache.ignite.ml.selection.split.mapper;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/package-info.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/package-info.java
new file mode 100644
index 0000000..05b1331
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Root package for dataset splitters and cross validation.
+ */
+package org.apache.ignite.ml.selection.split;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
index fcde3f5..4d7a262 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
@@ -17,16 +17,16 @@
 
 package org.apache.ignite.ml.trainers;
 
+import java.util.Map;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
+import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 
-import java.util.Map;
-
 /**
  * Interface for trainers. Trainer is just a function which produces model from the data.
  *
@@ -58,8 +58,8 @@ public interface DatasetTrainer<M extends Model, L> {
      * @param <V> Type of a value in {@code upstream} data.
      * @return Model.
      */
-    public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache, IgniteBiFunction<K,
V, double[]> featureExtractor,
-        IgniteBiFunction<K, V, L> lbExtractor) {
+    public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V,
L> lbExtractor) {
         return fit(
             new CacheBasedDatasetBuilder<>(ignite, cache),
             featureExtractor,
@@ -70,6 +70,27 @@ public interface DatasetTrainer<M extends Model, L> {
     /**
      * Trains model based on the specified data.
      *
+     * @param ignite Ignite instance.
+     * @param cache Ignite cache.
+     * @param filter Filter for {@code upstream} data.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Model.
+     */
+    public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache, IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        // Same as the unfiltered cache-based fit(), except that the filter is also handed to the dataset builder.
+        return fit(new CacheBasedDatasetBuilder<>(ignite, cache, filter), featureExtractor, lbExtractor);
+    }
+
+    /**
+     * Trains model based on the specified data.
+     *
      * @param data Data.
      * @param parts Number of partitions.
      * @param featureExtractor Feature extractor.
@@ -86,4 +107,26 @@ public interface DatasetTrainer<M extends Model, L> {
             lbExtractor
         );
     }
+
+    /**
+     * Trains model on the specified local data, handing {@code filter} to the local dataset builder
+     * together with the data and the number of partitions.
+     *
+     * @param data Data.
+     * @param filter Filter for {@code upstream} data.
+     * @param parts Number of partitions.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Model.
+     */
+    public default <K, V> M fit(Map<K, V> data, IgniteBiPredicate<K, V> filter, int parts,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        return fit(new LocalDatasetBuilder<>(data, filter, parts), featureExtractor, lbExtractor);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
index 2414236..ac0117d 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
@@ -17,14 +17,13 @@
 
 package org.apache.ignite.ml.regressions.linear;
 
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Random;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/test/java/org/apache/ignite/ml/selection/split/TrainTestDatasetSplitterTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/split/TrainTestDatasetSplitterTest.java
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/split/TrainTestDatasetSplitterTest.java
new file mode 100644
index 0000000..25ac74e
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/split/TrainTestDatasetSplitterTest.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.split;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@link TrainTestDatasetSplitter}.
+ */
+public class TrainTestDatasetSplitterTest {
+    /**
+     * Splits 0.4 / 0.4 with a mapper that uses the key itself as the point. Checks that keys 0.0 and 0.2 are
+     * accepted only by the train filter, and keys 0.4 and 0.6 only by the test filter.
+     */
+    @Test
+    public void testSplitWithSpecifiedTrainAndTestSize() {
+        TrainTestDatasetSplitter<Double, Double> splitter = new TrainTestDatasetSplitter<>((k, v) -> k);
+
+        TrainTestSplit<Double, Double> split = splitter.split(0.4, 0.4);
+
+        for (double trainKey : new double[] {0.0, 0.2}) {
+            assertTrue(split.getTrainFilter().apply(trainKey, 0.0));
+            assertFalse(split.getTestFilter().apply(trainKey, 0.0));
+        }
+
+        for (double testKey : new double[] {0.4, 0.6}) {
+            assertFalse(split.getTrainFilter().apply(testKey, 0.0));
+            assertTrue(split.getTestFilter().apply(testKey, 0.0));
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/2b4762a4/modules/ml/src/test/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapperTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapperTest.java
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapperTest.java
new file mode 100644
index 0000000..f1f1774
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapperTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.split.mapper;
+
+import java.util.Random;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@link SHA256UniformMapper}.
+ */
+public class SHA256UniformMapperTest {
+    /**
+     * Checks that roughly 20% of 10^5 consecutive integer keys are mapped below 0.2.
+     */
+    @Test
+    public void testMap() {
+        UniformMapper<Integer, Integer> mapper = new SHA256UniformMapper<>(new Random(42));
+
+        int below = 0;
+
+        for (int key = 0; key < 100_000; key++) {
+            if (mapper.map(key, key) < 0.2)
+                below++;
+        }
+
+        assertCountClose(20_000, below);
+    }
+
+    /**
+     * Checks that roughly a quarter of 10^5 keys are mapped below 0.5 by both of two differently seeded mappers,
+     * as expected if their outputs are independent.
+     */
+    @Test
+    public void testMapAndMapAgain() {
+        UniformMapper<Integer, Integer> firstMapper = new SHA256UniformMapper<>(new Random(42));
+        UniformMapper<Integer, Integer> secondMapper = new SHA256UniformMapper<>(new Random(21));
+
+        int bothBelow = 0;
+
+        for (int key = 0; key < 100_000; key++) {
+            double first = firstMapper.map(key, key);
+            double second = secondMapper.map(key, key);
+
+            if (first < 0.5 && second < 0.5)
+                bothBelow++;
+        }
+
+        assertCountClose(25_000, bothBelow);
+    }
+
+    /**
+     * Asserts that the observed count differs from the expected one by less than 2% (relative error).
+     * A hash function with a good distribution should stay within this bound over 10^5 keys.
+     *
+     * @param expCnt Expected count.
+     * @param actualCnt Observed count.
+     */
+    private static void assertCountClose(int expCnt, int actualCnt) {
+        double relErr = Math.abs(actualCnt - expCnt) / (double)expCnt;
+
+        assertTrue(relErr < 0.02);
+    }
+}


Mime
View raw message