ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ch...@apache.org
Subject ignite git commit: IGNITE-9064: [ML] Decision tree optimization
Date Fri, 03 Aug 2018 11:17:16 GMT
Repository: ignite
Updated Branches:
  refs/heads/master 25f2d1865 -> 44098bc6e


IGNITE-9064: [ML] Decision tree optimization

this closes #4436


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/44098bc6
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/44098bc6
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/44098bc6

Branch: refs/heads/master
Commit: 44098bc6e38ce9bbd4c191c3314b2123a60739d4
Parents: 25f2d18
Author: Alexey Platonov <aplatonovv@gmail.com>
Authored: Fri Aug 3 14:17:07 2018 +0300
Committer: Yury Babak <ybabak@gridgain.com>
Committed: Fri Aug 3 14:17:07 2018 +0300

----------------------------------------------------------------------
 .../ml/composition/boosting/GDBTrainer.java     |   2 +-
 .../dataset/impl/cache/CacheBasedDataset.java   |   1 +
 .../dataset/impl/cache/util/ComputeUtils.java   |  11 ++
 .../org/apache/ignite/ml/tree/DecisionTree.java |  13 +-
 .../tree/DecisionTreeClassificationTrainer.java |  13 +-
 .../ml/tree/DecisionTreeRegressionTrainer.java  |  14 +-
 .../GDBBinaryClassifierOnTreesTrainer.java      |  17 +-
 .../boosting/GDBRegressionOnTreesTrainer.java   |  17 +-
 .../ignite/ml/tree/data/DecisionTreeData.java   |  51 ++++-
 .../ml/tree/data/DecisionTreeDataBuilder.java   |   9 +-
 .../ignite/ml/tree/data/TreeDataIndex.java      | 184 +++++++++++++++++++
 .../impurity/ImpurityMeasureCalculator.java     |  67 ++++++-
 .../gini/GiniImpurityMeasureCalculator.java     |  67 ++++---
 .../mse/MSEImpurityMeasureCalculator.java       |  86 ++++++---
 .../RandomForestClassifierTrainer.java          |  13 +-
 .../RandomForestRegressionTrainer.java          |  13 +-
 .../tree/randomforest/RandomForestTrainer.java  |   4 +
 .../DecisionTreeClassificationTrainerTest.java  |  25 ++-
 .../tree/DecisionTreeRegressionTrainerTest.java |  18 +-
 .../ml/tree/data/DecisionTreeDataTest.java      |  21 ++-
 .../ignite/ml/tree/data/TreeDataIndexTest.java  | 159 ++++++++++++++++
 .../gini/GiniImpurityMeasureCalculatorTest.java |  27 ++-
 .../mse/MSEImpurityMeasureCalculatorTest.java   |  21 ++-
 23 files changed, 757 insertions(+), 96 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
index 6726892..8663d3d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
@@ -165,7 +165,7 @@ abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double>
 
         try (Dataset<EmptyContext, DecisionTreeData> dataset = builder.build(
             new EmptyContextBuilder<>(),
-            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor)
+            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, false)
         )) {
             IgniteBiTuple<Double, Long> meanTuple = dataset.compute(
                 data -> {

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
index 67e0d56..e5eb483 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
@@ -144,6 +144,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
     /** {@inheritDoc} */
     @Override public void close() {
         datasetCache.destroy();
+        ComputeUtils.removeData(ignite, datasetId);
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
index 39b3703..a5cdd3b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
@@ -184,6 +184,17 @@ public class ComputeUtils {
     }
 
     /**
+     * Remove data from local cache by Dataset ID.
+     *
+     * @param ignite Ignite instance.
+     * @param datasetId Dataset ID.
+     */
+    public static void removeData(Ignite ignite, UUID datasetId) {
+        ignite.cluster().nodeLocalMap().remove(String.format(DATA_STORAGE_KEY_TEMPLATE, datasetId));
+    }
+
+
+    /**
      * Initializes partition {@code context} by loading it from a partition {@code upstream}.
      *
      * @param ignite Ignite instance.

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
index c1e3abf..270f14a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
@@ -53,6 +53,9 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
     /** Decision tree leaf builder. */
     private final DecisionTreeLeafBuilder decisionTreeLeafBuilder;
 
+    /** Use index structure instead of using sorting while learning. */
+    protected boolean useIndex = true;
+
     /**
      * Constructs a new distributed decision tree trainer.
      *
@@ -74,7 +77,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
         try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
             new EmptyContextBuilder<>(),
-            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor)
+            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex)
         )) {
             return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
         }
@@ -105,7 +108,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
         if (deep >= maxDeep)
             return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
 
-        StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc);
+        StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc, deep);
 
         if (criterionFunctions == null)
             return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
@@ -132,14 +135,14 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
      * @return Array of impurity measure functions for all columns.
      */
     private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset,
-        TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc) {
+        TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc, int depth) {
 
         StepFunction<T>[] result = dataset.compute(
             part -> {
                 if (compressor != null)
-                    return compressor.compress(impurityCalc.calculate(part.filter(filter)));
+                    return compressor.compress(impurityCalc.calculate(part, filter, depth));
                 else
-                    return impurityCalc.calculate(part.filter(filter));
+                    return impurityCalc.calculate(part, filter, depth);
             }, this::reduce
         );
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
index 71e387f..f371334 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
@@ -85,12 +85,13 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity
     }
 
     /**
-     * Set up the step function compressor of decision tree.
-     * @param compressor The parameter value.
-     * @return Trainer with new compressor parameter value.
+     * Sets useIndex parameter and returns trainer instance.
+     *
+     * @param useIndex Use index.
+     * @return Decision tree trainer.
      */
-    public DecisionTreeClassificationTrainer withCompressor(StepFunctionCompressor compressor){
-        this.compressor = compressor;
+    public DecisionTreeClassificationTrainer withUseIndex(boolean useIndex) {
+        this.useIndex = useIndex;
         return this;
     }
 
@@ -126,6 +127,6 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity
         for (Double lb : labels)
             encoder.put(lb, idx++);
 
-        return new GiniImpurityMeasureCalculator(encoder);
+        return new GiniImpurityMeasureCalculator(encoder, useIndex);
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
index 2bf09d3..7446237 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
@@ -52,9 +52,21 @@ public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasu
         super(maxDeep, minImpurityDecrease, compressor, new MeanDecisionTreeLeafBuilder());
     }
 
+    /**
+     * Sets useIndex parameter and returns trainer instance.
+     *
+     * @param useIndex Use index.
+     * @return Decision tree trainer.
+     */
+    public DecisionTreeRegressionTrainer withUseIndex(boolean useIndex) {
+        this.useIndex = useIndex;
+        return this;
+    }
+
     /** {@inheritDoc} */
     @Override ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator(
         Dataset<EmptyContext, DecisionTreeData> dataset) {
-        return new MSEImpurityMeasureCalculator();
+
+        return new MSEImpurityMeasureCalculator(useIndex);
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
index 3789588..631e848 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
@@ -30,9 +30,13 @@ import org.jetbrains.annotations.NotNull;
 public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTrainer {
     /** Max depth. */
     private final int maxDepth;
+
     /** Min impurity decrease. */
     private final double minImpurityDecrease;
 
+    /** Use index structure instead of using sorting while learning. */
+    private boolean useIndex = true;
+
     /**
      * Constructs instance of GDBBinaryClassifierOnTreesTrainer.
      *
@@ -51,6 +55,17 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine
 
     /** {@inheritDoc} */
     @NotNull @Override protected DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer() {
-        return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease);
+        return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex);
+    }
+
+    /**
+     * Sets useIndex parameter and returns trainer instance.
+     *
+     * @param useIndex Use index.
+     * @return Decision tree trainer.
+     */
+    public GDBBinaryClassifierOnTreesTrainer withUseIndex(boolean useIndex) {
+        this.useIndex = useIndex;
+        return this;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
index 50c5f8d..450dae3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
@@ -30,9 +30,13 @@ import org.jetbrains.annotations.NotNull;
 public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer {
     /** Max depth. */
     private final int maxDepth;
+
     /** Min impurity decrease. */
     private final double minImpurityDecrease;
 
+    /** Use index structure instead of using sorting while learning. */
+    private boolean useIndex = true;
+
     /**
      * Constructs instance of GDBRegressionOnTreesTrainer.
      *
@@ -51,6 +55,17 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer {
 
     /** {@inheritDoc} */
     @NotNull @Override protected DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer() {
-        return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease);
+        return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex);
+    }
+
+    /**
+     * Sets useIndex parameter and returns trainer instance.
+     *
+     * @param useIndex Use index.
+     * @return Decision tree trainer.
+     */
+    public GDBRegressionOnTreesTrainer withUseIndex(boolean useIndex) {
+        this.useIndex = useIndex;
+        return this;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
index 34deb46..c017e5c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
@@ -17,6 +17,8 @@
 
 package org.apache.ignite.ml.tree.data;
 
+import java.util.ArrayList;
+import java.util.List;
 import org.apache.ignite.ml.tree.TreeFilter;
 
 /**
@@ -29,17 +31,29 @@ public class DecisionTreeData implements AutoCloseable {
     /** Vector with labels. */
     private final double[] labels;
 
+    /** Indexes cache. */
+    private final List<TreeDataIndex> indexesCache;
+
+    /** Build index. */
+    private final boolean buildIndex;
+
     /**
      * Constructs a new instance of decision tree data.
      *
      * @param features Matrix with features.
      * @param labels Vector with labels.
+     * @param buildIdx Build index.
      */
-    public DecisionTreeData(double[][] features, double[] labels) {
+    public DecisionTreeData(double[][] features, double[] labels, boolean buildIdx) {
         assert features.length == labels.length : "Features and labels have to be the same length";
 
         this.features = features;
         this.labels = labels;
+        this.buildIndex = buildIdx;
+
+        indexesCache = new ArrayList<>();
+        if (buildIdx)
+            indexesCache.add(new TreeDataIndex(features, labels));
     }
 
     /**
@@ -69,7 +83,7 @@ public class DecisionTreeData implements AutoCloseable {
             }
         }
 
-        return new DecisionTreeData(newFeatures, newLabels);
+        return new DecisionTreeData(newFeatures, newLabels, buildIndex);
     }
 
     /**
@@ -89,8 +103,10 @@ public class DecisionTreeData implements AutoCloseable {
             int i = from, j = to;
 
             while (i <= j) {
-                while (features[i][col] < pivot) i++;
-                while (features[j][col] > pivot) j--;
+                while (features[i][col] < pivot)
+                    i++;
+                while (features[j][col] > pivot)
+                    j--;
 
                 if (i <= j) {
                     double[] tmpFeature = features[i];
@@ -125,4 +141,31 @@ public class DecisionTreeData implements AutoCloseable {
     @Override public void close() {
         // Do nothing, GC will clean up.
     }
+
+    /**
+     * Builds an index according to the current tree depth and the indexes cached at upper levels. Uses depth as the
+     * key of a cached index and replaces any cached index with the same key.
+     *
+     * @param depth Tree Depth.
+     * @param filter Filter.
+     */
+    public TreeDataIndex createIndexByFilter(int depth, TreeFilter filter) {
+        assert depth >= 0 && depth <= indexesCache.size();
+
+        if (depth > 0 && depth <= indexesCache.size() - 1) {
+            for (int i = indexesCache.size() - 1; i >= depth; i--)
+                indexesCache.remove(i);
+        }
+
+        if (depth == indexesCache.size()) {
+            if (depth == 0)
+                indexesCache.add(new TreeDataIndex(features, labels));
+            else {
+                TreeDataIndex lastIndex = indexesCache.get(depth - 1);
+                indexesCache.add(lastIndex.filter(filter));
+            }
+        }
+
+        return indexesCache.get(depth);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
index 0ff2012..6678218 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
@@ -42,16 +42,21 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable>
     /** Function that extracts labels from an {@code upstream} data. */
     private final IgniteBiFunction<K, V, Double> lbExtractor;
 
+    /** Build index. */
+    private final boolean buildIndex;
+
     /**
      * Constructs a new instance of decision tree data builder.
      *
      * @param featureExtractor Function that extracts features from an {@code upstream} data.
      * @param lbExtractor Function that extracts labels from an {@code upstream} data.
+     * @param buildIdx Build index.
      */
     public DecisionTreeDataBuilder(IgniteBiFunction<K, V, Vector> featureExtractor,
-        IgniteBiFunction<K, V, Double> lbExtractor) {
+        IgniteBiFunction<K, V, Double> lbExtractor, boolean buildIdx) {
         this.featureExtractor = featureExtractor;
         this.lbExtractor = lbExtractor;
+        this.buildIndex = buildIdx;
     }
 
     /** {@inheritDoc} */
@@ -70,6 +75,6 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable>
             ptr++;
         }
 
-        return new DecisionTreeData(features, labels);
+        return new DecisionTreeData(features, labels, buildIndex);
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java
new file mode 100644
index 0000000..88ce190
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.data;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.tree.TreeFilter;
+
+/**
+ * Index for representing sorted dataset rows for each feature.
+ * It may be reused while decision tree learning at several levels through filter method.
+ */
+public class TreeDataIndex {
+    /** Index containing IDs of rows as if they are sorted by feature values. */
+    private final int[][] index;
+
+    /** Original features table. */
+    private final double[][] features;
+
+    /** Original labels. */
+    private final double[] labels;
+
+    /**
+     * Constructs an instance of TreeDataIndex.
+     *
+     * @param features Features.
+     * @param labels Labels.
+     */
+    public TreeDataIndex(double[][] features, double[] labels) {
+        this.features = features;
+        this.labels = labels;
+
+        int rows = features.length;
+        int cols = features.length == 0 ? 0 : features[0].length;
+
+        double[][] featuresCp = new double[rows][cols];
+        index = new int[rows][cols];
+        for (int row = 0; row < rows; row++) {
+            Arrays.fill(index[row], row);
+            featuresCp[row] = Arrays.copyOf(features[row], cols);
+        }
+
+        for (int col = 0; col < cols; col++)
+            sortIndex(featuresCp, col, 0, rows - 1);
+    }
+
+    /**
+     * Constructs an instance of TreeDataIndex
+     *
+     * @param indexProj Index projection.
+     * @param features Features.
+     * @param labels Labels.
+     */
+    private TreeDataIndex(int[][] indexProj, double[][] features, double[] labels) {
+        this.index = indexProj;
+        this.features = features;
+        this.labels = labels;
+    }
+
+    /**
+     * Returns label for kth order statistic for target feature.
+     *
+     * @param k K.
+     * @param featureId Feature id.
+     * @return Label value.
+     */
+    public double labelInSortedOrder(int k, int featureId) {
+        return labels[index[k][featureId]];
+    }
+
+    /**
+     * Returns vector of original features for kth order statistic for target feature.
+     *
+     * @param k K.
+     * @param featureId Feature id.
+     * @return Features vector.
+     */
+    public double[] featuresInSortedOrder(int k, int featureId) {
+        return features[index[k][featureId]];
+    }
+
+    /**
+     * Returns feature value for kth order statistic for target feature.
+     *
+     * @param k K.
+     * @param featureId Feature id.
+     * @return Feature value.
+     */
+    public double featureInSortedOrder(int k, int featureId) {
+        return featuresInSortedOrder(k, featureId)[featureId];
+    }
+
+    /**
+     * Creates a projection of the current index according to {@link TreeFilter}.
+     *
+     * @param filter Filter.
+     * @return Projection of the current index onto a smaller index according to the row filter.
+     */
+    public TreeDataIndex filter(TreeFilter filter) {
+        int projSize = 0;
+        for (int i = 0; i < rowsCount(); i++) {
+            if (filter.test(featuresInSortedOrder(i, 0)))
+                projSize++;
+        }
+
+        int[][] projection = new int[projSize][columnsCount()];
+        for(int feature = 0; feature < columnsCount(); feature++) {
+            int ptr = 0;
+            for(int row = 0; row < rowsCount(); row++) {
+                if(filter.test(featuresInSortedOrder(row, feature)))
+                    projection[ptr++][feature] = index[row][feature];
+            }
+        }
+
+        return new TreeDataIndex(projection, features, labels);
+    }
+
+    /**
+     * @return count of rows in current index.
+     */
+    public int rowsCount() {
+        return index.length;
+    }
+
+    /**
+     * @return count of columns in current index.
+     */
+    public int columnsCount() {
+        return rowsCount() == 0 ? 0 : index[0].length ;
+    }
+
+    /**
+     * Constructs the index structure according to the features table.
+     *
+     * @param features Features.
+     * @param col Column.
+     * @param from From.
+     * @param to To.
+     */
+    private void sortIndex(double[][] features, int col, int from, int to) {
+        if (from < to) {
+            double pivot = features[(from + to) / 2][col];
+
+            int i = from, j = to;
+
+            while (i <= j) {
+                while (features[i][col] < pivot)
+                    i++;
+                while (features[j][col] > pivot)
+                    j--;
+
+                if (i <= j) {
+                    double tmpFeature = features[i][col];
+                    features[i][col] = features[j][col];
+                    features[j][col] = tmpFeature;
+
+                    int tmpLb = index[i][col];
+                    index[i][col] = index[j][col];
+                    index[j][col] = tmpLb;
+
+                    i++;
+                    j--;
+                }
+            }
+
+            sortIndex(features, col, from, j);
+            sortIndex(features, col, i, to);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
index 2b69356..709f68e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
@@ -18,7 +18,9 @@
 package org.apache.ignite.ml.tree.impurity;
 
 import java.io.Serializable;
+import org.apache.ignite.ml.tree.TreeFilter;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.data.TreeDataIndex;
 import org.apache.ignite.ml.tree.impurity.util.StepFunction;
 
 /**
@@ -26,7 +28,19 @@ import org.apache.ignite.ml.tree.impurity.util.StepFunction;
  *
  * @param <T> Type of impurity measure.
  */
-public interface ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> extends Serializable {
+public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> implements Serializable {
+    /** Use index structure instead of using sorting while learning. */
+    protected final boolean useIndex;
+
+    /**
+     * Constructs an instance of ImpurityMeasureCalculator.
+     *
+     * @param useIndex Use index.
+     */
+    public ImpurityMeasureCalculator(boolean useIndex) {
+        this.useIndex = useIndex;
+    }
+
     /**
      * Calculates all impurity measures required to find a best split and returns them as an array of
      * {@link StepFunction} (for every column).
@@ -34,5 +48,54 @@ public interface ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> extends
      * @param data Features and labels.
      * @return Impurity measures as an array of {@link StepFunction} (for every column).
      */
-    public StepFunction<T>[] calculate(DecisionTreeData data);
+    public abstract StepFunction<T>[] calculate(DecisionTreeData data, TreeFilter filter, int depth);
+
+
+    /**
+     * Returns columns count in current dataset.
+     *
+     * @param data Data.
+     * @param idx Index.
+     * @return Columns count in current dataset.
+     */
+    protected int columnsCount(DecisionTreeData data, TreeDataIndex idx) {
+        return useIndex ? idx.columnsCount() : data.getFeatures()[0].length;
+    }
+
+    /**
+     * Returns rows count in current dataset.
+     *
+     * @param data Data.
+     * @param idx Index.
+     * @return rows count in current dataset
+     */
+    protected int rowsCount(DecisionTreeData data, TreeDataIndex idx) {
+        return useIndex ? idx.rowsCount() : data.getFeatures().length;
+    }
+
+    /**
+     * Returns label value in according to kth order statistic.
+     *
+     * @param data Data.
+     * @param idx Index.
+     * @param featureId Feature id.
+     * @param k K-th statistic.
+     * @return label value in according to kth order statistic
+     */
+    protected double getLabelValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) {
+        return useIndex ? idx.labelInSortedOrder(k, featureId) : data.getLabels()[k];
+    }
+
+    /**
+     * Returns feature value in according to kth order statistic.
+     *
+     * @param data Data.
+     * @param idx Index.
+     * @param featureId Feature id.
+     * @param k K-th statistic.
+     * @return feature value in according to kth order statistic.
+     */
+    protected double getFeatureValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) {
+        return useIndex ? idx.featureInSortedOrder(k, featureId) : data.getFeatures()[k][featureId];
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
index 0dd0a10..38b3097 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
@@ -19,14 +19,16 @@ package org.apache.ignite.ml.tree.impurity.gini;
 
 import java.util.Arrays;
 import java.util.Map;
+import org.apache.ignite.ml.tree.TreeFilter;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.data.TreeDataIndex;
 import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
 import org.apache.ignite.ml.tree.impurity.util.StepFunction;
 
 /**
  * Gini impurity measure calculator.
  */
-public class GiniImpurityMeasureCalculator implements ImpurityMeasureCalculator<GiniImpurityMeasure> {
+public class GiniImpurityMeasureCalculator extends ImpurityMeasureCalculator<GiniImpurityMeasure> {
     /** */
     private static final long serialVersionUID = -522995134128519679L;
 
@@ -37,51 +39,70 @@ public class GiniImpurityMeasureCalculator implements ImpurityMeasureCalculator<
      * Constructs a new instance of Gini impurity measure calculator.
      *
      * @param lbEncoder Label encoder which defines integer value for every label class.
+     * @param useIndex Use index while calculating.
      */
-    public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder) {
+    public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder, boolean useIndex) {
+        super(useIndex);
         this.lbEncoder = lbEncoder;
     }
 
     /** {@inheritDoc} */
     @SuppressWarnings("unchecked")
-    @Override public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data) {
-        double[][] features = data.getFeatures();
-        double[] labels = data.getLabels();
+    @Override public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
+        TreeDataIndex index = null;
+        boolean canCalculate = false;
 
-        if (features.length > 0) {
-            StepFunction<GiniImpurityMeasure>[] res = new StepFunction[features[0].length];
+        if (useIndex) {
+            index = data.createIndexByFilter(depth, filter);
+            canCalculate = index.rowsCount() > 0;
+        }
+        else {
+            data = data.filter(filter);
+            canCalculate = data.getFeatures().length > 0;
+        }
 
-            for (int col = 0; col < res.length; col++) {
-                data.sort(col);
+        if (canCalculate) {
+            int rowsCnt = rowsCount(data, index);
+            int colsCnt = columnsCount(data, index);
 
-                double[] x = new double[features.length + 1];
-                GiniImpurityMeasure[] y = new GiniImpurityMeasure[features.length + 1];
+            StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colsCnt];
 
-                int xPtr = 0, yPtr = 0;
+            long right[] = new long[lbEncoder.size()];
+            for (int i = 0; i < rowsCnt; i++) {
+                double lb = getLabelValue(data, index, 0, i);
+                right[getLabelCode(lb)]++;
+            }
 
-                long[] left = new long[lbEncoder.size()];
-                long[] right = new long[lbEncoder.size()];
+            for (int col = 0; col < res.length; col++) {
+                if(!useIndex)
+                    data.sort(col);
 
-                for (int i = 0; i < labels.length; i++)
-                    right[getLabelCode(labels[i])]++;
+                double[] x = new double[rowsCnt + 1];
+                GiniImpurityMeasure[] y = new GiniImpurityMeasure[rowsCnt + 1];
 
+                long[] left = new long[lbEncoder.size()];
+                long[] rightCopy = Arrays.copyOf(right, right.length);
+
+                int xPtr = 0, yPtr = 0;
                 x[xPtr++] = Double.NEGATIVE_INFINITY;
                 y[yPtr++] = new GiniImpurityMeasure(
                     Arrays.copyOf(left, left.length),
-                    Arrays.copyOf(right, right.length)
+                    Arrays.copyOf(rightCopy, rightCopy.length)
                 );
 
-                for (int i = 0; i < features.length; i++) {
-                    left[getLabelCode(labels[i])]++;
-                    right[getLabelCode(labels[i])]--;
+                for (int i = 0; i < rowsCnt; i++) {
+                    double lb = getLabelValue(data, index, col, i);
+                    left[getLabelCode(lb)]++;
+                    rightCopy[getLabelCode(lb)]--;
 
-                    if (i < (features.length - 1) && features[i + 1][col] == features[i][col])
+                    double featureVal = getFeatureValue(data, index, col, i);
+                    if (i < (rowsCnt - 1) && getFeatureValue(data, index, col, i + 1) == featureVal)
                         continue;
 
-                    x[xPtr++] = features[i][col];
+                    x[xPtr++] = featureVal;
                     y[yPtr++] = new GiniImpurityMeasure(
                         Arrays.copyOf(left, left.length),
-                        Arrays.copyOf(right, right.length)
+                        Arrays.copyOf(rightCopy, rightCopy.length)
                     );
                 }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java
index cb5019c..1788737 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java
@@ -17,56 +17,92 @@
 
 package org.apache.ignite.ml.tree.impurity.mse;
 
+import org.apache.ignite.ml.tree.TreeFilter;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.data.TreeDataIndex;
 import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
 import org.apache.ignite.ml.tree.impurity.util.StepFunction;
 
 /**
 * Mean squared error (variance) impurity measure calculator.
  */
-public class MSEImpurityMeasureCalculator implements ImpurityMeasureCalculator<MSEImpurityMeasure> {
+public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEImpurityMeasure> {
     /** */
     private static final long serialVersionUID = 288747414953756824L;
 
+    /**
+     * Constructs an instance of MSEImpurityMeasureCalculator.
+     *
+     * @param useIndex Use index while calculating.
+     */
+    public MSEImpurityMeasureCalculator(boolean useIndex) {
+        super(useIndex);
+    }
+
     /** {@inheritDoc} */
-    @SuppressWarnings("unchecked")
-    @Override public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data) {
-        double[][] features = data.getFeatures();
-        double[] labels = data.getLabels();
+    @Override public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
+        TreeDataIndex index = null;
+        boolean canCalculate = false;
+
+        if (useIndex) {
+            index = data.createIndexByFilter(depth, filter);
+            canCalculate = index.rowsCount() > 0;
+        }
+        else {
+            data = data.filter(filter);
+            canCalculate = data.getFeatures().length > 0;
+        }
+
+        if (canCalculate) {
+            int rowsCnt = rowsCount(data, index);
+            int colsCnt = columnsCount(data, index);
 
-        if (features.length > 0) {
-            StepFunction<MSEImpurityMeasure>[] res = new StepFunction[features[0].length];
+            @SuppressWarnings("unchecked")
+            StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt];
+
+            double rightYOriginal = 0;
+            double rightY2Original = 0;
+            for (int i = 0; i < rowsCnt; i++) {
+                double lbVal = getLabelValue(data, index, 0, i);
+
+                rightYOriginal += lbVal;
+                rightY2Original += Math.pow(lbVal, 2);
+            }
 
             for (int col = 0; col < res.length; col++) {
-                data.sort(col);
+                if (!useIndex)
+                    data.sort(col);
 
-                double[] x = new double[features.length + 1];
-                MSEImpurityMeasure[] y = new MSEImpurityMeasure[features.length + 1];
+                double[] x = new double[rowsCnt + 1];
+                MSEImpurityMeasure[] y = new MSEImpurityMeasure[rowsCnt + 1];
 
                 x[0] = Double.NEGATIVE_INFINITY;
 
-                for (int leftSize = 0; leftSize <= features.length; leftSize++) {
-                    double leftY = 0;
-                    double leftY2 = 0;
-                    double rightY = 0;
-                    double rightY2 = 0;
+                double leftY = 0;
+                double leftY2 = 0;
+                double rightY = rightYOriginal;
+                double rightY2 = rightY2Original;
 
-                    for (int i = 0; i < leftSize; i++) {
-                        leftY += labels[i];
-                        leftY2 += Math.pow(labels[i], 2);
-                    }
+                int leftSize = 0;
+                for (int i = 0; i <= rowsCnt; i++) {
+                    if (leftSize > 0) {
+                        double lblVal = getLabelValue(data, index, col, i - 1);
+
+                        leftY += lblVal;
+                        leftY2 += Math.pow(lblVal, 2);
 
-                    for (int i = leftSize; i < features.length; i++) {
-                        rightY += labels[i];
-                        rightY2 += Math.pow(labels[i], 2);
+                        rightY -= lblVal;
+                        rightY2 -= Math.pow(lblVal, 2);
                     }
 
-                    if (leftSize < features.length)
-                        x[leftSize + 1] = features[leftSize][col];
+                    if (leftSize < rowsCnt)
+                        x[leftSize + 1] = getFeatureValue(data, index, col, i);
 
                     y[leftSize] = new MSEImpurityMeasure(
-                        leftY, leftY2, leftSize, rightY, rightY2, features.length - leftSize
+                        leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize
                     );
+
+                    leftSize++;
                 }
 
                 res[col] = new StepFunction<>(x, y);

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
index daba4fa..bbbb2a9 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
@@ -72,6 +72,17 @@ public class RandomForestClassifierTrainer extends RandomForestTrainer {
 
     /** {@inheritDoc} */
     @Override protected DatasetTrainer<DecisionTreeNode, Double> buildDatasetTrainerForModel() {
-        return new DecisionTreeClassificationTrainer(maxDeep, minImpurityDecrease);
+        return new DecisionTreeClassificationTrainer(maxDeep, minImpurityDecrease).withUseIndex(useIndex);
+    }
+
+    /**
+     * Sets useIndex parameter and returns trainer instance.
+     *
+     * @param useIndex Use index.
+     * @return Decision tree trainer.
+     */
+    public RandomForestClassifierTrainer withUseIndex(boolean useIndex) {
+        this.useIndex = useIndex;
+        return this;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java
index 5b41b2c..009fff2 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java
@@ -73,6 +73,17 @@ public class RandomForestRegressionTrainer extends RandomForestTrainer {
 
     /** {@inheritDoc} */
     @Override protected DatasetTrainer<DecisionTreeNode, Double> buildDatasetTrainerForModel() {
-        return new DecisionTreeRegressionTrainer(maxDeep, minImpurityDecrease);
+        return new DecisionTreeRegressionTrainer(maxDeep, minImpurityDecrease).withUseIndex(useIndex);
+    }
+
+    /**
+     * Sets useIndex parameter and returns trainer instance.
+     *
+     * @param useIndex Use index.
+     * @return Decision tree trainer.
+     */
+    public RandomForestRegressionTrainer withUseIndex(boolean useIndex) {
+        this.useIndex = useIndex;
+        return this;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
index b5ecaed..8608f09 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
@@ -26,9 +26,13 @@ import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggrega
 public abstract class RandomForestTrainer extends BaggingModelTrainer {
     /** Max decision tree deep. */
     protected final int maxDeep;
+
     /** Min impurity decrease. */
     protected final double minImpurityDecrease;
 
+    /** Use index structure instead of using sorting while decision tree learning. */
+    protected boolean useIndex = false;
+
     /**
      * Constructs new instance of BaggingModelTrainer.
      *

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
index de40b48..c84da12 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
@@ -40,15 +40,21 @@ public class DecisionTreeClassificationTrainerTest {
     private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
 
     /** Number of partitions. */
-    @Parameterized.Parameter
+    @Parameterized.Parameter(0)
     public int parts;
 
+    /** Use index [= 1 if true]. */
+    @Parameterized.Parameter(1)
+    public int useIndex;
 
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
+    /** Test parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions. Use index = {1}.")
     public static Iterable<Integer[]> data() {
         List<Integer[]> res = new ArrayList<>();
-        for (int part : partsToBeTested)
-            res.add(new Integer[] {part});
+        for (int i = 0; i < 2; i++) {
+            for (int part : partsToBeTested)
+                res.add(new Integer[] {part, i});
+        }
 
         return res;
     }
@@ -63,10 +69,11 @@ public class DecisionTreeClassificationTrainerTest {
         Random rnd = new Random(0);
         for (int i = 0; i < size; i++) {
             double x = rnd.nextDouble() - 0.5;
-            data.put(i, new double[]{x, x > 0 ? 1 : 0});
+            data.put(i, new double[] {x, x > 0 ? 1 : 0});
         }
 
-        DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
+        DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0)
+            .withUseIndex(useIndex == 1);
 
         DecisionTreeNode tree = trainer.fit(
             data,
@@ -77,15 +84,15 @@ public class DecisionTreeClassificationTrainerTest {
 
         assertTrue(tree instanceof DecisionTreeConditionalNode);
 
-        DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
+        DecisionTreeConditionalNode node = (DecisionTreeConditionalNode)tree;
 
         assertEquals(0, node.getThreshold(), 1e-3);
 
         assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode);
         assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode);
 
-        DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode();
-        DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode();
+        DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode)node.getThenNode();
+        DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode)node.getElseNode();
 
         assertEquals(1, thenNode.getVal(), 1e-10);
         assertEquals(0, elseNode.getVal(), 1e-10);

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
index f69da4f..4e64925 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
@@ -40,14 +40,21 @@ public class DecisionTreeRegressionTrainerTest {
     private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
 
     /** Number of partitions. */
-    @Parameterized.Parameter
+    @Parameterized.Parameter(0)
     public int parts;
 
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
+    /** Use index [= 1 if true]. */
+    @Parameterized.Parameter(1)
+    public int useIndex;
+
+    /** Test parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions. Use index = {1}.")
     public static Iterable<Integer[]> data() {
         List<Integer[]> res = new ArrayList<>();
-        for (int part : partsToBeTested)
-            res.add(new Integer[] {part});
+        for (int i = 0; i < 2; i++) {
+            for (int part : partsToBeTested)
+                res.add(new Integer[] {part, i});
+        }
 
         return res;
     }
@@ -65,7 +72,8 @@ public class DecisionTreeRegressionTrainerTest {
             data.put(i, new double[]{x, x > 0 ? 1 : 0});
         }
 
-        DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0);
+        DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0)
+            .withUseIndex(useIndex == 1);
 
         DecisionTreeNode tree = trainer.fit(
             data,

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java
index 0c89d4e..4ee717a 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java
@@ -17,21 +17,38 @@
 
 package org.apache.ignite.ml.tree.data;
 
+import java.util.Arrays;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertArrayEquals;
 
 /**
  * Tests for {@link DecisionTreeData}.
  */
+@RunWith(Parameterized.class)
 public class DecisionTreeDataTest {
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Use index {0}")
+    public static Iterable<Boolean[]> data() {
+        return Arrays.asList(
+            new Boolean[] {true},
+            new Boolean[] {false}
+        );
+    }
+
+    /** Use index. */
+    @Parameterized.Parameter
+    public boolean useIndex;
+
     /** */
     @Test
     public void testFilter() {
         double[][] features = new double[][]{{0}, {1}, {2}, {3}, {4}, {5}};
         double[] labels = new double[]{0, 1, 2, 3, 4, 5};
 
-        DecisionTreeData data = new DecisionTreeData(features, labels);
+        DecisionTreeData data = new DecisionTreeData(features, labels, useIndex);
         DecisionTreeData filteredData = data.filter(obj -> obj[0] > 2);
 
         assertArrayEquals(new double[][]{{3}, {4}, {5}}, filteredData.getFeatures());
@@ -44,7 +61,7 @@ public class DecisionTreeDataTest {
         double[][] features = new double[][]{{4, 1}, {3, 3}, {2, 0}, {1, 4}, {0, 2}};
         double[] labels = new double[]{0, 1, 2, 3, 4};
 
-        DecisionTreeData data = new DecisionTreeData(features, labels);
+        DecisionTreeData data = new DecisionTreeData(features, labels, useIndex);
 
         data.sort(0);
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java
new file mode 100644
index 0000000..78bdfdf
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.data;
+
+import org.apache.ignite.ml.tree.TreeFilter;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test for {@link TreeDataIndex}.
+ */
+public class TreeDataIndexTest {
+    /**  */
+    private double[][] features = {
+        {1., 2., 3., 4.},
+        {2., 3., 4., 1.},
+        {3., 4., 1., 2.},
+        {4., 1., 2., 3.}
+    };
+
+    /** */
+    private double[] labels = {1., 2., 3, 4.};
+
+    /** */
+    private double[][] labelsInSortedOrder = {
+        {1., 4., 3., 2.},
+        {2., 1., 4., 3.},
+        {3., 2., 1., 4.},
+        {4., 3., 2., 1.}
+    };
+
+    /** */
+    private double[][][] featuresInSortedOrder = {
+        {
+            {1., 2., 3., 4.},
+            {4., 1., 2., 3.},
+            {3., 4., 1., 2.},
+            {2., 3., 4., 1.},
+        },
+        {
+            {2., 3., 4., 1.},
+            {1., 2., 3., 4.},
+            {4., 1., 2., 3.},
+            {3., 4., 1., 2.},
+        },
+        {
+            {3., 4., 1., 2.},
+            {2., 3., 4., 1.},
+            {1., 2., 3., 4.},
+            {4., 1., 2., 3.},
+        },
+        {
+            {4., 1., 2., 3.},
+            {3., 4., 1., 2.},
+            {2., 3., 4., 1.},
+            {1., 2., 3., 4.},
+        }
+    };
+
+    /** */
+    private TreeDataIndex index = new TreeDataIndex(features, labels);
+
+    /** */
+    @Test
+    public void labelInSortedOrderTest() {
+        assertEquals(features.length, index.rowsCount());
+        assertEquals(features[0].length, index.columnsCount());
+
+        for (int k = 0; k < index.rowsCount(); k++) {
+            for (int featureId = 0; featureId < index.columnsCount(); featureId++)
+                assertEquals(labelsInSortedOrder[k][featureId], index.labelInSortedOrder(k, featureId), 0.01);
+        }
+    }
+
+    /** */
+    @Test
+    public void featuresInSortedOrderTest() {
+        assertEquals(features.length, index.rowsCount());
+        assertEquals(features[0].length, index.columnsCount());
+
+        for (int k = 0; k < index.rowsCount(); k++) {
+            for (int featureId = 0; featureId < index.columnsCount(); featureId++)
+                assertArrayEquals(featuresInSortedOrder[k][featureId], index.featuresInSortedOrder(k, featureId), 0.01);
+        }
+    }
+
+    /** */
+    @Test
+    public void featureInSortedOrderTest() {
+        assertEquals(features.length, index.rowsCount());
+        assertEquals(features[0].length, index.columnsCount());
+
+        for (int k = 0; k < index.rowsCount(); k++) {
+            for (int featureId = 0; featureId < index.columnsCount(); featureId++)
+                assertEquals((double)k + 1, index.featureInSortedOrder(k, featureId), 0.01);
+        }
+    }
+
+    /** */
+    @Test
+    public void filterTest() {
+        TreeFilter filter1 = features -> features[0] > 2;
+        TreeFilter filter2 = features -> features[1] > 2;
+        TreeFilter filterAnd = filter1.and(features -> features[1] > 2);
+
+        TreeDataIndex filtered1 = index.filter(filter1);
+        TreeDataIndex filtered2 = filtered1.filter(filter2);
+        TreeDataIndex filtered3 = index.filter(filterAnd);
+
+        assertEquals(2, filtered1.rowsCount());
+        assertEquals(4, filtered1.columnsCount());
+        assertEquals(1, filtered2.rowsCount());
+        assertEquals(4, filtered2.columnsCount());
+        assertEquals(1, filtered3.rowsCount());
+        assertEquals(4, filtered3.columnsCount());
+
+        double[] obj1 = {3, 4, 1, 2};
+        double[] obj2 = {4, 1, 2, 3};
+        double[][] restObjs = new double[][] {obj1, obj2};
+        int[][] restObjIndxInSortedOrderPerFeatures = new int[][] {
+            {0, 1}, //feature 0
+            {1, 0}, //feature 1
+            {0, 1}, //feature 2
+            {0, 1}, //feature 3
+        };
+
+        for (int featureId = 0; featureId < filtered1.columnsCount(); featureId++) {
+            for (int k = 0; k < filtered1.rowsCount(); k++) {
+                int objId = restObjIndxInSortedOrderPerFeatures[featureId][k];
+                double[] obj = restObjs[objId];
+                assertArrayEquals(obj, filtered1.featuresInSortedOrder(k, featureId), 0.01);
+            }
+        }
+
+        for (int featureId = 0; featureId < filtered2.columnsCount(); featureId++) {
+            for (int k = 0; k < filtered2.rowsCount(); k++) {
+                assertArrayEquals(obj1, filtered2.featuresInSortedOrder(k, featureId), 0.01);
+                assertArrayEquals(obj1, filtered3.featuresInSortedOrder(k, featureId), 0.01);
+            }
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java
index afd81e8..a328bd7 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java
@@ -17,11 +17,14 @@
 
 package org.apache.ignite.ml.tree.impurity.gini;
 
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
 import org.apache.ignite.ml.tree.impurity.util.StepFunction;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import static junit.framework.TestCase.assertEquals;
 import static org.junit.Assert.assertArrayEquals;
@@ -29,7 +32,21 @@ import static org.junit.Assert.assertArrayEquals;
 /**
  * Tests for {@link GiniImpurityMeasureCalculator}.
  */
+@RunWith(Parameterized.class)
 public class GiniImpurityMeasureCalculatorTest {
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Use index {0}")
+    public static Iterable<Boolean[]> data() {
+        return Arrays.asList(
+            new Boolean[] {true},
+            new Boolean[] {false}
+        );
+    }
+
+    /** Use index. */
+    @Parameterized.Parameter
+    public boolean useIndex;
+
     /** */
     @Test
     public void testCalculate() {
@@ -39,9 +56,9 @@ public class GiniImpurityMeasureCalculatorTest {
         Map<Double, Integer> encoder = new HashMap<>();
         encoder.put(0.0, 0);
         encoder.put(1.0, 1);
-        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder);
+        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex);
 
-        StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels));
+        StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0);
 
         assertEquals(2, impurity.length);
 
@@ -71,9 +88,9 @@ public class GiniImpurityMeasureCalculatorTest {
         Map<Double, Integer> encoder = new HashMap<>();
         encoder.put(0.0, 0);
         encoder.put(1.0, 1);
-        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder);
+        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex);
 
-        StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels));
+        StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0);
 
         assertEquals(1, impurity.length);
 
@@ -94,7 +111,7 @@ public class GiniImpurityMeasureCalculatorTest {
         encoder.put(1.0, 1);
         encoder.put(2.0, 2);
 
-        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder);
+        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex);
 
         assertEquals(0, calculator.getLabelCode(0.0));
         assertEquals(1, calculator.getLabelCode(1.0));

http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java
index 510c18f..82b3805 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java
@@ -17,9 +17,12 @@
 
 package org.apache.ignite.ml.tree.impurity.mse;
 
+import java.util.Arrays;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
 import org.apache.ignite.ml.tree.impurity.util.StepFunction;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import static junit.framework.TestCase.assertEquals;
 import static org.junit.Assert.assertArrayEquals;
@@ -27,16 +30,30 @@ import static org.junit.Assert.assertArrayEquals;
 /**
  * Tests for {@link MSEImpurityMeasureCalculator}.
  */
+@RunWith(Parameterized.class)
 public class MSEImpurityMeasureCalculatorTest {
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Use index {0}")
+    public static Iterable<Boolean[]> data() {
+        return Arrays.asList(
+            new Boolean[] {true},
+            new Boolean[] {false}
+        );
+    }
+
+    /** Use index. */
+    @Parameterized.Parameter
+    public boolean useIndex;
+
     /** */
     @Test
     public void testCalculate() {
         double[][] data = new double[][]{{0, 2}, {1, 1}, {2, 0}, {3, 3}};
         double[] labels = new double[]{1, 2, 2, 1};
 
-        MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator();
+        MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator(useIndex);
 
-        StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels));
+        StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0);
 
         assertEquals(2, impurity.length);
 


Mime
View raw message