ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From a.@apache.org
Subject [43/50] [abbrv] ignite git commit: IGNITE-8741: [ML] Make a tutorial for data preprocessing
Date Wed, 27 Jun 2018 14:50:10 GMT
http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/resources/datasets/titanic_10_rows.csv
----------------------------------------------------------------------
diff --git a/examples/src/main/resources/datasets/titanic_10_rows.csv b/examples/src/main/resources/datasets/titanic_10_rows.csv
new file mode 100644
index 0000000..8146db3
--- /dev/null
+++ b/examples/src/main/resources/datasets/titanic_10_rows.csv
@@ -0,0 +1,11 @@
+pclass;survived;name;sex;age;sibsp;parch;ticket;fare;cabin;embarked;boat;body;homedest
+1;1;Allen, Miss. Elisabeth Walton;;29;;;24160;211,3375;B5;;2;;St Louis, MO
+1;1;Allison, Master. Hudson Trevor;male;0,9167;1;2;113781;151,55;C22 C26;S;11;;Montreal, PQ / Chesterville, ON
+1;0;Allison, Miss. Helen Loraine;female;2;1;2;113781;151,55;C22 C26;S;;;Montreal, PQ / Chesterville, ON
+1;0;Allison, Mr. Hudson Joshua Creighton;male;30;1;2;113781;151,55;C22 C26;S;;135;Montreal, PQ / Chesterville, ON
+1;0;Allison, Mrs. Hudson J C (Bessie Waldo Daniels);female;25;1;2;113781;151,55;C22 C26;S;;;Montreal, PQ / Chesterville, ON
+1;1;Anderson, Mr. Harry;male;48;0;0;19952;26,55;E12;S;3;;New York, NY
+1;1;Andrews, Miss. Kornelia Theodosia;female;63;1;0;13502;77,9583;D7;S;10;;Hudson, NY
+1;0;Andrews, Mr. Thomas Jr;male;39;0;0;112050;0;A36;S;;;Belfast, NI
+1;1;Appleton, Mrs. Edward Dale (Charlotte Lamson);female;53;2;0;11769;51,4792;C101;S;D;;Bayside, Queens, NY
+1;0;Artagaveytia, Mr. Ramon;male;71;0;0;PC 17609;49,5042;;C;;22;Montevideo, Uruguay

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
index f721d53..3525feb 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
@@ -29,6 +29,7 @@ import java.util.stream.IntStream;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
@@ -161,7 +162,7 @@ public abstract class BaggingModelTrainer implements DatasetTrainer<ModelsCompos
         Map<Integer, Integer> featuresMapping = createFeaturesMapping(featureExtractorSeed,
featureVectorSize);
 
         //TODO: IGNITE-8867 Need to implement bootstrapping algorithm
-        Model<double[], Double> mdl = buildDatasetTrainerForModel().fit(
+        Model<Vector, Double> mdl = buildDatasetTrainerForModel().fit(
             datasetBuilder.withFilter((features, answer) -> sampleFilter.map(features,
answer) < samplePartSizePerMdl),
             wrapFeatureExtractor(featureExtractor, featuresMapping),
             lbExtractor);
@@ -188,7 +189,7 @@ public abstract class BaggingModelTrainer implements DatasetTrainer<ModelsCompos
     /**
      * Creates trainer specific to ensemble.
      */
-    protected abstract DatasetTrainer<? extends Model<double[], Double>, Double>
buildDatasetTrainerForModel();
+    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double>
buildDatasetTrainerForModel();
 
     /**
      * Wraps the original feature extractor with features subspace mapping applying.

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
index 1de82e3..9077338 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
@@ -22,11 +22,13 @@ import java.util.List;
 import java.util.Map;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 
 /**
  * Model consisting of several models and prediction aggregation strategy.
  */
-public class ModelsComposition implements Model<double[], Double> {
+public class ModelsComposition implements Model<Vector, Double> {
     /**
      * Predictions aggregator.
      */
@@ -53,7 +55,7 @@ public class ModelsComposition implements Model<double[], Double>
{
      * @param features Features vector.
      * @return Estimation.
      */
-    @Override public Double apply(double[] features) {
+    @Override public Double apply(Vector features) {
         double[] predictions = new double[models.size()];
 
         for (int i = 0; i < models.size(); i++)
@@ -79,7 +81,7 @@ public class ModelsComposition implements Model<double[], Double>
{
     /**
      * Model trained on a features subspace with mapping from original features space to
subspace.
      */
-    public static class ModelOnFeaturesSubspace implements Model<double[], Double>
{
+    public static class ModelOnFeaturesSubspace implements Model<Vector, Double> {
         /**
          * Features mapping to subspace.
          */
@@ -87,7 +89,7 @@ public class ModelsComposition implements Model<double[], Double>
{
         /**
          * Trained model of features subspace.
          */
-        private final Model<double[], Double> model;
+        private final Model<Vector, Double> mdl;
 
         /**
          * Constructs new instance of ModelOnFeaturesSubspace.
@@ -95,9 +97,9 @@ public class ModelsComposition implements Model<double[], Double>
{
          * @param featuresMapping Features mapping to subspace.
          * @param mdl Learned model.
          */
-        ModelOnFeaturesSubspace(Map<Integer, Integer> featuresMapping, Model<double[],
Double> mdl) {
+        ModelOnFeaturesSubspace(Map<Integer, Integer> featuresMapping, Model<Vector,
Double> mdl) {
             this.featuresMapping = Collections.unmodifiableMap(featuresMapping);
-            this.model = mdl;
+            this.mdl = mdl;
         }
 
         /**
@@ -106,10 +108,10 @@ public class ModelsComposition implements Model<double[], Double>
{
          * @param features Features vector.
          * @return Estimation.
          */
-        @Override public Double apply(double[] features) {
+        @Override public Double apply(Vector features) {
             double[] newFeatures = new double[featuresMapping.size()];
-            featuresMapping.forEach((localId, featureVectorId) -> newFeatures[localId]
= features[featureVectorId]);
-            return model.apply(newFeatures);
+            featuresMapping.forEach((localId, featureVectorId) -> newFeatures[localId]
= features.get(featureVectorId));
+            return mdl.apply(new DenseLocalOnHeapVector(newFeatures));
         }
 
         /**
@@ -122,8 +124,8 @@ public class ModelsComposition implements Model<double[], Double>
{
         /**
          * Returns model.
          */
-        public Model<double[], Double> getModel() {
-            return model;
+        public Model<Vector, Double> getMdl() {
+            return mdl;
         }
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
index 4b21e67..275de13 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
@@ -18,6 +18,7 @@
 package org.apache.ignite.ml.preprocessing.encoding.stringencoder;
 
 import java.util.Map;
+import java.util.Set;
 import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownStringValue;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 
@@ -30,20 +31,27 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 public class StringEncoderPreprocessor<K, V> implements IgniteBiFunction<K, V, double[]>
{
     /** */
     private static final long serialVersionUID = 6237812226382623469L;
+    /** */
+    private static final String KEY_FOR_NULL_VALUES = "";
 
     /** Filling values. */
     private final Map<String, Integer>[] encodingValues;
 
     /** Base preprocessor. */
-    private final IgniteBiFunction<K, V, String[]> basePreprocessor;
+    private final IgniteBiFunction<K, V, Object[]> basePreprocessor;
+
+    /** Feature indices to apply encoder.*/
+    private final Set<Integer> handledIndices;
 
     /**
      * Constructs a new instance of String Encoder preprocessor.
      *
      * @param basePreprocessor Base preprocessor.
+     * @param handledIndices Handled indices.
      */
     public StringEncoderPreprocessor(Map<String, Integer>[] encodingValues,
-        IgniteBiFunction<K, V, String[]> basePreprocessor) {
+        IgniteBiFunction<K, V, Object[]> basePreprocessor, Set<Integer> handledIndices)
{
+        this.handledIndices = handledIndices;
         this.encodingValues = encodingValues;
         this.basePreprocessor = basePreprocessor;
     }
@@ -56,14 +64,20 @@ public class StringEncoderPreprocessor<K, V> implements IgniteBiFunction<K,
V, d
      * @return Preprocessed row.
      */
     @Override public double[] apply(K k, V v) {
-        String[] tmp = basePreprocessor.apply(k, v);
+        Object[] tmp = basePreprocessor.apply(k, v);
         double[] res = new double[tmp.length];
 
         for (int i = 0; i < res.length; i++) {
-            if (encodingValues[i].containsKey(tmp[i]))
-                res[i] = encodingValues[i].get(tmp[i]);
-            else
-                throw new UnknownStringValue(tmp[i]);
+            Object tmpObj = tmp[i];
+            if(handledIndices.contains(i)){
+                if(tmpObj.equals(Double.NaN) && encodingValues[i].containsKey(KEY_FOR_NULL_VALUES))
+                    res[i] = encodingValues[i].get(KEY_FOR_NULL_VALUES);
+                else if (encodingValues[i].containsKey(tmpObj))
+                    res[i] = encodingValues[i].get(tmpObj);
+                else
+                    throw new UnknownStringValue(tmpObj.toString());
+            } else
+                res[i] = (double)tmpObj;
         }
         return res;
     }

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
index 5a4d090..8ed073c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
@@ -18,8 +18,10 @@
 package org.apache.ignite.ml.preprocessing.encoding.stringencoder;
 
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -27,6 +29,7 @@ import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * Trainer of the String Encoder preprocessor.
@@ -36,20 +39,26 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
  * @param <K> Type of a key in {@code upstream} data.
  * @param <V> Type of a value in {@code upstream} data.
  */
-public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K, V, String[],
double[]> {
+public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[],
double[]> {
+    /** Indices of features which should be encoded. */
+    private Set<Integer> handledIndices = new HashSet<>();
+
     /** {@inheritDoc} */
     @Override public StringEncoderPreprocessor<K, V> fit(DatasetBuilder<K, V>
datasetBuilder,
-        IgniteBiFunction<K, V, String[]> basePreprocessor) {
+        IgniteBiFunction<K, V, Object[]> basePreprocessor) {
+        if(handledIndices.isEmpty())
+            throw new RuntimeException("Add indices of handled features");
+
         try (Dataset<EmptyContext, StringEncoderPartitionData> dataset = datasetBuilder.build(
             (upstream, upstreamSize) -> new EmptyContext(),
             (upstream, upstreamSize, ctx) -> {
+                // This array will contain not null values for handled indices
                 Map<String, Integer>[] categoryFrequencies = null;
 
                 while (upstream.hasNext()) {
                     UpstreamEntry<K, V> entity = upstream.next();
-                    String[] row = basePreprocessor.apply(entity.getKey(), entity.getValue());
+                    Object[] row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                     categoryFrequencies = calculateFrequencies(row, categoryFrequencies);
-
                 }
                 return new StringEncoderPartitionData()
                     .withCategoryFrequencies(categoryFrequencies);
@@ -57,7 +66,7 @@ public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K,
V, St
         )) {
             Map<String, Integer>[] encodingValues = calculateEncodingValuesByFrequencies(dataset);
 
-            return new StringEncoderPreprocessor<>(encodingValues, basePreprocessor);
+            return new StringEncoderPreprocessor<>(encodingValues, basePreprocessor,
handledIndices);
         }
         catch (Exception e) {
             throw new RuntimeException(e);
@@ -84,8 +93,10 @@ public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K,
V, St
                 assert a.length == b.length;
 
                 for (int i = 0; i < a.length; i++) {
-                    int finalI = i;
-                    a[i].forEach((k, v) -> b[finalI].merge(k, v, (f1, f2) -> f1 + f2));
+                    if(handledIndices.contains(i)){
+                        int finalI = i;
+                        a[i].forEach((k, v) -> b[finalI].merge(k, v, (f1, f2) -> f1
+ f2));
+                    }
                 }
                 return b;
             }
@@ -94,7 +105,8 @@ public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K,
V, St
         Map<String, Integer>[] res = new HashMap[frequencies.length];
 
         for (int i = 0; i < frequencies.length; i++)
-            res[i] = transformFrequenciesToEncodingValues(frequencies[i]);
+            if(handledIndices.contains(i))
+                res[i] = transformFrequenciesToEncodingValues(frequencies[i]);
 
         return res;
     }
@@ -127,26 +139,57 @@ public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K,
V, St
      * @param categoryFrequencies Holds the frequencies of categories by values and features.
      * @return Updated frequencies by values and features.
      */
-    private Map<String, Integer>[] calculateFrequencies(String[] row, Map<String,
Integer>[] categoryFrequencies) {
-        if (categoryFrequencies == null) {
-            categoryFrequencies = new HashMap[row.length];
-            for (int i = 0; i < categoryFrequencies.length; i++)
-                categoryFrequencies[i] = new HashMap<>();
-        }
+    private Map<String, Integer>[] calculateFrequencies(Object[] row, Map<String,
Integer>[] categoryFrequencies) {
+        if (categoryFrequencies == null)
+            categoryFrequencies = initializeCategoryFrequencies(row);
         else
-            assert categoryFrequencies.length == row.length : "Base preprocessor must return
exactly " + categoryFrequencies.length
-                + " features";
+            assert categoryFrequencies.length == row.length : "Base preprocessor must return
exactly "
+                + categoryFrequencies.length + " features";
 
         for (int i = 0; i < categoryFrequencies.length; i++) {
-            String s = row[i];
-            Map<String, Integer> map = categoryFrequencies[i];
+            if(handledIndices.contains(i)){
+                String strVal;
+                Object featureVal = row[i];
 
-            if (map.containsKey(s))
-                map.put(s, (map.get(s)) + 1);
-            else
-                map.put(s, 1);
+                if(featureVal.equals(Double.NaN)) {
+                    strVal = "";
+                    row[i] = strVal;
+                }
+                else strVal = (String)featureVal;
+
+                Map<String, Integer> map = categoryFrequencies[i];
+
+                if (map.containsKey(strVal))
+                    map.put(strVal, (map.get(strVal)) + 1);
+                else
+                    map.put(strVal, 1);
+            }
         }
         return categoryFrequencies;
     }
 
+    /**
+     * Initialize frequencies for handled indices only.
+     * @param row Feature vector.
+     * @return The array contains not null values for handled indices.
+     */
+    @NotNull private Map<String, Integer>[] initializeCategoryFrequencies(Object[]
row) {
+        Map<String, Integer>[] categoryFrequencies = new HashMap[row.length];
+
+        for (int i = 0; i < categoryFrequencies.length; i++)
+            if(handledIndices.contains(i))
+                categoryFrequencies[i] = new HashMap<>();
+
+        return categoryFrequencies;
+    }
+
+    /**
+     * Add the index of encoded feature.
+     * @param idx The index of encoded feature.
+     * @return The changed trainer.
+     */
+    public StringEncoderTrainer<K, V> encodeFeature(int idx){
+        handledIndices.add(idx);
+        return this;
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java
index 8ea1490..a5f4607 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionModel.java
@@ -17,14 +17,13 @@
 
 package org.apache.ignite.ml.regressions.logistic.binomial;
 
+import java.io.Serializable;
+import java.util.Objects;
 import org.apache.ignite.ml.Exportable;
 import org.apache.ignite.ml.Exporter;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.math.Vector;
 
-import java.io.Serializable;
-import java.util.Objects;
-
 /**
  * Logistic regression (logit model) is a generalized linear model used for binomial regression.
  */
@@ -132,6 +131,7 @@ public class LogisticRegressionModel implements Model<Vector, Double>,
Exportabl
 
     /** {@inheritDoc} */
     @Override public Double apply(Vector input) {
+
         final double res = sigmoid(input.dot(weights) + intercept);
 
         if (isKeepingRawLabels)

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java
index f885c3e..6aaac81 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java
@@ -27,6 +27,7 @@ import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.selection.score.ScoreCalculator;
 import org.apache.ignite.ml.selection.score.util.CacheBasedTruthWithPredictionCursor;
@@ -51,7 +52,7 @@ import org.apache.ignite.ml.trainers.DatasetTrainer;
  * @param <K> Type of a key in {@code upstream} data.
  * @param <V> Type of a value in {@code upstream} data.
  */
-public class CrossValidationScoreCalculator<M extends Model<double[], L>, L, K,
V> {
+public class CrossValidationScoreCalculator<M extends Model<Vector, L>, L, K, V>
{
     /**
      * Computes cross-validated metrics.
      *
@@ -109,6 +110,7 @@ public class CrossValidationScoreCalculator<M extends Model<double[],
L>, L, K,
         Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V>
filter,
         IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V,
L> lbExtractor,
         UniformMapper<K, V> mapper, int cv) {
+
         return score(
             trainer,
             predicate -> new CacheBasedDatasetBuilder<>(

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java
index 862c7ab..7cf6274 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java
@@ -24,7 +24,9 @@ import org.apache.ignite.cache.query.QueryCursor;
 import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.selection.score.TruthWithPrediction;
 import org.jetbrains.annotations.NotNull;
 
@@ -46,7 +48,7 @@ public class CacheBasedTruthWithPredictionCursor<L, K, V> implements
TruthWithPr
     private final IgniteBiFunction<K, V, L> lbExtractor;
 
     /** Model for inference. */
-    private final Model<double[], L> mdl;
+    private final Model<Vector, L> mdl;
 
     /**
      * Constructs a new instance of cache based truth with prediction cursor.
@@ -59,7 +61,7 @@ public class CacheBasedTruthWithPredictionCursor<L, K, V> implements
TruthWithPr
      */
     public CacheBasedTruthWithPredictionCursor(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K,
V> filter,
         IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V,
L> lbExtractor,
-        Model<double[], L> mdl) {
+        Model<Vector, L> mdl) {
         this.cursor = query(upstreamCache, filter);
         this.featureExtractor = featureExtractor;
         this.lbExtractor = lbExtractor;
@@ -118,7 +120,7 @@ public class CacheBasedTruthWithPredictionCursor<L, K, V> implements
TruthWithPr
             double[] features = featureExtractor.apply(entry.getKey(), entry.getValue());
             L lb = lbExtractor.apply(entry.getKey(), entry.getValue());
 
-            return new TruthWithPrediction<>(lb, mdl.apply(features));
+            return new TruthWithPrediction<>(lb, mdl.apply(new DenseLocalOnHeapVector(features)));
         }
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java
index 093c6ed..50ca2bd 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java
@@ -22,7 +22,9 @@ import java.util.Map;
 import java.util.NoSuchElementException;
 import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.selection.score.TruthWithPrediction;
 import org.jetbrains.annotations.NotNull;
 
@@ -33,7 +35,7 @@ import org.jetbrains.annotations.NotNull;
  * @param <K> Type of a key in {@code upstream} data.
  * @param <V> Type of a value in {@code upstream} data.
  */
-public class LocalTruthWithPredictionCursor<L, K, V> implements TruthWithPredictionCursor<L>
{
+public class LocalTruthWithPredictionCursor<L, K, V, T> implements TruthWithPredictionCursor<L>
{
     /** Map with {@code upstream} data. */
     private final Map<K, V> upstreamMap;
 
@@ -47,7 +49,7 @@ public class LocalTruthWithPredictionCursor<L, K, V> implements TruthWithPredict
     private final IgniteBiFunction<K, V, L> lbExtractor;
 
     /** Model for inference. */
-    private final Model<double[], L> mdl;
+    private final Model<Vector, L> mdl;
 
     /**
      * Constructs a new instance of local truth with prediction cursor.
@@ -60,7 +62,7 @@ public class LocalTruthWithPredictionCursor<L, K, V> implements TruthWithPredict
      */
     public LocalTruthWithPredictionCursor(Map<K, V> upstreamMap, IgniteBiPredicate<K,
V> filter,
         IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V,
L> lbExtractor,
-        Model<double[], L> mdl) {
+        Model<Vector, L> mdl) {
         this.upstreamMap = upstreamMap;
         this.filter = filter;
         this.featureExtractor = featureExtractor;
@@ -117,7 +119,7 @@ public class LocalTruthWithPredictionCursor<L, K, V> implements
TruthWithPredict
 
             nextEntry = null;
 
-            return new TruthWithPrediction<>(lb, mdl.apply(features));
+            return new TruthWithPrediction<>(lb, mdl.apply(new DenseLocalOnHeapVector(features)));
         }
 
         /**

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
index 9818239..fcb134b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
@@ -17,6 +17,8 @@
 
 package org.apache.ignite.ml.tree;
 
+import org.apache.ignite.ml.math.Vector;
+
 /**
  * Decision tree conditional (non-leaf) node.
  */
@@ -52,8 +54,8 @@ public class DecisionTreeConditionalNode implements DecisionTreeNode {
     }
 
     /** {@inheritDoc} */
-    @Override public Double apply(double[] features) {
-        return features[col] > threshold ? thenNode.apply(features) : elseNode.apply(features);
+    @Override public Double apply(Vector features) {
+        return features.get(col) > threshold ? thenNode.apply(features) : elseNode.apply(features);
     }
 
     /** */

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
index 4c6369d..b3645dd 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
@@ -17,6 +17,8 @@
 
 package org.apache.ignite.ml.tree;
 
+import org.apache.ignite.ml.math.Vector;
+
 /**
  * Decision tree leaf node which contains value.
  */
@@ -37,7 +39,7 @@ public class DecisionTreeLeafNode implements DecisionTreeNode {
     }
 
     /** {@inheritDoc} */
-    @Override public Double apply(double[] doubles) {
+    @Override public Double apply(Vector doubles) {
         return val;
     }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
index 94878eb..55afc52 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
@@ -18,9 +18,10 @@
 package org.apache.ignite.ml.tree;
 
 import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Vector;
 
 /**
  * Base interface for decision tree nodes.
  */
-public interface DecisionTreeNode extends Model<double[], Double> {
+public interface DecisionTreeNode extends Model<Vector, Double> {
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
index 67109ae..819af2b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
@@ -63,6 +63,7 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable>
             UpstreamEntry<K, V> entry = upstreamData.next();
 
             features[ptr] = featureExtractor.apply(entry.getKey(), entry.getValue());
+
             labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue());
 
             ptr++;

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java
index d74b923..d8c3aa0 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java
@@ -18,6 +18,7 @@
 package org.apache.ignite.ml.preprocessing.encoding;
 
 import java.util.HashMap;
+import java.util.HashSet;
 import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor;
 import org.junit.Test;
 
@@ -52,8 +53,14 @@ public class StringEncoderPreprocessorTest {
                     put("B", 0);
                 }
             }},
-            (k, v) -> v
-        );
+            (k, v) -> v,
+            new HashSet() {
+                {
+                    add(0);
+                    add(1);
+                    add(2);
+                }
+            });
 
         double[][] postProcessedData = new double[][]{
             {1.0, 0.0, 1.0},

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java
index aa17beb..cc79584 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java
@@ -66,7 +66,9 @@ public class StringEncoderTrainerTest {
 
         DatasetBuilder<Integer, String[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
 
-        StringEncoderTrainer<Integer, String[]> strEncoderTrainer = new StringEncoderTrainer<>();
+        StringEncoderTrainer<Integer, String[]> strEncoderTrainer = new StringEncoderTrainer<Integer, String[]>()
+            .encodeFeature(0)
+            .encodeFeature(1);
 
         StringEncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit(
             datasetBuilder,

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java
index 7eba10f..220f8d8 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java
@@ -65,7 +65,7 @@ public class CacheBasedTruthWithPredictionCursorTest extends GridCommonAbstractT
             (k, v) -> v % 2 == 0,
             (k, v) -> new double[]{v},
             (k, v) -> v,
-            arr -> (int)arr[0]
+            vec -> (int)vec.get(0)
         );
 
         int cnt = 0;

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java
index 3fc3c83..66a8fcd 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java
@@ -41,7 +41,7 @@ public class LocalTruthWithPredictionCursorTest {
             (k, v) -> v % 2 == 0,
             (k, v) -> new double[]{v},
             (k, v) -> v,
-            arr -> (int)arr[0]
+            vec -> (int)vec.get(0)
         );
 
         int cnt = 0;

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
index 35f805e..f83ae7c 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java
@@ -17,11 +17,13 @@
 
 package org.apache.ignite.ml.tree.performance;
 
+import java.io.IOException;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
@@ -29,8 +31,6 @@ import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor;
 import org.apache.ignite.ml.util.MnistUtils;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
-import java.io.IOException;
-
 /**
  * Tests {@link DecisionTreeClassificationTrainer} on the MNIST dataset that require to start the whole Ignite
  * infrastructure. For manual run.
@@ -91,7 +91,7 @@ public class DecisionTreeMNISTIntegrationTest extends GridCommonAbstractTest {
         int incorrectAnswers = 0;
 
         for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) {
-            double res = mdl.apply(e.getPixels());
+            double res = mdl.apply(new DenseLocalOnHeapVector(e.getPixels()));
 
             if (res == e.getLabel())
                 correctAnswers++;

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
index b40c7ac..c9e9fb2 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java
@@ -17,6 +17,10 @@
 
 package org.apache.ignite.ml.tree.performance;
 
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
@@ -24,10 +28,6 @@ import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor;
 import org.apache.ignite.ml.util.MnistUtils;
 import org.junit.Test;
 
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-
 import static junit.framework.TestCase.assertTrue;
 
 /**
@@ -60,7 +60,7 @@ public class DecisionTreeMNISTTest {
         int incorrectAnswers = 0;
 
         for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) {
-            double res = mdl.apply(e.getPixels());
+            double res = mdl.apply(new DenseLocalOnHeapVector(e.getPixels()));
 
             if (res == e.getLabel())
                 correctAnswers++;

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
index d581d6d..0494249 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
@@ -73,7 +73,7 @@ public class RandomForestClassifierTrainerTest {
         assertEquals(5, model.getModels().size());
 
         for (ModelsComposition.ModelOnFeaturesSubspace tree : model.getModels()) {
-            assertTrue(tree.getModel() instanceof DecisionTreeConditionalNode);
+            assertTrue(tree.getMdl() instanceof DecisionTreeConditionalNode);
             assertEquals(3, tree.getFeaturesMapping().size());
         }
     }

http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
index f7594a3..418a98c 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
@@ -73,7 +73,7 @@ public class RandomForestRegressionTrainerTest {
         assertEquals(5, model.getModels().size());
 
         for (ModelsComposition.ModelOnFeaturesSubspace tree : model.getModels()) {
-            assertTrue(tree.getModel() instanceof DecisionTreeConditionalNode);
+            assertTrue(tree.getMdl() instanceof DecisionTreeConditionalNode);
             assertEquals(3, tree.getFeaturesMapping().size());
         }
     }


Mime
View raw message