spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From andrewo...@apache.org
Subject [07/10] spark git commit: [SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
Date Tue, 10 May 2016 18:18:01 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index 8e7e000..125ad02 100755
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Dataset, Row}
 
 object StopWordsRemoverSuite extends SparkFunSuite {
   def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
@@ -42,7 +42,7 @@ class StopWordsRemoverSuite
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("test", "test"), Seq("test", "test")),
       (Seq("a", "b", "c", "d"), Seq("b", "c")),
       (Seq("a", "the", "an"), Seq()),
@@ -60,7 +60,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("test", "test"), Seq()),
       (Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
       (Seq("a", "the", "an"), Seq()),
@@ -77,7 +77,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setCaseSensitive(true)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("A"), Seq("A")),
       (Seq("The", "the"), Seq("The"))
     )).toDF("raw", "expected")
@@ -98,7 +98,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("acaba", "ama", "biri"), Seq()),
       (Seq("hep", "her", "scala"), Seq("scala"))
     )).toDF("raw", "expected")
@@ -112,7 +112,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords.toArray)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("python", "scala", "a"), Seq("python", "scala", "a")),
       (Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift"))
     )).toDF("raw", "expected")
@@ -126,7 +126,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords.toArray)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("python", "scala", "a"), Seq()),
       (Seq("Python", "Scala", "swift"), Seq("swift"))
     )).toDF("raw", "expected")
@@ -148,7 +148,7 @@ class StopWordsRemoverSuite
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol(outputCol)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("The", "the", "swift"), Seq("swift"))
     )).toDF("raw", outputCol)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d0f3cdc..c221d4a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -39,7 +39,7 @@ class StringIndexerSuite
 
   test("StringIndexer") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
@@ -63,8 +63,8 @@ class StringIndexerSuite
   test("StringIndexerUnseen") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
     val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
-    val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
+    val df2 = spark.createDataFrame(data2).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
@@ -93,7 +93,7 @@ class StringIndexerSuite
 
   test("StringIndexer with a numeric input column") {
     val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
@@ -114,12 +114,12 @@ class StringIndexerSuite
     val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
       .setInputCol("label")
       .setOutputCol("labelIndex")
-    val df = sqlContext.range(0L, 10L).toDF()
+    val df = spark.range(0L, 10L).toDF()
     assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
   }
 
   test("StringIndexerModel can't overwrite output column") {
-    val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+    val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
     val indexer = new StringIndexer()
       .setInputCol("input")
       .setOutputCol("output")
@@ -153,7 +153,7 @@ class StringIndexerSuite
 
   test("IndexToString.transform") {
     val labels = Array("a", "b", "c")
-    val df0 = sqlContext.createDataFrame(Seq(
+    val df0 = spark.createDataFrame(Seq(
       (0, "a"), (1, "b"), (2, "c"), (0, "a")
     )).toDF("index", "expected")
 
@@ -180,7 +180,7 @@ class StringIndexerSuite
 
   test("StringIndexer, IndexToString are inverses") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
@@ -213,7 +213,7 @@ class StringIndexerSuite
 
   test("StringIndexer metadata") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index 123ddfe..f30bdc3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -57,13 +57,13 @@ class RegexTokenizerSuite
       .setPattern("\\w+|\\p{Punct}")
       .setInputCol("rawText")
       .setOutputCol("tokens")
-    val dataset0 = sqlContext.createDataFrame(Seq(
+    val dataset0 = spark.createDataFrame(Seq(
       TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
       TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
     ))
     testRegexTokenizer(tokenizer0, dataset0)
 
-    val dataset1 = sqlContext.createDataFrame(Seq(
+    val dataset1 = spark.createDataFrame(Seq(
       TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
       TokenizerTestData("Te,st. punct", Array("punct"))
     ))
@@ -73,7 +73,7 @@ class RegexTokenizerSuite
     val tokenizer2 = new RegexTokenizer()
       .setInputCol("rawText")
       .setOutputCol("tokens")
-    val dataset2 = sqlContext.createDataFrame(Seq(
+    val dataset2 = spark.createDataFrame(Seq(
       TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
       TokenizerTestData("Te,st.  punct", Array("te,st.", "punct"))
     ))
@@ -85,7 +85,7 @@ class RegexTokenizerSuite
       .setInputCol("rawText")
       .setOutputCol("tokens")
       .setToLowercase(false)
-    val dataset = sqlContext.createDataFrame(Seq(
+    val dataset = spark.createDataFrame(Seq(
       TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
       TokenizerTestData("java scala", Array("java", "scala"))
     ))

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index dce994f..250011c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -57,7 +57,7 @@ class VectorAssemblerSuite
   }
 
   test("VectorAssembler") {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
     )).toDF("id", "x", "y", "name", "z", "n")
     val assembler = new VectorAssembler()
@@ -70,7 +70,7 @@ class VectorAssemblerSuite
   }
 
   test("transform should throw an exception in case of unsupported type") {
-    val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
+    val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
     val assembler = new VectorAssembler()
       .setInputCols(Array("a", "b", "c"))
       .setOutputCol("features")
@@ -87,7 +87,7 @@ class VectorAssemblerSuite
       NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
       NumericAttribute.defaultAttr.withName("salary")))
     val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
-    val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
+    val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
       .select(
         col("browser").as("browser", browser.toMetadata()),
         col("hour").as("hour", hour.toMetadata()),

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 1ffc62b..d1c0270 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -85,11 +85,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
     checkPair(densePoints1Seq, sparsePoints1Seq)
     checkPair(densePoints2Seq, sparsePoints2Seq)
 
-    densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
-    sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
-    densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
-    sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
-    badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
+    densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
+    sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
+    densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
+    sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
+    badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
   }
 
   private def getIndexer: VectorIndexer =
@@ -102,7 +102,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
   }
 
   test("Cannot fit an empty DataFrame") {
-    val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
+    val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
     val vectorIndexer = getIndexer
     intercept[IllegalArgumentException] {
       vectorIndexer.fit(rdd)

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 6bb4678..88a077f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -79,7 +79,7 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
     val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
 
     val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
-    val df = sqlContext.createDataFrame(rdd,
+    val df = spark.createDataFrame(rdd,
       StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
 
     val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 80c177b..8cbe0f3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -36,8 +36,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
 
   test("Word2Vec") {
 
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
 
     val sentence = "a b " * 100 + "a c " * 10
     val numOfWords = sentence.split(" ").size
@@ -78,8 +78,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
 
   test("getVectors") {
 
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
 
     val sentence = "a b " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -119,8 +119,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
 
   test("findSynonyms") {
 
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
 
     val sentence = "a b " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -146,8 +146,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
 
   test("window size") {
 
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
 
     val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 1704037..9da0c32 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -305,8 +305,8 @@ class ALSSuite
       numUserBlocks: Int = 2,
       numItemBlocks: Int = 3,
       targetRMSE: Double = 0.05): Unit = {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
     val als = new ALS()
       .setRank(rank)
       .setRegParam(regParam)
@@ -460,8 +460,8 @@ class ALSSuite
     allEstimatorParamSettings.foreach { case (p, v) =>
       als.set(als.getParam(p), v)
     }
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
     val model = als.fit(ratings.toDF())
 
     // Test Estimator save/load
@@ -535,8 +535,11 @@ class ALSCleanerSuite extends SparkFunSuite {
         // Generate test data
         val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0)
         // Implicitly test the cleaning of parents during ALS training
-        val sqlContext = new SQLContext(sc)
-        import sqlContext.implicits._
+        val spark = SparkSession.builder
+          .master("local[2]")
+          .appName("ALSCleanerSuite")
+          .getOrCreate()
+        import spark.implicits._
         val als = new ALS()
           .setRank(1)
           .setRegParam(1e-5)
@@ -577,8 +580,8 @@ class ALSStorageSuite
   }
 
   test("default and non-default storage params set correct RDD StorageLevels") {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
     val data = Seq(
       (0, 0, 1.0),
       (0, 1, 2.0),

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 76891ad..f8fc775 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -37,13 +37,13 @@ class AFTSurvivalRegressionSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    datasetUnivariate = sqlContext.createDataFrame(
+    datasetUnivariate = spark.createDataFrame(
       sc.parallelize(generateAFTInput(
         1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)))
-    datasetMultivariate = sqlContext.createDataFrame(
+    datasetMultivariate = spark.createDataFrame(
       sc.parallelize(generateAFTInput(
         2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
-    datasetUnivariateScaled = sqlContext.createDataFrame(
+    datasetUnivariateScaled = spark.createDataFrame(
       sc.parallelize(generateAFTInput(
         1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x =>
           AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor)
@@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite
   test("should support all NumericType labels") {
     val aft = new AFTSurvivalRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
-      aft, isClassification = false, sqlContext) { (expected, actual) =>
+      aft, isClassification = false, spark) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index e9fb267..d9f26ad 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite
   test("should support all NumericType labels and not support other types") {
     val dt = new DecisionTreeRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
-      dt, isClassification = false, sqlContext) { (expected, actual) =>
+      dt, isClassification = false, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 2163779..f6ea5bb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -72,7 +72,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
   }
 
   test("GBTRegressor behaves reasonably on toy data") {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
       LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
       LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
@@ -99,7 +99,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val path = tempDir.toURI.toString
     sc.setCheckpointDir(path)
 
-    val df = sqlContext.createDataFrame(data)
+    val df = spark.createDataFrame(data)
     val gbt = new GBTRegressor()
       .setMaxDepth(2)
       .setMaxIter(5)
@@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
   test("should support all NumericType labels and not support other types") {
     val gbt = new GBTRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
-      gbt, isClassification = false, sqlContext) { (expected, actual) =>
+      gbt, isClassification = false, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index b854be2..161f8c8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -52,19 +52,19 @@ class GeneralizedLinearRegressionSuite
 
     import GeneralizedLinearRegressionSuite._
 
-    datasetGaussianIdentity = sqlContext.createDataFrame(
+    datasetGaussianIdentity = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "gaussian", link = "identity"), 2))
 
-    datasetGaussianLog = sqlContext.createDataFrame(
+    datasetGaussianLog = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "gaussian", link = "log"), 2))
 
-    datasetGaussianInverse = sqlContext.createDataFrame(
+    datasetGaussianInverse = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
@@ -80,40 +80,40 @@ class GeneralizedLinearRegressionSuite
         generateMultinomialLogisticInput(coefficients, xMean, xVariance,
           addIntercept = true, nPoints, seed)
 
-      sqlContext.createDataFrame(sc.parallelize(testData, 2))
+      spark.createDataFrame(sc.parallelize(testData, 2))
     }
 
-    datasetPoissonLog = sqlContext.createDataFrame(
+    datasetPoissonLog = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "poisson", link = "log"), 2))
 
-    datasetPoissonIdentity = sqlContext.createDataFrame(
+    datasetPoissonIdentity = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "poisson", link = "identity"), 2))
 
-    datasetPoissonSqrt = sqlContext.createDataFrame(
+    datasetPoissonSqrt = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "poisson", link = "sqrt"), 2))
 
-    datasetGammaInverse = sqlContext.createDataFrame(
+    datasetGammaInverse = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "gamma", link = "inverse"), 2))
 
-    datasetGammaIdentity = sqlContext.createDataFrame(
+    datasetGammaIdentity = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "gamma", link = "identity"), 2))
 
-    datasetGammaLog = sqlContext.createDataFrame(
+    datasetGammaLog = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
@@ -540,7 +540,7 @@ class GeneralizedLinearRegressionSuite
        w <- c(1, 2, 3, 4)
        df <- as.data.frame(cbind(A, b))
      */
-    val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+    val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
       Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
       Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
       Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -668,7 +668,7 @@ class GeneralizedLinearRegressionSuite
        w <- c(1, 2, 3, 4)
        df <- as.data.frame(cbind(A, b))
      */
-    val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+    val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
       Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
       Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
       Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
@@ -782,7 +782,7 @@ class GeneralizedLinearRegressionSuite
        w <- c(1, 2, 3, 4)
        df <- as.data.frame(cbind(A, b))
      */
-    val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+    val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
       Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
       Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
       Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -899,7 +899,7 @@ class GeneralizedLinearRegressionSuite
        w <- c(1, 2, 3, 4)
        df <- as.data.frame(cbind(A, b))
      */
-    val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+    val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
       Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
       Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
       Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -1021,14 +1021,14 @@ class GeneralizedLinearRegressionSuite
     val glr = new GeneralizedLinearRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[
         GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
-      glr, isClassification = false, sqlContext) { (expected, actual) =>
+      glr, isClassification = false, spark) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }
   }
 
   test("glm accepts Dataset[LabeledPoint]") {
-    val context = sqlContext
+    val context = spark
     import context.implicits._
     new GeneralizedLinearRegression()
       .setFamily("gaussian")

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 3a10ad7..9bf7542 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -28,13 +28,13 @@ class IsotonicRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
-    sqlContext.createDataFrame(
+    spark.createDataFrame(
       labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) }
     ).toDF("label", "features", "weight")
   }
 
   private def generatePredictionInput(features: Seq[Double]): DataFrame = {
-    sqlContext.createDataFrame(features.map(Tuple1.apply))
+    spark.createDataFrame(features.map(Tuple1.apply))
       .toDF("features")
   }
 
@@ -145,7 +145,7 @@ class IsotonicRegressionSuite
   }
 
   test("vector features column with feature index") {
-    val dataset = sqlContext.createDataFrame(Seq(
+    val dataset = spark.createDataFrame(Seq(
       (4.0, Vectors.dense(0.0, 1.0)),
       (3.0, Vectors.dense(0.0, 2.0)),
       (5.0, Vectors.sparse(2, Array(1), Array(3.0))))
@@ -184,7 +184,7 @@ class IsotonicRegressionSuite
   test("should support all NumericType labels and not support other types") {
     val ir = new IsotonicRegression()
     MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
-      ir, isClassification = false, sqlContext) { (expected, actual) =>
+      ir, isClassification = false, spark) { (expected, actual) =>
         assert(expected.boundaries === actual.boundaries)
         assert(expected.predictions === actual.predictions)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index eb19d13..10f547b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -42,7 +42,7 @@ class LinearRegressionSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    datasetWithDenseFeature = sqlContext.createDataFrame(
+    datasetWithDenseFeature = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2))
@@ -50,7 +50,7 @@ class LinearRegressionSuite
        datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing
       but is useful for illustrating training a model without an intercept
      */
-    datasetWithDenseFeatureWithoutIntercept = sqlContext.createDataFrame(
+    datasetWithDenseFeatureWithoutIntercept = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2))
@@ -59,7 +59,7 @@ class LinearRegressionSuite
     // When feature size is larger than 4096, normal optimizer is chosen
     // as the solver of linear regression in the case of "auto" mode.
     val featureSize = 4100
-    datasetWithSparseFeature = sqlContext.createDataFrame(
+    datasetWithSparseFeature = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray,
         xMean = Seq.fill(featureSize)(r.nextDouble()).toArray,
@@ -74,7 +74,7 @@ class LinearRegressionSuite
        w <- c(1, 2, 3, 4)
        df <- as.data.frame(cbind(A, b))
      */
-    datasetWithWeight = sqlContext.createDataFrame(
+    datasetWithWeight = spark.createDataFrame(
       sc.parallelize(Seq(
         Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
         Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
@@ -90,14 +90,14 @@ class LinearRegressionSuite
        w <- c(1, 2, 3, 4)
        df.const.label <- as.data.frame(cbind(A, b.const))
      */
-    datasetWithWeightConstantLabel = sqlContext.createDataFrame(
+    datasetWithWeightConstantLabel = spark.createDataFrame(
       sc.parallelize(Seq(
         Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
         Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
         Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
         Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
       ), 2))
-    datasetWithWeightZeroLabel = sqlContext.createDataFrame(
+    datasetWithWeightZeroLabel = spark.createDataFrame(
       sc.parallelize(Seq(
         Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
         Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
@@ -828,8 +828,8 @@ class LinearRegressionSuite
         }
         val data2 = weightedSignedData ++ weightedNoiseData
 
-        (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
-          sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+        (spark.createDataFrame(sc.parallelize(data1, 4)),
+          spark.createDataFrame(sc.parallelize(data2, 4)))
       }
 
       val trainer1a = (new LinearRegression).setFitIntercept(true)
@@ -1010,7 +1010,7 @@ class LinearRegressionSuite
   test("should support all NumericType labels and not support other types") {
     val lr = new LinearRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
-      lr, isClassification = false, sqlContext) { (expected, actual) =>
+      lr, isClassification = false, spark) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index ca400e1..72f3c65 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
   test("should support all NumericType labels and not support other types") {
     val rf = new RandomForestRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
-      rf, isClassification = false, sqlContext) { (expected, actual) =>
+      rf, isClassification = false, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 1d7144f..7d0e01f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -56,7 +56,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("select as sparse vector") {
-    val df = sqlContext.read.format("libsvm").load(path)
+    val df = spark.read.format("libsvm").load(path)
     assert(df.columns(0) == "label")
     assert(df.columns(1) == "features")
     val row1 = df.first()
@@ -66,7 +66,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("select as dense vector") {
-    val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense"))
+    val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense"))
       .load(path)
     assert(df.columns(0) == "label")
     assert(df.columns(1) == "features")
@@ -78,7 +78,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("select a vector with specifying the longer dimension") {
-    val df = sqlContext.read.option("numFeatures", "100").format("libsvm")
+    val df = spark.read.option("numFeatures", "100").format("libsvm")
       .load(path)
     val row1 = df.first()
     val v = row1.getAs[SparseVector](1)
@@ -86,27 +86,27 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("write libsvm data and read it again") {
-    val df = sqlContext.read.format("libsvm").load(path)
+    val df = spark.read.format("libsvm").load(path)
     val tempDir2 = new File(tempDir, "read_write_test")
     val writepath = tempDir2.toURI.toString
     // TODO: Remove requirement to coalesce by supporting multiple reads.
     df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
 
-    val df2 = sqlContext.read.format("libsvm").load(writepath)
+    val df2 = spark.read.format("libsvm").load(writepath)
     val row1 = df2.first()
     val v = row1.getAs[SparseVector](1)
     assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }
 
   test("write libsvm data failed due to invalid schema") {
-    val df = sqlContext.read.format("text").load(path)
+    val df = spark.read.format("text").load(path)
     intercept[SparkException] {
       df.write.format("libsvm").save(path + "_2")
     }
   }
 
   test("select features from libsvm relation") {
-    val df = sqlContext.read.format("libsvm").load(path)
+    val df = spark.read.format("libsvm").load(path)
     df.select("features").rdd.map { case Row(d: Vector) => d }.first
     df.select("features").collect
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
index fecf372..de92b51 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -37,8 +37,8 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
     val numIterations = 20
     val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2)
     val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2)
-    val trainDF = sqlContext.createDataFrame(trainRdd)
-    val validateDF = sqlContext.createDataFrame(validateRdd)
+    val trainDF = spark.createDataFrame(trainRdd)
+    val validateDF = spark.createDataFrame(validateRdd)
 
     val algos = Array(Regression, Regression, Classification)
     val losses = Array(SquaredError, AbsoluteError, LogLoss)

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index e3f0989..12ade4c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.tree._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 
 private[ml] object TreeTests extends SparkFunSuite {
 
@@ -42,8 +42,12 @@ private[ml] object TreeTests extends SparkFunSuite {
       data: RDD[LabeledPoint],
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): DataFrame = {
-    val sqlContext = SQLContext.getOrCreate(data.sparkContext)
-    import sqlContext.implicits._
+    val spark = SparkSession.builder
+      .master("local[2]")
+      .appName("TreeTests")
+      .getOrCreate()
+    import spark.implicits._
+
     val df = data.toDF()
     val numFeatures = data.first().features.size
     val featuresAttributes = Range(0, numFeatures).map { feature =>

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 061d04c..85df6da 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -39,7 +39,7 @@ class CrossValidatorSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    dataset = sqlContext.createDataFrame(
+    dataset = spark.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
   }
 
@@ -67,7 +67,7 @@ class CrossValidatorSuite
   }
 
   test("cross validation with linear regression") {
-    val dataset = sqlContext.createDataFrame(
+    val dataset = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index df9ba41..f8d3de1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType
 class TrainValidationSplitSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("train validation with logistic regression") {
-    val dataset = sqlContext.createDataFrame(
+    val dataset = spark.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
 
     val lr = new LogisticRegression
@@ -58,7 +58,7 @@ class TrainValidationSplitSuite
   }
 
   test("train validation with linear regression") {
-    val dataset = sqlContext.createDataFrame(
+    val dataset = spark.createDataFrame(
         sc.parallelize(LinearDataGenerator.generateLinearInput(
             6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index d9e6fd5a..4fe473b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -38,17 +38,17 @@ object MLTestingUtils extends SparkFunSuite {
   def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
       estimator: T,
       isClassification: Boolean,
-      sqlContext: SQLContext)(check: (M, M) => Unit): Unit = {
+      spark: SparkSession)(check: (M, M) => Unit): Unit = {
     val dfs = if (isClassification) {
-      genClassifDFWithNumericLabelCol(sqlContext)
+      genClassifDFWithNumericLabelCol(spark)
     } else {
-      genRegressionDFWithNumericLabelCol(sqlContext)
+      genRegressionDFWithNumericLabelCol(spark)
     }
     val expected = estimator.fit(dfs(DoubleType))
     val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
     actuals.foreach(actual => check(expected, actual))
 
-    val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+    val dfWithStringLabels = spark.createDataFrame(Seq(
       ("0", Vectors.dense(0, 2, 3), 0.0)
     )).toDF("label", "features", "censor")
     val thrown = intercept[IllegalArgumentException] {
@@ -58,13 +58,13 @@ object MLTestingUtils extends SparkFunSuite {
       "Column label must be of type NumericType but was actually of type StringType"))
   }
 
-  def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): Unit = {
-    val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", "prediction")
+  def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
+    val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
     val expected = evaluator.evaluate(dfs(DoubleType))
     val actuals = dfs.keys.filter(_ != DoubleType).map(t => evaluator.evaluate(dfs(t)))
     actuals.foreach(actual => assert(expected === actual))
 
-    val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+    val dfWithStringLabels = spark.createDataFrame(Seq(
       ("0", 0d)
     )).toDF("label", "prediction")
     val thrown = intercept[IllegalArgumentException] {
@@ -75,10 +75,10 @@ object MLTestingUtils extends SparkFunSuite {
   }
 
   def genClassifDFWithNumericLabelCol(
-      sqlContext: SQLContext,
+      spark: SparkSession,
       labelColName: String = "label",
       featuresColName: String = "features"): Map[NumericType, DataFrame] = {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, Vectors.dense(0, 2, 3)),
       (1, Vectors.dense(0, 3, 1)),
       (0, Vectors.dense(0, 2, 2)),
@@ -95,11 +95,11 @@ object MLTestingUtils extends SparkFunSuite {
   }
 
   def genRegressionDFWithNumericLabelCol(
-      sqlContext: SQLContext,
+      spark: SparkSession,
       labelColName: String = "label",
       featuresColName: String = "features",
       censorColName: String = "censor"): Map[NumericType, DataFrame] = {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, Vectors.dense(0)),
       (1, Vectors.dense(1)),
       (2, Vectors.dense(2)),
@@ -117,10 +117,10 @@ object MLTestingUtils extends SparkFunSuite {
   }
 
   def genEvaluatorDFWithNumericLabelCol(
-      sqlContext: SQLContext,
+      spark: SparkSession,
       labelColName: String = "label",
       predictionColName: String = "prediction"): Map[NumericType, DataFrame] = {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, 0d),
       (1, 1d),
       (2, 2d),

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index 7f9e340..ba8d36f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -23,23 +23,22 @@ import org.scalatest.Suite
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.ml.util.TempDirectory
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.util.Utils
 
 trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
+  @transient var spark: SparkSession = _
   @transient var sc: SparkContext = _
-  @transient var sqlContext: SQLContext = _
   @transient var checkpointDir: String = _
 
   override def beforeAll() {
     super.beforeAll()
-    val conf = new SparkConf()
-      .setMaster("local[2]")
-      .setAppName("MLlibUnitTest")
-    sc = new SparkContext(conf)
-    SQLContext.clearActive()
-    sqlContext = new SQLContext(sc)
-    SQLContext.setActive(sqlContext)
+    spark = SparkSession.builder
+      .master("local[2]")
+      .appName("MLlibUnitTest")
+      .getOrCreate()
+    sc = spark.sparkContext
+
     checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
     sc.setCheckpointDir(checkpointDir)
   }
@@ -47,12 +46,11 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
   override def afterAll() {
     try {
       Utils.deleteRecursively(new File(checkpointDir))
-      sqlContext = null
       SQLContext.clearActive()
-      if (sc != null) {
-        sc.stop()
+      if (spark != null) {
+        spark.stop()
       }
-      sc = null
+      spark = null
     } finally {
       super.afterAll()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 189cc39..f2ae40e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -28,14 +28,13 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
@@ -44,21 +43,22 @@ import org.apache.spark.sql.types.StructType;
 // serialized, as an alternative to converting these anonymous classes to static inner classes;
 // see http://stackoverflow.com/questions/758570/.
 public class JavaApplySchemaSuite implements Serializable {
-  private transient JavaSparkContext javaCtx;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    SparkContext context = new SparkContext("local[*]", "testing");
-    javaCtx = new JavaSparkContext(context);
-    sqlContext = new SQLContext(context);
+    spark = SparkSession.builder()
+      .master("local[*]")
+      .appName("testing")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sqlContext.sparkContext().stop();
-    sqlContext = null;
-    javaCtx = null;
+    spark.stop();
+    spark = null;
   }
 
   public static class Person implements Serializable {
@@ -94,7 +94,7 @@ public class JavaApplySchemaSuite implements Serializable {
     person2.setAge(28);
     personList.add(person2);
 
-    JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+    JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
       new Function<Person, Row>() {
         @Override
         public Row call(Person person) throws Exception {
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);
 
-    Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
+    Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
-    List<Row> actual = sqlContext.sql("SELECT * FROM people").collectAsList();
+    List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();
 
     List<Row> expected = new ArrayList<>(2);
     expected.add(RowFactory.create("Michael", 29));
@@ -130,7 +130,7 @@ public class JavaApplySchemaSuite implements Serializable {
     person2.setAge(28);
     personList.add(person2);
 
-    JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+    JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
         new Function<Person, Row>() {
           @Override
           public Row call(Person person) {
@@ -143,9 +143,9 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);
 
-    Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
+    Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
-    List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD()
+    List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
       .map(new Function<Row, String>() {
         @Override
         public String call(Row row) {
@@ -162,7 +162,7 @@ public class JavaApplySchemaSuite implements Serializable {
 
   @Test
   public void applySchemaToJSON() {
-    JavaRDD<String> jsonRDD = javaCtx.parallelize(Arrays.asList(
+    JavaRDD<String> jsonRDD = jsc.parallelize(Arrays.asList(
       "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " +
         "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " +
         "\"boolean\":true, \"null\":null}",
@@ -199,18 +199,18 @@ public class JavaApplySchemaSuite implements Serializable {
         null,
         "this is another simple string."));
 
-    Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
+    Dataset<Row> df1 = spark.read().json(jsonRDD);
     StructType actualSchema1 = df1.schema();
     Assert.assertEquals(expectedSchema, actualSchema1);
     df1.registerTempTable("jsonTable1");
-    List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
+    List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList();
     Assert.assertEquals(expectedResult, actual1);
 
-    Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+    Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonRDD);
     StructType actualSchema2 = df2.schema();
     Assert.assertEquals(expectedSchema, actualSchema2);
     df2.registerTempTable("jsonTable2");
-    List<Row> actual2 = sqlContext.sql("select * from jsonTable2").collectAsList();
+    List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList();
     Assert.assertEquals(expectedResult, actual2);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 1eb680d..324ebba 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -20,12 +20,7 @@ package test.org.apache.spark.sql;
 import java.io.Serializable;
 import java.net.URISyntaxException;
 import java.net.URL;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-import java.util.ArrayList;
+import java.util.*;
 
 import scala.collection.JavaConverters;
 import scala.collection.Seq;
@@ -34,46 +29,45 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.primitives.Ints;
 import org.junit.*;
 
-import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.*;
-import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.test.TestSparkSession;
 import org.apache.spark.sql.types.*;
+import org.apache.spark.util.sketch.BloomFilter;
 import org.apache.spark.util.sketch.CountMinSketch;
 import static org.apache.spark.sql.functions.*;
 import static org.apache.spark.sql.types.DataTypes.*;
-import org.apache.spark.util.sketch.BloomFilter;
 
 public class JavaDataFrameSuite {
+  private transient TestSparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient TestSQLContext context;
 
   @Before
   public void setUp() {
     // Trigger static initializer of TestData
-    SparkContext sc = new SparkContext("local[*]", "testing");
-    jsc = new JavaSparkContext(sc);
-    context = new TestSQLContext(sc);
-    context.loadTestData();
+    spark = new TestSparkSession();
+    jsc = new JavaSparkContext(spark.sparkContext());
+    spark.loadTestData();
   }
 
   @After
   public void tearDown() {
-    context.sparkContext().stop();
-    context = null;
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void testExecution() {
-    Dataset<Row> df = context.table("testData").filter("key = 1");
+    Dataset<Row> df = spark.table("testData").filter("key = 1");
     Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
   }
 
   @Test
   public void testCollectAndTake() {
-    Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+    Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3");
     Assert.assertEquals(3, df.select("key").collectAsList().size());
     Assert.assertEquals(2, df.select("key").takeAsList(2).size());
   }
@@ -83,7 +77,7 @@ public class JavaDataFrameSuite {
    */
   @Test
   public void testVarargMethods() {
-    Dataset<Row> df = context.table("testData");
+    Dataset<Row> df = spark.table("testData");
 
     df.toDF("key1", "value1");
 
@@ -112,7 +106,7 @@ public class JavaDataFrameSuite {
     df.select(coalesce(col("key")));
 
     // Varargs with mathfunctions
-    Dataset<Row> df2 = context.table("testData2");
+    Dataset<Row> df2 = spark.table("testData2");
     df2.select(exp("a"), exp("b"));
     df2.select(exp(log("a")));
     df2.select(pow("a", "a"), pow("b", 2.0));
@@ -126,7 +120,7 @@ public class JavaDataFrameSuite {
   @Ignore
   public void testShow() {
     // This test case is intended ignored, but to make sure it compiles correctly
-    Dataset<Row> df = context.table("testData");
+    Dataset<Row> df = spark.table("testData");
     df.show();
     df.show(1000);
   }
@@ -194,7 +188,7 @@ public class JavaDataFrameSuite {
   public void testCreateDataFrameFromLocalJavaBeans() {
     Bean bean = new Bean();
     List<Bean> data = Arrays.asList(bean);
-    Dataset<Row> df = context.createDataFrame(data, Bean.class);
+    Dataset<Row> df = spark.createDataFrame(data, Bean.class);
     validateDataFrameWithBeans(bean, df);
   }
 
@@ -202,7 +196,7 @@ public class JavaDataFrameSuite {
   public void testCreateDataFrameFromJavaBeans() {
     Bean bean = new Bean();
     JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
-    Dataset<Row> df = context.createDataFrame(rdd, Bean.class);
+    Dataset<Row> df = spark.createDataFrame(rdd, Bean.class);
     validateDataFrameWithBeans(bean, df);
   }
 
@@ -210,7 +204,7 @@ public class JavaDataFrameSuite {
   public void testCreateDataFromFromList() {
     StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
     List<Row> rows = Arrays.asList(RowFactory.create(0));
-    Dataset<Row> df = context.createDataFrame(rows, schema);
+    Dataset<Row> df = spark.createDataFrame(rows, schema);
     List<Row> result = df.collectAsList();
     Assert.assertEquals(1, result.size());
   }
@@ -239,7 +233,7 @@ public class JavaDataFrameSuite {
 
   @Test
   public void testCrosstab() {
-    Dataset<Row> df = context.table("testData2");
+    Dataset<Row> df = spark.table("testData2");
     Dataset<Row> crosstab = df.stat().crosstab("a", "b");
     String[] columnNames = crosstab.schema().fieldNames();
     Assert.assertEquals("a_b", columnNames[0]);
@@ -258,7 +252,7 @@ public class JavaDataFrameSuite {
 
   @Test
   public void testFrequentItems() {
-    Dataset<Row> df = context.table("testData2");
+    Dataset<Row> df = spark.table("testData2");
     String[] cols = {"a"};
     Dataset<Row> results = df.stat().freqItems(cols, 0.2);
     Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1));
@@ -266,21 +260,21 @@ public class JavaDataFrameSuite {
 
   @Test
   public void testCorrelation() {
-    Dataset<Row> df = context.table("testData2");
+    Dataset<Row> df = spark.table("testData2");
     Double pearsonCorr = df.stat().corr("a", "b", "pearson");
     Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
   }
 
   @Test
   public void testCovariance() {
-    Dataset<Row> df = context.table("testData2");
+    Dataset<Row> df = spark.table("testData2");
     Double result = df.stat().cov("a", "b");
     Assert.assertTrue(Math.abs(result) < 1.0e-6);
   }
 
   @Test
   public void testSampleBy() {
-    Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+    Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
     Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
     List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
     Assert.assertEquals(0, actual.get(0).getLong(0));
@@ -291,7 +285,7 @@ public class JavaDataFrameSuite {
 
   @Test
   public void pivot() {
-    Dataset<Row> df = context.table("courseSales");
+    Dataset<Row> df = spark.table("courseSales");
     List<Row> actual = df.groupBy("year")
       .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
       .agg(sum("earnings")).orderBy("year").collectAsList();
@@ -324,10 +318,10 @@ public class JavaDataFrameSuite {
 
   @Test
   public void testGenericLoad() {
-    Dataset<Row> df1 = context.read().format("text").load(getResource("text-suite.txt"));
+    Dataset<Row> df1 = spark.read().format("text").load(getResource("text-suite.txt"));
     Assert.assertEquals(4L, df1.count());
 
-    Dataset<Row> df2 = context.read().format("text").load(
+    Dataset<Row> df2 = spark.read().format("text").load(
       getResource("text-suite.txt"),
       getResource("text-suite2.txt"));
     Assert.assertEquals(5L, df2.count());
@@ -335,10 +329,10 @@ public class JavaDataFrameSuite {
 
   @Test
   public void testTextLoad() {
-    Dataset<String> ds1 = context.read().text(getResource("text-suite.txt"));
+    Dataset<String> ds1 = spark.read().text(getResource("text-suite.txt"));
     Assert.assertEquals(4L, ds1.count());
 
-    Dataset<String> ds2 = context.read().text(
+    Dataset<String> ds2 = spark.read().text(
       getResource("text-suite.txt"),
       getResource("text-suite2.txt"));
     Assert.assertEquals(5L, ds2.count());
@@ -346,7 +340,7 @@ public class JavaDataFrameSuite {
 
   @Test
   public void testCountMinSketch() {
-    Dataset<Long> df = context.range(1000);
+    Dataset<Long> df = spark.range(1000);
 
     CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
     Assert.assertEquals(sketch1.totalCount(), 1000);
@@ -371,7 +365,7 @@ public class JavaDataFrameSuite {
 
   @Test
   public void testBloomFilter() {
-    Dataset<Long> df = context.range(1000);
+    Dataset<Long> df = spark.range(1000);
 
     BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
     Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index f1b1c22..8354a5b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,46 +23,43 @@ import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.*;
 
-import com.google.common.base.Objects;
-import org.junit.rules.ExpectedException;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
 import scala.Tuple5;
 
+import com.google.common.base.Objects;
 import org.junit.*;
+import org.junit.rules.ExpectedException;
 
 import org.apache.spark.Accumulator;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.*;
 import org.apache.spark.sql.*;
-import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.test.TestSparkSession;
 import org.apache.spark.sql.types.StructType;
-
-import static org.apache.spark.sql.functions.*;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.expr;
 import static org.apache.spark.sql.types.DataTypes.*;
 
 public class JavaDatasetSuite implements Serializable {
+  private transient TestSparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient TestSQLContext context;
 
   @Before
   public void setUp() {
     // Trigger static initializer of TestData
-    SparkContext sc = new SparkContext("local[*]", "testing");
-    jsc = new JavaSparkContext(sc);
-    context = new TestSQLContext(sc);
-    context.loadTestData();
+    spark = new TestSparkSession();
+    jsc = new JavaSparkContext(spark.sparkContext());
+    spark.loadTestData();
   }
 
   @After
   public void tearDown() {
-    context.sparkContext().stop();
-    context = null;
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
@@ -72,7 +69,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCollect() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     List<String> collected = ds.collectAsList();
     Assert.assertEquals(Arrays.asList("hello", "world"), collected);
   }
@@ -80,7 +77,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testTake() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     List<String> collected = ds.takeAsList(1);
     Assert.assertEquals(Arrays.asList("hello"), collected);
   }
@@ -88,7 +85,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testToLocalIterator() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     Iterator<String> iter = ds.toLocalIterator();
     Assert.assertEquals("hello", iter.next());
     Assert.assertEquals("world", iter.next());
@@ -98,7 +95,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCommonOperation() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     Assert.assertEquals("hello", ds.first());
 
     Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -149,7 +146,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testForeach() {
     final Accumulator<Integer> accum = jsc.accumulator(0);
     List<String> data = Arrays.asList("a", "b", "c");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
 
     ds.foreach(new ForeachFunction<String>() {
       @Override
@@ -163,7 +160,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testReduce() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
+    Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());
 
     int reduced = ds.reduce(new ReduceFunction<Integer>() {
       @Override
@@ -177,7 +174,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(
       new MapFunction<String, Integer>() {
         @Override
@@ -227,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
       toSet(reduced.collectAsList()));
 
     List<Integer> data2 = Arrays.asList(2, 6, 10);
-    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
+    Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT());
     KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(
       new MapFunction<Integer, Integer>() {
         @Override
@@ -261,7 +258,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSelect() {
     List<Integer> data = Arrays.asList(2, 6);
-    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
+    Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());
 
     Dataset<Tuple2<Integer, String>> selected = ds.select(
       expr("value + 1"),
@@ -275,12 +272,12 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSetOperation() {
     List<String> data = Arrays.asList("abc", "abc", "xyz");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
 
     Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList()));
 
     List<String> data2 = Arrays.asList("xyz", "foo", "foo");
-    Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
+    Dataset<String> ds2 = spark.createDataset(data2, Encoders.STRING());
 
     Dataset<String> intersected = ds.intersect(ds2);
     Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
@@ -307,9 +304,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testJoin() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
+    Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()).as("a");
     List<Integer> data2 = Arrays.asList(2, 3, 4);
-    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");
+    Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()).as("b");
 
     Dataset<Tuple2<Integer, Integer>> joined =
       ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -322,21 +319,21 @@ public class JavaDatasetSuite implements Serializable {
   public void testTupleEncoder() {
     Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
     List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
-    Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
+    Dataset<Tuple2<Integer, String>> ds2 = spark.createDataset(data2, encoder2);
     Assert.assertEquals(data2, ds2.collectAsList());
 
     Encoder<Tuple3<Integer, Long, String>> encoder3 =
       Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
     List<Tuple3<Integer, Long, String>> data3 =
       Arrays.asList(new Tuple3<>(1, 2L, "a"));
-    Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
+    Dataset<Tuple3<Integer, Long, String>> ds3 = spark.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());
 
     Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
       Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
     List<Tuple4<Integer, String, Long, String>> data4 =
       Arrays.asList(new Tuple4<>(1, "b", 2L, "a"));
-    Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
+    Dataset<Tuple4<Integer, String, Long, String>> ds4 = spark.createDataset(data4, encoder4);
     Assert.assertEquals(data4, ds4.collectAsList());
 
     Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
@@ -345,7 +342,7 @@ public class JavaDatasetSuite implements Serializable {
     List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
       Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true));
     Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
-      context.createDataset(data5, encoder5);
+      spark.createDataset(data5, encoder5);
     Assert.assertEquals(data5, ds5.collectAsList());
   }
 
@@ -356,7 +353,7 @@ public class JavaDatasetSuite implements Serializable {
       Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
     List<Tuple2<Tuple2<Integer, String>, String>> data =
       Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
-    Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
+    Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = spark.createDataset(data, encoder);
     Assert.assertEquals(data, ds.collectAsList());
 
     // test (int, (string, string, long))
@@ -366,7 +363,7 @@ public class JavaDatasetSuite implements Serializable {
     List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
       Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L)));
     Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
-      context.createDataset(data2, encoder2);
+      spark.createDataset(data2, encoder2);
     Assert.assertEquals(data2, ds2.collectAsList());
 
     // test (int, ((string, long), string))
@@ -376,7 +373,7 @@ public class JavaDatasetSuite implements Serializable {
     List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
       Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
     Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
-      context.createDataset(data3, encoder3);
+      spark.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());
   }
 
@@ -390,7 +387,7 @@ public class JavaDatasetSuite implements Serializable {
         1.7976931348623157E308, new BigDecimal("0.922337203685477589"),
           Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE));
     Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> ds =
-      context.createDataset(data, encoder);
+      spark.createDataset(data, encoder);
     Assert.assertEquals(data, ds.collectAsList());
   }
 
@@ -441,7 +438,7 @@ public class JavaDatasetSuite implements Serializable {
     Encoder<KryoSerializable> encoder = Encoders.kryo(KryoSerializable.class);
     List<KryoSerializable> data = Arrays.asList(
       new KryoSerializable("hello"), new KryoSerializable("world"));
-    Dataset<KryoSerializable> ds = context.createDataset(data, encoder);
+    Dataset<KryoSerializable> ds = spark.createDataset(data, encoder);
     Assert.assertEquals(data, ds.collectAsList());
   }
 
@@ -450,14 +447,14 @@ public class JavaDatasetSuite implements Serializable {
     Encoder<JavaSerializable> encoder = Encoders.javaSerialization(JavaSerializable.class);
     List<JavaSerializable> data = Arrays.asList(
       new JavaSerializable("hello"), new JavaSerializable("world"));
-    Dataset<JavaSerializable> ds = context.createDataset(data, encoder);
+    Dataset<JavaSerializable> ds = spark.createDataset(data, encoder);
     Assert.assertEquals(data, ds.collectAsList());
   }
 
   @Test
   public void testRandomSplit() {
     List<String> data = Arrays.asList("hello", "world", "from", "spark");
-    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     double[] arraySplit = {1, 2, 3};
 
     List<Dataset<String>> randomSplit =  ds.randomSplitAsList(arraySplit, 1);
@@ -647,14 +644,14 @@ public class JavaDatasetSuite implements Serializable {
     obj2.setF(Arrays.asList(300L, null, 400L));
 
     List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
-    Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class));
+    Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class));
     Assert.assertEquals(data, ds.collectAsList());
 
     NestedJavaBean obj3 = new NestedJavaBean();
     obj3.setA(obj1);
 
     List<NestedJavaBean> data2 = Arrays.asList(obj3);
-    Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class));
+    Dataset<NestedJavaBean> ds2 = spark.createDataset(data2, Encoders.bean(NestedJavaBean.class));
     Assert.assertEquals(data2, ds2.collectAsList());
 
     Row row1 = new GenericRow(new Object[]{
@@ -678,7 +675,7 @@ public class JavaDatasetSuite implements Serializable {
       .add("d", createArrayType(StringType))
       .add("e", createArrayType(StringType))
       .add("f", createArrayType(LongType));
-    Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema)
+    Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema)
       .as(Encoders.bean(SimpleJavaBean.class));
     Assert.assertEquals(data, ds3.collectAsList());
   }
@@ -692,7 +689,7 @@ public class JavaDatasetSuite implements Serializable {
     obj.setB(new Date(0));
     obj.setC(java.math.BigDecimal.valueOf(1));
     Dataset<SimpleJavaBean2> ds =
-      context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
+      spark.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
     ds.collect();
   }
 
@@ -776,7 +773,7 @@ public class JavaDatasetSuite implements Serializable {
           })
       });
 
-      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
 
       SmallBean smallBean = new SmallBean();
@@ -793,7 +790,7 @@ public class JavaDatasetSuite implements Serializable {
     {
       Row row = new GenericRow(new Object[] { null });
 
-      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
 
       NestedSmallBean nestedSmallBean = new NestedSmallBean();
@@ -810,7 +807,7 @@ public class JavaDatasetSuite implements Serializable {
           })
       });
 
-      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
 
       ds.collect();

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 4a78dca..2274912 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -24,33 +24,30 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.SparkContext;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.api.java.UDF1;
 import org.apache.spark.sql.api.java.UDF2;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.types.DataTypes;
 
 // The test suite itself is Serializable so that anonymous Function implementations can be
 // serialized, as an alternative to converting these anonymous classes to static inner classes;
 // see http://stackoverflow.com/questions/758570/.
 public class JavaUDFSuite implements Serializable {
-  private transient JavaSparkContext sc;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    SparkContext _sc = new SparkContext("local[*]", "testing");
-    sqlContext = new SQLContext(_sc);
-    sc = new JavaSparkContext(_sc);
+    spark = SparkSession.builder()
+      .master("local[*]")
+      .appName("testing")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    sqlContext.sparkContext().stop();
-    sqlContext = null;
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @SuppressWarnings("unchecked")
@@ -60,14 +57,14 @@ public class JavaUDFSuite implements Serializable {
     // sqlContext.registerFunction(
     //   "stringLengthTest", (String str) -> str.length(), DataType.IntegerType);
 
-    sqlContext.udf().register("stringLengthTest", new UDF1<String, Integer>() {
+    spark.udf().register("stringLengthTest", new UDF1<String, Integer>() {
       @Override
       public Integer call(String str) {
         return str.length();
       }
     }, DataTypes.IntegerType);
 
-    Row result = sqlContext.sql("SELECT stringLengthTest('test')").head();
+    Row result = spark.sql("SELECT stringLengthTest('test')").head();
     Assert.assertEquals(4, result.getInt(0));
   }
 
@@ -80,14 +77,14 @@ public class JavaUDFSuite implements Serializable {
     //   (String str1, String str2) -> str1.length() + str2.length,
     //   DataType.IntegerType);
 
-    sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() {
+    spark.udf().register("stringLengthTest", new UDF2<String, String, Integer>() {
       @Override
       public Integer call(String str1, String str2) {
         return str1.length() + str2.length();
       }
     }, DataTypes.IntegerType);
 
-    Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head();
+    Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
     Assert.assertEquals(9, result.getInt(0));
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message