spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-7738] [SQL] [PySpark] add reader and writer API in Python
Date Tue, 19 May 2015 21:23:32 GMT
Repository: spark
Updated Branches:
  refs/heads/master c12dff9b8 -> 4de74d260


[SPARK-7738] [SQL] [PySpark] add reader and writer API in Python

cc rxin, please take a quick look, I'm working on tests.

Author: Davies Liu <davies@databricks.com>

Closes #6238 from davies/readwrite and squashes the following commits:

c7200eb [Davies Liu] update tests
9cbf01b [Davies Liu] Merge branch 'master' of github.com:apache/spark into readwrite
f0c5a04 [Davies Liu] use sqlContext.read.load
5f68bc8 [Davies Liu] update tests
6437e9a [Davies Liu] Merge branch 'master' of github.com:apache/spark into readwrite
bcc6668 [Davies Liu] add reader and writer API in Python


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4de74d26
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4de74d26
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4de74d26

Branch: refs/heads/master
Commit: 4de74d2602f6577c3c8458aa85377e89c19724ca
Parents: c12dff9
Author: Davies Liu <davies@databricks.com>
Authored: Tue May 19 14:23:28 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Tue May 19 14:23:28 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/api/python/PythonUtils.scala   |  11 +-
 python/pyspark/sql/__init__.py                  |   1 +
 python/pyspark/sql/context.py                   |  28 +-
 python/pyspark/sql/dataframe.py                 |  67 ++--
 python/pyspark/sql/readwriter.py                | 338 +++++++++++++++++++
 python/pyspark/sql/tests.py                     |  77 ++---
 6 files changed, 430 insertions(+), 92 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4de74d26/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index efb6b93..90dacae 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -50,8 +50,15 @@ private[spark] object PythonUtils {
   /**
    * Convert list of T into seq of T (for calling API with varargs)
    */
-  def toSeq[T](cols: JList[T]): Seq[T] = {
-    cols.toList.toSeq
+  def toSeq[T](vs: JList[T]): Seq[T] = {
+    vs.toList.toSeq
+  }
+
+  /**
+   * Convert list of T into array of T (for calling API with array)
+   */
+  def toArray[T](vs: JList[T]): Array[T] = {
+    vs.toArray().asInstanceOf[Array[T]]
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4de74d26/python/pyspark/sql/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 19805e2..634c575 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -58,6 +58,7 @@ from pyspark.sql.context import SQLContext, HiveContext
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
 from pyspark.sql.group import GroupedData
+from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
 
 __all__ = [
     'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',

http://git-wip-us.apache.org/repos/asf/spark/blob/4de74d26/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 9f26d13..7543475 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -31,6 +31,7 @@ from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.readwriter import DataFrameReader
 
 try:
     import pandas
@@ -457,19 +458,7 @@ class SQLContext(object):
 
         Optionally, a schema can be provided as the schema of the returned DataFrame.
         """
-        if path is not None:
-            options["path"] = path
-        if source is None:
-            source = self.getConf("spark.sql.sources.default",
-                                  "org.apache.spark.sql.parquet")
-        if schema is None:
-            df = self._ssql_ctx.load(source, options)
-        else:
-            if not isinstance(schema, StructType):
-                raise TypeError("schema should be StructType")
-            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
-            df = self._ssql_ctx.load(source, scala_datatype, options)
-        return DataFrame(df, self)
+        return self.read.load(path, source, schema, **options)
 
     def createExternalTable(self, tableName, path=None, source=None,
                             schema=None, **options):
@@ -567,6 +556,19 @@ class SQLContext(object):
         """Removes all cached tables from the in-memory cache. """
         self._ssql_ctx.clearCache()
 
+    @property
+    def read(self):
+        """
+        Returns a :class:`DataFrameReader` that can be used to read data
+        in as a :class:`DataFrame`.
+
+        ::note: Experimental
+
+        >>> sqlContext.read
+        <pyspark.sql.readwriter.DataFrameReader object at ...>
+        """
+        return DataFrameReader(self)
+
 
 class HiveContext(SQLContext):
     """A variant of Spark SQL that integrates with data stored in Hive.

http://git-wip-us.apache.org/repos/asf/spark/blob/4de74d26/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e4a191a..f2280b5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -29,9 +29,10 @@ from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.sql.types import *
 from pyspark.sql.types import _create_cls, _parse_datatype_json_string
 from pyspark.sql.column import Column, _to_seq, _to_java_column
+from pyspark.sql.readwriter import DataFrameWriter
+from pyspark.sql.types import *
 
 __all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]
 
@@ -151,25 +152,6 @@ class DataFrame(object):
         """
         self._jdf.insertInto(tableName, overwrite)
 
-    def _java_save_mode(self, mode):
-        """Returns the Java save mode based on the Python save mode represented by a string.
-        """
-        jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode
-        jmode = jSaveMode.ErrorIfExists
-        mode = mode.lower()
-        if mode == "append":
-            jmode = jSaveMode.Append
-        elif mode == "overwrite":
-            jmode = jSaveMode.Overwrite
-        elif mode == "ignore":
-            jmode = jSaveMode.Ignore
-        elif mode == "error":
-            pass
-        else:
-            raise ValueError(
-                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
-        return jmode
-
     def saveAsTable(self, tableName, source=None, mode="error", **options):
         """Saves the contents of this :class:`DataFrame` to a data source as a table.
 
@@ -185,11 +167,7 @@ class DataFrame(object):
         * `error`: Throw an exception if data already exists.
         * `ignore`: Silently ignore this operation if data already exists.
         """
-        if source is None:
-            source = self.sql_ctx.getConf("spark.sql.sources.default",
-                                          "org.apache.spark.sql.parquet")
-        jmode = self._java_save_mode(mode)
-        self._jdf.saveAsTable(tableName, source, jmode, options)
+        self.write.saveAsTable(tableName, source, mode, **options)
 
     def save(self, path=None, source=None, mode="error", **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
@@ -206,13 +184,22 @@ class DataFrame(object):
         * `error`: Throw an exception if data already exists.
         * `ignore`: Silently ignore this operation if data already exists.
         """
-        if path is not None:
-            options["path"] = path
-        if source is None:
-            source = self.sql_ctx.getConf("spark.sql.sources.default",
-                                          "org.apache.spark.sql.parquet")
-        jmode = self._java_save_mode(mode)
-        self._jdf.save(source, jmode, options)
+        return self.write.save(path, source, mode, **options)
+
+    @property
+    def write(self):
+        """
+        Interface for saving the content of the :class:`DataFrame` out
+        into external storage.
+
+        :return :class:`DataFrameWriter`
+
+        ::note: Experimental
+
+        >>> df.write
+        <pyspark.sql.readwriter.DataFrameWriter object at ...>
+        """
+        return DataFrameWriter(self)
 
     @property
     def schema(self):
@@ -411,9 +398,19 @@ class DataFrame(object):
         self._jdf.unpersist(blocking)
         return self
 
-    # def coalesce(self, numPartitions, shuffle=False):
-    #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
-    #     return DataFrame(rdd, self.sql_ctx)
+    def coalesce(self, numPartitions):
+        """
+        Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions.
+
+        Similar to coalesce defined on an :class:`RDD`, this operation results in a
+        narrow dependency, e.g. if you go from 1000 partitions to 100 partitions,
+        there will not be a shuffle, instead each of the 100 new partitions will
+        claim 10 of the current partitions.
+
+        >>> df.coalesce(1).rdd.getNumPartitions()
+        1
+        """
+        return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
 
     def repartition(self, numPartitions):
         """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions.

http://git-wip-us.apache.org/repos/asf/spark/blob/4de74d26/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
new file mode 100644
index 0000000..e2b27fb
--- /dev/null
+++ b/python/pyspark/sql/readwriter.py
@@ -0,0 +1,338 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from py4j.java_gateway import JavaClass
+
+from pyspark.sql.column import _to_seq
+from pyspark.sql.types import *
+
+__all__ = ["DataFrameReader", "DataFrameWriter"]
+
+
+class DataFrameReader(object):
+    """
+    Interface used to load a :class:`DataFrame` from external storage systems
+    (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read`
+    to access this.
+
+    ::Note: Experimental
+    """
+
+    def __init__(self, sqlContext):
+        self._jreader = sqlContext._ssql_ctx.read()
+        self._sqlContext = sqlContext
+
+    def _df(self, jdf):
+        from pyspark.sql.dataframe import DataFrame
+        return DataFrame(jdf, self._sqlContext)
+
+    def load(self, path=None, format=None, schema=None, **options):
+        """Loads data from a data source and returns it as a :class`DataFrame`.
+
+        :param path: optional string for file-system backed data sources.
+        :param format: optional string for format of the data source. Default to 'parquet'.
+        :param schema: optional :class:`StructType` for the input schema.
+        :param options: all other string options
+        """
+        jreader = self._jreader
+        if format is not None:
+            jreader = jreader.format(format)
+        if schema is not None:
+            if not isinstance(schema, StructType):
+                raise TypeError("schema should be StructType")
+            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+            jreader = jreader.schema(jschema)
+        for k in options:
+            jreader = jreader.option(k, options[k])
+        if path is not None:
+            return self._df(jreader.load(path))
+        else:
+            return self._df(jreader.load())
+
+    def json(self, path, schema=None):
+        """
+        Loads a JSON file (one object per line) and returns the result as
+        a :class`DataFrame`.
+
+        If the ``schema`` parameter is not specified, this function goes
+        through the input once to determine the input schema.
+
+        :param path: string, path to the JSON dataset.
+        :param schema: an optional :class:`StructType` for the input schema.
+
+        >>> import tempfile, shutil
+        >>> jsonFile = tempfile.mkdtemp()
+        >>> shutil.rmtree(jsonFile)
+        >>> with open(jsonFile, 'w') as f:
+        ...     f.writelines(jsonStrings)
+        >>> df1 = sqlContext.read.json(jsonFile)
+        >>> df1.printSchema()
+        root
+         |-- field1: long (nullable = true)
+         |-- field2: string (nullable = true)
+         |-- field3: struct (nullable = true)
+         |    |-- field4: long (nullable = true)
+
+        >>> from pyspark.sql.types import *
+        >>> schema = StructType([
+        ...     StructField("field2", StringType()),
+        ...     StructField("field3",
+        ...         StructType([StructField("field5", ArrayType(IntegerType()))]))])
+        >>> df2 = sqlContext.read.json(jsonFile, schema)
+        >>> df2.printSchema()
+        root
+         |-- field2: string (nullable = true)
+         |-- field3: struct (nullable = true)
+         |    |-- field5: array (nullable = true)
+         |    |    |-- element: integer (containsNull = true)
+        """
+        if schema is None:
+            jdf = self._jreader.json(path)
+        else:
+            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+            jdf = self._jreader.schema(jschema).json(path)
+        return self._df(jdf)
+
+    def table(self, tableName):
+        """Returns the specified table as a :class:`DataFrame`.
+
+        >>> sqlContext.registerDataFrameAsTable(df, "table1")
+        >>> df2 = sqlContext.read.table("table1")
+        >>> sorted(df.collect()) == sorted(df2.collect())
+        True
+        """
+        return self._df(self._jreader.table(tableName))
+
+    def parquet(self, *path):
+        """Loads a Parquet file, returning the result as a :class:`DataFrame`.
+
+        >>> import tempfile, shutil
+        >>> parquetFile = tempfile.mkdtemp()
+        >>> shutil.rmtree(parquetFile)
+        >>> df.saveAsParquetFile(parquetFile)
+        >>> df2 = sqlContext.read.parquet(parquetFile)
+        >>> sorted(df.collect()) == sorted(df2.collect())
+        True
+        """
+        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path)))
+
+    def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
+             predicates=None, properties={}):
+        """
+        Construct a :class:`DataFrame` representing the database table accessible
+        via JDBC URL `url` named `table` and connection `properties`.
+
+        The `column` parameter could be used to partition the table, then it will
+        be retrieved in parallel based on the parameters passed to this function.
+
+        The `predicates` parameter gives a list expressions suitable for inclusion
+        in WHERE clauses; each one defines one partition of the :class:`DataFrame`.
+
+        ::Note: Don't create too many partitions in parallel on a large cluster;
+        otherwise Spark might crash your external database systems.
+
+        :param url: a JDBC URL
+        :param table: name of table
+        :param column: the column used to partition
+        :param lowerBound: the lower bound of partition column
+        :param upperBound: the upper bound of the partition column
+        :param numPartitions: the number of partitions
+        :param predicates: a list of expressions
+        :param properties: JDBC database connection arguments, a list of arbitrary string
+                           tag/value. Normally at least a "user" and "password" property
+                           should be included.
+        :return: a DataFrame
+        """
+        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+        for k in properties:
+            jprop.setProperty(k, properties[k])
+        if column is not None:
+            if numPartitions is None:
+                numPartitions = self._sqlContext._sc.defaultParallelism
+            return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound),
+                                               int(numPartitions), jprop))
+        if predicates is not None:
+            arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates)
+            return self._df(self._jreader.jdbc(url, table, arr, jprop))
+        return self._df(self._jreader.jdbc(url, table, jprop))
+
+
+class DataFrameWriter(object):
+    """
+    Interface used to write a [[DataFrame]] to external storage systems
+    (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write`
+    to access this.
+
+    ::Note: Experimental
+    """
+    def __init__(self, df):
+        self._df = df
+        self._sqlContext = df.sql_ctx
+        self._jwrite = df._jdf.write()
+
+    def save(self, path=None, format=None, mode="error", **options):
+        """
+        Saves the contents of the :class:`DataFrame` to a data source.
+
+        The data source is specified by the ``format`` and a set of ``options``.
+        If ``format`` is not specified, the default data source configured by
+        ``spark.sql.sources.default`` will be used.
+
+        Additionally, mode is used to specify the behavior of the save operation when
+        data already exists in the data source. There are four modes:
+
+        * `append`: Append contents of this :class:`DataFrame` to existing data.
+        * `overwrite`: Overwrite existing data.
+        * `error`: Throw an exception if data already exists.
+        * `ignore`: Silently ignore this operation if data already exists.
+
+        :param path: the path in a Hadoop supported file system
+        :param format: the format used to save
+        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+        :param options: all other string options
+        """
+        jwrite = self._jwrite.mode(mode)
+        if format is not None:
+            jwrite = jwrite.format(format)
+        for k in options:
+            jwrite = jwrite.option(k, options[k])
+        if path is None:
+            jwrite.save()
+        else:
+            jwrite.save(path)
+
+    def saveAsTable(self, name, format=None, mode="error", **options):
+        """
+        Saves the contents of this :class:`DataFrame` to a data source as a table.
+
+        The data source is specified by the ``source`` and a set of ``options``.
+        If ``source`` is not specified, the default data source configured by
+        ``spark.sql.sources.default`` will be used.
+
+        Additionally, mode is used to specify the behavior of the saveAsTable operation when
+        table already exists in the data source. There are four modes:
+
+        * `append`: Append contents of this :class:`DataFrame` to existing data.
+        * `overwrite`: Overwrite existing data.
+        * `error`: Throw an exception if data already exists.
+        * `ignore`: Silently ignore this operation if data already exists.
+
+        :param name: the table name
+        :param format: the format used to save
+        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+        :param options: all other string options
+        """
+        jwrite = self._jwrite.mode(mode)
+        if format is not None:
+            jwrite = jwrite.format(format)
+        for k in options:
+            jwrite = jwrite.option(k, options[k])
+        return jwrite.saveAsTable(name)
+
+    def json(self, path, mode="error"):
+        """
+        Saves the content of the :class:`DataFrame` in JSON format at the
+        specified path.
+
+        Additionally, mode is used to specify the behavior of the save operation when
+        data already exists in the data source. There are four modes:
+
+        * `append`: Append contents of this :class:`DataFrame` to existing data.
+        * `overwrite`: Overwrite existing data.
+        * `error`: Throw an exception if data already exists.
+        * `ignore`: Silently ignore this operation if data already exists.
+
+        :param path: the path in any Hadoop supported file system
+        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+        """
+        return self._jwrite.mode(mode).json(path)
+
+    def parquet(self, path, mode="error"):
+        """
+        Saves the content of the :class:`DataFrame` in Parquet format at the
+        specified path.
+
+        Additionally, mode is used to specify the behavior of the save operation when
+        data already exists in the data source. There are four modes:
+
+        * `append`: Append contents of this :class:`DataFrame` to existing data.
+        * `overwrite`: Overwrite existing data.
+        * `error`: Throw an exception if data already exists.
+        * `ignore`: Silently ignore this operation if data already exists.
+
+        :param path: the path in any Hadoop supported file system
+        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+        """
+        return self._jwrite.mode(mode).parquet(path)
+
+    def jdbc(self, url, table, mode="error", properties={}):
+        """
+        Saves the content of the :class:`DataFrame` to a external database table
+        via JDBC.
+
+        In the case the table already exists in the external database,
+        behavior of this function depends on the save mode, specified by the `mode`
+        function (default to throwing an exception). There are four modes:
+
+        * `append`: Append contents of this :class:`DataFrame` to existing data.
+        * `overwrite`: Overwrite existing data.
+        * `error`: Throw an exception if data already exists.
+        * `ignore`: Silently ignore this operation if data already exists.
+
+        :param url: a JDBC URL of the form `jdbc:subprotocol:subname`
+        :param table: Name of the table in the external database.
+        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+        :param properties: JDBC database connection arguments, a list of
+                                    arbitrary string tag/value. Normally at least a
+                                    "user" and "password" property should be included.
+        """
+        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+        for k in properties:
+            jprop.setProperty(k, properties[k])
+        self._jwrite.mode(mode).jdbc(url, table, jprop)
+
+
+def _test():
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import Row, SQLContext
+    import pyspark.sql.readwriter
+    globs = pyspark.sql.readwriter.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['sqlContext'] = SQLContext(sc)
+    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
+        .toDF(StructType([StructField('age', IntegerType()),
+                          StructField('name', StringType())]))
+    jsonStrings = [
+        '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
+        '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
+        '"field6":[{"field7": "row2"}]}',
+        '{"field1" : null, "field2": "row3", '
+        '"field3":{"field4":33, "field5": []}}'
+    ]
+    globs['jsonStrings'] = jsonStrings
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.readwriter, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()

http://git-wip-us.apache.org/repos/asf/spark/blob/4de74d26/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 84ae36f..7e34996 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -485,29 +485,29 @@ class SQLTests(ReusedPySparkTestCase):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.save(tmpPath, "org.apache.spark.sql.json", "error")
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
-        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+        actual = self.sqlCtx.read.json(tmpPath, schema)
+        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
 
-        df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
-        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.json(tmpPath, "overwrite")
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
-                noUse="this options will not be used in save.")
-        actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
-                                  noUse="this options will not be used in load.")
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.save(format="json", mode="overwrite", path=tmpPath,
+                      noUse="this options will not be used in save.")
+        actual = self.sqlCtx.read.load(format="json", path=tmpPath,
+                                       noUse="this options will not be used in load.")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
         self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
         actual = self.sqlCtx.load(path=tmpPath)
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
         self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
@@ -767,51 +767,44 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
-                                                 "org.apache.spark.sql.json")
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
         self.sqlCtx.sql("DROP TABLE externalJsonTable")
 
-        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+        df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.createExternalTable("externalJsonTable",
-                                                 source="org.apache.spark.sql.json",
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
                                                  schema=schema, path=tmpPath,
                                                  noUse="this options will not be used")
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
-        self.assertTrue(
-            sorted(df.select("value").collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
-        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertEqual(sorted(df.select("value").collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
         self.sqlCtx.sql("DROP TABLE savedJsonTable")
         self.sqlCtx.sql("DROP TABLE externalJsonTable")
 
         defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
         self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
         actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
-        self.assertTrue(
-            sorted(df.collect()) ==
-            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
-        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()),
+                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
         self.sqlCtx.sql("DROP TABLE savedJsonTable")
         self.sqlCtx.sql("DROP TABLE externalJsonTable")
         self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
 
+
 if __name__ == "__main__":
     unittest.main()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message