spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From andrewo...@apache.org
Subject spark git commit: [SPARK-14988][PYTHON] SparkSession API follow-ups
Date Fri, 29 Apr 2016 23:41:22 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4ae9fe091 -> d33e3d572


[SPARK-14988][PYTHON] SparkSession API follow-ups

## What changes were proposed in this pull request?

Addresses comments in #12765.

## How was this patch tested?

Python tests.

Author: Andrew Or <andrew@databricks.com>

Closes #12784 from andrewor14/python-followup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d33e3d57
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d33e3d57
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d33e3d57

Branch: refs/heads/master
Commit: d33e3d572ed7143f151f9c96fd08407f8de340f4
Parents: 4ae9fe0
Author: Andrew Or <andrew@databricks.com>
Authored: Fri Apr 29 16:41:13 2016 -0700
Committer: Andrew Or <andrew@databricks.com>
Committed: Fri Apr 29 16:41:13 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/catalog.py                   | 168 +---------------
 python/pyspark/sql/conf.py                      |  58 ++----
 python/pyspark/sql/context.py                   |   8 +-
 python/pyspark/sql/session.py                   |   4 +-
 python/pyspark/sql/tests.py                     | 199 ++++++++++++++++++-
 .../scala/org/apache/spark/sql/Dataset.scala    |   2 +-
 .../org/apache/spark/sql/RuntimeConfig.scala    |  17 ++
 .../scala/org/apache/spark/sql/SQLContext.scala |   2 +-
 .../org/apache/spark/sql/SparkSession.scala     |   2 +-
 .../spark/sql/execution/command/cache.scala     |   2 +-
 .../org/apache/spark/sql/internal/SQLConf.scala |   7 +
 11 files changed, 256 insertions(+), 213 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 4f92383..9cfdd0a 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -45,45 +45,19 @@ class Catalog(object):
     @ignore_unicode_prefix
     @since(2.0)
     def currentDatabase(self):
-        """Returns the current default database in this session.
-
-        >>> spark.catalog._reset()
-        >>> spark.catalog.currentDatabase()
-        u'default'
-        """
+        """Returns the current default database in this session."""
         return self._jcatalog.currentDatabase()
 
     @ignore_unicode_prefix
     @since(2.0)
     def setCurrentDatabase(self, dbName):
-        """Sets the current default database in this session.
-
-        >>> spark.catalog._reset()
-        >>> spark.sql("CREATE DATABASE some_db")
-        DataFrame[]
-        >>> spark.catalog.setCurrentDatabase("some_db")
-        >>> spark.catalog.currentDatabase()
-        u'some_db'
-        >>> spark.catalog.setCurrentDatabase("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
-        """
+        """Sets the current default database in this session."""
         return self._jcatalog.setCurrentDatabase(dbName)
 
     @ignore_unicode_prefix
     @since(2.0)
     def listDatabases(self):
-        """Returns a list of databases available across all sessions.
-
-        >>> spark.catalog._reset()
-        >>> [db.name for db in spark.catalog.listDatabases()]
-        [u'default']
-        >>> spark.sql("CREATE DATABASE some_db")
-        DataFrame[]
-        >>> [db.name for db in spark.catalog.listDatabases()]
-        [u'default', u'some_db']
-        """
+        """Returns a list of databases available across all sessions."""
         iter = self._jcatalog.listDatabases().toLocalIterator()
         databases = []
         while iter.hasNext():
@@ -101,31 +75,6 @@ class Catalog(object):
 
         If no database is specified, the current database is used.
         This includes all temporary tables.
-
-        >>> spark.catalog._reset()
-        >>> spark.sql("CREATE DATABASE some_db")
-        DataFrame[]
-        >>> spark.catalog.listTables()
-        []
-        >>> spark.catalog.listTables("some_db")
-        []
-        >>> spark.createDataFrame([(1, 1)]).registerTempTable("my_temp_tab")
-        >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
-        DataFrame[]
-        >>> spark.sql("CREATE TABLE some_db.my_tab2 (name STRING, age INT)")
-        DataFrame[]
-        >>> spark.catalog.listTables()
-        [Table(name=u'my_tab1', database=u'default', description=None, tableType=u'MANAGED',
-        isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
-        tableType=u'TEMPORARY', isTemporary=True)]
-        >>> spark.catalog.listTables("some_db")
-        [Table(name=u'my_tab2', database=u'some_db', description=None, tableType=u'MANAGED',
-        isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
-        tableType=u'TEMPORARY', isTemporary=True)]
-        >>> spark.catalog.listTables("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
         """
         if dbName is None:
             dbName = self.currentDatabase()
@@ -148,28 +97,6 @@ class Catalog(object):
 
         If no database is specified, the current database is used.
         This includes all temporary functions.
-
-        >>> spark.catalog._reset()
-        >>> spark.sql("CREATE DATABASE my_db")
-        DataFrame[]
-        >>> funcNames = set(f.name for f in spark.catalog.listFunctions())
-        >>> set(["+", "floor", "to_unix_timestamp", "current_database"]).issubset(funcNames)
-        True
-        >>> spark.sql("CREATE FUNCTION my_func1 AS 'org.apache.spark.whatever'")
-        DataFrame[]
-        >>> spark.sql("CREATE FUNCTION my_db.my_func2 AS 'org.apache.spark.whatever'")
-        DataFrame[]
-        >>> spark.catalog.registerFunction("temp_func", lambda x: str(x))
-        >>> newFuncNames = set(f.name for f in spark.catalog.listFunctions()) - funcNames
-        >>> newFuncNamesDb = set(f.name for f in spark.catalog.listFunctions("my_db")) - funcNames
-        >>> sorted(list(newFuncNames))
-        [u'my_func1', u'temp_func']
-        >>> sorted(list(newFuncNamesDb))
-        [u'my_func2', u'temp_func']
-        >>> spark.catalog.listFunctions("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
         """
         if dbName is None:
             dbName = self.currentDatabase()
@@ -193,26 +120,6 @@ class Catalog(object):
 
         Note: the order of arguments here is different from that of its JVM counterpart
         because Python does not support method overloading.
-
-        >>> spark.catalog._reset()
-        >>> spark.sql("CREATE DATABASE some_db")
-        DataFrame[]
-        >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
-        DataFrame[]
-        >>> spark.sql("CREATE TABLE some_db.my_tab2 (nickname STRING, tolerance FLOAT)")
-        DataFrame[]
-        >>> spark.catalog.listColumns("my_tab1")
-        [Column(name=u'name', description=None, dataType=u'string', nullable=True,
-        isPartition=False, isBucket=False), Column(name=u'age', description=None,
-        dataType=u'int', nullable=True, isPartition=False, isBucket=False)]
-        >>> spark.catalog.listColumns("my_tab2", "some_db")
-        [Column(name=u'nickname', description=None, dataType=u'string', nullable=True,
-        isPartition=False, isBucket=False), Column(name=u'tolerance', description=None,
-        dataType=u'float', nullable=True, isPartition=False, isBucket=False)]
-        >>> spark.catalog.listColumns("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
         """
         if dbName is None:
             dbName = self.currentDatabase()
@@ -247,7 +154,7 @@ class Catalog(object):
         if path is not None:
             options["path"] = path
         if source is None:
-            source = self._sparkSession.getConf(
+            source = self._sparkSession.conf.get(
                 "spark.sql.sources.default", "org.apache.spark.sql.parquet")
         if schema is None:
             df = self._jcatalog.createExternalTable(tableName, source, options)
@@ -275,16 +182,16 @@ class Catalog(object):
         self._jcatalog.dropTempTable(tableName)
 
     @since(2.0)
-    def registerDataFrameAsTable(self, df, tableName):
+    def registerTable(self, df, tableName):
         """Registers the given :class:`DataFrame` as a temporary table in the catalog.
 
         >>> df = spark.createDataFrame([(2, 1), (3, 1)])
-        >>> spark.catalog.registerDataFrameAsTable(df, "my_cool_table")
+        >>> spark.catalog.registerTable(df, "my_cool_table")
         >>> spark.table("my_cool_table").collect()
         [Row(_1=2, _2=1), Row(_1=3, _2=1)]
         """
         if isinstance(df, DataFrame):
-            self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
+            self._jsparkSession.registerTable(df._jdf, tableName)
         else:
             raise ValueError("Can only register DataFrame as table")
 
@@ -321,75 +228,22 @@ class Catalog(object):
 
     @since(2.0)
     def isCached(self, tableName):
-        """Returns true if the table is currently cached in-memory.
-
-        >>> spark.catalog._reset()
-        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
-        >>> spark.catalog.isCached("my_tab")
-        False
-        >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
-        """
+        """Returns true if the table is currently cached in-memory."""
         return self._jcatalog.isCached(tableName)
 
     @since(2.0)
     def cacheTable(self, tableName):
-        """Caches the specified table in-memory.
-
-        >>> spark.catalog._reset()
-        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
-        >>> spark.catalog.isCached("my_tab")
-        False
-        >>> spark.catalog.cacheTable("my_tab")
-        >>> spark.catalog.isCached("my_tab")
-        True
-        >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
-        """
+        """Caches the specified table in-memory."""
         self._jcatalog.cacheTable(tableName)
 
     @since(2.0)
     def uncacheTable(self, tableName):
-        """Removes the specified table from the in-memory cache.
-
-        >>> spark.catalog._reset()
-        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
-        >>> spark.catalog.cacheTable("my_tab")
-        >>> spark.catalog.isCached("my_tab")
-        True
-        >>> spark.catalog.uncacheTable("my_tab")
-        >>> spark.catalog.isCached("my_tab")
-        False
-        >>> spark.catalog.uncacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
-        """
+        """Removes the specified table from the in-memory cache."""
         self._jcatalog.uncacheTable(tableName)
 
     @since(2.0)
     def clearCache(self):
-        """Removes all cached tables from the in-memory cache.
-
-        >>> spark.catalog._reset()
-        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab1")
-        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab2")
-        >>> spark.catalog.cacheTable("my_tab1")
-        >>> spark.catalog.cacheTable("my_tab2")
-        >>> spark.catalog.isCached("my_tab1")
-        True
-        >>> spark.catalog.isCached("my_tab2")
-        True
-        >>> spark.catalog.clearCache()
-        >>> spark.catalog.isCached("my_tab1")
-        False
-        >>> spark.catalog.isCached("my_tab2")
-        False
-        """
+        """Removes all cached tables from the in-memory cache."""
         self._jcatalog.clearCache()
 
     def _reset(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/conf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 1d9f052..7428c91 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -33,64 +33,34 @@ class RuntimeConfig(object):
     @ignore_unicode_prefix
     @since(2.0)
     def set(self, key, value):
-        """Sets the given Spark runtime configuration property.
-
-        >>> spark.conf.set("garble", "marble")
-        >>> spark.getConf("garble")
-        u'marble'
-        """
+        """Sets the given Spark runtime configuration property."""
         self._jconf.set(key, value)
 
     @ignore_unicode_prefix
     @since(2.0)
-    def get(self, key):
+    def get(self, key, default=None):
         """Returns the value of Spark runtime configuration property for the given key,
         assuming it is set.
-
-        >>> spark.setConf("bogo", "sipeo")
-        >>> spark.conf.get("bogo")
-        u'sipeo'
-        >>> spark.conf.get("definitely.not.set") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        Py4JJavaError: ...
-        """
-        return self._jconf.get(key)
-
-    @ignore_unicode_prefix
-    @since(2.0)
-    def getOption(self, key):
-        """Returns the value of Spark runtime configuration property for the given key,
-        or None if it is not set.
-
-        >>> spark.setConf("bogo", "sipeo")
-        >>> spark.conf.getOption("bogo")
-        u'sipeo'
-        >>> spark.conf.getOption("definitely.not.set") is None
-        True
         """
-        iter = self._jconf.getOption(key).iterator()
-        if iter.hasNext():
-            return iter.next()
+        self._checkType(key, "key")
+        if default is None:
+            return self._jconf.get(key)
         else:
-            return None
+            self._checkType(default, "default")
+            return self._jconf.get(key, default)
 
     @ignore_unicode_prefix
     @since(2.0)
     def unset(self, key):
-        """Resets the configuration property for the given key.
-
-        >>> spark.setConf("armado", "larmado")
-        >>> spark.getConf("armado")
-        u'larmado'
-        >>> spark.conf.unset("armado")
-        >>> spark.getConf("armado") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        Py4JJavaError: ...
-        """
+        """Resets the configuration property for the given key."""
         self._jconf.unset(key)
 
+    def _checkType(self, obj, identifier):
+        """Assert that an object is of type str."""
+        if not isinstance(obj, str) and not isinstance(obj, unicode):
+            raise TypeError("expected %s '%s' to be a string (was '%s')" %
+                            (identifier, obj, type(obj).__name__))
+
 
 def _test():
     import os

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 94856c2..417d719 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -127,10 +127,10 @@ class SQLContext(object):
 
         >>> sqlContext.getConf("spark.sql.shuffle.partitions")
         u'200'
-        >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10")
+        >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10")
         u'10'
-        >>> sqlContext.setConf("spark.sql.shuffle.partitions", "50")
-        >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10")
+        >>> sqlContext.setConf("spark.sql.shuffle.partitions", u"50")
+        >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10")
         u'50'
         """
         return self.sparkSession.getConf(key, defaultValue)
@@ -301,7 +301,7 @@ class SQLContext(object):
 
         >>> sqlContext.registerDataFrameAsTable(df, "table1")
         """
-        self.sparkSession.catalog.registerDataFrameAsTable(df, tableName)
+        self.sparkSession.catalog.registerTable(df, tableName)
 
     @since(1.6)
     def dropTempTable(self, tableName):

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index b3bc896..c245261 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -443,7 +443,7 @@ class SparkSession(object):
 
         :return: :class:`DataFrame`
 
-        >>> spark.catalog.registerDataFrameAsTable(df, "table1")
+        >>> spark.catalog.registerTable(df, "table1")
         >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
         >>> df2.collect()
         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
@@ -456,7 +456,7 @@ class SparkSession(object):
 
         :return: :class:`DataFrame`
 
-        >>> spark.catalog.registerDataFrameAsTable(df, "table1")
+        >>> spark.catalog.registerTable(df, "table1")
         >>> df2 = spark.table("table1")
         >>> sorted(df.collect()) == sorted(df2.collect())
         True

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d3dc15..ea98206 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6):
 else:
     import unittest
 
-from pyspark.sql import SQLContext, HiveContext, Column, Row
+from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
@@ -199,7 +199,8 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.setUpClass()
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
         os.unlink(cls.tempdir.name)
-        cls.sqlCtx = SQLContext(cls.sc)
+        cls.sparkSession = SparkSession(cls.sc)
+        cls.sqlCtx = cls.sparkSession._wrapped
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
         rdd = cls.sc.parallelize(cls.testData, 2)
         cls.df = rdd.toDF()
@@ -1394,6 +1395,200 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(df.schema.simpleString(), "struct<value:int>")
         self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
 
+    def test_conf(self):
+        spark = self.sparkSession
+        spark.setConf("bogo", "sipeo")
+        self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo")
+        spark.setConf("bogo", "ta")
+        self.assertEqual(spark.conf.get("bogo"), "ta")
+        self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
+        self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
+        self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
+        spark.conf.unset("bogo")
+        self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
+
+    def test_current_database(self):
+        spark = self.sparkSession
+        spark.catalog._reset()
+        self.assertEquals(spark.catalog.currentDatabase(), "default")
+        spark.sql("CREATE DATABASE some_db")
+        spark.catalog.setCurrentDatabase("some_db")
+        self.assertEquals(spark.catalog.currentDatabase(), "some_db")
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
+
+    def test_list_databases(self):
+        spark = self.sparkSession
+        spark.catalog._reset()
+        databases = [db.name for db in spark.catalog.listDatabases()]
+        self.assertEquals(databases, ["default"])
+        spark.sql("CREATE DATABASE some_db")
+        databases = [db.name for db in spark.catalog.listDatabases()]
+        self.assertEquals(sorted(databases), ["default", "some_db"])
+
+    def test_list_tables(self):
+        from pyspark.sql.catalog import Table
+        spark = self.sparkSession
+        spark.catalog._reset()
+        spark.sql("CREATE DATABASE some_db")
+        self.assertEquals(spark.catalog.listTables(), [])
+        self.assertEquals(spark.catalog.listTables("some_db"), [])
+        spark.createDataFrame([(1, 1)]).registerTempTable("temp_tab")
+        spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
+        spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT)")
+        tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
+        tablesDefault = sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
+        tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
+        self.assertEquals(tables, tablesDefault)
+        self.assertEquals(len(tables), 2)
+        self.assertEquals(len(tablesSomeDb), 2)
+        self.assertEquals(tables[0], Table(
+            name="tab1",
+            database="default",
+            description=None,
+            tableType="MANAGED",
+            isTemporary=False))
+        self.assertEquals(tables[1], Table(
+            name="temp_tab",
+            database=None,
+            description=None,
+            tableType="TEMPORARY",
+            isTemporary=True))
+        self.assertEquals(tablesSomeDb[0], Table(
+            name="tab2",
+            database="some_db",
+            description=None,
+            tableType="MANAGED",
+            isTemporary=False))
+        self.assertEquals(tablesSomeDb[1], Table(
+            name="temp_tab",
+            database=None,
+            description=None,
+            tableType="TEMPORARY",
+            isTemporary=True))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.listTables("does_not_exist"))
+
+    def test_list_functions(self):
+        from pyspark.sql.catalog import Function
+        spark = self.sparkSession
+        spark.catalog._reset()
+        spark.sql("CREATE DATABASE some_db")
+        functions = dict((f.name, f) for f in spark.catalog.listFunctions())
+        functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
+        self.assertTrue(len(functions) > 200)
+        self.assertTrue("+" in functions)
+        self.assertTrue("like" in functions)
+        self.assertTrue("month" in functions)
+        self.assertTrue("to_unix_timestamp" in functions)
+        self.assertTrue("current_database" in functions)
+        self.assertEquals(functions["+"], Function(
+            name="+",
+            description=None,
+            className="org.apache.spark.sql.catalyst.expressions.Add",
+            isTemporary=True))
+        self.assertEquals(functions, functionsDefault)
+        spark.catalog.registerFunction("temp_func", lambda x: str(x))
+        spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
+        spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
+        newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
+        newFunctionsSomeDb = dict((f.name, f) for f in spark.catalog.listFunctions("some_db"))
+        self.assertTrue(set(functions).issubset(set(newFunctions)))
+        self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
+        self.assertTrue("temp_func" in newFunctions)
+        self.assertTrue("func1" in newFunctions)
+        self.assertTrue("func2" not in newFunctions)
+        self.assertTrue("temp_func" in newFunctionsSomeDb)
+        self.assertTrue("func1" not in newFunctionsSomeDb)
+        self.assertTrue("func2" in newFunctionsSomeDb)
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.listFunctions("does_not_exist"))
+
+    def test_list_columns(self):
+        from pyspark.sql.catalog import Column
+        spark = self.sparkSession
+        spark.catalog._reset()
+        spark.sql("CREATE DATABASE some_db")
+        spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
+        spark.sql("CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT)")
+        columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
+        columnsDefault = sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
+        self.assertEquals(columns, columnsDefault)
+        self.assertEquals(len(columns), 2)
+        self.assertEquals(columns[0], Column(
+            name="age",
+            description=None,
+            dataType="int",
+            nullable=True,
+            isPartition=False,
+            isBucket=False))
+        self.assertEquals(columns[1], Column(
+            name="name",
+            description=None,
+            dataType="string",
+            nullable=True,
+            isPartition=False,
+            isBucket=False))
+        columns2 = sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
+        self.assertEquals(len(columns2), 2)
+        self.assertEquals(columns2[0], Column(
+            name="nickname",
+            description=None,
+            dataType="string",
+            nullable=True,
+            isPartition=False,
+            isBucket=False))
+        self.assertEquals(columns2[1], Column(
+            name="tolerance",
+            description=None,
+            dataType="float",
+            nullable=True,
+            isPartition=False,
+            isBucket=False))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "tab2",
+            lambda: spark.catalog.listColumns("tab2"))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.listColumns("does_not_exist"))
+
+    def test_cache(self):
+        spark = self.sparkSession
+        spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1")
+        spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2")
+        self.assertFalse(spark.catalog.isCached("tab1"))
+        self.assertFalse(spark.catalog.isCached("tab2"))
+        spark.catalog.cacheTable("tab1")
+        self.assertTrue(spark.catalog.isCached("tab1"))
+        self.assertFalse(spark.catalog.isCached("tab2"))
+        spark.catalog.cacheTable("tab2")
+        spark.catalog.uncacheTable("tab1")
+        self.assertFalse(spark.catalog.isCached("tab1"))
+        self.assertTrue(spark.catalog.isCached("tab2"))
+        spark.catalog.clearCache()
+        self.assertFalse(spark.catalog.isCached("tab1"))
+        self.assertFalse(spark.catalog.isCached("tab2"))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.isCached("does_not_exist"))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.cacheTable("does_not_exist"))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.uncacheTable("does_not_exist"))
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1439d14..08be94e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2308,7 +2308,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def registerTempTable(tableName: String): Unit = {
-    sparkSession.registerDataFrameAsTable(toDF(), tableName)
+    sparkSession.registerTable(toDF(), tableName)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
index bf97d72..f2e8515 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
@@ -72,6 +72,15 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) {
    *
    * @since 2.0.0
    */
+  def get(key: String, default: String): String = {
+    sqlConf.getConfString(key, default)
+  }
+
+  /**
+   * Returns the value of Spark runtime configuration property for the given key.
+   *
+   * @since 2.0.0
+   */
   def getOption(key: String): Option[String] = {
     try Option(get(key)) catch {
       case _: NoSuchElementException => None
@@ -86,4 +95,12 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) {
   def unset(key: String): Unit = {
     sqlConf.unsetConf(key)
   }
+
+  /**
+   * Returns whether a particular key is set.
+   */
+  protected[sql] def contains(key: String): Boolean = {
+    sqlConf.contains(key)
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1f08a61..6dfac3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -600,7 +600,7 @@ class SQLContext private[sql](
    * only during the lifetime of this instance of SQLContext.
    */
   private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
-    sparkSession.registerDataFrameAsTable(df, tableName)
+    sparkSession.registerTable(df, tableName)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 2814b70..11c0aaa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -577,7 +577,7 @@ class SparkSession private(
    * Registers the given [[DataFrame]] as a temporary table in the catalog.
    * Temporary tables exist only during the lifetime of this instance of [[SparkSession]].
    */
-  protected[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
+  protected[sql] def registerTable(df: DataFrame, tableName: String): Unit = {
     sessionState.catalog.createTempTable(
       sessionState.sqlParser.parseTableIdentifier(tableName).table,
       df.logicalPlan,

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index ec3fada..f05401b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -30,7 +30,7 @@ case class CacheTableCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     plan.foreach { logicalPlan =>
-      sparkSession.registerDataFrameAsTable(Dataset.ofRows(sparkSession, logicalPlan), tableName)
+      sparkSession.registerTable(Dataset.ofRows(sparkSession, logicalPlan), tableName)
     }
     sparkSession.catalog.cacheTable(tableName)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2bfc895..7de7748 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -755,6 +755,13 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
     }.toSeq
   }
 
+  /**
+   * Return whether a given key is set in this [[SQLConf]].
+   */
+  def contains(key: String): Boolean = {
+    settings.containsKey(key)
+  }
+
   private def setConfWithCheck(key: String, value: String): Unit = {
     if (key.startsWith("spark.") && !key.startsWith("spark.sql.")) {
       logWarning(s"Attempt to set non-Spark SQL config in SQLConf: key = $key, value = $value")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message