flink-commits mailing list archives

From chiwanp...@apache.org
Subject [1/3] flink git commit: [FLINK-1745] [ml] Use QuadTree to speed up exact k-nearest-neighbor join
Date Mon, 30 May 2016 11:12:33 GMT
Repository: flink
Updated Branches:
  refs/heads/master 1212b6d3f -> 035f62969


[FLINK-1745] [ml] Use QuadTree to speed up exact k-nearest-neighbor join


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/4a5af42c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/4a5af42c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/4a5af42c

Branch: refs/heads/master
Commit: 4a5af42c678a0437aa5614741280e8c5465b8cec
Parents: 858ca14
Author: danielblazevski <daniel.blazevski@gmail.com>
Authored: Tue Sep 15 17:49:05 2015 -0400
Committer: Chiwan Park <chiwanpark@apache.org>
Committed: Mon May 30 19:32:26 2016 +0900

----------------------------------------------------------------------
 docs/libs/ml/knn.md                             | 145 ++++++++
 .../main/scala/org/apache/flink/ml/nn/KNN.scala | 353 +++++++++++++++++++
 .../scala/org/apache/flink/ml/nn/QuadTree.scala | 350 ++++++++++++++++++
 .../org/apache/flink/ml/nn/KNNITSuite.scala     | 108 ++++++
 .../org/apache/flink/ml/nn/QuadTreeSuite.scala  | 107 ++++++
 .../main/scala/org/apache/flink/ml/nn/KNN.scala | 207 +++++++++--
 .../scala/org/apache/flink/ml/nn/QuadTree.scala | 344 ++++++++++++++++++
 .../org/apache/flink/ml/nn/KNNITSuite.scala     |   7 +-
 .../org/apache/flink/ml/nn/QuadTreeSuite.scala  | 106 ++++++
 9 files changed, 1686 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/docs/libs/ml/knn.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/knn.md b/docs/libs/ml/knn.md
new file mode 100644
index 0000000..c9a7e03
--- /dev/null
+++ b/docs/libs/ml/knn.md
@@ -0,0 +1,145 @@
+---
+mathjax: include
+htmlTitle: FlinkML - k-nearest neighbors
+title: <a href="../ml">FlinkML</a> - knn
+---
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+* This will be replaced by the TOC
+{:toc}
+
+## Description
+Implements an exact k-nearest neighbors algorithm.  Given a training set $A$ and a testing set $B$, the algorithm returns
+
+$$
+KNN(A, B, k) = \{ \left( b, KNN(b, A, k) \right) \mid b \in B \}
+$$
+
+where $KNN(b, A, k)$ denotes the set of the $k$ nearest points to $b$ in $A$.
+
+The brute-force approach is to compute the distance between every training point and every testing point.  To avoid this full pairwise computation, a quadtree is used.  The quadtree scales well in the number of training points, though poorly in the spatial dimension.  The algorithm will automatically choose whether or not to use the quadtree, though the user can override that decision by setting a parameter that forces the quadtree to be used or not used.
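+
+Roughly speaking, the implementation falls back to the quadtree only for a Euclidean or squared Euclidean metric, and only when the quadtree-based query is expected to be cheaper than the brute-force cross product, i.e. when approximately
+
+$$
+4^{d} \cdot N_{test} \cdot \log(N_{train}) < N_{test} \cdot N_{train},
+$$
+
+where $d$ is the spatial dimension and $N_{train}$, $N_{test}$ are the numbers of training and testing points in a block (the exact constants are an implementation detail of `KNN.scala`).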
+
+## Operations
+
+`KNN` is a `Predictor`. 
+As such, it supports the `fit` and `predict` operations.
+
+### Fit
+
+KNN is trained given a set of `Vector`:
+
+* `fit[T <: Vector]: DataSet[T] => Unit`
+
+### Predict
+
+KNN predicts for all subtypes of FlinkML's `Vector` the corresponding k-nearest training points:
+
+* `predict[T <: Vector]: DataSet[T] => DataSet[(T, Array[Vector])]`, where the `(T, Array[Vector])` tuple
+  corresponds to (testPoint, K-nearest training points)
+
+## Parameters
+The KNN implementation can be controlled by the following parameters:
+
+   <table class="table table-bordered">
+    <thead>
+      <tr>
+        <th class="text-left" style="width: 20%">Parameters</th>
+        <th class="text-center">Description</th>
+      </tr>
+    </thead>
+
+    <tbody>
+      <tr>
+        <td><strong>K</strong></td>
+        <td>
+          <p>
+            Defines the number of nearest-neighbors to search for.  That is, for each test point, the algorithm finds the K-nearest neighbors in the training set
+            (Default value: <strong>5</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>DistanceMetric</strong></td>
+        <td>
+          <p>
+            Sets the distance metric we use to calculate the distance between two points. If no metric is specified, then [[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] is used.
+            (Default value: <strong>EuclideanDistanceMetric</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Blocks</strong></td>
+        <td>
+          <p>
+            Sets the number of blocks into which the input data will be split. This number should be set
+            at least to the degree of parallelism. If no value is specified, then the parallelism of the
+            input [[DataSet]] is used as the number of blocks.
+            (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>UseQuadTreeParam</strong></td>
+        <td>
+          <p>
+             A Boolean variable that determines whether or not to use a quadtree to partition the training set, which can potentially simplify the KNN search.  If no value is specified, the code will automatically decide whether or not to use a quadtree.  Use of a quadtree scales well with the number of training and testing points, though poorly with the dimension.
+            (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>SizeHint</strong></td>
+        <td>
+          <p>Specifies whether the training set or the test set is small, in order to optimize the cross product operation needed for the KNN search.  If the training set is small this should be `CrossHint.FIRST_IS_SMALL`; set it to `CrossHint.SECOND_IS_SMALL` if the test set is small.
+             (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+    </tbody>
+  </table>
+
+## Examples
+
+{% highlight scala %}
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.classification.Classification
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
+import org.apache.flink.ml.nn.KNN
+
+  val env = ExecutionEnvironment.getExecutionEnvironment
+
+  // prepare data
+  val trainingSet = env.fromCollection(Classification.trainingData).map(_.vector)
+  val testingSet = env.fromElements(DenseVector(0.0, 0.0))
+
+  val knn = KNN()
+    .setK(3)
+    .setBlocks(10)
+    .setDistanceMetric(SquaredEuclideanDistanceMetric())
+    .setUseQuadTree(false)
+    .setSizeHint(CrossHint.SECOND_IS_SMALL)
+
+  // run knn join
+  knn.fit(trainingSet)
+  val result = knn.predict(testingSet).collect()
+
+{% endhighlight %}
+
+For more details on computing KNN with and without a quadtree, see the following presentation:
+http://danielblazevski.github.io/

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
new file mode 100644
index 0000000..82f4b88
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala.utils._
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.common._
+import org.apache.flink.ml.math.{Vector => FlinkVector, DenseVector}
+import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric, DistanceMetric,
+EuclideanDistanceMetric}
+import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, Predictor}
+import org.apache.flink.util.Collector
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
+
+import scala.collection.immutable.Vector
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+/** Implements a k-nearest neighbor join.
+  *
+  * Calculates the `k`-nearest neighbor points in the training set for each point in the test set.
+  *
+  * @example
+  * {{{
+  *         val trainingDS: DataSet[Vector] = ...
+  *         val testingDS: DataSet[Vector] = ...
+  *
+  *         val knn = KNN()
+  *           .setK(10)
+  *           .setBlocks(5)
+  *           .setDistanceMetric(EuclideanDistanceMetric())
+  *
+  *         knn.fit(trainingDS)
+  *
+  *         val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS)
+  * }}}
+  *
+  * =Parameters=
+  *
+  * - [[org.apache.flink.ml.nn.KNN.K]]
+  * Sets the K which is the number of selected points as neighbors. (Default value: '''5''')
+  *
+  * - [[org.apache.flink.ml.nn.KNN.DistanceMetric]]
+  * Sets the distance metric we use to calculate the distance between two points. If no metric is
+  * specified, then [[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] is used.
+  * (Default value: '''EuclideanDistanceMetric()''')
+  *
+  * - [[org.apache.flink.ml.nn.KNN.Blocks]]
+  * Sets the number of blocks into which the input data will be split. This number should be set
+  * at least to the degree of parallelism. If no value is specified, then the parallelism of the
+  * input [[DataSet]] is used as the number of blocks. (Default value: '''None''')
+  *
+  * - [[org.apache.flink.ml.nn.KNN.UseQuadTreeParam]]
+  * A Boolean variable that determines whether or not to use a quadtree to partition the
+  * training set, which can potentially simplify the KNN search.  If no value is specified,
+  * the code will automatically decide whether or not to use a quadtree.  Use of a quadtree
+  * scales well with the number of training and testing points, though poorly with the
+  * dimension.
+  * (Default value: '''None''')
+  *
+  * - [[org.apache.flink.ml.nn.KNN.SizeHint]]
+  * Specifies whether the training set or the test set is small, in order to optimize the
+  * cross product operation needed for the KNN search.  If the training set is small this
+  * should be `CrossHint.FIRST_IS_SMALL`; set it to `CrossHint.SECOND_IS_SMALL` if the
+  * test set is small.
+  * (Default value: '''None''')
+  *
+  */
+
+class KNN extends Predictor[KNN] {
+
+  import KNN._
+
+  var trainingSet: Option[DataSet[Block[FlinkVector]]] = None
+
+  /** Sets K
+    * @param k the number of selected points as neighbors
+    */
+  def setK(k: Int): KNN = {
+    require(k > 0, "K must be positive.")
+    parameters.add(K, k)
+    this
+  }
+
+  /** Sets the distance metric
+    * @param metric the distance metric to calculate distance between two points
+    */
+  def setDistanceMetric(metric: DistanceMetric): KNN = {
+    parameters.add(DistanceMetric, metric)
+    this
+  }
+
+  /** Sets the number of data blocks/partitions
+    * @param n the number of data blocks
+    */
+  def setBlocks(n: Int): KNN = {
+    require(n > 0, "Number of blocks must be positive.")
+    parameters.add(Blocks, n)
+    this
+  }
+
+  /**
+    * Sets the Boolean variable that decides whether to use the QuadTree or not
+    */
+  def setUseQuadTree(useQuadTree: Boolean): KNN = {
+    if (useQuadTree) {
+      require(parameters(DistanceMetric).isInstanceOf[SquaredEuclideanDistanceMetric] ||
+        parameters(DistanceMetric).isInstanceOf[EuclideanDistanceMetric])
+    }
+    parameters.add(UseQuadTreeParam, useQuadTree)
+    this
+  }
+
+  /**
+    * Parameter a user can specify if one of the training or test sets is small
+    * @param sizeHint
+    * @return
+    */
+  def setSizeHint(sizeHint: CrossHint): KNN = {
+    parameters.add(SizeHint, sizeHint)
+    this
+  }
+
+}
+
+object KNN {
+
+  case object K extends Parameter[Int] {
+    val defaultValue: Option[Int] = Some(5)
+  }
+
+  case object DistanceMetric extends Parameter[DistanceMetric] {
+    val defaultValue: Option[DistanceMetric] = Some(EuclideanDistanceMetric())
+  }
+
+  case object Blocks extends Parameter[Int] {
+    val defaultValue: Option[Int] = None
+  }
+
+  case object UseQuadTreeParam extends Parameter[Boolean] {
+    val defaultValue: Option[Boolean] = None
+  }
+
+  case object SizeHint extends Parameter[CrossHint] {
+    val defaultValue: Option[CrossHint] = None
+  }
+
+  def apply(): KNN = {
+    new KNN()
+  }
+
+  /** [[FitOperation]] which trains a KNN based on the given training data set.
+    * @tparam T Subtype of [[org.apache.flink.ml.math.Vector]]
+    */
+  implicit def fitKNN[T <: FlinkVector : TypeInformation] = new FitOperation[KNN, T] {
+    override def fit(
+                      instance: KNN,
+                      fitParameters: ParameterMap,
+                      input: DataSet[T]): Unit = {
+      val resultParameters = instance.parameters ++ fitParameters
+
+      require(resultParameters.get(K).isDefined, "K is needed for calculation")
+
+      val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
+      val partitioner = FlinkMLTools.ModuloKeyPartitioner
+      val inputAsVector = input.asInstanceOf[DataSet[FlinkVector]]
+
+      instance.trainingSet = Some(FlinkMLTools.block(inputAsVector, blocks, Some(partitioner)))
+    }
+  }
+
+  /** [[PredictDataSetOperation]] which calculates k-nearest neighbors of the given testing data
+    * set.
+    * @tparam T Subtype of [[Vector]]
+    * @return The given testing data set with k-nearest neighbors
+    */
+  implicit def predictValues[T <: FlinkVector : ClassTag : TypeInformation] = {
+    new PredictDataSetOperation[KNN, T, (FlinkVector, Array[FlinkVector])] {
+      override def predictDataSet(
+                                   instance: KNN,
+                                   predictParameters: ParameterMap,
+                                   input: DataSet[T]): DataSet[(FlinkVector,
+        Array[FlinkVector])] = {
+        val resultParameters = instance.parameters ++ predictParameters
+
+        instance.trainingSet match {
+          case Some(trainingSet) =>
+            val k = resultParameters.get(K).get
+            val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
+            val metric = resultParameters.get(DistanceMetric).get
+            val partitioner = FlinkMLTools.ModuloKeyPartitioner
+
+            // attach unique id for each data
+            val inputWithId: DataSet[(Long, T)] = input.zipWithUniqueId
+
+            // split data into multiple blocks
+            val inputSplit = FlinkMLTools.block(inputWithId, blocks, Some(partitioner))
+
+            val sizeHint = resultParameters.get(SizeHint)
+            val crossTuned = sizeHint match {
+              case Some(hint) if hint == CrossHint.FIRST_IS_SMALL =>
+                trainingSet.crossWithHuge(inputSplit)
+              case Some(hint) if hint == CrossHint.SECOND_IS_SMALL =>
+                trainingSet.crossWithTiny(inputSplit)
+              case _ => trainingSet.cross(inputSplit)
+            }
+
+            // join input and training set
+            val crossed = crossTuned.mapPartition {
+              (iter, out: Collector[(FlinkVector, FlinkVector, Long, Double)]) => {
+                for ((training, testing) <- iter) {
+                  // use a quadtree if (4^dim)Ntest*log(Ntrain)
+                  // < Ntest*Ntrain, and distance is Euclidean
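+                  // i.e., after cancelling Ntest and taking logs:
+                  // dim*log(4) + log(log(Ntrain)) < log(Ntrain)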
+                  val useQuadTree = resultParameters.get(UseQuadTreeParam).getOrElse(
+                    math.log(4.0) * training.values.head.size + math.log(math.log(training.values.length))
+                      < math.log(training.values.length) &&
+                      (metric.isInstanceOf[EuclideanDistanceMetric] ||
+                        metric.isInstanceOf[SquaredEuclideanDistanceMetric]))
+
+                  if (useQuadTree) {
+                    knnQueryWithQuadTree(training.values, testing.values, k, metric, out)
+                  } else {
+                    knnQueryBasic(training.values, testing.values, k, metric, out)
+                  }
+                }
+              }
+            }
+
+            // group by input vector id and pick k nearest neighbor for each group
+            val result = crossed.groupBy(2).sortGroup(3, Order.ASCENDING).reduceGroup {
+              (iter, out: Collector[(FlinkVector, Array[FlinkVector])]) => {
+                if (iter.hasNext) {
+                  val head = iter.next()
+                  val key = head._2
+                  val neighbors: ArrayBuffer[FlinkVector] = ArrayBuffer(head._1)
+
+                  for ((vector, _, _, _) <- iter.take(k - 1)) {
+                    // we already took a first element
+                    neighbors += vector
+                  }
+
+                  out.collect(key, neighbors.toArray)
+                }
+              }
+            }
+
+            result
+          case None => throw new RuntimeException("The KNN model has not been trained." +
+            " Call fit first before calling the predict operation.")
+
+        }
+      }
+    }
+  }
+
+  def knnQueryWithQuadTree[T <: FlinkVector](
+                                              training: Vector[T],
+                                              testing: Vector[(Long, T)],
+                                              k: Int, metric: DistanceMetric,
+                                              out: Collector[(FlinkVector,
+                                                FlinkVector, Long, Double)]) {
+    /// find a bounding box
+    val MinArr = Array.tabulate(training.head.size)(x => x)
+    val MaxArr = Array.tabulate(training.head.size)(x => x)
+
+    val minVecTrain = MinArr.map(i => training.map(x => x(i)).min - 0.01)
+    val minVecTest = MinArr.map(i => testing.map(x => x._2(i)).min - 0.01)
+    val maxVecTrain = MaxArr.map(i => training.map(x => x(i)).max + 0.01)
+    val maxVecTest = MaxArr.map(i => testing.map(x => x._2(i)).max + 0.01)
+
+    val MinVec = DenseVector(MinArr.map(i => math.min(minVecTrain(i), minVecTest(i))))
+    val MaxVec = DenseVector(MinArr.map(i => math.max(maxVecTrain(i), maxVecTest(i))))
+
+    //default value of max elements/box is set to max(20,k)
+    val maxPerBox = math.max(k, 20)
+    val trainingQuadTree = new QuadTree(MinVec, MaxVec, metric, maxPerBox)
+
+    val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, Double)]()(
+      Ordering.by(_._4))
+
+    for (v <- training) {
+      trainingQuadTree.insert(v)
+    }
+
+    for ((id, vector) <- testing) {
+      //  Find siblings' objects and do local kNN there
+      val siblingObjects =
+        trainingQuadTree.searchNeighborsSiblingQueue(vector)
+
+      // do KNN query on siblingObjects and get max distance of kNN
+      // then rad is good choice for a neighborhood to do a refined
+      // local kNN search
+      val knnSiblings = siblingObjects.map(v => metric.distance(vector, v)
+      ).sortWith(_ < _).take(k)
+
+      val rad = knnSiblings.last
+      val trainingFiltered = trainingQuadTree.searchNeighbors(vector, rad)
+
+      for (b <- trainingFiltered) {
+        // (training vector, input vector, input key, distance)
+        queue.enqueue((b, vector, id, metric.distance(b, vector)))
+        if (queue.size > k) {
+          queue.dequeue()
+        }
+      }
+      for (v <- queue) {
+        out.collect(v)
+      }
+    }
+  }
+
+  def knnQueryBasic[T <: FlinkVector](
+                                       training: Vector[T],
+                                       testing: Vector[(Long, T)],
+                                       k: Int, metric: DistanceMetric,
+                                       out: Collector[(FlinkVector, FlinkVector, Long, Double)]) {
+
+    val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, Double)]()(
+      Ordering.by(_._4))
+    
+    for ((id, vector) <- testing) {
+      for (b <- training) {
+        // (training vector, input vector, input key, distance)
+        queue.enqueue((b, vector, id, metric.distance(b, vector)))
+        if (queue.size > k) {
+          queue.dequeue()
+        }
+      }
+      for (v <- queue) {
+        out.collect(v)
+      }
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
new file mode 100644
index 0000000..d08dcdd
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.ml.math.{Breeze, Vector}
+import Breeze._
+
+import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
+EuclideanDistanceMetric, DistanceMetric}
+
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.PriorityQueue
+
+/**
+ * n-dimensional QuadTree data structure; partitions
+ * spatial data for faster queries (e.g. KNN query)
+ * The skeleton of the data structure was initially
+ * based off of the 2D Quadtree found here:
+ * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+ *
+ * Many additional methods were added to the class both for
+ * efficient KNN queries and generalizing to n-dim.
+ *
+ * @param minVec vector of the corner of the bounding box with smallest coordinates
+ * @param maxVec vector of the corner of the bounding box with largest coordinates
+ * @param distMetric metric, must be Euclidean or SquaredEuclidean
+ * @param maxPerBox threshold for number of points in each box before splitting a box
+ */
+class QuadTree(
+  minVec: Vector,
+  maxVec: Vector,
+  distMetric: DistanceMetric,
+  maxPerBox: Int) {
+
+  class Node(
+    center: Vector,
+    width: Vector,
+    var children: Seq[Node]) {
+
+    val nodeElements = new ListBuffer[Vector]
+
+    /** for testing purposes only; used in QuadTreeSuite.scala
+      *
+      * @return center and width of the box
+      */
+    def getCenterWidth(): (Vector, Vector) = {
+      (center, width)
+    }
+
+    /** Tests whether the queryPoint is in the node, or a child of that node
+      *
+      * @param queryPoint
+      * @return
+      */
+    def contains(queryPoint: Vector): Boolean = {
+      overlap(queryPoint, 0.0)
+    }
+
+    /** Tests if queryPoint is within a radius of the node
+      *
+      * @param queryPoint
+      * @param radius
+      * @return
+      */
+    def overlap(
+      queryPoint: Vector,
+      radius: Double): Boolean = {
+      (0 until queryPoint.size).forall{ i =>
+          (queryPoint(i) - radius < center(i) + width(i) / 2) &&
+            (queryPoint(i) + radius > center(i) - width(i) / 2)
+      }
+    }
+
+    /** Tests if queryPoint is near a node
+      *
+      * @param queryPoint
+      * @param radius
+      * @return
+      */
+    def isNear(
+      queryPoint: Vector,
+      radius: Double): Boolean = {
+      minDist(queryPoint) < radius
+    }
+
+    /**
+     * minDist is defined so that every point in the box
+     * has distance to queryPoint greater than minDist
+     * (minDist adopted from "Nearest Neighbors Queries" by N. Roussopoulos et al.)
+     *
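+     * For example, a queryPoint lying inside the box has minDist equal to 0.
+     *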
+     * @param queryPoint
+     * @return
+     */
+    def minDist(queryPoint: Vector): Double = {
+      val minDist = (0 until queryPoint.size).map { i =>
+        if (queryPoint(i) < center(i) - width(i) / 2) {
+          math.pow(queryPoint(i) - center(i) + width(i) / 2, 2)
+        } else if (queryPoint(i) > center(i) + width(i) / 2) {
+          math.pow(queryPoint(i) - center(i) - width(i) / 2, 2)
+        } else {
+          0
+        }
+      }.sum
+
+      distMetric match {
+        case _: SquaredEuclideanDistanceMetric => minDist
+        case _: EuclideanDistanceMetric => math.sqrt(minDist)
+        case _ => throw new IllegalArgumentException(s" Error: metric must be" +
+          s" Euclidean or SquaredEuclidean!")
+      }
+    }
+
+    /**
+     * Finds which child queryPoint lies in.  node.children is a Seq[Node], and
+     * whichChild finds the appropriate index of that Seq.
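+     * For example, in 2D a query point with x greater than the center's x and y not greater
+     * than the center's y maps to child index 2 (= 2^1 + 0).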
+     * @param queryPoint
+     * @return
+     */
+    def whichChild(queryPoint: Vector): Int = {
+      (0 until queryPoint.size).map { i =>
+        if (queryPoint(i) > center(i)) {
+          scala.math.pow(2, queryPoint.size - 1 - i).toInt
+        } else {
+          0
+        }
+      }.sum
+    }
+
+    /** Makes children nodes by partitioning the box into equal sub-boxes
+      * and adding a node for each sub-box
+      */
+    def makeChildren() {
+      val centerClone = center.copy
+      val cPart = partitionBox(centerClone, width)
+      val mappedWidth = 0.5 * width.asBreeze
+      children = cPart.map(p => new Node(p, mappedWidth.fromBreeze, null))
+    }
+
+    /**
+     * Recursive function that partitions an n-dim box by taking the (n-1) dimensional
+     * plane through the center of the box keeping the n-th coordinate fixed,
+     * then shifting it in the n-th direction up and down
+     * and recursively applying partitionBox to the two shifted (n-1) dimensional planes.
+     *
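+     * For example, a 2D box with center (0, 0) and width (2, 1) is partitioned into four
+     * sub-boxes centered at (-0.5, -0.25), (-0.5, 0.25), (0.5, -0.25) and (0.5, 0.25),
+     * each with width (1, 0.5).
+     *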
+     * @param center the center of the box
+     * @param width a vector of lengths of each dimension of the box
+     * @return
+     */
+    def partitionBox(
+      center: Vector,
+      width: Vector): Seq[Vector] = {
+      def partitionHelper(
+        box: Seq[Vector],
+        dim: Int): Seq[Vector] = {
+        if (dim >= width.size) {
+          box
+        } else {
+          val newBox = box.flatMap {
+            vector =>
+              val (up, down) = (vector.copy, vector)
+              up.update(dim, up(dim) - width(dim) / 4)
+              down.update(dim, down(dim) + width(dim) / 4)
+
+              Seq(up, down)
+          }
+          partitionHelper(newBox, dim + 1)
+        }
+      }
+      partitionHelper(Seq(center), 0)
+    }
+  }
+
+
+  val root = new Node(((minVec.asBreeze + maxVec.asBreeze) * 0.5).fromBreeze,
+    (maxVec.asBreeze - minVec.asBreeze).fromBreeze, null)
+
+  /**
+   * simple printing of tree for testing/debugging
+   */
+  def printTree(): Unit = {
+    printTreeRecur(root)
+  }
+
+  def printTreeRecur(node: Node) {
+    if (node.children != null) {
+      for (c <- node.children) {
+        printTreeRecur(c)
+      }
+    } else {
+      println("printing tree: n.nodeElements " + node.nodeElements)
+    }
+  }
+
+  /**
+   * Recursively adds an object to the tree
+   * @param queryPoint
+   */
+  def insert(queryPoint: Vector) {
+    insertRecur(queryPoint, root)
+  }
+
+  private def insertRecur(
+    queryPoint: Vector,
+    node: Node) {
+    if (node.children == null) {
+      if (node.nodeElements.length < maxPerBox) {
+        node.nodeElements += queryPoint
+      } else {
+        node.makeChildren()
+        for (o <- node.nodeElements) {
+          insertRecur(o, node.children(node.whichChild(o)))
+        }
+        node.nodeElements.clear()
+        insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
+      }
+    } else {
+      insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
+    }
+  }
+
+  /**
+   * Used to zoom in on a region near a test point for a fast KNN query.
+   * This capability is used in the KNN query to find k "near" neighbors n_1,...,n_k, from
+   * which one computes the max distance D_s to queryPoint.  D_s is then used during the
+   * kNN query to find all points within a radius D_s of queryPoint using searchNeighbors.
+   * To find the "near" neighbors, a min-heap is defined on the sibling leaf nodes of the
+   * minimal bounding box of the queryPoint. The priority of a leaf node
+   * is an appropriate notion of the distance between the test point and the node,
+   * which is defined by minDist(queryPoint).
+   *
+   * @param queryPoint a test point for which the method finds the minimal bounding
+   *                   box that queryPoint lies in and returns elements in that box's
+   *                   siblings' leaf nodes
+   * @return
+   */
+  def searchNeighborsSiblingQueue(queryPoint: Vector): ListBuffer[Vector] = {
+    val ret = new ListBuffer[Vector]
+    // edge case when the main box has not been partitioned at all
+    if (root.children == null) {
+      root.nodeElements.clone()
+    } else {
+      val nodeQueue = new PriorityQueue[(Double, Node)]()(Ordering.by(x => x._1))
+      searchRecurSiblingQueue(queryPoint, root, nodeQueue)
+
+      var count = 0
+      while (count < maxPerBox) {
+        val dq = nodeQueue.dequeue()
+        if (dq._2.nodeElements.nonEmpty) {
+          ret ++= dq._2.nodeElements
+          count += dq._2.nodeElements.length
+        }
+      }
+      ret
+    }
+  }
+
+  /**
+   *
+   * @param queryPoint point under consideration
+   * @param node node that queryPoint lies in
+   * @param nodeQueue defined in searchSiblingQueue, this stores nodes based on their
+   *                  distance to node as defined by minDist
+   */
+  private def searchRecurSiblingQueue(
+    queryPoint: Vector,
+    node: Node,
+    nodeQueue: PriorityQueue[(Double, Node)]) {
+    if (node.children != null) {
+      for (child <- node.children; if child.contains(queryPoint)) {
+        if (child.children == null) {
+          for (c <- node.children) {
+            minNodes(queryPoint, c, nodeQueue)
+          }
+        } else {
+          searchRecurSiblingQueue(queryPoint, child, nodeQueue)
+        }
+      }
+    }
+  }
+
+  /**
+   * Goes down to the minimal bounding box of queryPoint, and adds elements to nodeQueue
+   *
+   * @param queryPoint point under consideration
+   * @param node node that queryPoint lies in
+   * @param nodeQueue PriorityQueue that stores all points in minimal bounding box of queryPoint
+   */
+  private def minNodes(
+    queryPoint: Vector,
+    node: Node,
+    nodeQueue: PriorityQueue[(Double, Node)]) {
+    if (node.children == null) {
+      nodeQueue += ((-node.minDist(queryPoint), node))
+    } else {
+      for (c <- node.children) {
+        minNodes(queryPoint, c, nodeQueue)
+      }
+    }
+  }
+
+  /** Finds all objects within a neighborhood of queryPoint of a specified radius;
+    * scope is modified from the original 2D version in:
+    * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+    *
+    * original version only looks in minimal box; for the KNN Query, we look at
+    * all nearby boxes. The radius is determined from searchNeighborsSiblingQueue
+    * by defining a min-heap on the leaf nodes
+    *
+    * @param queryPoint
+    * @param radius
+    * @return all points within the given radius of queryPoint
+    */
+  def searchNeighbors(
+    queryPoint: Vector,
+    radius: Double): ListBuffer[Vector] = {
+    val ret = new ListBuffer[Vector]
+    searchRecur(queryPoint, radius, root, ret)
+    ret
+  }
+
+  private def searchRecur(
+    queryPoint: Vector,
+    radius: Double,
+    node: Node,
+    ret: ListBuffer[Vector]) {
+    if (node.children == null) {
+      ret ++= node.nodeElements
+    } else {
+      for (child <- node.children; if child.isNear(queryPoint, radius)) {
+        searchRecur(queryPoint, radius, child, ret)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
new file mode 100644
index 0000000..63e412a
--- /dev/null
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.classification.Classification
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.metrics.distances.{ManhattanDistanceMetric,
+SquaredEuclideanDistanceMetric}
+import org.apache.flink.test.util.FlinkTestBase
+import org.scalatest.{FlatSpec, Matchers}
+
+class KNNITSuite extends FlatSpec with Matchers with FlinkTestBase {
+  behavior of "The KNN Join Implementation"
+
+  it should "throw an exception when the given K is not valid" in {
+    intercept[IllegalArgumentException] {
+      KNN().setK(0)
+    }
+  }
+
+  it should "throw an exception when the given count of blocks is not valid" in {
+    intercept[IllegalArgumentException] {
+      KNN().setBlocks(0)
+    }
+  }
+
+  val env = ExecutionEnvironment.getExecutionEnvironment
+
+  // prepare data
+  val trainingSet = env.fromCollection(Classification.trainingData).map(_.vector)
+  val testingSet = env.fromElements(DenseVector(0.0, 0.0))
+
+  // calculate answer
+  val answer = Classification.trainingData.map {
+    v => (v.vector, SquaredEuclideanDistanceMetric().distance(DenseVector(0.0, 0.0), v.vector))
+  }.sortBy(_._2).take(3).map(_._1).toArray
+
+  it should "calculate kNN join correctly without using a Quadtree" in {
+
+    val knn = KNN()
+      .setK(3)
+      .setBlocks(10)
+      .setDistanceMetric(SquaredEuclideanDistanceMetric())
+      .setUseQuadTree(false)
+      .setSizeHint(CrossHint.SECOND_IS_SMALL)
+
+    // run knn join
+    knn.fit(trainingSet)
+    val result = knn.predict(testingSet).collect()
+
+    result.size should be(1)
+    result.head._1 should be(DenseVector(0.0, 0.0))
+    result.head._2 should be(answer)
+  }
+
+  it should "calculate kNN join correctly with a Quadtree" in {
+
+    val knn = KNN()
+      .setK(3)
+      .setBlocks(2) // blocks set to 2 to make sure initial quadtree box is partitioned
+      .setDistanceMetric(SquaredEuclideanDistanceMetric())
+      .setUseQuadTree(true)
+      .setSizeHint(CrossHint.SECOND_IS_SMALL)
+
+    // run knn join
+    knn.fit(trainingSet)
+    val result = knn.predict(testingSet).collect()
+
+    result.size should be(1)
+    result.head._1 should be(DenseVector(0.0, 0.0))
+    result.head._2 should be(answer)
+  }
+
+  it should "throw an exception when using a Quadtree with an incompatible metric" in {
+    intercept[IllegalArgumentException] {
+      val knn = KNN()
+        .setK(3)
+        .setBlocks(10)
+        .setDistanceMetric(ManhattanDistanceMetric())
+        .setUseQuadTree(true)
+
+      // run knn join
+      knn.fit(trainingSet)
+      val result = knn.predict(testingSet).collect()
+
+    }
+  }
+
+}
+

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
new file mode 100644
index 0000000..8be5c6e
--- /dev/null
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn
+
+import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric
+import org.apache.flink.test.util.FlinkTestBase
+import org.apache.flink.ml.math.{Vector, DenseVector}
+
+import org.scalatest.{Matchers, FlatSpec}
+
+/** Test of Quadtree class
+  * Constructor for the QuadTree class:
+  * class QuadTree(minVec: Vector, maxVec: Vector, distMetric: DistanceMetric, maxPerBox: Int)
+  *
+  */
+
+class QuadTreeSuite extends FlatSpec with Matchers with FlinkTestBase {
+  behavior of "The QuadTree Class"
+
+  it should "partition into equal size sub-boxes and search for nearby objects properly" in {
+
+    val minVec = DenseVector(-1.0, -0.5)
+    val maxVec = DenseVector(1.0, 0.5)
+
+    val myTree = new QuadTree(minVec, maxVec, EuclideanDistanceMetric(), 3)
+
+    myTree.insert(DenseVector(-0.25, 0.3).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.20, 0.31).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.21, 0.29).asInstanceOf[Vector])
+
+    var a = myTree.root.getCenterWidth()
+
+    /** Tree will partition once the 4th point is added
+      */
+
+    myTree.insert(DenseVector(0.2, 0.27).asInstanceOf[Vector])
+    myTree.insert(DenseVector(0.2, 0.26).asInstanceOf[Vector])
+
+    myTree.insert(DenseVector(-0.21, 0.289).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.1, 0.289).asInstanceOf[Vector])
+
+    myTree.insert(DenseVector(0.7, 0.45).asInstanceOf[Vector])
+
+    /**
+     * Exact values of (centers,dimensions) of root + children nodes, to test
+     * partitionBox and makeChildren methods; exact values are given to avoid
+     * essentially copying and pasting the code to automatically generate them
+     * from minVec/maxVec
+     */
+
+    val knownCentersLengths = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)),
+      (DenseVector(-0.5, -0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(-0.5, 0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(0.5, -0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(0.5, 0.25), DenseVector(1.0, 0.5))
+    )
+
+    /**
+     * (centers,dimensions) computed from QuadTree.makeChildren
+     */
+
+    var computedCentersLength = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)))
+    for (child <- myTree.root.children) {
+      computedCentersLength += child.getCenterWidth().asInstanceOf[(DenseVector, DenseVector)]
+    }
+
+
+    /**
+     * Tests the search for nearby neighbors, making sure the right object is contained in
+     * the neighbor search (the neighbor search will contain more points)
+     */
+    val neighborsComputed = myTree.searchNeighbors(DenseVector(0.7001, 0.45001), 0.001)
+    val isNeighborInSearch = neighborsComputed.contains(DenseVector(0.7, 0.45))
+
+    /**
+     * Test ability to get all objects in the minimal bounding box plus objects in siblings' boxes
+     * In this case, drawing a picture of the QuadTree shows that
+     * (-0.2, 0.31), (-0.21, 0.29), (-0.21, 0.289)
+     * are objects near (-0.2001, 0.31001)
+     */
+
+    val siblingsObjectsComputed = myTree.searchNeighborsSiblingQueue(DenseVector(-0.2001, 0.31001))
+    val isSiblingsInSearch = siblingsObjectsComputed.contains(DenseVector(-0.2, 0.31)) &&
+      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.29)) &&
+      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.289))
+
+    computedCentersLength should be(knownCentersLengths)
+    isNeighborInSearch should be(true)
+    isSiblingsInSearch should be(true)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
index 35073b6..6d563e9 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala
@@ -20,14 +20,20 @@ package org.apache.flink.ml.nn
 
 import org.apache.flink.api.common.operators.Order
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.scala.DataSetUtils._
+import org.apache.flink.api.scala.utils._
 import org.apache.flink.api.scala._
 import org.apache.flink.ml.common._
-import org.apache.flink.ml.math.Vector
-import org.apache.flink.ml.metrics.distances.{DistanceMetric, EuclideanDistanceMetric}
+import org.apache.flink.ml.math.{Vector => FlinkVector, DenseVector}
+import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
+DistanceMetric, EuclideanDistanceMetric}
 import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, Predictor}
 import org.apache.flink.util.Collector
+import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
 
+import org.apache.flink.ml.nn.util.QuadTree
+
+import scala.collection.immutable.Vector
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
@@ -38,17 +44,17 @@ import scala.reflect.ClassTag
   *
   * @example
   * {{{
-  *     val trainingDS: DataSet[Vector] = ...
-  *     val testingDS: DataSet[Vector] = ...
+  *       val trainingDS: DataSet[Vector] = ...
+  *       val testingDS: DataSet[Vector] = ...
   *
-  *     val knn = KNN()
-  *       .setK(10)
-  *       .setBlocks(5)
-  *       .setDistanceMetric(EuclideanDistanceMetric())
+  *       val knn = KNN()
+  *         .setK(10)
+  *         .setBlocks(5)
+  *         .setDistanceMetric(EuclideanDistanceMetric())
   *
-  *     knn.fit(trainingDS)
+  *       knn.fit(trainingDS)
   *
-  *     val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS)
+  *       val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS)
   * }}}
   *
   * =Parameters=
@@ -67,11 +73,12 @@ import scala.reflect.ClassTag
   * (Default value: '''EuclideanDistanceMetric()''')
   *
   */
+
 class KNN extends Predictor[KNN] {
 
   import KNN._
 
-  var trainingSet: Option[DataSet[Block[Vector]]] = None
+  var trainingSet: Option[DataSet[Block[FlinkVector]]] = None
 
   /** Sets K
     * @param k the number of selected points as neighbors
@@ -98,6 +105,25 @@ class KNN extends Predictor[KNN] {
     parameters.add(Blocks, n)
     this
   }
+
+  /**
+   * Sets the Boolean variable that decides whether to use the QuadTree or not
+   */
+  def setUseQuadTree(useQuadTree: Boolean): KNN = {
+    parameters.add(UseQuadTreeParam, useQuadTree)
+    this
+  }
+
+  /**
+   * Parameter a user can specify if one of the training or test sets is small
+   * @param sizeHint
+   * @return
+   */
+  def setSizeHint(sizeHint: CrossHint): KNN = {
+    parameters.add(SizeHint, sizeHint)
+    this
+  }
+
 }
 
 object KNN {
@@ -114,6 +140,14 @@ object KNN {
     val defaultValue: Option[Int] = None
   }
 
+  case object UseQuadTreeParam extends Parameter[Boolean] {
+    val defaultValue: Option[Boolean] = None
+  }
+
+  case object SizeHint extends Parameter[CrossHint] {
+    val defaultValue: Option[CrossHint] = None
+  }
+
   def apply(): KNN = {
     new KNN()
   }
@@ -121,18 +155,18 @@ object KNN {
   /** [[FitOperation]] which trains a KNN based on the given training data set.
     * @tparam T Subtype of [[org.apache.flink.ml.math.Vector]]
     */
-  implicit def fitKNN[T <: Vector : TypeInformation] = new FitOperation[KNN, T] {
+  implicit def fitKNN[T <: FlinkVector : TypeInformation] = new FitOperation[KNN, T] {
     override def fit(
-        instance: KNN,
-        fitParameters: ParameterMap,
-        input: DataSet[T]): Unit = {
+      instance: KNN,
+      fitParameters: ParameterMap,
+      input: DataSet[T]): Unit = {
       val resultParameters = instance.parameters ++ fitParameters
 
       require(resultParameters.get(K).isDefined, "K is needed for calculation")
 
       val blocks = resultParameters.get(Blocks).getOrElse(input.getParallelism)
       val partitioner = FlinkMLTools.ModuloKeyPartitioner
-      val inputAsVector = input.asInstanceOf[DataSet[Vector]]
+      val inputAsVector = input.asInstanceOf[DataSet[FlinkVector]]
 
       instance.trainingSet = Some(FlinkMLTools.block(inputAsVector, blocks, Some(partitioner)))
     }
@@ -143,12 +177,13 @@ object KNN {
     * @tparam T Subtype of [[Vector]]
     * @return The given testing data set with k-nearest neighbors
     */
-  implicit def predictValues[T <: Vector : ClassTag : TypeInformation] = {
-    new PredictDataSetOperation[KNN, T, (Vector, Array[Vector])] {
+  implicit def predictValues[T <: FlinkVector : ClassTag : TypeInformation] = {
+    new PredictDataSetOperation[KNN, T, (FlinkVector, Array[FlinkVector])] {
       override def predictDataSet(
-          instance: KNN,
-          predictParameters: ParameterMap,
-          input: DataSet[T]): DataSet[(Vector, Array[Vector])] = {
+        instance: KNN,
+        predictParameters: ParameterMap,
+        input: DataSet[T]): DataSet[(FlinkVector,
+        Array[FlinkVector])] = {
         val resultParameters = instance.parameters ++ predictParameters
 
         instance.trainingSet match {
@@ -164,24 +199,40 @@ object KNN {
             // split data into multiple blocks
             val inputSplit = FlinkMLTools.block(inputWithId, blocks, Some(partitioner))
 
+            val sizeHint = resultParameters.get(SizeHint)
+            val crossTuned = sizeHint match {
+              case Some(hint) if hint == CrossHint.FIRST_IS_SMALL =>
+                trainingSet.crossWithHuge(inputSplit)
+              case Some(hint) if hint == CrossHint.SECOND_IS_SMALL =>
+                trainingSet.crossWithTiny(inputSplit)
+              case _ => trainingSet.cross(inputSplit)
+            }
+
             // join input and training set
-            val crossed = trainingSet.cross(inputSplit).mapPartition {
-              (iter, out: Collector[(Vector, Vector, Long, Double)]) => {
+            val crossed = crossTuned.mapPartition {
+              (iter, out: Collector[(FlinkVector, FlinkVector, Long, Double)]) => {
                 for ((training, testing) <- iter) {
-                  val queue = mutable.PriorityQueue[(Vector, Vector, Long, Double)]()(
+                  val queue = mutable.PriorityQueue[(FlinkVector, FlinkVector, Long, Double)]()(
                     Ordering.by(_._4))
 
-                  for (a <- testing.values; b <- training.values) {
-                    // (training vector, input vector, input key, distance)
-                    queue.enqueue((b, a._2, a._1, metric.distance(b, a._2)))
+                  // use a quadtree if (4^dim)Ntest*log(Ntrain)
+                  // < Ntest*Ntrain, and distance is Euclidean
+                  val useQuadTree = resultParameters.get(UseQuadTreeParam).getOrElse(
+                    training.values.head.size + math.log(math.log(training.values.length) /
+                      math.log(4.0)) < math.log(training.values.length) / math.log(4.0) &&
+                      (metric.isInstanceOf[EuclideanDistanceMetric] ||
+                        metric.isInstanceOf[SquaredEuclideanDistanceMetric]))
 
-                    if (queue.size > k) {
-                      queue.dequeue()
+                  if (useQuadTree) {
+                    if (metric.isInstanceOf[EuclideanDistanceMetric] ||
+                      metric.isInstanceOf[SquaredEuclideanDistanceMetric]){
+                      knnQueryWithQuadTree(training.values, testing.values, k, metric, queue, out)
+                    } else {
+                      throw new IllegalArgumentException(s" Error: metric must be" +
+                        s" Euclidean or SquaredEuclidean!")
                     }
-                  }
-
-                  for (v <- queue) {
-                    out.collect(v)
+                  } else {
+                    knnQueryBasic(training.values, testing.values, k, metric, queue, out)
                   }
                 }
               }
@@ -189,13 +240,14 @@ object KNN {
 
             // group by input vector id and pick k nearest neighbor for each group
             val result = crossed.groupBy(2).sortGroup(3, Order.ASCENDING).reduceGroup {
-              (iter, out: Collector[(Vector, Array[Vector])]) => {
+              (iter, out: Collector[(FlinkVector, Array[FlinkVector])]) => {
                 if (iter.hasNext) {
                   val head = iter.next()
                   val key = head._2
-                  val neighbors: ArrayBuffer[Vector] = ArrayBuffer(head._1)
+                  val neighbors: ArrayBuffer[FlinkVector] = ArrayBuffer(head._1)
 
-                  for ((vector, _, _, _) <- iter.take(k - 1)) { // we already took a first element
+                  for ((vector, _, _, _) <- iter.take(k - 1)) {
+                    // we already took a first element
                     neighbors += vector
                   }
 
@@ -206,9 +258,88 @@ object KNN {
 
             result
           case None => throw new RuntimeException("The KNN model has not been trained." +
-              "Call first fit before calling the predict operation.")
+            "Call first fit before calling the predict operation.")
+
         }
       }
     }
   }
+
+  def knnQueryWithQuadTree[T <: FlinkVector](
+    training: Vector[T],
+    testing: Vector[(Long, T)],
+    k: Int, metric: DistanceMetric,
+    queue: mutable.PriorityQueue[(FlinkVector,
+      FlinkVector, Long, Double)],
+    out: Collector[(FlinkVector,
+      FlinkVector, Long, Double)]) {
+    /// find a bounding box
+    val MinArr = Array.tabulate(training.head.size)(x => x)
+    val MaxArr = Array.tabulate(training.head.size)(x => x)
+
+    val minVecTrain = MinArr.map(i => training.map(x => x(i)).min - 0.01)
+    val minVecTest = MinArr.map(i => testing.map(x => x._2(i)).min - 0.01)
+    val maxVecTrain = MaxArr.map(i => training.map(x => x(i)).max + 0.01)
+    val maxVecTest = MaxArr.map(i => testing.map(x => x._2(i)).max + 0.01)
+
+    val MinVec = DenseVector(MinArr.map(i => Array(minVecTrain(i), minVecTest(i)).min))
+    val MaxVec = DenseVector(MinArr.map(i => Array(maxVecTrain(i), maxVecTest(i)).max))
+
+    //default value of max elements/box is set to max(20,k)
+    val maxPerBox = Array(k, 20).max
+    val trainingQuadTree = new QuadTree(MinVec, MaxVec, metric, maxPerBox)
+
+    for (v <- training) {
+      trainingQuadTree.insert(v)
+    }
+
+    for ((id, vector) <- testing) {
+      //  Find siblings' objects and do local kNN there
+      val siblingObjects =
+        trainingQuadTree.searchNeighborsSiblingQueue(vector)
+
+      // do KNN query on siblingObjects and get max distance of kNN
+      // then rad is good choice for a neighborhood to do a refined
+      // local kNN search
+      val knnSiblings = siblingObjects.map(v => metric.distance(vector, v)
+      ).sortWith(_ < _).take(k)
+
+      val rad = knnSiblings.last
+      val trainingFiltered = trainingQuadTree.searchNeighbors(vector, rad)
+
+      for (b <- trainingFiltered) {
+        // (training vector, input vector, input key, distance)
+        queue.enqueue((b, vector, id, metric.distance(b, vector)))
+        if (queue.size > k) {
+          queue.dequeue()
+        }
+      }
+      for (v <- queue) {
+        out.collect(v)
+      }
+    }
+  }
+
+  def knnQueryBasic[T <: FlinkVector](
+    training: Vector[T],
+    testing: Vector[(Long, T)],
+    k: Int, metric: DistanceMetric,
+    queue: mutable.PriorityQueue[(FlinkVector,
+      FlinkVector, Long, Double)],
+    out: Collector[(FlinkVector, FlinkVector, Long, Double)]) {
+
+    for ((id, vector) <- testing) {
+      for (b <- training) {
+        // (training vector, input vector, input key, distance)
+        queue.enqueue((b, vector, id, metric.distance(b, vector)))
+        if (queue.size > k) {
+          queue.dequeue()
+        }
+      }
+      for (v <- queue) {
+        out.collect(v)
+      }
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
new file mode 100644
index 0000000..0b37313
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.nn.util
+
+import org.apache.flink.ml.math.{Breeze, Vector}
+import Breeze._
+
+import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric,
+EuclideanDistanceMetric, DistanceMetric}
+
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.PriorityQueue
+
+/**
+ * n-dimensional QuadTree data structure; partitions
+ * spatial data for faster queries (e.g. KNN queries).
+ * The skeleton of the data structure was initially
+ * based on the 2D Quadtree found here:
+ * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+ *
+ * Many additional methods were added to the class both for
+ * efficient KNN queries and for generalizing to n dimensions.
+ *
+ * @param minVec vector of the corner of the bounding box with the smallest coordinates
+ * @param maxVec vector of the corner of the bounding box with the largest coordinates
+ * @param distMetric distance metric, must be Euclidean or squared Euclidean
+ * @param maxPerBox threshold for the number of points in each box before splitting a box
+ */
+class QuadTree(
+  minVec: Vector,
+  maxVec: Vector,
+  distMetric: DistanceMetric,
+  maxPerBox: Int) {
+
+  class Node(
+    center: Vector,
+    width: Vector,
+    var children: Seq[Node]) {
+
+    val nodeElements = new ListBuffer[Vector]
+
+    /** for testing purposes only; used in QuadTreeSuite.scala
+      *
+      * @return center and width of the box
+      */
+    def getCenterWidth(): (Vector, Vector) = {
+      (center, width)
+    }
+
+    def contains(queryPoint: Vector): Boolean = {
+      overlap(queryPoint, 0.0)
+    }
+
+    /** Tests whether the axis-aligned box of half-width radius around queryPoint
+      * overlaps this node's box
+      *
+      * @param queryPoint point under consideration
+      * @param radius half-width of the region around queryPoint
+      * @return true if the region overlaps the box in every coordinate
+      */
+    def overlap(
+      queryPoint: Vector,
+      radius: Double): Boolean = {
+      val count = (0 until queryPoint.size).filter { i =>
+        (queryPoint(i) - radius < center(i) + width(i) / 2) &&
+          (queryPoint(i) + radius > center(i) - width(i) / 2)
+      }.size
+
+      count == queryPoint.size
+    }
+
+    /** Tests whether queryPoint is within the given radius of this node's box
+      *
+      * @param queryPoint point under consideration
+      * @param radius search radius
+      * @return true if minDist(queryPoint) is less than radius
+      */
+    def isNear(
+      queryPoint: Vector,
+      radius: Double): Boolean = {
+      minDist(queryPoint) < radius
+    }
+
+    /**
+     * minDist is defined so that every point in the box
+     * has distance to queryPoint of at least minDist
+     * (minDist adapted from "Nearest Neighbor Queries" by N. Roussopoulos et al.)
+     *
+     * @param queryPoint point under consideration
+     * @return lower bound on the distance from queryPoint to any point in the box
+     */
+    def minDist(queryPoint: Vector): Double = {
+      val minDist = (0 until queryPoint.size).map { i =>
+        if (queryPoint(i) < center(i) - width(i) / 2) {
+          math.pow(queryPoint(i) - center(i) + width(i) / 2, 2)
+        } else if (queryPoint(i) > center(i) + width(i) / 2) {
+          math.pow(queryPoint(i) - center(i) - width(i) / 2, 2)
+        } else {
+          0
+        }
+      }.sum
+
+      distMetric match {
+        case _: SquaredEuclideanDistanceMetric => minDist
+        case _: EuclideanDistanceMetric => math.sqrt(minDist)
+        case _ => throw new IllegalArgumentException("Error: metric must be" +
+          " Euclidean or SquaredEuclidean!")
+      }
+    }
+
+    /**
+     * Finds which child queryPoint lies in.  node.children is a Seq[Node], and
+     * whichChild finds the appropriate index of that Seq.
+     * @param queryPoint point under consideration
+     * @return index into children of the sub-box containing queryPoint
+     */
+    def whichChild(queryPoint: Vector): Int = {
+      (0 until queryPoint.size).map { i =>
+        if (queryPoint(i) > center(i)) {
+          Math.pow(2, queryPoint.size - 1 - i).toInt
+        } else {
+          0
+        }
+      }.sum
+    }
+
+    def makeChildren() {
+      val centerClone = center.copy
+      val cPart = partitionBox(centerClone, width)
+      val mappedWidth = 0.5 * width.asBreeze
+      children = cPart.map(p => new Node(p, mappedWidth.fromBreeze, null))
+    }
+
+    /**
+     * Recursive function that partitions an n-dimensional box by taking the (n-1)-dimensional
+     * plane through the center of the box keeping the n-th coordinate fixed,
+     * then shifting it in the n-th direction up and down,
+     * and recursively applying partitionBox to the two shifted (n-1)-dimensional planes.
+     *
+     * @param center the center of the box
+     * @param width a vector of lengths of each dimension of the box
+     * @return centers of the 2^n equally sized sub-boxes
+     */
+    def partitionBox(
+      center: Vector,
+      width: Vector): Seq[Vector] = {
+      def partitionHelper(
+        box: Seq[Vector],
+        dim: Int): Seq[Vector] = {
+        if (dim >= width.size) {
+          box
+        } else {
+          val newBox = box.flatMap {
+            vector =>
+              val (down, up) = (vector.copy, vector)
+              down.update(dim, down(dim) - width(dim) / 4)
+              up.update(dim, up(dim) + width(dim) / 4)
+
+              Seq(down, up)
+          }
+          partitionHelper(newBox, dim + 1)
+        }
+      }
+      partitionHelper(Seq(center), 0)
+    }
+  }
+
+
+  val root = new Node(((minVec.asBreeze + maxVec.asBreeze) * 0.5).fromBreeze,
+    (maxVec.asBreeze - minVec.asBreeze).fromBreeze, null)
+
+  /**
+   * simple printing of tree for testing/debugging
+   */
+  def printTree(): Unit = {
+    printTreeRecur(root)
+  }
+
+  def printTreeRecur(node: Node) {
+    if (node.children != null) {
+      for (c <- node.children) {
+        printTreeRecur(c)
+      }
+    } else {
+      println("printing tree: node.nodeElements " + node.nodeElements)
+    }
+  }
+
+  /**
+   * Recursively adds an object to the tree
+   * @param queryPoint the point to insert
+   */
+  def insert(queryPoint: Vector) {
+    insertRecur(queryPoint, root)
+  }
+
+  private def insertRecur(
+    queryPoint: Vector,
+    node: Node) {
+    if (node.children == null) {
+      if (node.nodeElements.length < maxPerBox) {
+        node.nodeElements += queryPoint
+      } else {
+        node.makeChildren()
+        for (o <- node.nodeElements) {
+          insertRecur(o, node.children(node.whichChild(o)))
+        }
+        node.nodeElements.clear()
+        insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
+      }
+    } else {
+      insertRecur(queryPoint, node.children(node.whichChild(queryPoint)))
+    }
+  }
+
+  /**
+   * Used to zoom in on a region near a test point for a fast KNN query.
+   * This capability is used in the KNN query to find k "near" neighbors n_1,...,n_k, from
+   * which one computes the max distance D_s to queryPoint.  D_s is then used during the
+   * kNN query to find all points within a radius D_s of queryPoint using searchNeighbors.
+   * To find the "near" neighbors, a min-heap is defined on the leaf nodes of the
+   * minimal bounding box of queryPoint and its siblings. The priority of a leaf node
+   * is an appropriate notion of the distance between the test point and the node,
+   * which is defined by minDist(queryPoint).
+   *
+   * @param queryPoint a test point for which the method finds the minimal bounding
+   *                   box that queryPoint lies in and returns elements in that box's
+   *                   siblings' leaf nodes
+   * @return elements contained in the leaf nodes of the minimal bounding box and its siblings
+   */
+  def searchNeighborsSiblingQueue(queryPoint: Vector): ListBuffer[Vector] = {
+    val ret = new ListBuffer[Vector]
+    // edge case when the main box has not been partitioned at all
+    if (root.children == null) {
+      root.nodeElements.clone()
+    } else {
+      val nodeQueue = new PriorityQueue[(Double, Node)]()(Ordering.by(x => x._1))
+      searchRecurSiblingQueue(queryPoint, root, nodeQueue)
+
+      var count = 0
+      while (count < maxPerBox) {
+        val dq = nodeQueue.dequeue()
+        if (dq._2.nodeElements.nonEmpty) {
+          ret ++= dq._2.nodeElements
+          count += dq._2.nodeElements.length
+        }
+      }
+      ret
+    }
+  }
+
+  /**
+   * Recursively descends to the minimal bounding box of queryPoint and adds the leaf
+   * nodes of that box and its siblings to nodeQueue
+   *
+   * @param queryPoint point under consideration
+   * @param node node that queryPoint lies in
+   * @param nodeQueue defined in searchNeighborsSiblingQueue, this stores nodes based on their
+   *                  distance to queryPoint as defined by minDist
+   */
+  private def searchRecurSiblingQueue(
+    queryPoint: Vector,
+    node: Node,
+    nodeQueue: PriorityQueue[(Double, Node)]) {
+    if (node.children != null) {
+      for (child <- node.children; if child.contains(queryPoint)) {
+        if (child.children == null) {
+          for (c <- node.children) {
+            minNodes(queryPoint, c, nodeQueue)
+          }
+        } else {
+          searchRecurSiblingQueue(queryPoint, child, nodeQueue)
+        }
+      }
+    }
+  }
+
+  /**
+   * Collects into nodeQueue all leaf nodes under node, keyed by their minDist to queryPoint
+   *
+   * @param queryPoint point under consideration
+   * @param node node under consideration
+   * @param nodeQueue PriorityQueue that stores the leaf nodes near the minimal bounding box of
+   *                  queryPoint, prioritized by minDist
+   */
+  private def minNodes(
+    queryPoint: Vector,
+    node: Node,
+    nodeQueue: PriorityQueue[(Double, Node)]) {
+    if (node.children == null) {
+      nodeQueue += ((-node.minDist(queryPoint), node))
+    } else {
+      for (c <- node.children) {
+        minNodes(queryPoint, c, nodeQueue)
+      }
+    }
+  }
+
+  /** Finds all objects within a neighborhood of queryPoint of a specified radius;
+    * adapted from the original 2D version in:
+    * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala
+    *
+    * The original version only looks in the minimal box; for the KNN query, we look at
+    * all nearby boxes. The radius is determined from searchNeighborsSiblingQueue
+    * by defining a min-heap on the leaf nodes.
+    *
+    * @param queryPoint point under consideration
+    * @param radius radius of the search neighborhood
+    * @return all points within the given radius of queryPoint
+    */
+  def searchNeighbors(
+    queryPoint: Vector,
+    radius: Double): ListBuffer[Vector] = {
+    val ret = new ListBuffer[Vector]
+    searchRecur(queryPoint, radius, root, ret)
+    ret
+  }
+
+  private def searchRecur(
+    queryPoint: Vector,
+    radius: Double,
+    node: Node,
+    ret: ListBuffer[Vector]) {
+    if (node.children == null) {
+      ret ++= node.nodeElements
+    } else {
+      for (child <- node.children; if child.isNear(queryPoint, radius)) {
+        searchRecur(queryPoint, radius, child, ret)
+      }
+    }
+  }
+}

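The trickiest invariant in QuadTree.scala above is the correspondence between whichChild and partitionBox: whichChild treats each coordinate as one bit of the child index, and partitionBox enumerates the sub-boxes in exactly that bit order. The following is a small stand-alone illustration in plain Scala of the same index arithmetic; the object and function names here are only for illustration.

object WhichChildSketch {
  // bit (n - 1 - i) of the child index is set when queryPoint(i) > center(i)
  def whichChild(queryPoint: Array[Double], center: Array[Double]): Int =
    queryPoint.indices.map { i =>
      if (queryPoint(i) > center(i)) 1 << (queryPoint.length - 1 - i) else 0
    }.sum

  def main(args: Array[String]): Unit = {
    val center = Array(0.0, 0.0)
    // partitionBox enumerates the 2D children in the order (-,-), (-,+), (+,-), (+,+),
    // so these indices select the matching entry of node.children
    println(whichChild(Array(-0.5, -0.5), center)) // 0
    println(whichChild(Array(-0.5,  0.5), center)) // 1
    println(whichChild(Array( 0.5, -0.5), center)) // 2
    println(whichChild(Array( 0.5,  0.5), center)) // 3
  }
}

In two dimensions this reproduces the familiar quadrant numbering; the same arithmetic generalizes to 2^n children in n dimensions.
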
http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
index 107724b..350af95 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/KNNITSuite.scala
@@ -53,9 +53,10 @@ class KNNITSuite extends FlatSpec with Matchers with FlinkTestBase {
     }.sortBy(_._2).take(3).map(_._1).toArray
 
     val knn = KNN()
-        .setK(3)
-        .setBlocks(10)
-        .setDistanceMetric(SquaredEuclideanDistanceMetric())
+      .setK(3)
+      .setBlocks(10)
+      .setDistanceMetric(SquaredEuclideanDistanceMetric())
+      .setUseQuadTree(true)
 
     // run knn join
     knn.fit(trainingSet)

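For reference, a hedged usage sketch of the new setUseQuadTree flag exercised by the test above. It assumes the standard FlinkML fit/predict Predictor API and uses toy, locally generated data; the data and the object name are illustrative only.

import org.apache.flink.api.scala._
import org.apache.flink.ml.math.DenseVector
import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
import org.apache.flink.ml.nn.KNN

object KnnQuadTreeExample {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment

    // toy training and testing points
    val trainingSet = env.fromCollection(Seq(
      DenseVector(0.0, 0.0), DenseVector(1.0, 1.0), DenseVector(2.0, 2.0)))
    val testingSet = env.fromCollection(Seq(DenseVector(0.1, 0.1)))

    val knn = KNN()
      .setK(2)
      .setBlocks(2)
      .setDistanceMetric(SquaredEuclideanDistanceMetric())
      .setUseQuadTree(true) // force the quadtree-based join, as in the test above

    // run the knn join and print the results
    knn.fit(trainingSet)
    val result = knn.predict(testingSet).collect()
    result.foreach(println)
  }
}
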
http://git-wip-us.apache.org/repos/asf/flink/blob/4a5af42c/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
new file mode 100644
index 0000000..9b84a80
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/nn/QuadTreeSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric
+import org.apache.flink.ml.nn.util.QuadTree
+import org.apache.flink.test.util.FlinkTestBase
+import org.apache.flink.ml.math.{Breeze, Vector, DenseVector}
+
+import org.scalatest.{Matchers, FlatSpec}
+
+/** Test of the QuadTree class
+  * Constructor for the QuadTree class:
+  * class QuadTree(minVec: Vector, maxVec: Vector, distMetric: DistanceMetric, maxPerBox: Int)
+  *
+  */
+
+class QuadTreeSuite extends FlatSpec with Matchers with FlinkTestBase {
+  behavior of "The QuadTree Class"
+
+  it should "partition into equal size sub-boxes and search for nearby objects properly" in {
+
+    val minVec = DenseVector(-1.0, -0.5)
+    val maxVec = DenseVector(1.0, 0.5)
+
+    val myTree = new QuadTree(minVec, maxVec, EuclideanDistanceMetric(), 3)
+
+    myTree.insert(DenseVector(-0.25, 0.3).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.20, 0.31).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.21, 0.29).asInstanceOf[Vector])
+
+    var a = myTree.root.getCenterWidth()
+
+    /** Tree will partition once the 4th point is added
+      */
+
+    myTree.insert(DenseVector(0.2, 0.27).asInstanceOf[Vector])
+    myTree.insert(DenseVector(0.2, 0.26).asInstanceOf[Vector])
+
+    myTree.insert(DenseVector(-0.21, 0.289).asInstanceOf[Vector])
+    myTree.insert(DenseVector(-0.1, 0.289).asInstanceOf[Vector])
+
+    myTree.insert(DenseVector(0.7, 0.45).asInstanceOf[Vector])
+
+    /**
+     * Exact values of (centers,dimensions) of root + children nodes, to test
+     * partitionBox and makeChildren methods; exact values are given to avoid
+     * essentially copying and pasting the code to automatically generate them
+     * from minVec/maxVec
+     */
+
+    val knownCentersLengths = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)),
+      (DenseVector(-0.5, -0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(-0.5, 0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(0.5, -0.25), DenseVector(1.0, 0.5)),
+      (DenseVector(0.5, 0.25), DenseVector(1.0, 0.5))
+    )
+
+    /**
+     * (centers,dimensions) computed from QuadTree.makeChildren
+     */
+
+    var computedCentersLength = Set((DenseVector(0.0, 0.0), DenseVector(2.0, 1.0)))
+    for (child <- myTree.root.children) {
+      computedCentersLength += child.getCenterWidth().asInstanceOf[(DenseVector, DenseVector)]
+    }
+
+
+    /**
+     * Tests the search for nearby neighbors; makes sure the right object is contained in the
+     * neighbor search (the neighbor search may contain additional points)
+     */
+    val neighborsComputed = myTree.searchNeighbors(DenseVector(0.7001, 0.45001), 0.001)
+    val isNeighborInSearch = neighborsComputed.contains(DenseVector(0.7, 0.45))
+
+    /**
+     * Tests the ability to get all objects in the minimal bounding box plus the objects in its
+     * siblings' boxes. In this case, drawing a picture of the QuadTree shows that
+     * (-0.2, 0.31), (-0.21, 0.29), (-0.21, 0.289)
+     * are objects near (-0.2001, 0.31001)
+     */
+
+    val siblingsObjectsComputed = myTree.searchNeighborsSiblingQueue(DenseVector(-0.2001, 0.31001))
+    val isSiblingsInSearch = siblingsObjectsComputed.contains(DenseVector(-0.2, 0.31)) &&
+      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.29)) &&
+      siblingsObjectsComputed.contains(DenseVector(-0.21, 0.289))
+
+    computedCentersLength should be(knownCentersLengths)
+    isNeighborInSearch should be(true)
+    isSiblingsInSearch should be(true)
+  }
+}

