ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ch...@apache.org
Subject [1/2] ignite git commit: IGNITE-9261: [ML] Add ANN algorithm based on ACD concept
Date Wed, 15 Aug 2018 16:09:46 GMT
Repository: ignite
Updated Branches:
  refs/heads/master d57c3d6e0 -> 82131d23a


http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java
deleted file mode 100644
index 2440587..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java
+++ /dev/null
@@ -1,220 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.structures;
-
-import org.apache.ignite.ml.math.exceptions.CardinalityException;
-import org.apache.ignite.ml.math.exceptions.NoDataException;
-import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-
-/**
- * Class for set of labeled vectors.
- */
-public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> implements AutoCloseable {
-    /**
-     * Default constructor (required by Externalizable).
-     */
-    public LabeledDataset() {
-        super();
-    }
-
-    /**
-     * Creates new Labeled Dataset and initialized with empty data structure.
-     *
-     * @param rowSize Amount of instances. Should be > 0.
-     * @param colSize Amount of attributes. Should be > 0.
-     * @param isDistributed Use distributed data structures to keep data.
-     */
-    public LabeledDataset(int rowSize, int colSize,  boolean isDistributed){
-        this(rowSize, colSize, null, isDistributed);
-    }
-
-    /**
-     * Creates new local Labeled Dataset and initialized with empty data structure.
-     *
-     * @param rowSize Amount of instances. Should be > 0.
-     * @param colSize Amount of attributes. Should be > 0.
-     */
-    public LabeledDataset(int rowSize, int colSize){
-        this(rowSize, colSize, null, false);
-    }
-
-    /**
-     * Creates new Labeled Dataset and initialized with empty data structure.
-     *
-     * @param rowSize Amount of instances. Should be > 0.
-     * @param colSize Amount of attributes. Should be > 0
-     * @param featureNames Column names.
-     * @param isDistributed Use distributed data structures to keep data.
-     */
-    public LabeledDataset(int rowSize, int colSize, String[] featureNames, boolean isDistributed){
-        super(rowSize, colSize, featureNames, isDistributed);
-
-        initializeDataWithLabeledVectors();
-    }
-
-    /**
-     * Creates new Labeled Dataset by given data.
-     *
-     * @param data Should be initialized with one vector at least.
-     */
-    public LabeledDataset(Row[] data) {
-        super(data);
-    }
-
-    /** */
-    private void initializeDataWithLabeledVectors() {
-        data = (Row[])new LabeledVector[rowSize];
-        for (int i = 0; i < rowSize; i++)
-            data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), null);
-    }
-
-    /**
-     * Creates new Labeled Dataset by given data.
-     *
-     * @param data Should be initialized with one vector at least.
-     * @param colSize Amount of observed attributes in each vector.
-     */
-    public LabeledDataset(Row[] data, int colSize) {
-        super(data, colSize);
-    }
-
-
-    /**
-     * Creates new local Labeled Dataset by matrix and vector of labels.
-     *
-     * @param mtx Given matrix with rows as observations.
-     * @param lbs Labels of observations.
-     */
-    public LabeledDataset(double[][] mtx, double[] lbs) {
-       this(mtx, lbs, null, false);
-    }
-
-    /**
-     * Creates new Labeled Dataset by matrix and vector of labels.
-     *
-     * @param mtx Given matrix with rows as observations.
-     * @param lbs Labels of observations.
-     * @param featureNames Column names.
-     * @param isDistributed Use distributed data structures to keep data.
-     */
-    public LabeledDataset(double[][] mtx, double[] lbs, String[] featureNames, boolean isDistributed) {
-        super();
-        assert mtx != null;
-        assert lbs != null;
-
-        if(mtx.length != lbs.length)
-            throw new CardinalityException(lbs.length, mtx.length);
-
-        if(mtx[0] == null)
-            throw new NoDataException("Pass filled array, the first vector is empty");
-
-        this.rowSize = lbs.length;
-        this.colSize = mtx[0].length;
-
-        if(featureNames == null)
-            generateFeatureNames();
-        else {
-            assert colSize == featureNames.length;
-            convertStringNamesToFeatureMetadata(featureNames);
-        }
-
-        data = (Row[])new LabeledVector[rowSize];
-        for (int i = 0; i < rowSize; i++){
-
-            data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), lbs[i]);
-            for (int j = 0; j < colSize; j++) {
-                try {
-                    data[i].features().set(j, mtx[i][j]);
-                } catch (ArrayIndexOutOfBoundsException e) {
-                    throw new NoDataException("No data in given matrix by coordinates (" + i + "," + j + ")");
-                }
-            }
-        }
-    }
-
-    /**
-     * Returns label if label is attached or null if label is missed.
-     *
-     * @param idx Index of observation.
-     * @return Label.
-     */
-    public double label(int idx) {
-        LabeledVector labeledVector = data[idx];
-
-        if(labeledVector!=null)
-            return (double)labeledVector.label();
-        else
-            return Double.NaN;
-    }
-
-    /**
-     * Returns new copy of labels of all labeled vectors NOTE: This method is useful for copying labels from test
-     * dataset.
-     *
-     * @return Copy of labels.
-     */
-    public double[] labels() {
-        assert data != null;
-        assert data.length > 0;
-
-        double[] labels = new double[data.length];
-
-        for (int i = 0; i < data.length; i++)
-            labels[i] = (double)data[i].label();
-
-        return labels;
-    }
-
-    /**
-     * Fill the label with given value.
-     *
-     * @param idx Index of observation.
-     * @param lb The given label.
-     */
-    public void setLabel(int idx, double lb) {
-        LabeledVector<Vector, Double> labeledVector = data[idx];
-
-        if(labeledVector != null)
-            labeledVector.setLabel(lb);
-        else
-            throw new NoLabelVectorException(idx);
-    }
-
-    /** */
-    public static Vector emptyVector(int size, boolean isDistributed) {
-            return new DenseVector(size);
-    }
-
-    /** Makes copy with new Label objects and old features and Metadata objects. */
-    public LabeledDataset copy(){
-        LabeledDataset res = new LabeledDataset(this.data, this.colSize);
-        res.isDistributed = this.isDistributed;
-        res.meta = this.meta;
-        for (int i = 0; i < rowSize; i++)
-            res.setLabel(i, this.label(i));
-
-        return res;
-    }
-
-    /** Closes LabeledDataset. */
-    @Override public void close() throws Exception {
-
-    }
-}

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java
deleted file mode 100644
index f362fbc..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.structures;
-
-import java.io.Serializable;
-import java.util.Map;
-import java.util.Random;
-import java.util.TreeMap;
-import java.util.TreeSet;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * Class for splitting Labeled Dataset on train and test sets.
- */
-public class LabeledDatasetTestTrainPair implements Serializable {
-    /** Data to keep train set. */
-    private LabeledDataset train;
-
-    /** Data to keep test set. */
-    private LabeledDataset test;
-
-    /**
-     * Creates two subsets of given dataset.
-     * <p>
-     * NOTE: This method uses next algorithm with O(n log n) by calculations and O(n) by memory.
-     * </p>
-     * @param dataset The dataset to split on train and test subsets.
-     * @param testPercentage The percentage of the test subset.
-     */
-    public LabeledDatasetTestTrainPair(LabeledDataset dataset, double testPercentage) {
-        assert testPercentage > 0.0;
-        assert testPercentage < 1.0;
-        final int datasetSize = dataset.rowSize();
-        assert datasetSize > 2;
-
-        final int testSize = (int)Math.floor(datasetSize * testPercentage);
-        final int trainSize = datasetSize - testSize;
-
-        final TreeSet<Integer> sortedTestIndices = getSortedIndices(datasetSize, testSize);
-
-        LabeledVector[] testVectors = new LabeledVector[testSize];
-        LabeledVector[] trainVectors = new LabeledVector[trainSize];
-
-        int datasetCntr = 0;
-        int trainCntr = 0;
-        int testCntr = 0;
-
-        for (Integer idx: sortedTestIndices){ // guarantee order as iterator
-            testVectors[testCntr] = (LabeledVector)dataset.getRow(idx);
-            testCntr++;
-
-            for (int i = datasetCntr; i < idx; i++) {
-                trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i);
-                trainCntr++;
-            }
-
-            datasetCntr = idx + 1;
-        }
-        if(datasetCntr < datasetSize){
-            for (int i = datasetCntr; i < datasetSize; i++) {
-                trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i);
-                trainCntr++;
-            }
-        }
-
-        test = new LabeledDataset(testVectors, dataset.colSize());
-        train = new LabeledDataset(trainVectors, dataset.colSize());
-    }
-
-    /** This method generates "random double, integer" pairs, sort them, gets first "testSize" elements and returns appropriate indices */
-    @NotNull private TreeSet<Integer> getSortedIndices(int datasetSize, int testSize) {
-        Random rnd = new Random();
-        TreeMap<Double, Integer> randomIdxPairs = new TreeMap<>();
-        for (int i = 0; i < datasetSize; i++)
-            randomIdxPairs.put(rnd.nextDouble(), i);
-
-        final TreeMap<Double, Integer> testIdxPairs = randomIdxPairs.entrySet().stream()
-            .limit(testSize)
-            .collect(TreeMap::new, (m, e) -> m.put(e.getKey(), e.getValue()), Map::putAll);
-
-        return new TreeSet<>(testIdxPairs.values());
-    }
-
-    /**
-     * Train subset of the whole dataset.
-     * @return Train subset.
-     */
-    public LabeledDataset train() {
-        return train;
-    }
-
-    /**
-     * Test subset of the whole dataset.
-     * @return Test subset.
-     */
-    public LabeledDataset test() {
-        return test;
-    }
-}

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java
new file mode 100644
index 0000000..e98d793
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures;
+
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.NoDataException;
+import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+
+/**
+ * The set of labeled vectors used in local partition calculations.
+ */
+public class LabeledVectorSet<L, Row extends LabeledVector> extends Dataset<Row> implements AutoCloseable {
+    /**
+     * Default constructor (required by Externalizable).
+     */
+    public LabeledVectorSet() {
+        super();
+    }
+
+    /**
+     * Creates a new labeled vector set initialized with an empty data structure.
+     *
+     * @param rowSize Amount of instances. Should be > 0.
+     * @param colSize Amount of attributes. Should be > 0.
+     * @param isDistributed Use distributed data structures to keep data.
+     */
+    public LabeledVectorSet(int rowSize, int colSize, boolean isDistributed){
+        this(rowSize, colSize, null, isDistributed);
+    }
+
+    /**
+     * Creates a new local labeled vector set initialized with an empty data structure.
+     *
+     * @param rowSize Amount of instances. Should be > 0.
+     * @param colSize Amount of attributes. Should be > 0.
+     */
+    public LabeledVectorSet(int rowSize, int colSize){
+        this(rowSize, colSize, null, false);
+    }
+
+    /**
+     * Creates a new labeled vector set initialized with an empty data structure.
+     *
+     * @param rowSize Amount of instances. Should be > 0.
+     * @param colSize Amount of attributes. Should be > 0
+     * @param featureNames Column names.
+     * @param isDistributed Use distributed data structures to keep data.
+     */
+    public LabeledVectorSet(int rowSize, int colSize, String[] featureNames, boolean isDistributed){
+        super(rowSize, colSize, featureNames, isDistributed);
+
+        initializeDataWithLabeledVectors();
+    }
+
+    /**
+     * Creates new Labeled Dataset by given data.
+     *
+     * @param data Should be initialized with one vector at least.
+     */
+    public LabeledVectorSet(Row[] data) {
+        super(data);
+    }
+
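+    /** Fills {@code data} with {@code rowSize} labeled vectors, each with empty features of size {@code colSize} and a {@code null} label. */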
+    private void initializeDataWithLabeledVectors() {
+        data = (Row[])new LabeledVector[rowSize];
+        for (int i = 0; i < rowSize; i++)
+            data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), null);
+    }
+
+    /**
+     * Creates new Labeled Dataset by given data.
+     *
+     * @param data Should be initialized with one vector at least.
+     * @param colSize Amount of observed attributes in each vector.
+     */
+    public LabeledVectorSet(Row[] data, int colSize) {
+        super(data, colSize);
+    }
+
+
+    /**
+     * Creates new local Labeled Dataset by matrix and vector of labels.
+     *
+     * @param mtx Given matrix with rows as observations.
+     * @param lbs Labels of observations.
+     */
+    public LabeledVectorSet(double[][] mtx, double[] lbs) {
+       this(mtx, lbs, null, false);
+    }
+
+    /**
+     * Creates new Labeled Dataset by matrix and vector of labels.
+     *
+     * @param mtx Given matrix with rows as observations.
+     * @param lbs Labels of observations.
+     * @param featureNames Column names.
+     * @param isDistributed Use distributed data structures to keep data.
+     */
+    public LabeledVectorSet(double[][] mtx, double[] lbs, String[] featureNames, boolean isDistributed) {
+        super();
+        assert mtx != null;
+        assert lbs != null;
+
+        if(mtx.length != lbs.length)
+            throw new CardinalityException(lbs.length, mtx.length);
+
+        if(mtx[0] == null)
+            throw new NoDataException("Pass filled array, the first vector is empty");
+
+        this.rowSize = lbs.length;
+        this.colSize = mtx[0].length;
+
+        if(featureNames == null)
+            generateFeatureNames();
+        else {
+            assert colSize == featureNames.length;
+            convertStringNamesToFeatureMetadata(featureNames);
+        }
+
+        data = (Row[])new LabeledVector[rowSize];
+        for (int i = 0; i < rowSize; i++){
+
+            data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), lbs[i]);
+            for (int j = 0; j < colSize; j++) {
+                try {
+                    data[i].features().set(j, mtx[i][j]);
+                } catch (ArrayIndexOutOfBoundsException e) {
+                    throw new NoDataException("No data in given matrix by coordinates (" + i + "," + j + ")");
+                }
+            }
+        }
+    }
+
+    /**
+     * Returns the label of the vector at the given index, or {@code Double.NaN} if there is no vector at that index.
+     *
+     * @param idx Index of observation.
+     * @return Label.
+     */
+    public double label(int idx) {
+        LabeledVector labeledVector = data[idx];
+
+        if(labeledVector!=null)
+            return (double)labeledVector.label();
+        else
+            return Double.NaN;
+    }
+
+    /**
+     * Returns a new array with the labels of all labeled vectors. NOTE: This method is useful for copying labels
+     * from a test dataset.
+     *
+     * @return Copy of labels.
+     */
+    public double[] labels() {
+        assert data != null;
+        assert data.length > 0;
+
+        double[] labels = new double[data.length];
+
+        for (int i = 0; i < data.length; i++)
+            labels[i] = (double)data[i].label();
+
+        return labels;
+    }
+
+    /**
+     * Fill the label with given value.
+     *
+     * @param idx Index of observation.
+     * @param lb The given label.
+     */
+    public void setLabel(int idx, double lb) {
+        LabeledVector<Vector, Double> labeledVector = data[idx];
+
+        if(labeledVector != null)
+            labeledVector.setLabel(lb);
+        else
+            throw new NoLabelVectorException(idx);
+    }
+
+    /** Creates a new {@link DenseVector} of the given size; {@code isDistributed} is not used. */
+    public static Vector emptyVector(int size, boolean isDistributed) {
+            return new DenseVector(size);
+    }
+
+    /** Makes copy with new Label objects and old features and Metadata objects. */
+    public LabeledVectorSet copy(){
+        LabeledVectorSet res = new LabeledVectorSet(this.data, this.colSize);
+        res.isDistributed = this.isDistributed;
+        res.meta = this.meta;
+        for (int i = 0; i < rowSize; i++)
+            res.setLabel(i, this.label(i));
+
+        return res;
+    }
+
+    /** Closes LabeledVectorSet; currently does nothing. */
+    @Override public void close() throws Exception {
+
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java
new file mode 100644
index 0000000..d06dfd0
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures;
+
+import java.io.Serializable;
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Class for splitting a labeled vector set into train and test sets.
+ */
+public class LabeledVectorSetTestTrainPair implements Serializable {
+    /** Data to keep train set. */
+    private LabeledVectorSet train;
+
+    /** Data to keep test set. */
+    private LabeledVectorSet test;
+
+    /**
+     * Creates two subsets of given dataset.
+     * <p>
+     * NOTE: The splitting algorithm takes O(n log n) time and O(n) memory.
+     * </p>
+     * @param dataset The dataset to split into train and test subsets.
+     * @param testPercentage The fraction of rows for the test subset; must be strictly between 0 and 1.
+     */
+    public LabeledVectorSetTestTrainPair(LabeledVectorSet dataset, double testPercentage) {
+        assert testPercentage > 0.0;
+        assert testPercentage < 1.0;
+        final int datasetSize = dataset.rowSize();
+        assert datasetSize > 2;
+
+        final int testSize = (int)Math.floor(datasetSize * testPercentage);
+        final int trainSize = datasetSize - testSize;
+
+        final TreeSet<Integer> sortedTestIndices = getSortedIndices(datasetSize, testSize);
+
+        LabeledVector[] testVectors = new LabeledVector[testSize];
+        LabeledVector[] trainVectors = new LabeledVector[trainSize];
+
+        int datasetCntr = 0;
+        int trainCntr = 0;
+        int testCntr = 0;
+
+        for (Integer idx: sortedTestIndices){ // guarantee order as iterator
+            testVectors[testCntr] = (LabeledVector)dataset.getRow(idx);
+            testCntr++;
+
+            for (int i = datasetCntr; i < idx; i++) {
+                trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i);
+                trainCntr++;
+            }
+
+            datasetCntr = idx + 1;
+        }
+        if(datasetCntr < datasetSize){
+            for (int i = datasetCntr; i < datasetSize; i++) {
+                trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i);
+                trainCntr++;
+            }
+        }
+
+        test = new LabeledVectorSet(testVectors, dataset.colSize());
+        train = new LabeledVectorSet(trainVectors, dataset.colSize());
+    }
+
+    /** Generates (random double, index) pairs, sorts them by the random key, takes the first {@code testSize} pairs and returns their indices in ascending order. */
+    @NotNull private TreeSet<Integer> getSortedIndices(int datasetSize, int testSize) {
+        Random rnd = new Random();
+        TreeMap<Double, Integer> randomIdxPairs = new TreeMap<>();
+        for (int i = 0; i < datasetSize; i++)
+            randomIdxPairs.put(rnd.nextDouble(), i);
+
+        final TreeMap<Double, Integer> testIdxPairs = randomIdxPairs.entrySet().stream()
+            .limit(testSize)
+            .collect(TreeMap::new, (m, e) -> m.put(e.getKey(), e.getValue()), Map::putAll);
+
+        return new TreeSet<>(testIdxPairs.values());
+    }
+
+    /**
+     * Train subset of the whole dataset.
+     * @return Train subset.
+     */
+    public LabeledVectorSet train() {
+        return train;
+    }
+
+    /**
+     * Test subset of the whole dataset.
+     * @return Test subset.
+     */
+    public LabeledVectorSet test() {
+        return test;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
index b4e552b..0351037 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
@@ -23,18 +23,18 @@ import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 
 /**
- * Partition data builder that builds {@link LabeledDataset}.
+ * Partition data builder that builds {@link LabeledVectorSet}.
  *
  * @param <K> Type of a key in <tt>upstream</tt> data.
  * @param <V> Type of a value in <tt>upstream</tt> data.
  * @param <C> Type of a partition <tt>context</tt>.
  */
 public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializable>
-    implements PartitionDataBuilder<K, V, C, LabeledDataset<Double, LabeledVector>> {
+    implements PartitionDataBuilder<K, V, C, LabeledVectorSet<Double, LabeledVector>> {
     /** */
     private static final long serialVersionUID = -7820760153954269227L;
 
@@ -57,8 +57,8 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab
     }
 
     /** {@inheritDoc} */
-    @Override public LabeledDataset<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
-        long upstreamDataSize, C ctx) {
+    @Override public LabeledVectorSet<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
+                                                                   long upstreamDataSize, C ctx) {
         int xCols = -1;
         double[][] x = null;
         double[] y = new double[Math.toIntExact(upstreamDataSize)];
@@ -82,6 +82,6 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab
 
             ptr++;
         }
-        return new LabeledDataset<>(x, y);
+        return new LabeledVectorSet<>(x, y);
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
index 5c20d9c..f370cbd 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
@@ -28,8 +28,8 @@ import org.apache.ignite.ml.math.exceptions.NoDataException;
 import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
 import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.jetbrains.annotations.NotNull;
 
 /** Data pre-processing step which loads data from different file types. */
@@ -43,8 +43,8 @@ public class LabeledDatasetLoader {
      * @param isFallOnBadData Fall on incorrect data if true.
      * @return Labeled Dataset parsed from file.
      */
-    public static LabeledDataset loadFromTxtFile(Path pathToFile, String separator, boolean isDistributed,
-        boolean isFallOnBadData) throws IOException {
+    public static LabeledVectorSet loadFromTxtFile(Path pathToFile, String separator, boolean isDistributed,
+                                                   boolean isFallOnBadData) throws IOException {
         Stream<String> stream = Files.lines(pathToFile);
         List<String> list = new ArrayList<>();
         stream.forEach(list::add);
@@ -81,7 +81,7 @@ public class LabeledDatasetLoader {
                 for (int i = 0; i < vectors.size(); i++)
                     data[i] = new LabeledVector(vectors.get(i), labels.get(i));
 
-                return new LabeledDataset(data, colSize);
+                return new LabeledVectorSet(data, colSize);
             }
             else
                 throw new NoDataException("File should contain first row with data");
@@ -93,7 +93,7 @@ public class LabeledDatasetLoader {
     /** */
     @NotNull private static Vector parseFeatures(Path pathToFile, boolean isDistributed, boolean isFallOnBadData,
         int colSize, int rowIdx, String[] rowData) {
-        final Vector vec = LabeledDataset.emptyVector(colSize, isDistributed);
+        final Vector vec = LabeledVectorSet.emptyVector(colSize, isDistributed);
 
         if (isFallOnBadData && rowData.length != colSize + 1)
             throw new CardinalityException(colSize + 1, rowData.length);

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 1ae896f..4f11318 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -25,8 +25,8 @@ import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.jetbrains.annotations.NotNull;
@@ -60,14 +60,14 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
 
         assert datasetBuilder != null;
 
-        PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
+        PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
             featureExtractor,
             lbExtractor
         );
 
         Vector weights;
 
-        try(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build(
+        try(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
@@ -91,7 +91,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
     }
 
     /** */
-    private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
+    private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
         return dataset.compute(data -> {
             Vector copiedWeights = weights.copy();
             Vector deltaWeights = initializeWeightsWithZeros(weights.size());
@@ -116,8 +116,8 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
     }
 
     /** */
-    private Deltas getDeltas(LabeledDataset data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
-        int randomIdx) {
+    private Deltas getDeltas(LabeledVectorSet data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
+                             int randomIdx) {
         LabeledVector row = (LabeledVector)data.getRow(randomIdx);
         Double lb = (Double)row.label();
         Vector v = makeVectorWithInterceptElement(row);

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
index 3e3bab5..42f5dec 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
@@ -28,13 +28,20 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
 import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.knn.NNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNModelFormat;
+import org.apache.ignite.ml.knn.ann.ProbableLabel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNModelFormat;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.distances.ManhattanDistance;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
 import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
 import org.junit.Assert;
@@ -165,10 +172,10 @@ public class LocalModelsTest {
     @Test
     public void importExportKNNModelTest() throws IOException {
         executeModelTest(mdlFilePath -> {
-            KNNClassificationModel mdl = new KNNClassificationModel(null)
+            NNClassificationModel mdl = new KNNClassificationModel(null)
                 .withK(3)
                 .withDistanceMeasure(new EuclideanDistance())
-                .withStrategy(KNNStrategy.SIMPLE);
+                .withStrategy(NNStrategy.SIMPLE);
 
             Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
             mdl.saveModel(exporter, mdlFilePath);
@@ -177,7 +184,37 @@ public class LocalModelsTest {
 
             Assert.assertNotNull(load);
 
-            KNNClassificationModel importedMdl = new KNNClassificationModel(null)
+            NNClassificationModel importedMdl = new KNNClassificationModel(null)
+                .withK(load.getK())
+                .withDistanceMeasure(load.getDistanceMeasure())
+                .withStrategy(load.getStgy());
+
+            Assert.assertTrue("", mdl.equals(importedMdl));
+
+            return null;
+        });
+    }
+
+    /** */
+    @Test
+    public void importExportANNModelTest() throws IOException {
+        executeModelTest(mdlFilePath -> {
+            final LabeledVectorSet<ProbableLabel, LabeledVector> centers = new LabeledVectorSet<>();
+
+            NNClassificationModel mdl = new ANNClassificationModel(centers)
+                .withK(4)
+                .withDistanceMeasure(new ManhattanDistance())
+                .withStrategy(NNStrategy.WEIGHTED);
+
+            Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
+            mdl.saveModel(exporter, mdlFilePath);
+
+            ANNModelFormat load = (ANNModelFormat) exporter.load(mdlFilePath);
+
+            Assert.assertNotNull(load);
+
+
+            NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates())
                 .withK(load.getK())
                 .withDistanceMeasure(load.getDistanceMeasure())
                 .withStrategy(load.getStgy());

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
index c4d896c..552c478 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
@@ -23,7 +23,7 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNModelFormat;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.distances.HammingDistance;
 import org.apache.ignite.ml.math.distances.ManhattanDistance;
@@ -83,8 +83,8 @@ public class CollectionsTest {
         test(new KMeansModel(new Vector[] {}, new ManhattanDistance()),
             new KMeansModel(new Vector[] {}, new HammingDistance()));
 
-        test(new KNNModelFormat(1, new ManhattanDistance(), KNNStrategy.SIMPLE),
-            new KNNModelFormat(2, new ManhattanDistance(), KNNStrategy.SIMPLE));
+        test(new KNNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE),
+            new KNNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE));
 
         test(new KNNClassificationModel(null).withK(1), new KNNClassificationModel(null).withK(2));
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
new file mode 100644
index 0000000..ea602cd
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/** Tests behaviour of ANNClassificationTrainer. */
+@RunWith(Parameterized.class)
+public class ANNClassificationTest {
+    /** Number of parts to be tested. */
+    private static final int[] partsToBeTested = new int[]{1, 2, 3, 4, 5, 7, 100};
+
+    /** Fixed size of Dataset. */
+    private static final int AMOUNT_OF_OBSERVATIONS = 1000;
+
+    /** Fixed size of columns in Dataset. */
+    private static final int AMOUNT_OF_FEATURES = 2;
+
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}")
+    public static Iterable<Integer[]> data() {
+        List<Integer[]> res = new ArrayList<>();
+
+        for (int part : partsToBeTested)
+            res.add(new Integer[]{part});
+
+        return res;
+    }
+
+    /** */
+    @Test
+    public void testBinaryClassificationTest() {
+        Map<Integer, double[]> data = new HashMap<>();
+
+        ThreadLocalRandom rndX = ThreadLocalRandom.current();
+        ThreadLocalRandom rndY = ThreadLocalRandom.current();
+
+        for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
+            double x = rndX.nextDouble(500, 600);
+            double y = rndY.nextDouble(500, 600);
+            double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+            vec[0] = 0; // assign label.
+            vec[1] = x;
+            vec[2] = y;
+            data.put(i, vec);
+        }
+
+        for (int i = AMOUNT_OF_OBSERVATIONS; i < AMOUNT_OF_OBSERVATIONS * 2; i++) {
+            double x = rndX.nextDouble(-600, -500);
+            double y = rndY.nextDouble(-600, -500);
+            double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+            vec[0] = 1; // assign label.
+            vec[1] = x;
+            vec[2] = y;
+            data.put(i, vec);
+        }
+
+        ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+            .withK(10);
+
+        NNClassificationModel mdl = trainer.fit(
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
+
+        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(550, 550)), PRECISION);
+        TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-550, -550)), PRECISION);
+    }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index aeb2414..c176682 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -22,9 +22,8 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -70,14 +69,14 @@ public class KNNClassificationTest {
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-        KNNClassificationModel knnMdl = trainer.fit(
+        NNClassificationModel knnMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.SIMPLE);
+            .withStrategy(NNStrategy.SIMPLE);
 
         assertTrue(knnMdl.toString().length() > 0);
         assertTrue(knnMdl.toString(true).length() > 0);
@@ -102,14 +101,14 @@ public class KNNClassificationTest {
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-        KNNClassificationModel knnMdl = trainer.fit(
+        NNClassificationModel knnMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         ).withK(1)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.SIMPLE);
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
         assertEquals(knnMdl.apply(firstVector), 1.0);
@@ -130,14 +129,14 @@ public class KNNClassificationTest {
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-        KNNClassificationModel knnMdl = trainer.fit(
+        NNClassificationModel knnMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.SIMPLE);
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector vector = new DenseVector(new double[] {-1.01, -1.01});
         assertEquals(knnMdl.apply(vector), 2.0);
@@ -156,14 +155,14 @@ public class KNNClassificationTest {
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-        KNNClassificationModel knnMdl = trainer.fit(
+        NNClassificationModel knnMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.WEIGHTED);
+            .withStrategy(NNStrategy.WEIGHTED);
 
         Vector vector = new DenseVector(new double[] {-1.01, -1.01});
         assertEquals(knnMdl.apply(vector), 1.0);

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
index 7d57ec9..e05903e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
@@ -23,7 +23,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
 import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
@@ -77,7 +77,7 @@ public class KNNRegressionTest {
             (k, v) -> v[0]
         ).withK(1)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.SIMPLE);
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector vector = new DenseVector(new double[] {0, 0, 0, 5.0, 0.0});
         System.out.println(knnMdl.apply(vector));
@@ -87,17 +87,17 @@ public class KNNRegressionTest {
     /** */
     @Test
     public void testLongly() {
-        testLongly(KNNStrategy.SIMPLE);
+        testLongly(NNStrategy.SIMPLE);
     }
 
     /** */
     @Test
     public void testLonglyWithWeightedStrategy() {
-        testLongly(KNNStrategy.WEIGHTED);
+        testLongly(NNStrategy.WEIGHTED);
     }
 
     /** */
-    private void testLongly(KNNStrategy stgy) {
+    private void testLongly(NNStrategy stgy) {
         Map<Integer, double[]> data = new HashMap<>();
         data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
         data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
@@ -123,16 +123,12 @@ public class KNNRegressionTest {
             (k, v) -> v[0]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(stgy);
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector vector = new DenseVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
 
         Assert.assertNotNull(knnMdl.apply(vector));
 
         Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
-
-        Assert.assertTrue(knnMdl.toString().contains(stgy.name()));
-        Assert.assertTrue(knnMdl.toString(true).contains(stgy.name()));
-        Assert.assertTrue(knnMdl.toString(false).contains(stgy.name()));
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
index 55ef24e..0303d26 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
@@ -25,9 +25,10 @@ import org.junit.runners.Suite;
  */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
+    ANNClassificationTest.class,
     KNNClassificationTest.class,
     KNNRegressionTest.class,
-    LabeledDatasetTest.class
+    LabeledVectorSetTest.class
 })
 public class KNNTestSuite {
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
index dbcdb99..f3b8b3a 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
@@ -21,7 +21,7 @@ import java.io.IOException;
 import java.net.URISyntaxException;
 import java.nio.file.Path;
 import java.nio.file.Paths;
-import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
 
 /**
@@ -37,7 +37,7 @@ public class LabeledDatasetHelper {
      * @param rsrcPath path to dataset.
      * @return null if path is incorrect.
      */
-    public static LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
+    public static LabeledVectorSet loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
         try {
             Path path = Paths.get(LabeledDatasetHelper.class.getClassLoader().getResource(rsrcPath).toURI());
             try {

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
deleted file mode 100644
index 9867fbe..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
+++ /dev/null
@@ -1,294 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.knn;
-
-import java.io.IOException;
-import java.net.URISyntaxException;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.util.Objects;
-import org.apache.ignite.ml.math.ExternalizableTest;
-import org.apache.ignite.ml.math.exceptions.CardinalityException;
-import org.apache.ignite.ml.math.exceptions.NoDataException;
-import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
-import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
-import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
-import org.junit.Test;
-
-import static junit.framework.TestCase.assertEquals;
-import static junit.framework.TestCase.fail;
-
-/** Tests behaviour of LabeledDataset. */
-public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
-    /** */
-    private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";
-
-    /** */
-    private static final String NO_DATA_TXT = "datasets/knn/no_data.txt";
-
-    /** */
-    private static final String EMPTY_TXT = "datasets/knn/empty.txt";
-
-    /** */
-    private static final String IRIS_INCORRECT_TXT = "datasets/knn/iris_incorrect.txt";
-
-    /** */
-    private static final String IRIS_MISSED_DATA = "datasets/knn/missed_data.txt";
-
-    /** */
-    @Test
-    public void testFeatureNames() {
-        double[][] mtx =
-            new double[][] {
-                {1.0, 1.0},
-                {1.0, 2.0},
-                {2.0, 1.0},
-                {-1.0, -1.0},
-                {-1.0, -2.0},
-                {-2.0, -1.0}};
-        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
-        String[] featureNames = new String[] {"x", "y"};
-        final LabeledDataset dataset = new LabeledDataset(mtx, lbs, featureNames, false);
-
-        assertEquals(dataset.getFeatureName(0), "x");
-    }
-
-    /** */
-    @Test
-    public void testAccessMethods() {
-        double[][] mtx =
-            new double[][] {
-                {1.0, 1.0},
-                {1.0, 2.0},
-                {2.0, 1.0},
-                {-1.0, -1.0},
-                {-1.0, -2.0},
-                {-2.0, -1.0}};
-        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
-        final LabeledDataset dataset = new LabeledDataset(mtx, lbs, null, false);
-
-        assertEquals(dataset.colSize(), 2);
-        assertEquals(dataset.rowSize(), 6);
-
-        assertEquals(dataset.label(0), lbs[0], 0);
-
-        assertEquals(dataset.copy().colSize(), 2);
-
-        @SuppressWarnings("unchecked")
-        final LabeledVector<Vector, Double> row = (LabeledVector<Vector, Double>)dataset.getRow(0);
-
-        assertEquals(row.features().get(0), 1.0);
-        assertEquals(row.label(), 1.0);
-        dataset.setLabel(0, 2.0);
-        assertEquals(row.label(), 2.0);
-
-        assertEquals(0, new LabeledDataset().rowSize());
-        assertEquals(1, new LabeledDataset(1, 2).rowSize());
-        assertEquals(1, new LabeledDataset(1, 2, true).rowSize());
-        assertEquals(1, new LabeledDataset(1, 2, null, true).rowSize());
-    }
-
-    /** */
-    @Test
-    public void testFailOnYNull() {
-        double[][] mtx =
-            new double[][] {
-                {1.0, 1.0},
-                {1.0, 2.0},
-                {2.0, 1.0},
-                {-1.0, -1.0},
-                {-1.0, -2.0},
-                {-2.0, -1.0}};
-        double[] lbs = new double[] {};
-
-        try {
-            new LabeledDataset(mtx, lbs);
-            fail("CardinalityException");
-        }
-        catch (CardinalityException e) {
-            return;
-        }
-        fail("CardinalityException");
-    }
-
-    /** */
-    @Test
-    public void testFailOnXNull() {
-        double[][] mtx =
-            new double[][] {};
-        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
-        try {
-            new LabeledDataset(mtx, lbs);
-            fail("CardinalityException");
-        }
-        catch (CardinalityException e) {
-            return;
-        }
-        fail("CardinalityException");
-    }
-
-    /** */
-    @Test
-    public void testLoadingCorrectTxtFile() {
-        LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false);
-        assertEquals(Objects.requireNonNull(training).rowSize(), 150);
-    }
-
-    /** */
-    @Test
-    public void testLoadingEmptyFile() {
-        try {
-            LabeledDatasetHelper.loadDatasetFromTxt(EMPTY_TXT, false);
-            fail("EmptyFileException");
-        }
-        catch (EmptyFileException e) {
-            return;
-        }
-        fail("EmptyFileException");
-    }
-
-    /** */
-    @Test
-    public void testLoadingFileWithFirstEmptyRow() {
-        try {
-            LabeledDatasetHelper.loadDatasetFromTxt(NO_DATA_TXT, false);
-            fail("NoDataException");
-        }
-        catch (NoDataException e) {
-            return;
-        }
-        fail("NoDataException");
-    }
-
-    /** */
-    @Test
-    public void testLoadingFileWithIncorrectData() {
-        LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
-        assertEquals(149, Objects.requireNonNull(training).rowSize());
-    }
-
-    /** */
-    @Test
-    public void testFailOnLoadingFileWithIncorrectData() {
-        try {
-            LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, true);
-            fail("FileParsingException");
-        }
-        catch (FileParsingException e) {
-            return;
-        }
-        fail("FileParsingException");
-
-    }
-
-    /** */
-    @Test
-    public void testLoadingFileWithMissedData() throws URISyntaxException, IOException {
-        Path path = Paths.get(Objects.requireNonNull(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA)).toURI());
-
-        LabeledDataset training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false);
-
-        assertEquals(training.features(2).get(1), 0.0);
-    }
-
-    /** */
-    @Test
-    public void testSplitting() {
-        double[][] mtx =
-            new double[][] {
-                {1.0, 1.0},
-                {1.0, 2.0},
-                {2.0, 1.0},
-                {-1.0, -1.0},
-                {-1.0, -2.0},
-                {-2.0, -1.0}};
-        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
-        LabeledDataset training = new LabeledDataset(mtx, lbs);
-
-        LabeledDatasetTestTrainPair split1 = new LabeledDatasetTestTrainPair(training, 0.67);
-
-        assertEquals(4, split1.test().rowSize());
-        assertEquals(2, split1.train().rowSize());
-
-        LabeledDatasetTestTrainPair split2 = new LabeledDatasetTestTrainPair(training, 0.65);
-
-        assertEquals(3, split2.test().rowSize());
-        assertEquals(3, split2.train().rowSize());
-
-        LabeledDatasetTestTrainPair split3 = new LabeledDatasetTestTrainPair(training, 0.4);
-
-        assertEquals(2, split3.test().rowSize());
-        assertEquals(4, split3.train().rowSize());
-
-        LabeledDatasetTestTrainPair split4 = new LabeledDatasetTestTrainPair(training, 0.3);
-
-        assertEquals(1, split4.test().rowSize());
-        assertEquals(5, split4.train().rowSize());
-    }
-
-    /** */
-    @Test
-    public void testLabels() {
-        double[][] mtx =
-            new double[][] {
-                {1.0, 1.0},
-                {1.0, 2.0},
-                {2.0, 1.0},
-                {-1.0, -1.0},
-                {-1.0, -2.0},
-                {-2.0, -1.0}};
-        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
-        LabeledDataset dataset = new LabeledDataset(mtx, lbs);
-        final double[] labels = dataset.labels();
-        for (int i = 0; i < lbs.length; i++)
-            assertEquals(lbs[i], labels[i]);
-    }
-
-    /** */
-    @Test(expected = NoLabelVectorException.class)
-    @SuppressWarnings("unchecked")
-    public void testSetLabelInvalid() {
-        new LabeledDataset(new LabeledVector[1]).setLabel(0, 2.0);
-    }
-
-    /** */
-    @Override public void testExternalization() {
-        double[][] mtx =
-            new double[][] {
-                {1.0, 1.0},
-                {1.0, 2.0},
-                {2.0, 1.0},
-                {-1.0, -1.0},
-                {-1.0, -2.0},
-                {-2.0, -1.0}};
-        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
-        LabeledDataset dataset = new LabeledDataset(mtx, lbs);
-        this.externalizeTest(dataset);
-    }
-}

http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java
new file mode 100644
index 0000000..2303e96
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Objects;
+import org.apache.ignite.ml.math.ExternalizableTest;
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.NoDataException;
+import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
+import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
+import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
+import org.apache.ignite.ml.structures.LabeledVectorSetTestTrainPair;
+import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.fail;
+
+/** Tests behaviour of LabeledVectorSet. */
+public class LabeledVectorSetTest implements ExternalizableTest<LabeledVectorSet> {
+    /** */
+    private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";
+
+    /** */
+    private static final String NO_DATA_TXT = "datasets/knn/no_data.txt";
+
+    /** */
+    private static final String EMPTY_TXT = "datasets/knn/empty.txt";
+
+    /** */
+    private static final String IRIS_INCORRECT_TXT = "datasets/knn/iris_incorrect.txt";
+
+    /** */
+    private static final String IRIS_MISSED_DATA = "datasets/knn/missed_data.txt";
+
+    /** */
+    @Test
+    public void testFeatureNames() {
+        double[][] mtx =
+            new double[][] {
+                {1.0, 1.0},
+                {1.0, 2.0},
+                {2.0, 1.0},
+                {-1.0, -1.0},
+                {-1.0, -2.0},
+                {-2.0, -1.0}};
+        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+        String[] featureNames = new String[] {"x", "y"};
+        final LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs, featureNames, false);
+
+        assertEquals(dataset.getFeatureName(0), "x");
+    }
+
+    /** */
+    @Test
+    public void testAccessMethods() {
+        double[][] mtx =
+            new double[][] {
+                {1.0, 1.0},
+                {1.0, 2.0},
+                {2.0, 1.0},
+                {-1.0, -1.0},
+                {-1.0, -2.0},
+                {-2.0, -1.0}};
+        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+        final LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs, null, false);
+
+        assertEquals(dataset.colSize(), 2);
+        assertEquals(dataset.rowSize(), 6);
+
+        assertEquals(dataset.label(0), lbs[0], 0);
+
+        assertEquals(dataset.copy().colSize(), 2);
+
+        @SuppressWarnings("unchecked")
+        final LabeledVector<Vector, Double> row = (LabeledVector<Vector, Double>)dataset.getRow(0);
+
+        assertEquals(row.features().get(0), 1.0);
+        assertEquals(row.label(), 1.0);
+        dataset.setLabel(0, 2.0);
+        assertEquals(row.label(), 2.0);
+
+        assertEquals(0, new LabeledVectorSet().rowSize());
+        assertEquals(1, new LabeledVectorSet(1, 2).rowSize());
+        assertEquals(1, new LabeledVectorSet(1, 2, true).rowSize());
+        assertEquals(1, new LabeledVectorSet(1, 2, null, true).rowSize());
+    }
+
+    /** */
+    @Test
+    public void testFailOnYNull() {
+        double[][] mtx =
+            new double[][] {
+                {1.0, 1.0},
+                {1.0, 2.0},
+                {2.0, 1.0},
+                {-1.0, -1.0},
+                {-1.0, -2.0},
+                {-2.0, -1.0}};
+        double[] lbs = new double[] {};
+
+        try {
+            new LabeledVectorSet(mtx, lbs);
+            fail("CardinalityException");
+        }
+        catch (CardinalityException e) {
+            return;
+        }
+        fail("CardinalityException");
+    }
+
+    /** */
+    @Test
+    public void testFailOnXNull() {
+        double[][] mtx =
+            new double[][] {};
+        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+        try {
+            new LabeledVectorSet(mtx, lbs);
+            fail("CardinalityException");
+        }
+        catch (CardinalityException e) {
+            return;
+        }
+        fail("CardinalityException");
+    }
+
+    /** */
+    @Test
+    public void testLoadingCorrectTxtFile() {
+        LabeledVectorSet training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false);
+        assertEquals(training.rowSize(), 150);
+    }
+
+    /** */
+    @Test
+    public void testLoadingEmptyFile() {
+        try {
+            LabeledDatasetHelper.loadDatasetFromTxt(EMPTY_TXT, false);
+            fail("EmptyFileException");
+        }
+        catch (EmptyFileException e) {
+            return;
+        }
+        fail("EmptyFileException");
+    }
+
+    /** */
+    @Test
+    public void testLoadingFileWithFirstEmptyRow() {
+        try {
+            LabeledDatasetHelper.loadDatasetFromTxt(NO_DATA_TXT, false);
+            fail("NoDataException");
+        }
+        catch (NoDataException e) {
+            return;
+        }
+        fail("NoDataException");
+    }
+
+    /** */
+    @Test
+    public void testLoadingFileWithIncorrectData() {
+        LabeledVectorSet training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
+        assertEquals(149, training.rowSize());
+    }
+
+    /** */
+    @Test
+    public void testFailOnLoadingFileWithIncorrectData() {
+        try {
+            LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, true);
+            fail("FileParsingException");
+        }
+        catch (FileParsingException e) {
+            return;
+        }
+        fail("FileParsingException");
+
+    }
+
+    /** */
+    @Test
+    public void testLoadingFileWithMissedData() throws URISyntaxException, IOException {
+        Path path = Paths.get(Objects.requireNonNull(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA)).toURI());
+
+        LabeledVectorSet training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false);
+
+        assertEquals(training.features(2).get(1), 0.0);
+    }
+
+    /** */
+    @Test
+    public void testSplitting() {
+        double[][] mtx =
+            new double[][] {
+                {1.0, 1.0},
+                {1.0, 2.0},
+                {2.0, 1.0},
+                {-1.0, -1.0},
+                {-1.0, -2.0},
+                {-2.0, -1.0}};
+        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+        LabeledVectorSet training = new LabeledVectorSet(mtx, lbs);
+
+        LabeledVectorSetTestTrainPair split1 = new LabeledVectorSetTestTrainPair(training, 0.67);
+
+        assertEquals(4, split1.test().rowSize());
+        assertEquals(2, split1.train().rowSize());
+
+        LabeledVectorSetTestTrainPair split2 = new LabeledVectorSetTestTrainPair(training, 0.65);
+
+        assertEquals(3, split2.test().rowSize());
+        assertEquals(3, split2.train().rowSize());
+
+        LabeledVectorSetTestTrainPair split3 = new LabeledVectorSetTestTrainPair(training, 0.4);
+
+        assertEquals(2, split3.test().rowSize());
+        assertEquals(4, split3.train().rowSize());
+
+        LabeledVectorSetTestTrainPair split4 = new LabeledVectorSetTestTrainPair(training, 0.3);
+
+        assertEquals(1, split4.test().rowSize());
+        assertEquals(5, split4.train().rowSize());
+    }
+
+    /** */
+    @Test
+    public void testLabels() {
+        double[][] mtx =
+            new double[][] {
+                {1.0, 1.0},
+                {1.0, 2.0},
+                {2.0, 1.0},
+                {-1.0, -1.0},
+                {-1.0, -2.0},
+                {-2.0, -1.0}};
+        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+        LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs);
+        final double[] labels = dataset.labels();
+        for (int i = 0; i < lbs.length; i++)
+            assertEquals(lbs[i], labels[i]);
+    }
+
+    /** */
+    @Test(expected = NoLabelVectorException.class)
+    @SuppressWarnings("unchecked")
+    public void testSetLabelInvalid() {
+        new LabeledVectorSet(new LabeledVector[1]).setLabel(0, 2.0);
+    }
+
+    /** */
+    @Override public void testExternalization() {
+        double[][] mtx =
+            new double[][] {
+                {1.0, 1.0},
+                {1.0, 2.0},
+                {2.0, 1.0},
+                {-1.0, -1.0},
+                {-1.0, -2.0},
+                {-2.0, -1.0}};
+        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+        LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs);
+        this.externalizeTest(dataset);
+    }
+}


Mime
View raw message