spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-5677] [SPARK-5734] [SQL] [PySpark] Python DataFrame API remaining tasks
Date Wed, 11 Feb 2015 20:13:18 GMT
Repository: spark
Updated Branches:
  refs/heads/master 1ac099e3e -> b694eb9c2


[SPARK-5677] [SPARK-5734] [SQL] [PySpark] Python DataFrame API remaining tasks

1. DataFrame.renameColumn

2. DataFrame.show() and __repr__

3. Use simpleString() rather than jsonValue in DataFrame.dtypes

4. createDataFrame from local Python data, including pandas.DataFrame

Author: Davies Liu <davies@databricks.com>

Closes #4528 from davies/df3 and squashes the following commits:

014acea [Davies Liu] fix typo
6ba526e [Davies Liu] fix tests
46f5f95 [Davies Liu] address comments
6cbc154 [Davies Liu] dataframe.show() and improve dtypes
6f94f25 [Davies Liu] create DataFrame from local Python data


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b694eb9c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b694eb9c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b694eb9c

Branch: refs/heads/master
Commit: b694eb9c2fefeaa33891d3e61f9bea369bc09984
Parents: 1ac099e
Author: Davies Liu <davies@databricks.com>
Authored: Wed Feb 11 12:13:16 2015 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Wed Feb 11 12:13:16 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql/context.py                   | 114 ++++++++++++-------
 python/pyspark/sql/dataframe.py                 |  42 ++++++-
 python/pyspark/sql/tests.py                     |   2 +-
 python/pyspark/sql/types.py                     |  32 ++++++
 .../org/apache/spark/sql/DataFrameImpl.scala    |  15 ++-
 5 files changed, 155 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b694eb9c/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 9d29ef4..db4bcbe 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -23,12 +23,18 @@ from itertools import imap
 from py4j.protocol import Py4JError
 from py4j.java_collections import MapConverter
 
-from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \
+from pyspark.sql.types import StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
 from pyspark.sql.dataframe import DataFrame
 
+try:
+    import pandas
+    has_pandas = True
+except ImportError:
+    has_pandas = False
+
 __all__ = ["SQLContext", "HiveContext"]
 
 
@@ -116,6 +122,31 @@ class SQLContext(object):
                                             self._sc._javaAccumulator,
                                             returnType.json())
 
+    def _inferSchema(self, rdd, samplingRatio=None):
+        first = rdd.first()
+        if not first:
+            raise ValueError("The first row in RDD is empty, "
+                             "can not infer schema")
+        if type(first) is dict:
+            warnings.warn("Using RDD of dict to inferSchema is deprecated,"
+                          "please use pyspark.sql.Row instead")
+
+        if samplingRatio is None:
+            schema = _infer_schema(first)
+            if _has_nulltype(schema):
+                for row in rdd.take(100)[1:]:
+                    schema = _merge_type(schema, _infer_schema(row))
+                    if not _has_nulltype(schema):
+                        break
+                else:
+                    raise ValueError("Some of types cannot be determined by the "
+                                     "first 100 rows, please try again with sampling")
+        else:
+            if samplingRatio < 0.99:
+                rdd = rdd.sample(False, float(samplingRatio))
+            schema = rdd.map(_infer_schema).reduce(_merge_type)
+        return schema
+
     def inferSchema(self, rdd, samplingRatio=None):
         """Infer and apply a schema to an RDD of L{Row}.
 
@@ -171,29 +202,7 @@ class SQLContext(object):
         if isinstance(rdd, DataFrame):
             raise TypeError("Cannot apply schema to DataFrame")
 
-        first = rdd.first()
-        if not first:
-            raise ValueError("The first row in RDD is empty, "
-                             "can not infer schema")
-        if type(first) is dict:
-            warnings.warn("Using RDD of dict to inferSchema is deprecated,"
-                          "please use pyspark.sql.Row instead")
-
-        if samplingRatio is None:
-            schema = _infer_schema(first)
-            if _has_nulltype(schema):
-                for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row))
-                    if not _has_nulltype(schema):
-                        break
-                else:
-                    warnings.warn("Some of types cannot be determined by the "
-                                  "first 100 rows, please try again with sampling")
-        else:
-            if samplingRatio < 0.99:
-                rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(_infer_schema).reduce(_merge_type)
-
+        schema = self._inferSchema(rdd, samplingRatio)
         converter = _create_converter(schema)
         rdd = rdd.map(converter)
         return self.applySchema(rdd, schema)
@@ -274,7 +283,7 @@ class SQLContext(object):
             raise TypeError("Cannot apply schema to DataFrame")
 
         if not isinstance(schema, StructType):
-            raise TypeError("schema should be StructType")
+            raise TypeError("schema should be StructType, but got %s" % schema)
 
         # take the first few rows to verify schema
         rows = rdd.take(10)
@@ -294,9 +303,9 @@ class SQLContext(object):
         df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
         return DataFrame(df, self)
 
-    def createDataFrame(self, rdd, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None):
         """
-        Create a DataFrame from an RDD of tuple/list and an optional `schema`.
+        Create a DataFrame from an RDD of tuple/list, list or pandas.DataFrame.
 
         `schema` could be :class:`StructType` or a list of column names.
 
@@ -311,12 +320,20 @@ class SQLContext(object):
         rows will be used to do referring. The first row will be used if
         `samplingRatio` is None.
 
-        :param rdd: an RDD of Row or tuple or list or dict
+        :param data: an RDD of Row/tuple/list/dict, list, or pandas.DataFrame
         :param schema: a StructType or list of names of columns
         :param samplingRatio: the sample ratio of rows used for inferring
         :return: a DataFrame
 
-        >>> rdd = sc.parallelize([('Alice', 1)])
+        >>> l = [('Alice', 1)]
+        >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
+        [Row(name=u'Alice', age=1)]
+
+        >>> d = [{'name': 'Alice', 'age': 1}]
+        >>> sqlCtx.createDataFrame(d).collect()
+        [Row(age=1, name=u'Alice')]
+
+        >>> rdd = sc.parallelize(l)
         >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
         >>> df.collect()
         [Row(name=u'Alice', age=1)]
@@ -336,19 +353,32 @@ class SQLContext(object):
         >>> df3.collect()
         [Row(name=u'Alice', age=1)]
         """
-        if isinstance(rdd, DataFrame):
-            raise TypeError("rdd is already a DataFrame")
+        if isinstance(data, DataFrame):
+            raise TypeError("data is already a DataFrame")
 
-        if isinstance(schema, StructType):
-            return self.applySchema(rdd, schema)
-        else:
-            if isinstance(schema, (list, tuple)):
-                first = rdd.first()
-                if not isinstance(first, (list, tuple)):
-                    raise ValueError("each row in `rdd` should be list or tuple")
-                row_cls = Row(*schema)
-                rdd = rdd.map(lambda r: row_cls(*r))
-            return self.inferSchema(rdd, samplingRatio)
+        if has_pandas and isinstance(data, pandas.DataFrame):
+            data = self._sc.parallelize(data.to_records(index=False))
+            if schema is None:
+                schema = list(data.columns)
+
+        if not isinstance(data, RDD):
+            try:
+                # data could be list, tuple, generator ...
+                data = self._sc.parallelize(data)
+            except Exception:
+                raise ValueError("cannot create an RDD from type: %s" % type(data))
+
+        if schema is None:
+            return self.inferSchema(data, samplingRatio)
+
+        if isinstance(schema, (list, tuple)):
+            first = data.first()
+            if not isinstance(first, (list, tuple)):
+                raise ValueError("each row in `rdd` should be list or tuple")
+            row_cls = Row(*schema)
+            schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
+
+        return self.applySchema(data, schema)
 
     def registerRDDAsTable(self, rdd, tableName):
         """Registers the given RDD as a temporary table in the catalog.

http://git-wip-us.apache.org/repos/asf/spark/blob/b694eb9c/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3eef0cc..3eb56ed 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -236,6 +236,24 @@ class DataFrame(object):
         """
         print (self._jdf.schema().treeString())
 
+    def show(self):
+        """
+        Print the first 20 rows.
+
+        >>> df.show()
+        age name
+        2   Alice
+        5   Bob
+        >>> df
+        age name
+        2   Alice
+        5   Bob
+        """
+        print (self)
+
+    def __repr__(self):
+        return self._jdf.showString()
+
     def count(self):
         """Return the number of elements in this RDD.
 
@@ -380,9 +398,9 @@ class DataFrame(object):
         """Return all column names and their data types as a list.
 
         >>> df.dtypes
-        [('age', 'integer'), ('name', 'string')]
+        [('age', 'int'), ('name', 'string')]
         """
-        return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
+        return [(str(f.name), f.dataType.simpleString()) for f in self.schema().fields]
 
     @property
     def columns(self):
@@ -606,6 +624,17 @@ class DataFrame(object):
         """
         return self.select('*', col.alias(colName))
 
+    def renameColumn(self, existing, new):
+        """ Rename an existing column to a new name
+
+        >>> df.renameColumn('age', 'age2').collect()
+        [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
+        """
+        cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
+                if c == existing else c
+                for c in self.columns]
+        return self.select(*cols)
+
     def to_pandas(self):
         """
         Collect all the rows and return a `pandas.DataFrame`.
@@ -885,6 +914,12 @@ class Column(DataFrame):
             jc = self._jc.cast(jdt)
         return Column(jc, self.sql_ctx)
 
+    def __repr__(self):
+        if self._jdf.isComputable():
+            return self._jdf.samples()
+        else:
+            return 'Column<%s>' % self._jdf.toString()
+
     def to_pandas(self):
         """
         Return a pandas.Series from the column
@@ -1030,7 +1065,8 @@ def _test():
     globs['df'] = sqlCtx.inferSchema(rdd2)
     globs['df2'] = sqlCtx.inferSchema(rdd3)
     (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS)
+        pyspark.sql.dataframe, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
     globs['sc'].stop()
     if failure_count:
         exit(-1)

http://git-wip-us.apache.org/repos/asf/spark/blob/b694eb9c/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5e41e36..43e5c3a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -194,7 +194,7 @@ class SQLTests(ReusedPySparkTestCase):
         result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
-        df2 = self.sqlCtx.createDataFrame(rdd, 1.0)
+        df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
         self.assertEqual(df.schema(), df2.schema())
         self.assertEqual({}, df2.map(lambda r: r.d).first())
         self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())

http://git-wip-us.apache.org/repos/asf/spark/blob/b694eb9c/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 41afefe..40bd7e5 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -52,6 +52,9 @@ class DataType(object):
     def typeName(cls):
         return cls.__name__[:-4].lower()
 
+    def simpleString(self):
+        return self.typeName()
+
     def jsonValue(self):
         return self.typeName()
 
@@ -145,6 +148,12 @@ class DecimalType(DataType):
         self.scale = scale
         self.hasPrecisionInfo = precision is not None
 
+    def simpleString(self):
+        if self.hasPrecisionInfo:
+            return "decimal(%d,%d)" % (self.precision, self.scale)
+        else:
+            return "decimal(10,0)"
+
     def jsonValue(self):
         if self.hasPrecisionInfo:
             return "decimal(%d,%d)" % (self.precision, self.scale)
@@ -180,6 +189,8 @@ class ByteType(PrimitiveType):
 
     The data type representing int values with 1 singed byte.
     """
+    def simpleString(self):
+        return 'tinyint'
 
 
 class IntegerType(PrimitiveType):
@@ -188,6 +199,8 @@ class IntegerType(PrimitiveType):
 
     The data type representing int values.
     """
+    def simpleString(self):
+        return 'int'
 
 
 class LongType(PrimitiveType):
@@ -198,6 +211,8 @@ class LongType(PrimitiveType):
     beyond the range of [-9223372036854775808, 9223372036854775807],
     please use DecimalType.
     """
+    def simpleString(self):
+        return 'bigint'
 
 
 class ShortType(PrimitiveType):
@@ -206,6 +221,8 @@ class ShortType(PrimitiveType):
 
     The data type representing int values with 2 signed bytes.
     """
+    def simpleString(self):
+        return 'smallint'
 
 
 class ArrayType(DataType):
@@ -233,6 +250,9 @@ class ArrayType(DataType):
         self.elementType = elementType
         self.containsNull = containsNull
 
+    def simpleString(self):
+        return 'array<%s>' % self.elementType.simpleString()
+
     def __repr__(self):
         return "ArrayType(%s,%s)" % (self.elementType,
                                      str(self.containsNull).lower())
@@ -283,6 +303,9 @@ class MapType(DataType):
         self.valueType = valueType
         self.valueContainsNull = valueContainsNull
 
+    def simpleString(self):
+        return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString())
+
     def __repr__(self):
         return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
                                       str(self.valueContainsNull).lower())
@@ -337,6 +360,9 @@ class StructField(DataType):
         self.nullable = nullable
         self.metadata = metadata or {}
 
+    def simpleString(self):
+        return '%s:%s' % (self.name, self.dataType.simpleString())
+
     def __repr__(self):
         return "StructField(%s,%s,%s)" % (self.name, self.dataType,
                                           str(self.nullable).lower())
@@ -379,6 +405,9 @@ class StructType(DataType):
         """
         self.fields = fields
 
+    def simpleString(self):
+        return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
+
     def __repr__(self):
         return ("StructType(List(%s))" %
                 ",".join(str(field) for field in self.fields))
@@ -435,6 +464,9 @@ class UserDefinedType(DataType):
         """
         raise NotImplementedError("UDT must implement deserialize().")
 
+    def simpleString(self):
+        return 'null'
+
     def json(self):
         return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b694eb9c/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 9638ce0..41da442 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -126,7 +126,10 @@ private[sql] class DataFrameImpl protected[sql](
     logicalPlan.isInstanceOf[LocalRelation]
   }
 
-  override def show(): Unit = {
+  /**
+   * Internal API for Python
+   */
+  private[sql] def showString(): String = {
     val data = take(20)
     val numCols = schema.fieldNames.length
 
@@ -146,12 +149,16 @@ private[sql] class DataFrameImpl protected[sql](
       }
     }
 
-    // Pad the cells and print them
-    println(rows.map { row =>
+    // Pad the cells
+    rows.map { row =>
       row.zipWithIndex.map { case (cell, i) =>
         String.format(s"%-${colWidths(i)}s", cell)
       }.mkString(" ")
-    }.mkString("\n"))
+    }.mkString("\n")
+  }
+
+  override def show(): Unit = {
+    println(showString)
   }
 
   override def join(right: DataFrame): DataFrame = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message