spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zsxw...@apache.org
Subject [1/2] spark git commit: [SPARK-12244][SPARK-12245][STREAMING] Rename trackStateByKey to mapWithState and change tracking function signature
Date Thu, 10 Dec 2015 04:59:36 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 699f497cf -> f6d866173


http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
----------------------------------------------------------------------
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
deleted file mode 100644
index eac4cdd..0000000
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Set;
-
-import scala.Tuple2;
-
-import com.google.common.base.Optional;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.streaming.api.java.JavaDStream;
-import org.apache.spark.util.ManualClock;
-import org.junit.Assert;
-import org.junit.Test;
-
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.Function4;
-import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
-
-public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable {
-
-  /**
-   * This test is only for testing the APIs. It's not necessary to run it.
-   */
-  public void testAPI() {
-    JavaPairRDD<String, Boolean> initialRDD = null;
-    JavaPairDStream<String, Integer> wordsDstream = null;
-
-    final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
-        trackStateFunc =
-        new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
-
-          @Override
-          public Optional<Double> call(
-              Time time, String word, Optional<Integer> one, State<Boolean> state) {
-            // Use all State's methods here
-            state.exists();
-            state.get();
-            state.isTimingOut();
-            state.remove();
-            state.update(true);
-            return Optional.of(2.0);
-          }
-        };
-
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
-        wordsDstream.trackStateByKey(
-            StateSpec.function(trackStateFunc)
-                .initialState(initialRDD)
-                .numPartitions(10)
-                .partitioner(new HashPartitioner(10))
-                .timeout(Durations.seconds(10)));
-
-    JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
-
-    final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 =
-        new Function2<Optional<Integer>, State<Boolean>, Double>() {
-
-          @Override
-          public Double call(Optional<Integer> one, State<Boolean> state) {
-            // Use all State's methods here
-            state.exists();
-            state.get();
-            state.isTimingOut();
-            state.remove();
-            state.update(true);
-            return 2.0;
-          }
-        };
-
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
-        wordsDstream.trackStateByKey(
-            StateSpec.<String, Integer, Boolean, Double> function(trackStateFunc2)
-                .initialState(initialRDD)
-                .numPartitions(10)
-                .partitioner(new HashPartitioner(10))
-                .timeout(Durations.seconds(10)));
-
-    JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
-  }
-
-  @Test
-  public void testBasicFunction() {
-    List<List<String>> inputData = Arrays.asList(
-        Collections.<String>emptyList(),
-        Arrays.asList("a"),
-        Arrays.asList("a", "b"),
-        Arrays.asList("a", "b", "c"),
-        Arrays.asList("a", "b"),
-        Arrays.asList("a"),
-        Collections.<String>emptyList()
-    );
-
-    List<Set<Integer>> outputData = Arrays.asList(
-        Collections.<Integer>emptySet(),
-        Sets.newHashSet(1),
-        Sets.newHashSet(2, 1),
-        Sets.newHashSet(3, 2, 1),
-        Sets.newHashSet(4, 3),
-        Sets.newHashSet(5),
-        Collections.<Integer>emptySet()
-    );
-
-    List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
-        Collections.<Tuple2<String, Integer>>emptySet(),
-        Sets.newHashSet(new Tuple2<String, Integer>("a", 1)),
-        Sets.newHashSet(new Tuple2<String, Integer>("a", 2), new Tuple2<String, Integer>("b", 1)),
-        Sets.newHashSet(
-            new Tuple2<String, Integer>("a", 3),
-            new Tuple2<String, Integer>("b", 2),
-            new Tuple2<String, Integer>("c", 1)),
-        Sets.newHashSet(
-            new Tuple2<String, Integer>("a", 4),
-            new Tuple2<String, Integer>("b", 3),
-            new Tuple2<String, Integer>("c", 1)),
-        Sets.newHashSet(
-            new Tuple2<String, Integer>("a", 5),
-            new Tuple2<String, Integer>("b", 3),
-            new Tuple2<String, Integer>("c", 1)),
-        Sets.newHashSet(
-            new Tuple2<String, Integer>("a", 5),
-            new Tuple2<String, Integer>("b", 3),
-            new Tuple2<String, Integer>("c", 1))
-    );
-
-    Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc =
-        new Function2<Optional<Integer>, State<Integer>, Integer>() {
-
-          @Override
-          public Integer call(Optional<Integer> value, State<Integer> state) throws Exception {
-            int sum = value.or(0) + (state.exists() ? state.get() : 0);
-            state.update(sum);
-            return sum;
-          }
-        };
-    testOperation(
-        inputData,
-        StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc),
-        outputData,
-        stateData);
-  }
-
-  private <K, S, T> void testOperation(
-      List<List<K>> input,
-      StateSpec<K, Integer, S, T> trackStateSpec,
-      List<Set<T>> expectedOutputs,
-      List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
-    int numBatches = expectedOutputs.size();
-    JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
-    JavaTrackStateDStream<K, Integer, S, T> trackeStateStream =
-        JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() {
-          @Override
-          public Tuple2<K, Integer> call(K x) throws Exception {
-            return new Tuple2<K, Integer>(x, 1);
-          }
-        })).trackStateByKey(trackStateSpec);
-
-    final List<Set<T>> collectedOutputs =
-        Collections.synchronizedList(Lists.<Set<T>>newArrayList());
-    trackeStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
-      @Override
-      public Void call(JavaRDD<T> rdd) throws Exception {
-        collectedOutputs.add(Sets.newHashSet(rdd.collect()));
-        return null;
-      }
-    });
-    final List<Set<Tuple2<K, S>>> collectedStateSnapshots =
-        Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList());
-    trackeStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
-      @Override
-      public Void call(JavaPairRDD<K, S> rdd) throws Exception {
-        collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()));
-        return null;
-      }
-    });
-    BatchCounter batchCounter = new BatchCounter(ssc.ssc());
-    ssc.start();
-    ((ManualClock) ssc.ssc().scheduler().clock())
-        .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1);
-    batchCounter.waitUntilBatchesCompleted(numBatches, 10000);
-
-    Assert.assertEquals(expectedOutputs, collectedOutputs);
-    Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots);
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
new file mode 100644
index 0000000..4b08085
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
@@ -0,0 +1,581 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.File
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.reflect.ClassTag
+
+import org.scalatest.PrivateMethodTester._
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl}
+import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+
+class MapWithStateSuite extends SparkFunSuite
+  with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
+
+  private var sc: SparkContext = null
+  protected var checkpointDir: File = null
+  protected val batchDuration = Seconds(1)
+
+  before {
+    StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
+    checkpointDir = Utils.createTempDir("checkpoint")
+  }
+
+  after {
+    if (checkpointDir != null) {
+      Utils.deleteRecursively(checkpointDir)
+    }
+    StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
+  }
+
+  override def beforeAll(): Unit = {
+    val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite")
+    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+    sc = new SparkContext(conf)
+  }
+
+  override def afterAll(): Unit = {
+    if (sc != null) {
+      sc.stop()
+    }
+  }
+
+  test("state - get, exists, update, remove, ") {
+    var state: StateImpl[Int] = null
+
+    def testState(
+        expectedData: Option[Int],
+        shouldBeUpdated: Boolean = false,
+        shouldBeRemoved: Boolean = false,
+        shouldBeTimingOut: Boolean = false
+      ): Unit = {
+      if (expectedData.isDefined) {
+        assert(state.exists)
+        assert(state.get() === expectedData.get)
+        assert(state.getOption() === expectedData)
+        assert(state.getOption.getOrElse(-1) === expectedData.get)
+      } else {
+        assert(!state.exists)
+        intercept[NoSuchElementException] {
+          state.get()
+        }
+        assert(state.getOption() === None)
+        assert(state.getOption.getOrElse(-1) === -1)
+      }
+
+      assert(state.isTimingOut() === shouldBeTimingOut)
+      if (shouldBeTimingOut) {
+        intercept[IllegalArgumentException] {
+          state.remove()
+        }
+        intercept[IllegalArgumentException] {
+          state.update(-1)
+        }
+      }
+
+      assert(state.isUpdated() === shouldBeUpdated)
+
+      assert(state.isRemoved() === shouldBeRemoved)
+      if (shouldBeRemoved) {
+        intercept[IllegalArgumentException] {
+          state.remove()
+        }
+        intercept[IllegalArgumentException] {
+          state.update(-1)
+        }
+      }
+    }
+
+    state = new StateImpl[Int]()
+    testState(None)
+
+    state.wrap(None)
+    testState(None)
+
+    state.wrap(Some(1))
+    testState(Some(1))
+
+    state.update(2)
+    testState(Some(2), shouldBeUpdated = true)
+
+    state = new StateImpl[Int]()
+    state.update(2)
+    testState(Some(2), shouldBeUpdated = true)
+
+    state.remove()
+    testState(None, shouldBeRemoved = true)
+
+    state.wrapTiminoutState(3)
+    testState(Some(3), shouldBeTimingOut = true)
+  }
+
+  test("mapWithState - basic operations with simple API") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(1),
+        Seq(2, 1),
+        Seq(3, 2, 1),
+        Seq(4, 3),
+        Seq(5),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    // state maintains running count, and updated count is returned
+    val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      state.update(sum)
+      sum
+    }
+
+    testOperation[String, Int, Int](
+      inputData, StateSpec.function(mappingFunc), outputData, stateData)
+  }
+
+  test("mapWithState - basic operations with advanced API") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq("aa"),
+        Seq("aa", "bb"),
+        Seq("aa", "bb", "cc"),
+        Seq("aa", "bb"),
+        Seq("aa"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    // state maintains running count, key string doubled and returned
+    val mappingFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      state.update(sum)
+      Some(key * 2)
+    }
+
+    testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData)
+  }
+
+  test("mapWithState - type inferencing and class tags") {
+
+    // Simple track state function with value as Int, state as Double and mapped type as Long
+    val simpleFunc = (key: String, value: Option[Int], state: State[Double]) => {
+      0L
+    }
+
+    // Advanced track state function with key as String, value as Int, state as Double and
+    // mapped type as Long
+    val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
+      Some(0L)
+    }
+
+    def testTypes(dstream: MapWithStateDStream[_, _, _, _]): Unit = {
+      val dstreamImpl = dstream.asInstanceOf[MapWithStateDStreamImpl[_, _, _, _]]
+      assert(dstreamImpl.keyClass === classOf[String])
+      assert(dstreamImpl.valueClass === classOf[Int])
+      assert(dstreamImpl.stateClass === classOf[Double])
+      assert(dstreamImpl.mappedClass === classOf[Long])
+    }
+    val ssc = new StreamingContext(sc, batchDuration)
+    val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
+
+    // Defining StateSpec inline with mapWithState and simple function implicitly gets the types
+    val simpleFunctionStateStream1 = inputStream.mapWithState(
+      StateSpec.function(simpleFunc).numPartitions(1))
+    testTypes(simpleFunctionStateStream1)
+
+    // Separately defining StateSpec with simple function requires explicitly specifying types
+    val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
+    val simpleFunctionStateStream2 = inputStream.mapWithState(simpleFuncSpec)
+    testTypes(simpleFunctionStateStream2)
+
+    // Separately defining StateSpec with advanced function implicitly gets the types
+    val advFuncSpec1 = StateSpec.function(advancedFunc)
+    val advFunctionStateStream1 = inputStream.mapWithState(advFuncSpec1)
+    testTypes(advFunctionStateStream1)
+
+    // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types
+    val advFunctionStateStream2 = inputStream.mapWithState(
+      StateSpec.function(simpleFunc).numPartitions(1))
+    testTypes(advFunctionStateStream2)
+
+    // Separately defining StateSpec with advanced function while explicitly specifying types
+    val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
+    val advFunctionStateStream3 = inputStream.mapWithState[Double, Long](advFuncSpec2)
+    testTypes(advFunctionStateStream3)
+  }
+
+  test("mapWithState - states as mapped data") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3)),
+        Seq(("a", 5)),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (key, sum)
+      state.update(sum)
+      Some(output)
+    }
+
+    testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData)
+  }
+
+  test("mapWithState - initial states, with nothing returned as from mapping function") {
+
+    val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
+
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData = Seq.fill(inputData.size)(Seq.empty[Int])
+
+    val stateData =
+      Seq(
+        Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)),
+        Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)),
+        Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)),
+        Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)),
+        Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)),
+        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)),
+        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
+      )
+
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (key, sum)
+      state.update(sum)
+      None.asInstanceOf[Option[Int]]
+    }
+
+    val mapWithStateSpec = StateSpec.function(mappingFunc).initialState(sc.makeRDD(initialState))
+    testOperation(inputData, mapWithStateSpec, outputData, stateData)
+  }
+
+  test("mapWithState - state removing") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"), // a will be removed
+        Seq("a", "b", "c"), // b will be removed
+        Seq("a", "b", "c"), // a and c will be removed
+        Seq("a", "b"), // b will be removed
+        Seq("a"), // a will be removed
+        Seq()
+      )
+
+    // States that were removed
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(),
+        Seq("a"),
+        Seq("b"),
+        Seq("a", "c"),
+        Seq("b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("b", 1)),
+        Seq(("a", 1), ("c", 1)),
+        Seq(("b", 1)),
+        Seq(("a", 1)),
+        Seq(),
+        Seq()
+      )
+
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      if (state.exists) {
+        state.remove()
+        Some(key)
+      } else {
+        state.update(value.get)
+        None
+      }
+    }
+
+    testOperation(
+      inputData, StateSpec.function(mappingFunc).numPartitions(1), outputData, stateData)
+  }
+
+  test("mapWithState - state timing out") {
+    val inputData =
+      Seq(
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq(), // c will time out
+        Seq(), // b will time out
+        Seq("a") // a will not time out
+      ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
+
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      if (value.isDefined) {
+        state.update(1)
+      }
+      if (state.isTimingOut) {
+        Some(key)
+      } else {
+        None
+      }
+    }
+
+    val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
+      inputData, StateSpec.function(mappingFunc).timeout(Seconds(3)), 20)
+
+    // b and c should be returned once each, when they were marked as expired
+    assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
+
+    // States for a, b, c should all be defined at some point in time
+    assert(collectedStateSnapshots.exists {
+      _.toSet == Set(("a", 1), ("b", 1), ("c", 1))
+    })
+
+    // Finally state should be defined only for a
+    assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
+  }
+
+  test("mapWithState - checkpoint durations") {
+    val privateMethod = PrivateMethod[InternalMapWithStateDStream[_, _, _, _]]('internalStream)
+
+    def testCheckpointDuration(
+        batchDuration: Duration,
+        expectedCheckpointDuration: Duration,
+        explicitCheckpointDuration: Option[Duration] = None
+      ): Unit = {
+      val ssc = new StreamingContext(sc, batchDuration)
+
+      try {
+        val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
+        val dummyFunc = (key: Int, value: Option[Int], state: State[Int]) => 0
+        val mapWithStateStream = inputStream.mapWithState(StateSpec.function(dummyFunc))
+        val internalmapWithStateStream = mapWithStateStream invokePrivate privateMethod()
+
+        explicitCheckpointDuration.foreach { d =>
+          mapWithStateStream.checkpoint(d)
+        }
+        mapWithStateStream.register()
+        ssc.checkpoint(checkpointDir.toString)
+        ssc.start()  // should initialize all the checkpoint durations
+        assert(mapWithStateStream.checkpointDuration === null)
+        assert(internalmapWithStateStream.checkpointDuration === expectedCheckpointDuration)
+      } finally {
+        ssc.stop(stopSparkContext = false)
+      }
+    }
+
+    testCheckpointDuration(Milliseconds(100), Seconds(1))
+    testCheckpointDuration(Seconds(1), Seconds(10))
+    testCheckpointDuration(Seconds(10), Seconds(100))
+
+    testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2)))
+    testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2)))
+    testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20)))
+  }
+
+
+  test("mapWithState - driver failure recovery") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    def operation(dstream: DStream[String]): DStream[(String, Int)] = {
+
+      val checkpointDuration = batchDuration * (stateData.size / 2)
+
+      val runningCount = (key: String, value: Option[Int], state: State[Int]) => {
+        state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
+        state.get()
+      }
+
+      val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState(
+        StateSpec.function(runningCount))
+      // Set interval to make sure there is one RDD checkpointing
+      mapWithStateStream.checkpoint(checkpointDuration)
+      mapWithStateStream.stateSnapshots()
+    }
+
+    testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
+      batchDuration = batchDuration, stopSparkContextAfterTest = false)
+  }
+
+  private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
+      input: Seq[Seq[K]],
+      mapWithStateSpec: StateSpec[K, Int, S, T],
+      expectedOutputs: Seq[Seq[T]],
+      expectedStateSnapshots: Seq[Seq[(K, S)]]
+    ): Unit = {
+    require(expectedOutputs.size == expectedStateSnapshots.size)
+
+    val (collectedOutputs, collectedStateSnapshots) =
+      getOperationOutput(input, mapWithStateSpec, expectedOutputs.size)
+    assert(expectedOutputs, collectedOutputs, "outputs")
+    assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
+  }
+
+  private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
+      input: Seq[Seq[K]],
+      mapWithStateSpec: StateSpec[K, Int, S, T],
+      numBatches: Int
+    ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
+
+    // Setup the stream computation
+    val ssc = new StreamingContext(sc, Seconds(1))
+    val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+    val trackeStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec)
+    val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
+    val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs)
+    val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
+    val stateSnapshotStream = new TestOutputStream(
+      trackeStateStream.stateSnapshots(), collectedStateSnapshots)
+    outputStream.register()
+    stateSnapshotStream.register()
+
+    val batchCounter = new BatchCounter(ssc)
+    ssc.checkpoint(checkpointDir.toString)
+    ssc.start()
+
+    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+    clock.advance(batchDuration.milliseconds * numBatches)
+
+    batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
+    ssc.stop(stopSparkContext = false)
+    (collectedOutputs, collectedStateSnapshots)
+  }
+
+  private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) {
+    val debugString = "\nExpected:\n" + expected.mkString("\n") +
+      "\nCollected:\n" + collected.mkString("\n")
+    assert(expected.size === collected.size,
+      s"number of collected $typ (${collected.size}) different from expected (${expected.size})" +
+        debugString)
+    expected.zip(collected).foreach { case (c, e) =>
+      assert(c.toSet === e.toSet,
+        s"collected $typ is different from expected $debugString"
+      )
+    }
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
deleted file mode 100644
index 1fc320d..0000000
--- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
+++ /dev/null
@@ -1,581 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming
-
-import java.io.File
-
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
-import scala.reflect.ClassTag
-
-import org.scalatest.PrivateMethodTester._
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-
-import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
-import org.apache.spark.util.{ManualClock, Utils}
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-
-class TrackStateByKeySuite extends SparkFunSuite
-  with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
-
-  private var sc: SparkContext = null
-  protected var checkpointDir: File = null
-  protected val batchDuration = Seconds(1)
-
-  before {
-    StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
-    checkpointDir = Utils.createTempDir("checkpoint")
-  }
-
-  after {
-    if (checkpointDir != null) {
-      Utils.deleteRecursively(checkpointDir)
-    }
-    StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
-  }
-
-  override def beforeAll(): Unit = {
-    val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite")
-    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
-    sc = new SparkContext(conf)
-  }
-
-  override def afterAll(): Unit = {
-    if (sc != null) {
-      sc.stop()
-    }
-  }
-
-  test("state - get, exists, update, remove, ") {
-    var state: StateImpl[Int] = null
-
-    def testState(
-        expectedData: Option[Int],
-        shouldBeUpdated: Boolean = false,
-        shouldBeRemoved: Boolean = false,
-        shouldBeTimingOut: Boolean = false
-      ): Unit = {
-      if (expectedData.isDefined) {
-        assert(state.exists)
-        assert(state.get() === expectedData.get)
-        assert(state.getOption() === expectedData)
-        assert(state.getOption.getOrElse(-1) === expectedData.get)
-      } else {
-        assert(!state.exists)
-        intercept[NoSuchElementException] {
-          state.get()
-        }
-        assert(state.getOption() === None)
-        assert(state.getOption.getOrElse(-1) === -1)
-      }
-
-      assert(state.isTimingOut() === shouldBeTimingOut)
-      if (shouldBeTimingOut) {
-        intercept[IllegalArgumentException] {
-          state.remove()
-        }
-        intercept[IllegalArgumentException] {
-          state.update(-1)
-        }
-      }
-
-      assert(state.isUpdated() === shouldBeUpdated)
-
-      assert(state.isRemoved() === shouldBeRemoved)
-      if (shouldBeRemoved) {
-        intercept[IllegalArgumentException] {
-          state.remove()
-        }
-        intercept[IllegalArgumentException] {
-          state.update(-1)
-        }
-      }
-    }
-
-    state = new StateImpl[Int]()
-    testState(None)
-
-    state.wrap(None)
-    testState(None)
-
-    state.wrap(Some(1))
-    testState(Some(1))
-
-    state.update(2)
-    testState(Some(2), shouldBeUpdated = true)
-
-    state = new StateImpl[Int]()
-    state.update(2)
-    testState(Some(2), shouldBeUpdated = true)
-
-    state.remove()
-    testState(None, shouldBeRemoved = true)
-
-    state.wrapTiminoutState(3)
-    testState(Some(3), shouldBeTimingOut = true)
-  }
-
-  test("trackStateByKey - basic operations with simple API") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val outputData =
-      Seq(
-        Seq(),
-        Seq(1),
-        Seq(2, 1),
-        Seq(3, 2, 1),
-        Seq(4, 3),
-        Seq(5),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1))
-      )
-
-    // state maintains running count, and updated count is returned
-    val trackStateFunc = (value: Option[Int], state: State[Int]) => {
-      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
-      state.update(sum)
-      sum
-    }
-
-    testOperation[String, Int, Int](
-      inputData, StateSpec.function(trackStateFunc), outputData, stateData)
-  }
-
-  test("trackStateByKey - basic operations with advanced API") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val outputData =
-      Seq(
-        Seq(),
-        Seq("aa"),
-        Seq("aa", "bb"),
-        Seq("aa", "bb", "cc"),
-        Seq("aa", "bb"),
-        Seq("aa"),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1))
-      )
-
-    // state maintains running count, key string doubled and returned
-    val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
-      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
-      state.update(sum)
-      Some(key * 2)
-    }
-
-    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
-  }
-
-  test("trackStateByKey - type inferencing and class tags") {
-
-    // Simple track state function with value as Int, state as Double and emitted type as Double
-    val simpleFunc = (value: Option[Int], state: State[Double]) => {
-      0L
-    }
-
-    // Advanced track state function with key as String, value as Int, state as Double and
-    // emitted type as Double
-    val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
-      Some(0L)
-    }
-
-    def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = {
-      val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]]
-      assert(dstreamImpl.keyClass === classOf[String])
-      assert(dstreamImpl.valueClass === classOf[Int])
-      assert(dstreamImpl.stateClass === classOf[Double])
-      assert(dstreamImpl.emittedClass === classOf[Long])
-    }
-    val ssc = new StreamingContext(sc, batchDuration)
-    val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
-
-    // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
-    val simpleFunctionStateStream1 = inputStream.trackStateByKey(
-      StateSpec.function(simpleFunc).numPartitions(1))
-    testTypes(simpleFunctionStateStream1)
-
-    // Separately defining StateSpec with simple function requires explicitly specifying types
-    val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
-    val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
-    testTypes(simpleFunctionStateStream2)
-
-    // Separately defining StateSpec with advanced function implicitly gets the types
-    val advFuncSpec1 = StateSpec.function(advancedFunc)
-    val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
-    testTypes(advFunctionStateStream1)
-
-    // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
-    val advFunctionStateStream2 = inputStream.trackStateByKey(
-      StateSpec.function(simpleFunc).numPartitions(1))
-    testTypes(advFunctionStateStream2)
-
-    // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
-    val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
-    val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2)
-    testTypes(advFunctionStateStream3)
-  }
-
-  test("trackStateByKey - states as emitted records") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val outputData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3)),
-        Seq(("a", 5)),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1))
-      )
-
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
-      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
-      val output = (key, sum)
-      state.update(sum)
-      Some(output)
-    }
-
-    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
-  }
-
-  test("trackStateByKey - initial states, with nothing emitted") {
-
-    val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
-
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val outputData = Seq.fill(inputData.size)(Seq.empty[Int])
-
-    val stateData =
-      Seq(
-        Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)),
-        Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)),
-        Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)),
-        Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)),
-        Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)),
-        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)),
-        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
-      )
-
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
-      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
-      val output = (key, sum)
-      state.update(sum)
-      None.asInstanceOf[Option[Int]]
-    }
-
-    val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState))
-    testOperation(inputData, trackStateSpec, outputData, stateData)
-  }
-
-  test("trackStateByKey - state removing") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"), // a will be removed
-        Seq("a", "b", "c"), // b will be removed
-        Seq("a", "b", "c"), // a and c will be removed
-        Seq("a", "b"), // b will be removed
-        Seq("a"), // a will be removed
-        Seq()
-      )
-
-    // States that were removed
-    val outputData =
-      Seq(
-        Seq(),
-        Seq(),
-        Seq("a"),
-        Seq("b"),
-        Seq("a", "c"),
-        Seq("b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("b", 1)),
-        Seq(("a", 1), ("c", 1)),
-        Seq(("b", 1)),
-        Seq(("a", 1)),
-        Seq(),
-        Seq()
-      )
-
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
-      if (state.exists) {
-        state.remove()
-        Some(key)
-      } else {
-        state.update(value.get)
-        None
-      }
-    }
-
-    testOperation(
-      inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData)
-  }
-
-  test("trackStateByKey - state timing out") {
-    val inputData =
-      Seq(
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq(), // c will time out
-        Seq(), // b will time out
-        Seq("a") // a will not time out
-      ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
-
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
-      if (value.isDefined) {
-        state.update(1)
-      }
-      if (state.isTimingOut) {
-        Some(key)
-      } else {
-        None
-      }
-    }
-
-    val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
-      inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20)
-
-    // b and c should be emitted once each, when they were marked as expired
-    assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
-
-    // States for a, b, c should be defined at one point of time
-    assert(collectedStateSnapshots.exists {
-      _.toSet == Set(("a", 1), ("b", 1), ("c", 1))
-    })
-
-    // Finally state should be defined only for a
-    assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
-  }
-
-  test("trackStateByKey - checkpoint durations") {
-    val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream)
-
-    def testCheckpointDuration(
-        batchDuration: Duration,
-        expectedCheckpointDuration: Duration,
-        explicitCheckpointDuration: Option[Duration] = None
-      ): Unit = {
-      val ssc = new StreamingContext(sc, batchDuration)
-
-      try {
-        val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
-        val dummyFunc = (value: Option[Int], state: State[Int]) => 0
-        val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc))
-        val internalTrackStateStream = trackStateStream invokePrivate privateMethod()
-
-        explicitCheckpointDuration.foreach { d =>
-          trackStateStream.checkpoint(d)
-        }
-        trackStateStream.register()
-        ssc.checkpoint(checkpointDir.toString)
-        ssc.start()  // should initialize all the checkpoint durations
-        assert(trackStateStream.checkpointDuration === null)
-        assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration)
-      } finally {
-        ssc.stop(stopSparkContext = false)
-      }
-    }
-
-    testCheckpointDuration(Milliseconds(100), Seconds(1))
-    testCheckpointDuration(Seconds(1), Seconds(10))
-    testCheckpointDuration(Seconds(10), Seconds(100))
-
-    testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2)))
-    testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2)))
-    testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20)))
-  }
-
-
-  test("trackStateByKey - driver failure recovery") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1))
-      )
-
-    def operation(dstream: DStream[String]): DStream[(String, Int)] = {
-
-      val checkpointDuration = batchDuration * (stateData.size / 2)
-
-      val runningCount = (value: Option[Int], state: State[Int]) => {
-        state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
-        state.get()
-      }
-
-      val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey(
-        StateSpec.function(runningCount))
-      // Set internval make sure there is one RDD checkpointing
-      trackStateStream.checkpoint(checkpointDuration)
-      trackStateStream.stateSnapshots()
-    }
-
-    testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
-      batchDuration = batchDuration, stopSparkContextAfterTest = false)
-  }
-
-  private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
-      input: Seq[Seq[K]],
-      trackStateSpec: StateSpec[K, Int, S, T],
-      expectedOutputs: Seq[Seq[T]],
-      expectedStateSnapshots: Seq[Seq[(K, S)]]
-    ): Unit = {
-    require(expectedOutputs.size == expectedStateSnapshots.size)
-
-    val (collectedOutputs, collectedStateSnapshots) =
-      getOperationOutput(input, trackStateSpec, expectedOutputs.size)
-    assert(expectedOutputs, collectedOutputs, "outputs")
-    assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
-  }
-
-  private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
-      input: Seq[Seq[K]],
-      trackStateSpec: StateSpec[K, Int, S, T],
-      numBatches: Int
-    ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
-
-    // Setup the stream computation
-    val ssc = new StreamingContext(sc, Seconds(1))
-    val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
-    val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
-    val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
-    val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs)
-    val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
-    val stateSnapshotStream = new TestOutputStream(
-      trackeStateStream.stateSnapshots(), collectedStateSnapshots)
-    outputStream.register()
-    stateSnapshotStream.register()
-
-    val batchCounter = new BatchCounter(ssc)
-    ssc.checkpoint(checkpointDir.toString)
-    ssc.start()
-
-    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-    clock.advance(batchDuration.milliseconds * numBatches)
-
-    batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
-    ssc.stop(stopSparkContext = false)
-    (collectedOutputs, collectedStateSnapshots)
-  }
-
-  private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) {
-    val debugString = "\nExpected:\n" + expected.mkString("\n") +
-      "\nCollected:\n" + collected.mkString("\n")
-    assert(expected.size === collected.size,
-      s"number of collected $typ (${collected.size}) different from expected (${expected.size})" +
-        debugString)
-    expected.zip(collected).foreach { case (c, e) =>
-      assert(c.toSet === e.toSet,
-        s"collected $typ is different from expected $debugString"
-      )
-    }
-  }
-}
-

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
new file mode 100644
index 0000000..aa95bd3
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
@@ -0,0 +1,389 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.rdd
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
+import org.apache.spark.streaming.{State, Time}
+import org.apache.spark.util.Utils
+
+class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
+
+  private var sc: SparkContext = null
+  private var checkpointDir: File = _
+
+  override def beforeAll(): Unit = {
+    sc = new SparkContext(
+      new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite"))
+    checkpointDir = Utils.createTempDir()
+    sc.setCheckpointDir(checkpointDir.toString)
+  }
+
+  override def afterAll(): Unit = {
+    if (sc != null) {
+      sc.stop()
+    }
+    Utils.deleteRecursively(checkpointDir)
+  }
+
+  override def sparkContext: SparkContext = sc
+
+  test("creation from pair RDD") {
+    val data = Seq((1, "1"), (2, "2"), (3, "3"))
+    val partitioner = new HashPartitioner(10)
+    val rdd = MapWithStateRDD.createFromPairRDD[Int, Int, String, Int](
+      sc.parallelize(data), partitioner, Time(123))
+    assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty)
+    assert(rdd.partitions.size === partitioner.numPartitions)
+
+    assert(rdd.partitioner === Some(partitioner))
+  }
+
+  test("updating state and generating mapped data in MapWithStateRDDRecord") {
+
+    val initialTime = 1000L
+    val updatedTime = 2000L
+    val thresholdTime = 1500L
+    @volatile var functionCalled = false
+
+    /**
+     * Assert that applying given data on a prior record generates correct updated record, with
+     * correct state map and mapped data
+     */
+    def assertRecordUpdate(
+        initStates: Iterable[Int],
+        data: Iterable[String],
+        expectedStates: Iterable[(Int, Long)],
+        timeoutThreshold: Option[Long] = None,
+        removeTimedoutData: Boolean = false,
+        expectedOutput: Iterable[Int] = None,
+        expectedTimingOutStates: Iterable[Int] = None,
+        expectedRemovedStates: Iterable[Int] = None
+      ): Unit = {
+      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
+      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
+      functionCalled = false
+      val record = MapWithStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+      val dataIterator = data.map { v => ("key", v) }.iterator
+      val removedStates = new ArrayBuffer[Int]
+      val timingOutStates = new ArrayBuffer[Int]
+      /**
+       * Mapping function that updates/removes state based on instructions in the data, returns
+       * the state when instructed, and records removed and timing-out states.
+       */
+      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
+        functionCalled = true
+
+        assert(t.milliseconds === updatedTime, "mapping func called with wrong time")
+
+        data match {
+          case Some("noop") =>
+            None
+          case Some("get-state") =>
+            Some(state.getOption().getOrElse(-1))
+          case Some("update-state") =>
+            if (state.exists) state.update(state.get + 1) else state.update(0)
+            None
+          case Some("remove-state") =>
+            removedStates += state.get()
+            state.remove()
+            None
+          case None =>
+            assert(state.isTimingOut() === true, "State is not timing out when data = None")
+            timingOutStates += state.get()
+            None
+          case _ =>
+            fail("Unexpected test data")
+        }
+      }
+
+      val updatedRecord = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+        Some(record), dataIterator, testFunc,
+        Time(updatedTime), timeoutThreshold, removeTimedoutData)
+
+      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
+      assert(updatedStateData.toSet === expectedStates.toSet,
+        "states do not match after updating the MapWithStateRDDRecord")
+
+      assert(updatedRecord.mappedData.toSet === expectedOutput.toSet,
+        "mapped data do not match after updating the MapWithStateRDDRecord")
+
+      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
+        "match those that were expected to do so while updating the MapWithStateRDDRecord")
+
+      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
+        "match those that were expected to do so while updating the MapWithStateRDDRecord")
+
+    }
+
+    // No data: no state should be changed and the function should not be called.
+    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
+    assert(functionCalled === false)
+    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === false)
+
+    // Data present, function should be called irrespective of whether state exists
+    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
+      expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === true)
+    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
+    assert(functionCalled === true)
+
+    // Function called with right state data
+    assertRecordUpdate(initStates = None, data = Seq("get-state"),
+      expectedStates = None, expectedOutput = Seq(-1))
+    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
+      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
+
+    // Update state and timestamp, when timeout not present
+    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
+      expectedStates = Seq((0, updatedTime)))
+    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
+      expectedStates = Seq((1, updatedTime)))
+
+    // Remove state
+    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
+      expectedStates = Nil, expectedRemovedStates = Seq(345))
+
+    // State strictly older than timeout threshold should be timed out
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
+      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
+
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Seq(123))
+
+    // State should not be timed out after it has received data
+    assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
+    assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
+
+  }
+
+  test("states generated by MapWithStateRDD") {
+    val initStates = Seq(("k1", 0), ("k2", 0))
+    val initTime = 123
+    val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet
+    val partitioner = new HashPartitioner(2)
+    val initStateRDD = MapWithStateRDD.createFromPairRDD[String, Int, Int, Int](
+      sc.parallelize(initStates), partitioner, Time(initTime)).persist()
+    assertRDD(initStateRDD, initStateWthTime, Set.empty)
+
+    val updateTime = 345
+
+    /**
+     * Test that the test state RDD, when operated with new data,
+     * creates a new state RDD with expected states
+     */
+    def testStateUpdates(
+        testStateRDD: MapWithStateRDD[String, Int, Int, Int],
+        testData: Seq[(String, Int)],
+        expectedStates: Set[(String, Int, Int)]): MapWithStateRDD[String, Int, Int, Int] = {
+
+      // Persist the test MapWithStateRDD so that it is not recomputed during the next operation.
+      // This makes sure that we only track the state keys being touched in the next op.
+      testStateRDD.persist().count()
+
+      // To track which keys are being touched
+      MapWithStateRDDSuite.touchedStateKeys.clear()
+
+      val mappingFunction = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
+
+        // Track the key that has been touched
+        MapWithStateRDDSuite.touchedStateKeys += key
+
+        // If the data is 0, do not do anything with the state
+        // else if the data is 1, increment the state if it exists, or set new state to 0
+        // else if the data is 2, remove the state if it exists
+        data match {
+          case Some(1) =>
+            if (state.exists()) { state.update(state.get + 1) }
+            else state.update(0)
+          case Some(2) =>
+            state.remove()
+          case _ =>
+        }
+        None.asInstanceOf[Option[Int]]  // Do not return anything, not being tested
+      }
+      val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get)
+
+      // Assert that the new state RDD has expected state data
+      val newStateRDD = assertOperation(
+        testStateRDD, newDataRDD, mappingFunction, updateTime, expectedStates, Set.empty)
+
+      // Assert that the function was called only for the keys present in the data
+      assert(MapWithStateRDDSuite.touchedStateKeys.size === testData.size,
+        "More number of keys are being touched than that is expected")
+      assert(MapWithStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
+        "Keys not in the data are being touched unexpectedly")
+
+      // Assert that the test RDD's data has not changed
+      assertRDD(initStateRDD, initStateWthTime, Set.empty)
+      newStateRDD
+    }
+
+    // Test no-op, no state should change
+    testStateUpdates(initStateRDD, Seq(), initStateWthTime)   // should not scan any state
+    testStateUpdates(
+      initStateRDD, Seq(("k1", 0)), initStateWthTime)         // should not update existing state
+    testStateUpdates(
+      initStateRDD, Seq(("k3", 0)), initStateWthTime)         // should not create new state
+
+    // Test creation of new state
+    val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0
+      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime)))
+
+    val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)),         // should create k4's state as 0
+      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime)))
+
+    // Test updating of state
+    val rdd3 = testStateUpdates(
+      initStateRDD, Seq(("k1", 1)),                   // should increment k1's state 0 -> 1
+      Set(("k1", 1, updateTime), ("k2", 0, initTime)))
+
+    val rdd4 = testStateUpdates(rdd3,
+      Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)),  // should update k2, 0 -> 2 and create k3, 0
+      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime)))
+
+    val rdd5 = testStateUpdates(
+      rdd4, Seq(("k3", 1)),                           // should update k3's state 0 -> 1
+      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime)))
+
+    // Test removing of state
+    val rdd6 = testStateUpdates(                      // should remove k1's state
+      initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime)))
+
+    val rdd7 = testStateUpdates(                      // should remove k2's state
+      rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
+
+    val rdd8 = testStateUpdates(                      // should remove k3's state
+      rdd7, Seq(("k3", 2)), Set())
+  }
+
+  test("checkpointing") {
+    // Verifies that a MapWithStateRDD correctly truncates any references to its parent RDDs,
+    // i.e. the data RDD and the parent MapWithStateRDD.
+
+    /** Collect every record as a (state entries, mapped data) pair so results can be compared */
+    def collectRecords(rdd: RDD[MapWithStateRDDRecord[Int, Int, Int]])
+      : Set[(List[(Int, Int, Long)], List[Int])] = {
+      val perRecord = rdd.map { rec =>
+        (rec.stateMap.getAll().toList, rec.mappedData.toList)
+      }
+      perRecord.collect().toSet
+    }
+
+    /** Build a MapWithStateRDD whose data RDD is the one with the long lineage */
+    def stateRDDFromLongLineageData(longLineageRDD: RDD[Int])
+      : MapWithStateRDD[Int, Int, Int, Int] = {
+      val keyedData = longLineageRDD.map { elem => (elem, 1) }
+      MapWithStateRDD.createFromPairRDD(keyedData, partitioner, Time(0))
+    }
+
+    testRDD(
+      stateRDDFromLongLineageData, reliableCheckpoint = true, collectRecords _)
+    testRDDPartitions(
+      stateRDDFromLongLineageData, reliableCheckpoint = true, collectRecords _)
+
+    /** Build a MapWithStateRDD whose parent state RDD is the one with the long lineage */
+    def stateRDDFromLongLineageParentState(
+        longLineageRDD: RDD[Int]): MapWithStateRDD[Int, Int, Int, Int] = {
+
+      // The parent state RDD gets its long lineage from the long-lineage data RDD
+      val parentStateRDD = stateRDDFromLongLineageData(longLineageRDD)
+
+      // The new data is empty, and partitioned with the same partitioner used for the parent
+      val noNewData = parentStateRDD.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner)
+
+      // New MapWithStateRDD on top of the long-lineage parent; the mapping function emits nothing
+      new MapWithStateRDD[Int, Int, Int, Int](
+        parentStateRDD,
+        noNewData,
+        (time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
+        Time(10),
+        None
+      )
+    }
+
+    testRDD(
+      stateRDDFromLongLineageParentState, reliableCheckpoint = true, collectRecords _)
+    testRDDPartitions(
+      stateRDDFromLongLineageParentState, reliableCheckpoint = true, collectRecords _)
+  }
+
+  test("checkpointing empty state RDD") {
+    // A state RDD created from an empty pair RDD must contain no state, both when it is
+    // computed directly and when it is read back from its checkpoint file.
+    val noPairs = sc.emptyRDD[(Int, Int)]
+    val stateRDD = MapWithStateRDD.createFromPairRDD[Int, Int, Int, Int](
+      noPairs, new HashPartitioner(10), Time(0))
+
+    // Mark the RDD for checkpointing before the first job is run on it
+    stateRDD.checkpoint()
+    val computedStates = stateRDD.flatMap { record => record.stateMap.getAll() }.collect()
+    assert(computedStates.isEmpty)
+
+    // Read the checkpointed records back and check that they hold no state either
+    val checkpointPath = stateRDD.getCheckpointFile.get
+    val restoredRDD = sc.checkpointFile[MapWithStateRDDRecord[Int, Int, Int]](checkpointPath)
+    val restoredStates = restoredRDD.flatMap { record => record.stateMap.getAll() }.collect()
+    assert(restoredStates.isEmpty)
+  }
+
+  /**
+   * Assert whether the `mapWithState` operation generates expected results.
+   *
+   * A new [[MapWithStateRDD]] is created on top of `testStateRDD` from the new data, computed
+   * once, and then checked against the expected states and mapped data.
+   *
+   * @param testStateRDD       state RDD used as the parent of the new state RDD
+   * @param newDataRDD         new keyed data; it is repartitioned with the state RDD's
+   *                           partitioner when its own partitioner differs
+   * @param mappingFunction    function handed to the new state RDD
+   * @param currentTime        time, wrapped in `Time`, given to the new state RDD
+   * @param expectedStates     expected (key, state, update time) triples in the new state RDD
+   * @param expectedMappedData expected mapped data in the new state RDD
+   * @param doFullScan         whether `setFullScan()` is called on the new state RDD before
+   *                           it is computed
+   * @return the new state RDD, persisted and already computed
+   */
+  private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+      testStateRDD: MapWithStateRDD[K, V, S, T],
+      newDataRDD: RDD[(K, V)],
+      mappingFunction: (Time, K, Option[V], State[S]) => Option[T],
+      currentTime: Long,
+      expectedStates: Set[(K, S, Int)],
+      expectedMappedData: Set[T],
+      doFullScan: Boolean = false
+    ): MapWithStateRDD[K, V, S, T] = {
+
+    // Align the new data with the state RDD's partitioner. This aligned RDD, and not the raw
+    // `newDataRDD`, is what the new state RDD is built from.
+    val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
+      newDataRDD.partitionBy(testStateRDD.partitioner.get)
+    } else {
+      newDataRDD
+    }
+
+    val newStateRDD = new MapWithStateRDD[K, V, S, T](
+      testStateRDD, partitionedNewDataRDD, mappingFunction, Time(currentTime), None)
+    if (doFullScan) newStateRDD.setFullScan()
+
+    // Persist to make sure that it gets computed only once and we can track precisely how many
+    // state keys the computing touched
+    newStateRDD.persist().count()
+    assertRDD(newStateRDD, expectedStates, expectedMappedData)
+    newStateRDD
+  }
+
+  /**
+   * Collect the contents of the given [[MapWithStateRDD]] and check them against expectations.
+   * The state entries and the mapped data of all records are each compared as a set.
+   */
+  private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+      stateRDD: MapWithStateRDD[K, V, S, T],
+      expectedStates: Set[(K, S, Int)],
+      expectedMappedData: Set[T]): Unit = {
+    // Flatten the state map entries of every record into one set
+    val actualStates = stateRDD.flatMap { record => record.stateMap.getAll() }.collect().toSet
+    // Flatten the mapped data of every record into one set
+    val actualMappedData = stateRDD.flatMap { record => record.mappedData }.collect().toSet
+
+    assert(actualStates === expectedStates,
+      "states after mapWithState operation were not as expected")
+    assert(actualMappedData === expectedMappedData,
+      "mapped data after mapWithState operation were not as expected")
+  }
+}
+
+/** Companion object holding state that the suite's tests inspect after running an operation */
+object MapWithStateRDDSuite {
+  // Buffer of touched state keys; the suite asserts on its size and on the set of keys in it
+  private val touchedStateKeys = ArrayBuffer.empty[String]
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
deleted file mode 100644
index 3b2d43f..0000000
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ /dev/null
@@ -1,389 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.rdd
-
-import java.io.File
-
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
-import org.apache.spark.streaming.{State, Time}
-import org.apache.spark.util.Utils
-
-class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
-
-  private var sc: SparkContext = null
-  private var checkpointDir: File = _
-
-  override def beforeAll(): Unit = {
-    sc = new SparkContext(
-      new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
-    checkpointDir = Utils.createTempDir()
-    sc.setCheckpointDir(checkpointDir.toString)
-  }
-
-  override def afterAll(): Unit = {
-    if (sc != null) {
-      sc.stop()
-    }
-    Utils.deleteRecursively(checkpointDir)
-  }
-
-  override def sparkContext: SparkContext = sc
-
-  test("creation from pair RDD") {
-    val data = Seq((1, "1"), (2, "2"), (3, "3"))
-    val partitioner = new HashPartitioner(10)
-    val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int](
-      sc.parallelize(data), partitioner, Time(123))
-    assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty)
-    assert(rdd.partitions.size === partitioner.numPartitions)
-
-    assert(rdd.partitioner === Some(partitioner))
-  }
-
-  test("updating state and generating emitted data in TrackStateRecord") {
-
-    val initialTime = 1000L
-    val updatedTime = 2000L
-    val thresholdTime = 1500L
-    @volatile var functionCalled = false
-
-    /**
-     * Assert that applying given data on a prior record generates correct updated record, with
-     * correct state map and emitted data
-     */
-    def assertRecordUpdate(
-        initStates: Iterable[Int],
-        data: Iterable[String],
-        expectedStates: Iterable[(Int, Long)],
-        timeoutThreshold: Option[Long] = None,
-        removeTimedoutData: Boolean = false,
-        expectedOutput: Iterable[Int] = None,
-        expectedTimingOutStates: Iterable[Int] = None,
-        expectedRemovedStates: Iterable[Int] = None
-      ): Unit = {
-      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
-      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
-      functionCalled = false
-      val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
-      val dataIterator = data.map { v => ("key", v) }.iterator
-      val removedStates = new ArrayBuffer[Int]
-      val timingOutStates = new ArrayBuffer[Int]
-      /**
-       * Tracking function that updates/removes state based on instructions in the data, and
-       * return state (when instructed or when state is timing out).
-       */
-      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
-        functionCalled = true
-
-        assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
-
-        data match {
-          case Some("noop") =>
-            None
-          case Some("get-state") =>
-            Some(state.getOption().getOrElse(-1))
-          case Some("update-state") =>
-            if (state.exists) state.update(state.get + 1) else state.update(0)
-            None
-          case Some("remove-state") =>
-            removedStates += state.get()
-            state.remove()
-            None
-          case None =>
-            assert(state.isTimingOut() === true, "State is not timing out when data = None")
-            timingOutStates += state.get()
-            None
-          case _ =>
-            fail("Unexpected test data")
-        }
-      }
-
-      val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
-        Some(record), dataIterator, testFunc,
-        Time(updatedTime), timeoutThreshold, removeTimedoutData)
-
-      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
-      assert(updatedStateData.toSet === expectedStates.toSet,
-        "states do not match after updating the TrackStateRecord")
-
-      assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
-        "emitted data do not match after updating the TrackStateRecord")
-
-      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
-        "match those that were expected to do so while updating the TrackStateRecord")
-
-      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
-        "match those that were expected to do so while updating the TrackStateRecord")
-
-    }
-
-    // No data, no state should be changed, function should not be called,
-    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
-    assert(functionCalled === false)
-    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
-    assert(functionCalled === false)
-
-    // Data present, function should be called irrespective of whether state exists
-    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
-      expectedStates = Seq((0, initialTime)))
-    assert(functionCalled === true)
-    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
-    assert(functionCalled === true)
-
-    // Function called with right state data
-    assertRecordUpdate(initStates = None, data = Seq("get-state"),
-      expectedStates = None, expectedOutput = Seq(-1))
-    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
-      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
-
-    // Update state and timestamp, when timeout not present
-    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
-      expectedStates = Seq((0, updatedTime)))
-    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
-      expectedStates = Seq((1, updatedTime)))
-
-    // Remove state
-    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
-      expectedStates = Nil, expectedRemovedStates = Seq(345))
-
-    // State strictly older than timeout threshold should be timed out
-    assertRecordUpdate(initStates = Seq(123), data = Nil,
-      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
-      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
-
-    assertRecordUpdate(initStates = Seq(123), data = Nil,
-      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
-      expectedStates = Nil, expectedTimingOutStates = Seq(123))
-
-    // State should not be timed out after it has received data
-    assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
-      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
-      expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
-    assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
-      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
-      expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
-
-  }
-
-  test("states generated by TrackStateRDD") {
-    val initStates = Seq(("k1", 0), ("k2", 0))
-    val initTime = 123
-    val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet
-    val partitioner = new HashPartitioner(2)
-    val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int](
-      sc.parallelize(initStates), partitioner, Time(initTime)).persist()
-    assertRDD(initStateRDD, initStateWthTime, Set.empty)
-
-    val updateTime = 345
-
-    /**
-     * Test that the test state RDD, when operated with new data,
-     * creates a new state RDD with expected states
-     */
-    def testStateUpdates(
-        testStateRDD: TrackStateRDD[String, Int, Int, Int],
-        testData: Seq[(String, Int)],
-        expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = {
-
-      // Persist the test TrackStateRDD so that its not recomputed while doing the next operation.
-      // This is to make sure that we only track which state keys are being touched in the next op.
-      testStateRDD.persist().count()
-
-      // To track which keys are being touched
-      TrackStateRDDSuite.touchedStateKeys.clear()
-
-      val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
-
-        // Track the key that has been touched
-        TrackStateRDDSuite.touchedStateKeys += key
-
-        // If the data is 0, do not do anything with the state
-        // else if the data is 1, increment the state if it exists, or set new state to 0
-        // else if the data is 2, remove the state if it exists
-        data match {
-          case Some(1) =>
-            if (state.exists()) { state.update(state.get + 1) }
-            else state.update(0)
-          case Some(2) =>
-            state.remove()
-          case _ =>
-        }
-        None.asInstanceOf[Option[Int]]  // Do not return anything, not being tested
-      }
-      val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get)
-
-      // Assert that the new state RDD has expected state data
-      val newStateRDD = assertOperation(
-        testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty)
-
-      // Assert that the function was called only for the keys present in the data
-      assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size,
-        "More number of keys are being touched than that is expected")
-      assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
-        "Keys not in the data are being touched unexpectedly")
-
-      // Assert that the test RDD's data has not changed
-      assertRDD(initStateRDD, initStateWthTime, Set.empty)
-      newStateRDD
-    }
-
-    // Test no-op, no state should change
-    testStateUpdates(initStateRDD, Seq(), initStateWthTime)   // should not scan any state
-    testStateUpdates(
-      initStateRDD, Seq(("k1", 0)), initStateWthTime)         // should not update existing state
-    testStateUpdates(
-      initStateRDD, Seq(("k3", 0)), initStateWthTime)         // should not create new state
-
-    // Test creation of new state
-    val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0
-      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime)))
-
-    val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)),         // should create k4's state as 0
-      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime)))
-
-    // Test updating of state
-    val rdd3 = testStateUpdates(
-      initStateRDD, Seq(("k1", 1)),                   // should increment k1's state 0 -> 1
-      Set(("k1", 1, updateTime), ("k2", 0, initTime)))
-
-    val rdd4 = testStateUpdates(rdd3,
-      Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)),  // should update k2, 0 -> 2 and create k3, 0
-      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime)))
-
-    val rdd5 = testStateUpdates(
-      rdd4, Seq(("k3", 1)),                           // should update k3's state 0 -> 2
-      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime)))
-
-    // Test removing of state
-    val rdd6 = testStateUpdates(                      // should remove k1's state
-      initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime)))
-
-    val rdd7 = testStateUpdates(                      // should remove k2's state
-      rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
-
-    val rdd8 = testStateUpdates(                      // should remove k3's state
-      rdd7, Seq(("k3", 2)), Set())
-  }
-
-  test("checkpointing") {
-    /**
-     * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs -
-     * the data RDD and the parent TrackStateRDD.
-     */
-    def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]])
-      : Set[(List[(Int, Int, Long)], List[Int])] = {
-      rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) }
-         .collect.toSet
-    }
-
-    /** Generate TrackStateRDD with data RDD having a long lineage */
-    def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int])
-      : TrackStateRDD[Int, Int, Int, Int] = {
-      TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
-    }
-
-    testRDD(
-      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
-    testRDDPartitions(
-      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
-
-    /** Generate TrackStateRDD with parent state RDD having a long lineage */
-    def makeStateRDDWithLongLineageParenttateRDD(
-        longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = {
-
-      // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage
-      val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD)
-
-      // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent
-      new TrackStateRDD[Int, Int, Int, Int](
-        stateRDDWithLongLineage,
-        stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner),
-        (time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
-        Time(10),
-        None
-      )
-    }
-
-    testRDD(
-      makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
-    testRDDPartitions(
-      makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
-  }
-
-  test("checkpointing empty state RDD") {
-    val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
-      sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
-    emptyStateRDD.checkpoint()
-    assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
-    val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
-      emptyStateRDD.getCheckpointFile.get)
-    assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
-  }
-
-  /** Assert whether the `trackStateByKey` operation generates expected results */
-  private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-      testStateRDD: TrackStateRDD[K, V, S, T],
-      newDataRDD: RDD[(K, V)],
-      trackStateFunc: (Time, K, Option[V], State[S]) => Option[T],
-      currentTime: Long,
-      expectedStates: Set[(K, S, Int)],
-      expectedEmittedRecords: Set[T],
-      doFullScan: Boolean = false
-    ): TrackStateRDD[K, V, S, T] = {
-
-    val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
-      newDataRDD.partitionBy(testStateRDD.partitioner.get)
-    } else {
-      newDataRDD
-    }
-
-    val newStateRDD = new TrackStateRDD[K, V, S, T](
-      testStateRDD, newDataRDD, trackStateFunc, Time(currentTime), None)
-    if (doFullScan) newStateRDD.setFullScan()
-
-    // Persist to make sure that it gets computed only once and we can track precisely how many
-    // state keys the computing touched
-    newStateRDD.persist().count()
-    assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
-    newStateRDD
-  }
-
-  /** Assert whether the [[TrackStateRDD]] has the expected state ad emitted records */
-  private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-      trackStateRDD: TrackStateRDD[K, V, S, T],
-      expectedStates: Set[(K, S, Int)],
-      expectedEmittedRecords: Set[T]): Unit = {
-    val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
-    val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
-    assert(states === expectedStates,
-      "states after track state operation were not as expected")
-    assert(emittedRecords === expectedEmittedRecords,
-      "emitted records after track state operation were not as expected")
-  }
-}
-
-object TrackStateRDDSuite {
-  private val touchedStateKeys = new ArrayBuffer[String]()
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message