spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From andrewo...@apache.org
Subject [06/10] spark git commit: [SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
Date Tue, 10 May 2016 18:18:00 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
index 7863177..059c2d9 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
@@ -26,36 +26,30 @@ import scala.Tuple2;
 import org.junit.After;
 import org.junit.Before;
 
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.KeyValueGroupedDataset;
-import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.test.TestSparkSession;
 
 /**
  * Common test base shared across this and Java8DatasetAggregatorSuite.
  */
 public class JavaDatasetAggregatorSuiteBase implements Serializable {
-  protected transient JavaSparkContext jsc;
-  protected transient TestSQLContext context;
+  private transient TestSparkSession spark;
 
   @Before
   public void setUp() {
     // Trigger static initializer of TestData
-    SparkContext sc = new SparkContext("local[*]", "testing");
-    jsc = new JavaSparkContext(sc);
-    context = new TestSQLContext(sc);
-    context.loadTestData();
+    spark = new TestSparkSession();
+    spark.loadTestData();
   }
 
   @After
   public void tearDown() {
-    context.sparkContext().stop();
-    context = null;
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
@@ -66,7 +60,7 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable {
     Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
     List<Tuple2<String, Integer>> data =
       Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
-    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+    Dataset<Tuple2<String, Integer>> ds = spark.createDataset(data, encoder);
 
     return ds.groupByKey(
       new MapFunction<Tuple2<String, Integer>, String>() {

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e65158..d0435e4 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -19,14 +19,16 @@ package test.org.apache.spark.sql.sources;
 
 import java.io.File;
 import java.io.IOException;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.*;
@@ -37,8 +39,8 @@ import org.apache.spark.util.Utils;
 
 public class JavaSaveLoadSuite {
 
-  private transient JavaSparkContext sc;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   File path;
   Dataset<Row> df;
@@ -52,9 +54,11 @@ public class JavaSaveLoadSuite {
 
   @Before
   public void setUp() throws IOException {
-    SparkContext _sc = new SparkContext("local[*]", "testing");
-    sqlContext = new SQLContext(_sc);
-    sc = new JavaSparkContext(_sc);
+    spark = SparkSession.builder()
+      .master("local[*]")
+      .appName("testing")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
 
     path =
       Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
@@ -66,16 +70,15 @@ public class JavaSaveLoadSuite {
     for (int i = 0; i < 10; i++) {
       jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
     }
-    JavaRDD<String> rdd = sc.parallelize(jsonObjects);
-    df = sqlContext.read().json(rdd);
+    JavaRDD<String> rdd = jsc.parallelize(jsonObjects);
+    df = spark.read().json(rdd);
     df.registerTempTable("jsonTable");
   }
 
   @After
   public void tearDown() {
-    sqlContext.sparkContext().stop();
-    sqlContext = null;
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -83,7 +86,7 @@ public class JavaSaveLoadSuite {
     Map<String, String> options = new HashMap<>();
     options.put("path", path.toString());
     df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
-    Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
+    Dataset<Row> loadedDF = spark.read().format("json").options(options).load();
     checkAnswer(loadedDF, df.collectAsList());
   }
 
@@ -96,8 +99,8 @@ public class JavaSaveLoadSuite {
     List<StructField> fields = new ArrayList<>();
     fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
     StructType schema = DataTypes.createStructType(fields);
-    Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+    Dataset<Row> loadedDF = spark.read().format("json").schema(schema).options(options).load();
 
-    checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
+    checkAnswer(loadedDF, spark.sql("SELECT b FROM jsonTable").collectAsList());
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 5ef2026..800316c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -36,7 +36,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   import testImplicits._
 
   def rddIdOf(tableName: String): Int = {
-    val plan = sqlContext.table(tableName).queryExecution.sparkPlan
+    val plan = spark.table(tableName).queryExecution.sparkPlan
     plan.collect {
       case InMemoryTableScanExec(_, _, relation) =>
         relation.cachedColumnBuffers.id
@@ -73,41 +73,41 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   test("cache temp table") {
     testData.select('key).registerTempTable("tempTable")
     assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0)
-    sqlContext.cacheTable("tempTable")
+    spark.catalog.cacheTable("tempTable")
     assertCached(sql("SELECT COUNT(*) FROM tempTable"))
-    sqlContext.uncacheTable("tempTable")
+    spark.catalog.uncacheTable("tempTable")
   }
 
   test("unpersist an uncached table will not raise exception") {
-    assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+    assert(None == spark.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = true)
-    assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+    assert(None == spark.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = false)
-    assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+    assert(None == spark.cacheManager.lookupCachedData(testData))
     testData.persist()
-    assert(None != sqlContext.cacheManager.lookupCachedData(testData))
+    assert(None != spark.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = true)
-    assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+    assert(None == spark.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = false)
-    assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+    assert(None == spark.cacheManager.lookupCachedData(testData))
   }
 
   test("cache table as select") {
     sql("CACHE TABLE tempTable AS SELECT key FROM testData")
     assertCached(sql("SELECT COUNT(*) FROM tempTable"))
-    sqlContext.uncacheTable("tempTable")
+    spark.catalog.uncacheTable("tempTable")
   }
 
   test("uncaching temp table") {
     testData.select('key).registerTempTable("tempTable1")
     testData.select('key).registerTempTable("tempTable2")
-    sqlContext.cacheTable("tempTable1")
+    spark.catalog.cacheTable("tempTable1")
 
     assertCached(sql("SELECT COUNT(*) FROM tempTable1"))
     assertCached(sql("SELECT COUNT(*) FROM tempTable2"))
 
     // Is this valid?
-    sqlContext.uncacheTable("tempTable2")
+    spark.catalog.uncacheTable("tempTable2")
 
     // Should this be cached?
     assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0)
@@ -117,101 +117,101 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     val data = "*" * 1000
     sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF()
       .registerTempTable("bigData")
-    sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
-    assert(sqlContext.table("bigData").count() === 200000L)
-    sqlContext.table("bigData").unpersist(blocking = true)
+    spark.table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
+    assert(spark.table("bigData").count() === 200000L)
+    spark.table("bigData").unpersist(blocking = true)
   }
 
   test("calling .cache() should use in-memory columnar caching") {
-    sqlContext.table("testData").cache()
-    assertCached(sqlContext.table("testData"))
-    sqlContext.table("testData").unpersist(blocking = true)
+    spark.table("testData").cache()
+    assertCached(spark.table("testData"))
+    spark.table("testData").unpersist(blocking = true)
   }
 
   test("calling .unpersist() should drop in-memory columnar cache") {
-    sqlContext.table("testData").cache()
-    sqlContext.table("testData").count()
-    sqlContext.table("testData").unpersist(blocking = true)
-    assertCached(sqlContext.table("testData"), 0)
+    spark.table("testData").cache()
+    spark.table("testData").count()
+    spark.table("testData").unpersist(blocking = true)
+    assertCached(spark.table("testData"), 0)
   }
 
   test("isCached") {
-    sqlContext.cacheTable("testData")
+    spark.catalog.cacheTable("testData")
 
-    assertCached(sqlContext.table("testData"))
-    assert(sqlContext.table("testData").queryExecution.withCachedData match {
+    assertCached(spark.table("testData"))
+    assert(spark.table("testData").queryExecution.withCachedData match {
       case _: InMemoryRelation => true
       case _ => false
     })
 
-    sqlContext.uncacheTable("testData")
-    assert(!sqlContext.isCached("testData"))
-    assert(sqlContext.table("testData").queryExecution.withCachedData match {
+    spark.catalog.uncacheTable("testData")
+    assert(!spark.catalog.isCached("testData"))
+    assert(spark.table("testData").queryExecution.withCachedData match {
       case _: InMemoryRelation => false
       case _ => true
     })
   }
 
   test("SPARK-1669: cacheTable should be idempotent") {
-    assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
+    assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
 
-    sqlContext.cacheTable("testData")
-    assertCached(sqlContext.table("testData"))
+    spark.catalog.cacheTable("testData")
+    assertCached(spark.table("testData"))
 
     assertResult(1, "InMemoryRelation not found, testData should have been cached") {
-      sqlContext.table("testData").queryExecution.withCachedData.collect {
+      spark.table("testData").queryExecution.withCachedData.collect {
         case r: InMemoryRelation => r
       }.size
     }
 
-    sqlContext.cacheTable("testData")
+    spark.catalog.cacheTable("testData")
     assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
-      sqlContext.table("testData").queryExecution.withCachedData.collect {
+      spark.table("testData").queryExecution.withCachedData.collect {
         case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r
       }.size
     }
 
-    sqlContext.uncacheTable("testData")
+    spark.catalog.uncacheTable("testData")
   }
 
   test("read from cached table and uncache") {
-    sqlContext.cacheTable("testData")
-    checkAnswer(sqlContext.table("testData"), testData.collect().toSeq)
-    assertCached(sqlContext.table("testData"))
+    spark.catalog.cacheTable("testData")
+    checkAnswer(spark.table("testData"), testData.collect().toSeq)
+    assertCached(spark.table("testData"))
 
-    sqlContext.uncacheTable("testData")
-    checkAnswer(sqlContext.table("testData"), testData.collect().toSeq)
-    assertCached(sqlContext.table("testData"), 0)
+    spark.catalog.uncacheTable("testData")
+    checkAnswer(spark.table("testData"), testData.collect().toSeq)
+    assertCached(spark.table("testData"), 0)
   }
 
   test("correct error on uncache of non-cached table") {
     intercept[IllegalArgumentException] {
-      sqlContext.uncacheTable("testData")
+      spark.catalog.uncacheTable("testData")
     }
   }
 
   test("SELECT star from cached table") {
     sql("SELECT * FROM testData").registerTempTable("selectStar")
-    sqlContext.cacheTable("selectStar")
+    spark.catalog.cacheTable("selectStar")
     checkAnswer(
       sql("SELECT * FROM selectStar WHERE key = 1"),
       Seq(Row(1, "1")))
-    sqlContext.uncacheTable("selectStar")
+    spark.catalog.uncacheTable("selectStar")
   }
 
   test("Self-join cached") {
     val unCachedAnswer =
       sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
-    sqlContext.cacheTable("testData")
+    spark.catalog.cacheTable("testData")
     checkAnswer(
       sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
       unCachedAnswer.toSeq)
-    sqlContext.uncacheTable("testData")
+    spark.catalog.uncacheTable("testData")
   }
 
   test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
     sql("CACHE TABLE testData")
-    assertCached(sqlContext.table("testData"))
+    assertCached(spark.table("testData"))
 
     val rddId = rddIdOf("testData")
     assert(
@@ -219,7 +219,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       "Eagerly cached in-memory table should have already been materialized")
 
     sql("UNCACHE TABLE testData")
-    assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached")
+    assert(!spark.catalog.isCached("testData"), "Table 'testData' should not be cached")
 
     eventually(timeout(10 seconds)) {
       assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
@@ -228,14 +228,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
 
   test("CACHE TABLE tableName AS SELECT * FROM anotherTable") {
     sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
-    assertCached(sqlContext.table("testCacheTable"))
+    assertCached(spark.table("testCacheTable"))
 
     val rddId = rddIdOf("testCacheTable")
     assert(
       isMaterialized(rddId),
       "Eagerly cached in-memory table should have already been materialized")
 
-    sqlContext.uncacheTable("testCacheTable")
+    spark.catalog.uncacheTable("testCacheTable")
     eventually(timeout(10 seconds)) {
       assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
     }
@@ -243,14 +243,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
 
   test("CACHE TABLE tableName AS SELECT ...") {
     sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10")
-    assertCached(sqlContext.table("testCacheTable"))
+    assertCached(spark.table("testCacheTable"))
 
     val rddId = rddIdOf("testCacheTable")
     assert(
       isMaterialized(rddId),
       "Eagerly cached in-memory table should have already been materialized")
 
-    sqlContext.uncacheTable("testCacheTable")
+    spark.catalog.uncacheTable("testCacheTable")
     eventually(timeout(10 seconds)) {
       assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
     }
@@ -258,7 +258,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
 
   test("CACHE LAZY TABLE tableName") {
     sql("CACHE LAZY TABLE testData")
-    assertCached(sqlContext.table("testData"))
+    assertCached(spark.table("testData"))
 
     val rddId = rddIdOf("testData")
     assert(
@@ -270,7 +270,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       isMaterialized(rddId),
       "Lazily cached in-memory table should have been materialized")
 
-    sqlContext.uncacheTable("testData")
+    spark.catalog.uncacheTable("testData")
     eventually(timeout(10 seconds)) {
       assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
     }
@@ -278,7 +278,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
 
   test("InMemoryRelation statistics") {
     sql("CACHE TABLE testData")
-    sqlContext.table("testData").queryExecution.withCachedData.collect {
+    spark.table("testData").queryExecution.withCachedData.collect {
       case cached: InMemoryRelation =>
         val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum
         assert(cached.statistics.sizeInBytes === actualSizeInBytes)
@@ -287,62 +287,62 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
 
   test("Drops temporary table") {
     testData.select('key).registerTempTable("t1")
-    sqlContext.table("t1")
-    sqlContext.dropTempTable("t1")
-    intercept[AnalysisException](sqlContext.table("t1"))
+    spark.table("t1")
+    spark.catalog.dropTempTable("t1")
+    intercept[AnalysisException](spark.table("t1"))
   }
 
   test("Drops cached temporary table") {
     testData.select('key).registerTempTable("t1")
     testData.select('key).registerTempTable("t2")
-    sqlContext.cacheTable("t1")
+    spark.catalog.cacheTable("t1")
 
-    assert(sqlContext.isCached("t1"))
-    assert(sqlContext.isCached("t2"))
+    assert(spark.catalog.isCached("t1"))
+    assert(spark.catalog.isCached("t2"))
 
-    sqlContext.dropTempTable("t1")
-    intercept[AnalysisException](sqlContext.table("t1"))
-    assert(!sqlContext.isCached("t2"))
+    spark.catalog.dropTempTable("t1")
+    intercept[AnalysisException](spark.table("t1"))
+    assert(!spark.catalog.isCached("t2"))
   }
 
   test("Clear all cache") {
     sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
     sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
-    sqlContext.cacheTable("t1")
-    sqlContext.cacheTable("t2")
-    sqlContext.clearCache()
-    assert(sqlContext.cacheManager.isEmpty)
+    spark.catalog.cacheTable("t1")
+    spark.catalog.cacheTable("t2")
+    spark.catalog.clearCache()
+    assert(spark.cacheManager.isEmpty)
 
     sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
     sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
-    sqlContext.cacheTable("t1")
-    sqlContext.cacheTable("t2")
+    spark.catalog.cacheTable("t1")
+    spark.catalog.cacheTable("t2")
     sql("Clear CACHE")
-    assert(sqlContext.cacheManager.isEmpty)
+    assert(spark.cacheManager.isEmpty)
   }
 
   test("Clear accumulators when uncacheTable to prevent memory leaking") {
     sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
     sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
 
-    sqlContext.cacheTable("t1")
-    sqlContext.cacheTable("t2")
+    spark.catalog.cacheTable("t1")
+    spark.catalog.cacheTable("t2")
 
     sql("SELECT * FROM t1").count()
     sql("SELECT * FROM t2").count()
     sql("SELECT * FROM t1").count()
     sql("SELECT * FROM t2").count()
 
-    val accId1 = sqlContext.table("t1").queryExecution.withCachedData.collect {
+    val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
       case i: InMemoryRelation => i.batchStats.id
     }.head
 
-    val accId2 = sqlContext.table("t1").queryExecution.withCachedData.collect {
+    val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
       case i: InMemoryRelation => i.batchStats.id
     }.head
 
-    sqlContext.uncacheTable("t1")
-    sqlContext.uncacheTable("t2")
+    spark.catalog.uncacheTable("t1")
+    spark.catalog.uncacheTable("t2")
 
     assert(AccumulatorContext.get(accId1).isEmpty)
     assert(AccumulatorContext.get(accId2).isEmpty)
@@ -351,7 +351,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") {
     sparkContext.parallelize((1, 1) :: (2, 2) :: Nil)
       .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc")
-    sqlContext.cacheTable("abc")
+    spark.catalog.cacheTable("abc")
 
     val sparkPlan = sql(
       """select a.key, b.key, c.key from
@@ -374,15 +374,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     table3x.registerTempTable("testData3x")
 
     sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable")
-    sqlContext.cacheTable("orderedTable")
-    assertCached(sqlContext.table("orderedTable"))
+    spark.catalog.cacheTable("orderedTable")
+    assertCached(spark.table("orderedTable"))
     // Should not have an exchange as the query is already sorted on the group by key.
     verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0)
     checkAnswer(
       sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
       sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect())
-    sqlContext.uncacheTable("orderedTable")
-    sqlContext.dropTempTable("orderedTable")
+    spark.catalog.uncacheTable("orderedTable")
+    spark.catalog.dropTempTable("orderedTable")
 
     // Set up two tables distributed in the same way. Try this with the data distributed into
     // different number of partitions.
@@ -390,8 +390,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       withTempTable("t1", "t2") {
         testData.repartition(numPartitions, $"key").registerTempTable("t1")
         testData2.repartition(numPartitions, $"a").registerTempTable("t2")
-        sqlContext.cacheTable("t1")
-        sqlContext.cacheTable("t2")
+        spark.catalog.cacheTable("t1")
+        spark.catalog.cacheTable("t2")
 
         // Joining them should result in no exchanges.
         verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
@@ -403,8 +403,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
         checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
           sql("SELECT count(*) FROM testData GROUP BY key"))
 
-        sqlContext.uncacheTable("t1")
-        sqlContext.uncacheTable("t2")
+        spark.catalog.uncacheTable("t1")
+        spark.catalog.uncacheTable("t2")
       }
     }
 
@@ -412,8 +412,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     withTempTable("t1", "t2") {
       testData.repartition(6, $"key").registerTempTable("t1")
       testData2.repartition(3, $"a").registerTempTable("t2")
-      sqlContext.cacheTable("t1")
-      sqlContext.cacheTable("t2")
+      spark.catalog.cacheTable("t1")
+      spark.catalog.cacheTable("t2")
 
       val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
       verifyNumExchanges(query, 1)
@@ -421,16 +421,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       checkAnswer(
         query,
         testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
-      sqlContext.uncacheTable("t1")
-      sqlContext.uncacheTable("t2")
+      spark.catalog.uncacheTable("t1")
+      spark.catalog.uncacheTable("t2")
     }
 
     // One side of join is not partitioned in the desired way. Need to shuffle one side.
     withTempTable("t1", "t2") {
       testData.repartition(6, $"value").registerTempTable("t1")
       testData2.repartition(6, $"a").registerTempTable("t2")
-      sqlContext.cacheTable("t1")
-      sqlContext.cacheTable("t2")
+      spark.catalog.cacheTable("t1")
+      spark.catalog.cacheTable("t2")
 
       val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
       verifyNumExchanges(query, 1)
@@ -438,15 +438,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       checkAnswer(
         query,
         testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
-      sqlContext.uncacheTable("t1")
-      sqlContext.uncacheTable("t2")
+      spark.catalog.uncacheTable("t1")
+      spark.catalog.uncacheTable("t2")
     }
 
     withTempTable("t1", "t2") {
       testData.repartition(6, $"value").registerTempTable("t1")
       testData2.repartition(12, $"a").registerTempTable("t2")
-      sqlContext.cacheTable("t1")
-      sqlContext.cacheTable("t2")
+      spark.catalog.cacheTable("t1")
+      spark.catalog.cacheTable("t2")
 
       val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
       verifyNumExchanges(query, 1)
@@ -454,8 +454,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       checkAnswer(
         query,
         testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
-      sqlContext.uncacheTable("t1")
-      sqlContext.uncacheTable("t2")
+      spark.catalog.uncacheTable("t1")
+      spark.catalog.uncacheTable("t2")
     }
 
     // One side of join is not partitioned in the desired way. Since the number of partitions of
@@ -464,30 +464,30 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     withTempTable("t1", "t2") {
       testData.repartition(6, $"value").registerTempTable("t1")
       testData2.repartition(3, $"a").registerTempTable("t2")
-      sqlContext.cacheTable("t1")
-      sqlContext.cacheTable("t2")
+      spark.catalog.cacheTable("t1")
+      spark.catalog.cacheTable("t2")
 
       val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
       verifyNumExchanges(query, 2)
       checkAnswer(
         query,
         testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
-      sqlContext.uncacheTable("t1")
-      sqlContext.uncacheTable("t2")
+      spark.catalog.uncacheTable("t1")
+      spark.catalog.uncacheTable("t2")
     }
 
     // repartition's column ordering is different from group by column ordering.
     // But they use the same set of columns.
     withTempTable("t1") {
       testData.repartition(6, $"value", $"key").registerTempTable("t1")
-      sqlContext.cacheTable("t1")
+      spark.catalog.cacheTable("t1")
 
       val query = sql("SELECT value, key from t1 group by key, value")
       verifyNumExchanges(query, 0)
       checkAnswer(
         query,
         testData.distinct().select($"value", $"key"))
-      sqlContext.uncacheTable("t1")
+      spark.catalog.uncacheTable("t1")
     }
 
     // repartition's column ordering is different from join condition's column ordering.
@@ -499,8 +499,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       df1.repartition(6, $"value", $"key").registerTempTable("t1")
       val df2 = testData2.select($"a", $"b".cast("string"))
       df2.repartition(6, $"a", $"b").registerTempTable("t2")
-      sqlContext.cacheTable("t1")
-      sqlContext.cacheTable("t2")
+      spark.catalog.cacheTable("t1")
+      spark.catalog.cacheTable("t2")
 
       val query =
         sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b")
@@ -509,8 +509,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       checkAnswer(
         query,
         df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b"))
-      sqlContext.uncacheTable("t1")
-      sqlContext.uncacheTable("t2")
+      spark.catalog.uncacheTable("t1")
+      spark.catalog.uncacheTable("t2")
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 19fe29a..a5aecca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
   private lazy val booleanData = {
-    sqlContext.createDataFrame(sparkContext.parallelize(
+    spark.createDataFrame(sparkContext.parallelize(
       Row(false, false) ::
       Row(false, true) ::
       Row(true, false) ::
@@ -287,7 +287,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
   }
 
   test("isNaN") {
-    val testData = sqlContext.createDataFrame(sparkContext.parallelize(
+    val testData = spark.createDataFrame(sparkContext.parallelize(
       Row(Double.NaN, Float.NaN) ::
       Row(math.log(-1), math.log(-3).toFloat) ::
       Row(null, null) ::
@@ -308,7 +308,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
   }
 
   test("nanvl") {
-    val testData = sqlContext.createDataFrame(sparkContext.parallelize(
+    val testData = spark.createDataFrame(sparkContext.parallelize(
       Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil),
       StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType),
         StructField("c", DoubleType), StructField("d", DoubleType),
@@ -351,7 +351,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
   }
 
   test("=!=") {
-    val nullData = sqlContext.createDataFrame(sparkContext.parallelize(
+    val nullData = spark.createDataFrame(sparkContext.parallelize(
       Row(1, 1) ::
       Row(1, 2) ::
       Row(1, null) ::
@@ -370,7 +370,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
       nullData.filter($"a" <=> $"b"),
       Row(1, 1) :: Row(null, null) :: Nil)
 
-    val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize(
+    val nullData2 = spark.createDataFrame(sparkContext.parallelize(
         Row("abc") ::
         Row(null)  ::
         Row("xyz") :: Nil),
@@ -596,7 +596,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     withTempPath { dir =>
       val data = sparkContext.parallelize(0 to 10).toDF("id")
       data.write.parquet(dir.getCanonicalPath)
-      val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name())
+      val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name())
         .head.getString(0)
       assert(answer.contains(dir.getCanonicalPath))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 63f4b75..8a99866 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -70,7 +70,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
         Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)))
     )
 
-    val decimalDataWithNulls = sqlContext.sparkContext.parallelize(
+    val decimalDataWithNulls = spark.sparkContext.parallelize(
       DecimalData(1, 1) ::
       DecimalData(1, null) ::
       DecimalData(2, 1) ::
@@ -114,7 +114,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
         Row(null, null, 113000.0) :: Nil
     )
 
-    val df0 = sqlContext.sparkContext.parallelize(Seq(
+    val df0 = spark.sparkContext.parallelize(Seq(
       Fact(20151123, 18, 35, "room1", 18.6),
       Fact(20151123, 18, 35, "room2", 22.4),
       Fact(20151123, 18, 36, "room1", 17.4),
@@ -207,12 +207,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Seq(Row(1, 3), Row(2, 3), Row(3, 3))
     )
 
-    sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false)
+    spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, false)
     checkAnswer(
       testData2.groupBy("a").agg(sum($"b")),
       Seq(Row(3), Row(3), Row(3))
     )
-    sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true)
+    spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, true)
   }
 
   test("agg without groups") {
@@ -433,10 +433,10 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
 
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
-      sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
+      spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
       Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil)
     checkAnswer(
-      sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
+      spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
       Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 0414fa1..031e66b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -154,7 +154,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
     // SPARK-12275: no physical plan for BroadcastHint in some condition
     withTempPath { path =>
       df1.write.parquet(path.getCanonicalPath)
-      val pf1 = sqlContext.read.parquet(path.getCanonicalPath)
+      val pf1 = spark.read.parquet(path.getCanonicalPath)
       assert(df1.join(broadcast(pf1)).count() === 4)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index c6d6751..fa8fa06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -81,11 +81,11 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
   }
 
   test("pivot max values enforced") {
-    sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
+    spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, 1)
     intercept[AnalysisException](
       courseSales.groupBy("year").pivot("course")
     )
-    sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
+    spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key,
       SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
   }
 
@@ -104,7 +104,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
       // pivot with extra columns to trigger optimization
       .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
       .agg(sum($"earnings"))
-    val queryExecution = sqlContext.executePlan(df.queryExecution.logical)
+    val queryExecution = spark.executePlan(df.queryExecution.logical)
     assert(queryExecution.simpleString.contains("pivotfirst"))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0ea7727..ab7733b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -236,7 +236,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   }
 
   test("sampleBy") {
-    val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))
+    val df = spark.range(0, 100).select((col("id") % 3).as("key"))
     val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
     checkAnswer(
       sampled.groupBy("key").count().orderBy("key"),
@@ -247,7 +247,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   // `CountMinSketch`es that meet required specs.  Test cases for `CountMinSketch` can be found in
   // `CountMinSketchSuite` in project spark-sketch.
   test("countMinSketch") {
-    val df = sqlContext.range(1000)
+    val df = spark.range(1000)
 
     val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
     assert(sketch1.totalCount() === 1000)
@@ -279,7 +279,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   // This test only verifies some basic requirements, more correctness tests can be found in
   // `BloomFilterSuite` in project spark-sketch.
   test("Bloom filter") {
-    val df = sqlContext.range(1000)
+    val df = spark.range(1000)
 
     val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
     assert(filter1.expectedFpp() - 0.03 < 1e-3)
@@ -304,7 +304,7 @@ class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Loggin
 
   // Turn on this test if you want to test the performance of approximate quantiles.
   ignore("computing quantiles should not take much longer than describe()") {
-    val df = sqlContext.range(5000000L).toDF("col1").cache()
+    val df = spark.range(5000000L).toDF("col1").cache()
     def seconds(f: => Any): Double = {
       // Do some warmup
       logDebug("warmup...")

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 80a93ee..f77403c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -99,8 +99,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0))))
     val schema2 = StructType(Array(StructField("label", IntegerType, false),
                     StructField("point", new ExamplePointUDT(), false)))
-    val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
-    val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
+    val df1 = spark.createDataFrame(rowRDD1, schema1)
+    val df2 = spark.createDataFrame(rowRDD2, schema2)
 
     checkAnswer(
       df1.union(df2).orderBy("label"),
@@ -109,8 +109,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   test("empty data frame") {
-    assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
-    assert(sqlContext.emptyDataFrame.count() === 0)
+    assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String])
+    assert(spark.emptyDataFrame.count() === 0)
   }
 
   test("head and take") {
@@ -369,7 +369,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
     // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake
     checkAnswer(
-      sqlContext.range(2).toDF().limit(2147483638),
+      spark.range(2).toDF().limit(2147483638),
       Row(0) :: Row(1) :: Nil
     )
   }
@@ -672,12 +672,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
       val parquetDir = new File(dir, "parquet").getCanonicalPath
       df.write.parquet(parquetDir)
-      val parquetDF = sqlContext.read.parquet(parquetDir)
+      val parquetDF = spark.read.parquet(parquetDir)
       assert(parquetDF.inputFiles.nonEmpty)
 
       val jsonDir = new File(dir, "json").getCanonicalPath
       df.write.json(jsonDir)
-      val jsonDF = sqlContext.read.json(jsonDir)
+      val jsonDF = spark.read.json(jsonDir)
       assert(parquetDF.inputFiles.nonEmpty)
 
       val unioned = jsonDF.union(parquetDF).inputFiles.sorted
@@ -801,7 +801,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
     val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
     val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
-    val df = sqlContext.createDataFrame(rowRDD, schema)
+    val df = spark.createDataFrame(rowRDD, schema)
     df.rdd.collect()
   }
 
@@ -818,14 +818,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-7551: support backticks for DataFrame attribute resolution") {
-    val df = sqlContext.read.json(sparkContext.makeRDD(
+    val df = spark.read.json(sparkContext.makeRDD(
       """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil))
     checkAnswer(
       df.select(df("`a.b`.c.`d..e`.`f`")),
       Row(1)
     )
 
-    val df2 = sqlContext.read.json(sparkContext.makeRDD(
+    val df2 = spark.read.json(sparkContext.makeRDD(
       """{"a  b": {"c": {"d  e": {"f": 1}}}}""" :: Nil))
     checkAnswer(
       df2.select(df2("`a  b`.c.d  e.f")),
@@ -881,53 +881,53 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
   test("SPARK-7150 range api") {
     // numSlice is greater than length
-    val res1 = sqlContext.range(0, 10, 1, 15).select("id")
+    val res1 = spark.range(0, 10, 1, 15).select("id")
     assert(res1.count == 10)
     assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
 
-    val res2 = sqlContext.range(3, 15, 3, 2).select("id")
+    val res2 = spark.range(3, 15, 3, 2).select("id")
     assert(res2.count == 4)
     assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
 
-    val res3 = sqlContext.range(1, -2).select("id")
+    val res3 = spark.range(1, -2).select("id")
     assert(res3.count == 0)
 
     // start is positive, end is negative, step is negative
-    val res4 = sqlContext.range(1, -2, -2, 6).select("id")
+    val res4 = spark.range(1, -2, -2, 6).select("id")
     assert(res4.count == 2)
     assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
 
     // start, end, step are negative
-    val res5 = sqlContext.range(-3, -8, -2, 1).select("id")
+    val res5 = spark.range(-3, -8, -2, 1).select("id")
     assert(res5.count == 3)
     assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
 
     // start, end are negative, step is positive
-    val res6 = sqlContext.range(-8, -4, 2, 1).select("id")
+    val res6 = spark.range(-8, -4, 2, 1).select("id")
     assert(res6.count == 2)
     assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
 
-    val res7 = sqlContext.range(-10, -9, -20, 1).select("id")
+    val res7 = spark.range(-10, -9, -20, 1).select("id")
     assert(res7.count == 0)
 
-    val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
     assert(res8.count == 3)
     assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
 
-    val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
     assert(res9.count == 2)
     assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
 
     // only end provided as argument
-    val res10 = sqlContext.range(10).select("id")
+    val res10 = spark.range(10).select("id")
     assert(res10.count == 10)
     assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
 
-    val res11 = sqlContext.range(-1).select("id")
+    val res11 = spark.range(-1).select("id")
     assert(res11.count == 0)
 
     // using the default slice number
-    val res12 = sqlContext.range(3, 15, 3).select("id")
+    val res12 = spark.range(3, 15, 3).select("id")
     assert(res12.count == 4)
     assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
   }
@@ -993,13 +993,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
       // pass case: parquet table (HadoopFsRelation)
       df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath)
-      val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath)
+      val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath)
       pdf.registerTempTable("parquet_base")
       insertion.write.insertInto("parquet_base")
 
       // pass case: json table (InsertableRelation)
       df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath)
-      val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath)
+      val jdf = spark.read.json(tempJsonFile.getCanonicalPath)
       jdf.registerTempTable("json_base")
       insertion.write.mode(SaveMode.Overwrite).insertInto("json_base")
 
@@ -1019,7 +1019,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
 
       // error case: insert into an OneRowRelation
-      Dataset.ofRows(sqlContext.sparkSession, OneRowRelation).registerTempTable("one_row")
+      Dataset.ofRows(spark, OneRowRelation).registerTempTable("one_row")
       val e3 = intercept[AnalysisException] {
         insertion.write.insertInto("one_row")
       }
@@ -1062,7 +1062,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-9323: DataFrame.orderBy should support nested column name") {
-    val df = sqlContext.read.json(sparkContext.makeRDD(
+    val df = spark.read.json(sparkContext.makeRDD(
       """{"a": {"b": 1}}""" :: Nil))
     checkAnswer(df.orderBy("a.b"), Row(Row(1)))
   }
@@ -1091,10 +1091,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       val dir2 = new File(dir, "dir2").getCanonicalPath
       df2.write.format("json").save(dir2)
 
-      checkAnswer(sqlContext.read.format("json").load(dir1, dir2),
+      checkAnswer(spark.read.format("json").load(dir1, dir2),
         Row(1, 22) :: Row(2, 23) :: Nil)
 
-      checkAnswer(sqlContext.read.format("json").load(dir1),
+      checkAnswer(spark.read.format("json").load(dir1),
         Row(1, 22) :: Nil)
     }
   }
@@ -1116,7 +1116,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") {
-    val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+    val input = spark.read.json(spark.sparkContext.makeRDD(
       (1 to 10).map(i => s"""{"id": $i}""")))
 
     val df = input.select($"id", rand(0).as('r))
@@ -1185,7 +1185,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
       withTempPath { path =>
         Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath)
-        val df = sqlContext.read.parquet(path.getAbsolutePath)
+        val df = spark.read.parquet(path.getAbsolutePath)
         checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a"))
       }
     }
@@ -1244,7 +1244,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     verifyExchangingAgg(testData.repartition($"key", $"value")
       .groupBy("key").count())
 
-    val data = sqlContext.sparkContext.parallelize(
+    val data = spark.sparkContext.parallelize(
       (1 to 100).map(i => TestData2(i % 10, i))).toDF()
 
     // Distribute and order by.
@@ -1308,7 +1308,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       withTempPath { path =>
         val p = path.getAbsolutePath
         Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p)
-        checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012))
+        checkAnswer(spark.read.parquet(p).select("YeaR"), Row(2012))
       }
     }
   }
@@ -1317,7 +1317,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") {
     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
       val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1)))
-      val df = sqlContext.createDataFrame(
+      val df = spark.createDataFrame(
         rdd,
         new StructType().add("f1", IntegerType).add("f2", IntegerType),
         needsConversion = false).select($"F1", $"f2".as("f2"))
@@ -1344,7 +1344,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
     checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
 
-    sqlContext.udf.register("boxedUDF",
+    spark.udf.register("boxedUDF",
       (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer)
     checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil)
 
@@ -1393,7 +1393,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
   test("reuse exchange") {
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") {
-      val df = sqlContext.range(100).toDF()
+      val df = spark.range(100).toDF()
       val join = df.join(df, "id")
       val plan = join.queryExecution.executedPlan
       checkAnswer(join, df)
@@ -1415,14 +1415,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   }
 
   test("sameResult() on aggregate") {
-    val df = sqlContext.range(100)
+    val df = spark.range(100)
     val agg1 = df.groupBy().count()
     val agg2 = df.groupBy().count()
     // two aggregates with different ExprId within them should have same result
     assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan))
     val agg3 = df.groupBy().sum()
     assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan))
-    val df2 = sqlContext.range(101)
+    val df2 = spark.range(101)
     val agg4 = df2.groupBy().count()
     assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan))
   }
@@ -1454,24 +1454,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
   test("assertAnalyzed shouldn't replace original stack trace") {
     val e = intercept[AnalysisException] {
-      sqlContext.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b)
+      spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b)
     }
 
     assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName)
   }
 
   test("SPARK-13774: Check error message for non existent path without globbed paths") {
-    val e = intercept[AnalysisException] (sqlContext.read.format("csv").
+    val e = intercept[AnalysisException] (spark.read.format("csv").
       load("/xyz/file2", "/xyz/file21", "/abc/files555", "a")).getMessage()
     assert(e.startsWith("Path does not exist"))
    }
 
   test("SPARK-13774: Check error message for not existent globbed paths") {
-    val e = intercept[AnalysisException] (sqlContext.read.format("text").
+    val e = intercept[AnalysisException] (spark.read.format("text").
       load( "/xyz/*")).getMessage()
     assert(e.startsWith("Path does not exist"))
 
-    val e1 = intercept[AnalysisException] (sqlContext.read.json("/mnt/*/*-xyz.json").rdd).
+    val e1 = intercept[AnalysisException] (spark.read.json("/mnt/*/*-xyz.json").rdd).
       getMessage()
     assert(e1.startsWith("Path does not exist"))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index 06584ec..a957d5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -249,14 +249,14 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
     try {
       f(tableName)
     } finally {
-      sqlContext.dropTempTable(tableName)
+      spark.catalog.dropTempTable(tableName)
     }
   }
 
   test("time window in SQL with single string expression") {
     withTempTable { table =>
       checkAnswer(
-        sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""")
+        spark.sql(s"""select window(time, "10 seconds"), value from $table""")
           .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
         Seq(
           Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
@@ -270,7 +270,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
   test("time window in SQL with with two expressions") {
     withTempTable { table =>
       checkAnswer(
-        sqlContext.sql(
+        spark.sql(
           s"""select window(time, "10 seconds", 10000000), value from $table""")
           .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
         Seq(
@@ -285,7 +285,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
   test("time window in SQL with with three expressions") {
     withTempTable { table =>
       checkAnswer(
-        sqlContext.sql(
+        spark.sql(
           s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""")
           .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
         Seq(

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index 68e99d6..fe6ba83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -48,7 +48,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
           .add("b3", FloatType)
           .add("b4", DoubleType))
 
-    val df = sqlContext.createDataFrame(data, schema)
+    val df = spark.createDataFrame(data, schema)
     assert(df.select("b").first() === Row(struct))
   }
 
@@ -70,7 +70,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
           .add("b5b", StringType))
           .add("b6", StringType))
 
-    val df = sqlContext.createDataFrame(data, schema)
+    val df = spark.createDataFrame(data, schema)
     assert(df.select("b").first() === Row(outerStruct))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index ae9fb80..d8e241c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
@@ -31,14 +31,14 @@ object DatasetBenchmark {
 
   case class Data(l: Long, s: String)
 
-  def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
-    import sqlContext.implicits._
+  def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
+    import spark.implicits._
 
-    val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
     val benchmark = new Benchmark("back-to-back map", numRows)
     val func = (d: Data) => Data(d.l + 1, d.s)
 
-    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
     benchmark.addCase("RDD") { iter =>
       var res = rdd
       var i = 0
@@ -72,17 +72,17 @@ object DatasetBenchmark {
     benchmark
   }
 
-  def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
-    import sqlContext.implicits._
+  def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
+    import spark.implicits._
 
-    val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
     val benchmark = new Benchmark("back-to-back filter", numRows)
     val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
     val funcs = 0.until(numChains).map { i =>
       (d: Data) => func(d, i)
     }
 
-    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
     benchmark.addCase("RDD") { iter =>
       var res = rdd
       var i = 0
@@ -130,13 +130,13 @@ object DatasetBenchmark {
     override def outputEncoder: Encoder[Long] = Encoders.scalaLong
   }
 
-  def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = {
-    import sqlContext.implicits._
+  def aggregate(spark: SparkSession, numRows: Long): Benchmark = {
+    import spark.implicits._
 
-    val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
     val benchmark = new Benchmark("aggregate", numRows)
 
-    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
     benchmark.addCase("RDD sum") { iter =>
       rdd.aggregate(0L)(_ + _.l, _ + _)
     }
@@ -157,15 +157,17 @@ object DatasetBenchmark {
   }
 
   def main(args: Array[String]): Unit = {
-    val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
-    val sqlContext = new SQLContext(sparkContext)
+    val spark = SparkSession.builder
+      .master("local[*]")
+      .appName("Dataset benchmark")
+      .getOrCreate()
 
     val numRows = 100000000
     val numChains = 10
 
-    val benchmark = backToBackMap(sqlContext, numRows, numChains)
-    val benchmark2 = backToBackFilter(sqlContext, numRows, numChains)
-    val benchmark3 = aggregate(sqlContext, numRows)
+    val benchmark = backToBackMap(spark, numRows, numChains)
+    val benchmark2 = backToBackFilter(spark, numRows, numChains)
+    val benchmark3 = aggregate(spark, numRows)
 
     /*
     Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 942cc09..8c0906b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -39,7 +39,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
     // Drop the cache.
     cached.unpersist()
-    assert(!sqlContext.isCached(cached), "The Dataset should not be cached.")
+    assert(spark.cacheManager.lookupCachedData(cached).isEmpty, "The Dataset should not be cached.")
   }
 
   test("persist and then rebind right encoder when join 2 datasets") {
@@ -56,9 +56,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     assertCached(joined, 2)
 
     ds1.unpersist()
-    assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.")
+    assert(spark.cacheManager.lookupCachedData(ds1).isEmpty,
+      "The Dataset ds1 should not be cached.")
     ds2.unpersist()
-    assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.")
+    assert(spark.cacheManager.lookupCachedData(ds2).isEmpty,
+      "The Dataset ds2 should not be cached.")
   }
 
   test("persist and then groupBy columns asKey, map") {
@@ -73,8 +75,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     assertCached(agged.filter(_._1 == "b"))
 
     ds.unpersist()
-    assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.")
+    assert(spark.cacheManager.lookupCachedData(ds).isEmpty, "The Dataset ds should not be cached.")
     agged.unpersist()
-    assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.")
+    assert(spark.cacheManager.lookupCachedData(agged).isEmpty,
+      "The Dataset agged should not be cached.")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3cb4e52..3c8c862 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -46,12 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 
   test("range") {
-    assert(sqlContext.range(10).map(_ + 1).reduce(_ + _) == 55)
-    assert(sqlContext.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
-    assert(sqlContext.range(0, 10).map(_ + 1).reduce(_ + _) == 55)
-    assert(sqlContext.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
-    assert(sqlContext.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55)
-    assert(sqlContext.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
+    assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55)
+    assert(spark.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
+    assert(spark.range(0, 10).map(_ + 1).reduce(_ + _) == 55)
+    assert(spark.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
+    assert(spark.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55)
+    assert(spark.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
   }
 
   test("SPARK-12404: Datatype Helper Serializability") {
@@ -472,7 +472,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-14696: implicit encoders for boxed types") {
-    assert(sqlContext.range(1).map { i => i : java.lang.Long }.head == 0L)
+    assert(spark.range(1).map { i => i : java.lang.Long }.head == 0L)
   }
 
   test("SPARK-11894: Incorrect results are returned when using null") {
@@ -510,8 +510,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     ))
 
     def buildDataset(rows: Row*): Dataset[NestedStruct] = {
-      val rowRDD = sqlContext.sparkContext.parallelize(rows)
-      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+      val rowRDD = spark.sparkContext.parallelize(rows)
+      spark.createDataFrame(rowRDD, schema).as[NestedStruct]
     }
 
     checkDataset(
@@ -626,7 +626,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-14554: Dataset.map may generate wrong java code for wide table") {
-    val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*)
+    val wideDF = spark.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*)
     // Make sure the generated code for this plan can compile and execute.
     checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*)
   }
@@ -654,7 +654,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     dataset.join(actual, dataset("user") === actual("id")).collect()
   }
 
-  test("SPARK-15097: implicits on dataset's sqlContext can be imported") {
+  test("SPARK-15097: implicits on dataset's spark can be imported") {
     val dataset = Seq(1, 2, 3).toDS()
     checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
   }
@@ -735,10 +735,10 @@ object JavaData {
   def apply(a: Int): JavaData = new JavaData(a)
 }
 
-/** Used to test importing dataset.sqlContext.implicits._ */
+/** Used to test importing dataset.sparkSession.implicits._ */
 object DatasetTransform {
   def addOne(ds: Dataset[Int]): Dataset[Int] = {
-    import ds.sqlContext.implicits._
+    import ds.sparkSession.implicits._
     ds.map(_ + 1)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index b1987c6..a41b465 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -51,7 +51,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
 
   test("insert an extraStrategy") {
     try {
-      sqlContext.experimental.extraStrategies = TestStrategy :: Nil
+      spark.experimental.extraStrategies = TestStrategy :: Nil
 
       val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
       checkAnswer(
@@ -62,7 +62,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
         df.select("a", "b"),
         Row("so slow", 1))
     } finally {
-      sqlContext.experimental.extraStrategies = Nil
+      spark.experimental.extraStrategies = Nil
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8cbad04..da567db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val x = testData2.as("x")
     val y = testData2.as("y")
     val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
-    val planned = sqlContext.sessionState.planner.JoinSelection(join)
+    val planned = spark.sessionState.planner.JoinSelection(join)
     assert(planned.size === 1)
   }
 
@@ -60,7 +60,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   }
 
   test("join operator selection") {
-    sqlContext.cacheManager.clearCache()
+    spark.cacheManager.clearCache()
 
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
       Seq(
@@ -112,7 +112,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 //  }
 
   test("broadcasted hash join operator selection") {
-    sqlContext.cacheManager.clearCache()
+    spark.cacheManager.clearCache()
     sql("CACHE TABLE testData")
     Seq(
       ("SELECT * FROM testData join testData2 ON key = a",
@@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   }
 
   test("broadcasted hash outer join operator selection") {
-    sqlContext.cacheManager.clearCache()
+    spark.cacheManager.clearCache()
     sql("CACHE TABLE testData")
     sql("CACHE TABLE testData2")
     Seq(
@@ -144,7 +144,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val x = testData2.as("x")
     val y = testData2.as("y")
     val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
-    val planned = sqlContext.sessionState.planner.JoinSelection(join)
+    val planned = spark.sessionState.planner.JoinSelection(join)
     assert(planned.size === 1)
   }
 
@@ -435,7 +435,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   }
 
   test("broadcasted existence join operator selection") {
-    sqlContext.cacheManager.clearCache()
+    spark.cacheManager.clearCache()
     sql("CACHE TABLE testData")
 
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
@@ -461,17 +461,17 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   test("cross join with broadcast") {
     sql("CACHE TABLE testData")
 
-    val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData"))
+    val sizeInByteOfTestData = statisticSizeInByte(spark.table("testData"))
 
     // we set the threshold is greater than statistic of the cached table testData
     withSQLConf(
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) {
 
-      assert(statisticSizeInByte(sqlContext.table("testData2")) >
-        sqlContext.conf.autoBroadcastJoinThreshold)
+      assert(statisticSizeInByte(spark.table("testData2")) >
+        spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
 
-      assert(statisticSizeInByte(sqlContext.table("testData")) <
-        sqlContext.conf.autoBroadcastJoinThreshold)
+      assert(statisticSizeInByte(spark.table("testData")) <
+        spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
 
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 9f6c86a..c88dfe5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -33,36 +33,36 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
   }
 
   after {
-    sqlContext.sessionState.catalog.dropTable(
+    spark.sessionState.catalog.dropTable(
       TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
   }
 
   test("get all tables") {
     checkAnswer(
-      sqlContext.tables().filter("tableName = 'listtablessuitetable'"),
+      spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"),
       Row("listtablessuitetable", true))
 
     checkAnswer(
       sql("SHOW tables").filter("tableName = 'listtablessuitetable'"),
       Row("listtablessuitetable", true))
 
-    sqlContext.sessionState.catalog.dropTable(
+    spark.sessionState.catalog.dropTable(
       TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
-    assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
+    assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
   }
 
   test("getting all tables with a database name has no impact on returned table names") {
     checkAnswer(
-      sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"),
+      spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"),
       Row("listtablessuitetable", true))
 
     checkAnswer(
       sql("show TABLES in default").filter("tableName = 'listtablessuitetable'"),
       Row("listtablessuitetable", true))
 
-    sqlContext.sessionState.catalog.dropTable(
+    spark.sessionState.catalog.dropTable(
       TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
-    assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
+    assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
   }
 
   test("query the returned DataFrame of tables") {
@@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
       StructField("tableName", StringType, false) ::
       StructField("isTemporary", BooleanType, false) :: Nil)
 
-    Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach {
+    Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach {
       case tableDF =>
         assert(expectedSchema === tableDF.schema)
 
@@ -81,9 +81,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
           Row(true, "listtablessuitetable")
         )
         checkAnswer(
-          sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+          spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
           Row("tables", true))
-        sqlContext.dropTempTable("tables")
+        spark.catalog.dropTempTable("tables")
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala
new file mode 100644
index 0000000..1732977
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.Suite
+
+/** Manages a local `spark` [[SparkSession]] variable, correctly stopping it after each test. */
+trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite =>
+
+  @transient var spark: SparkSession = _
+
+  override def beforeAll() {
+    super.beforeAll()
+    InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory())
+  }
+
+  override def afterEach() {
+    try {
+      resetSparkContext()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  def resetSparkContext(): Unit = {
+    LocalSparkSession.stop(spark)
+    spark = null
+  }
+
+}
+
+object LocalSparkSession {
+  def stop(spark: SparkSession) {
+    if (spark != null) {
+      spark.stop()
+    }
+    // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown
+    System.clearProperty("spark.driver.port")
+  }
+
+  /** Runs `f` by passing in `sc` and ensures that `sc` is stopped, even if `f` throws. */
+  def withSparkSession[T](sc: SparkSession)(f: SparkSession => T): T = {
+    try {
+      f(sc)
+    } finally {
+      stop(sc)
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index df8b3b7..a1a9b66 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.types.ObjectType
 
 abstract class QueryTest extends PlanTest {
 
-  protected def sqlContext: SQLContext
+  protected def spark: SparkSession
 
   // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
   TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
@@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest {
       expectedAnswer: T*): Unit = {
     checkAnswer(
       ds.toDF(),
-      sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
+      spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
 
     checkDecoding(ds, expectedAnswer: _*)
   }
@@ -267,7 +267,7 @@ abstract class QueryTest extends PlanTest {
 
 
     val jsonBackPlan = try {
-      TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
+      TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext)
     } catch {
       case NonFatal(e) =>
         fail(
@@ -282,7 +282,7 @@ abstract class QueryTest extends PlanTest {
     def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
       case l: LogicalRDD =>
         val origin = logicalRDDs.pop()
-        LogicalRDD(l.output, origin.rdd)(sqlContext.sparkSession)
+        LogicalRDD(l.output, origin.rdd)(spark)
       case l: LocalRelation =>
         val origin = localRelations.pop()
         l.copy(data = origin.data)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message