spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject [2/4] spark git commit: [SPARK-5469] restructure pyspark.sql into multiple files
Date Tue, 10 Feb 2015 04:58:34 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/f0562b42/python/pyspark/sql/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
new file mode 100644
index 0000000..0a5ba00
--- /dev/null
+++ b/python/pyspark/sql/__init__.py
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+public classes of Spark SQL:
+
+    - L{SQLContext}
+      Main entry point for SQL functionality.
+    - L{DataFrame}
+      A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
+      addition to normal RDD operations, DataFrames also support SQL.
+    - L{GroupedData}
+    - L{Column}
+      Column is a DataFrame with a single column.
+    - L{Row}
+      A Row of data returned by a Spark SQL query.
+    - L{HiveContext}
+      Main entry point for accessing data stored in Apache Hive.
+"""
+
+from pyspark.sql.context import SQLContext, HiveContext
+from pyspark.sql.types import Row
+from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD
+
+# Names exported by ``from pyspark.sql import *``.  SchemaRDD is imported
+# above but is not listed here.
+__all__ = [
+    'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
+    'Dsl',
+]

http://git-wip-us.apache.org/repos/asf/spark/blob/f0562b42/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
new file mode 100644
index 0000000..49f016a
--- /dev/null
+++ b/python/pyspark/sql/context.py
@@ -0,0 +1,642 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+import json
+from array import array
+from itertools import imap
+
+from py4j.protocol import Py4JError
+
+from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.sql.types import StringType, StructType, _verify_type, \
+    _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
+from pyspark.sql.dataframe import DataFrame
+
+# Public names of this module; Row and _create_row, defined further down in
+# this file, are not exported from here.
+__all__ = ["SQLContext", "HiveContext"]
+
+
+class SQLContext(object):
+
+    """Main entry point for Spark SQL functionality.
+
+    A SQLContext can be used to create L{DataFrame}, register L{DataFrame} as
+    tables, execute SQL over tables, cache tables, and read parquet files.
+    """
+
+    def __init__(self, sparkContext, sqlContext=None):
+        """Create a new SQLContext.
+
+        :param sparkContext: The SparkContext to wrap.
+        :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
+        SQLContext in the JVM, instead we make all calls to this object.
+
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        TypeError:...
+
+        >>> bad_rdd = sc.parallelize([1,2,3])
+        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        ValueError:...
+
+        >>> from datetime import datetime
+        >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
+        ...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
+        ...     time=datetime(2014, 8, 1, 14, 1, 5))])
+        >>> df = sqlCtx.inferSchema(allTypes)
+        >>> df.registerTempTable("allTypes")
+        >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
+        ...            'from allTypes where b and i > 0').collect()
+        [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
+        >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
+        ...                     x.row.a, x.list)).collect()
+        [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
+        """
+        self._sc = sparkContext
+        # Keep direct references to the SparkContext's _jsc and _jvm
+        # attributes; the methods of this class use them to reach the JVM.
+        self._jsc = self._sc._jsc
+        self._jvm = self._sc._jvm
+        # May be None, in which case the _ssql_ctx property creates the JVM
+        # SQLContext on first access.
+        self._scala_SQLContext = sqlContext
+
+    @property
+    def _ssql_ctx(self):
+        """Return the JVM Spark SQL context, creating it on first use.
+
+        The context handed to the constructor is returned when there is
+        one; otherwise a JVM SQLContext is built from the wrapped
+        SparkContext and remembered for later calls.
+
+        Subclasses can override this property to provide their own
+        JVM contexts.
+        """
+        jvm_ctx = self._scala_SQLContext
+        if jvm_ctx is not None:
+            return jvm_ctx
+        jvm_ctx = self._jvm.SQLContext(self._jsc.sc())
+        self._scala_SQLContext = jvm_ctx
+        return jvm_ctx
+
+    def registerFunction(self, name, f, returnType=StringType()):
+        """Registers a lambda function as a UDF so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not given it defaults to a string and conversion will automatically
+        be done.  For any other return type, the produced object must match the specified type.
+
+        :param name: name under which the function is registered.
+        :param f: Python callable; it is invoked as ``f(*args)`` for each input tuple.
+        :param returnType: data type object of the result; its ``json()`` form is sent to the JVM.
+
+        >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
+        >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+        [Row(c0=u'4')]
+        >>> from pyspark.sql.types import IntegerType
+        >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+        [Row(c0=4)]
+        """
+        # Two-argument wrapper: the first argument is ignored, the second is
+        # an iterator of argument tuples that are unpacked into f lazily.
+        func = lambda _, it: imap(lambda x: f(*x), it)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        # NOTE(review): the layout of this tuple is dictated by
+        # _prepare_for_python_RDD (not visible here); the same serializer is
+        # passed in both of the last two positions -- confirm their meaning.
+        command = (func, None, ser, ser)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
+        self._ssql_ctx.udf().registerPython(name,
+                                            bytearray(pickled_cmd),
+                                            env,
+                                            includes,
+                                            self._sc.pythonExec,
+                                            bvars,
+                                            self._sc._javaAccumulator,
+                                            returnType.json())
+
+    def inferSchema(self, rdd, samplingRatio=None):
+        """Infer and apply a schema to an RDD of L{Row}.
+
+        When samplingRatio is specified, the schema is inferred by looking
+        at the types of each row in the sampled dataset. Otherwise the
+        first row is used, and up to the first 100 rows of the RDD are
+        consulted while the type of some field is still undetermined.
+        Nested collections are supported, which can include array, dict,
+        list, Row, tuple, namedtuple, or object.
+
+        Each row could be L{pyspark.sql.Row} object or namedtuple or objects.
+        Using top level dicts is deprecated, as dict is used to represent Maps.
+
+        If a single column has multiple distinct inferred types, it may cause
+        runtime exceptions.
+
+        :param rdd: the RDD to convert; must not already be a L{DataFrame}.
+        :param samplingRatio: optional fraction of the rows used for schema
+            inference. Sampling only affects inference; the schema is
+            applied to every row of `rdd`.
+        :return: the result of L{applySchema} for `rdd` and the inferred schema.
+        :raises TypeError: if `rdd` is a L{DataFrame}.
+        :raises ValueError: if the first row of `rdd` is empty.
+
+        >>> rdd = sc.parallelize(
+        ...     [Row(field1=1, field2="row1"),
+        ...      Row(field1=2, field2="row2"),
+        ...      Row(field1=3, field2="row3")])
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.collect()[0]
+        Row(field1=1, field2=u'row1')
+
+        >>> NestedRow = Row("f1", "f2")
+        >>> nestedRdd1 = sc.parallelize([
+        ...     NestedRow(array('i', [1, 2]), {"row1": 1.0}),
+        ...     NestedRow(array('i', [2, 3]), {"row2": 2.0})])
+        >>> df = sqlCtx.inferSchema(nestedRdd1)
+        >>> df.collect()
+        [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
+
+        >>> nestedRdd2 = sc.parallelize([
+        ...     NestedRow([[1, 2], [2, 3]], [1, 2]),
+        ...     NestedRow([[2, 3], [3, 4]], [2, 3])])
+        >>> df = sqlCtx.inferSchema(nestedRdd2)
+        >>> df.collect()
+        [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
+
+        >>> from collections import namedtuple
+        >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
+        >>> rdd = sc.parallelize(
+        ...     [CustomRow(field1=1, field2="row1"),
+        ...      CustomRow(field1=2, field2="row2"),
+        ...      CustomRow(field1=3, field2="row3")])
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.collect()[0]
+        Row(field1=1, field2=u'row1')
+        """
+
+        if isinstance(rdd, DataFrame):
+            raise TypeError("Cannot apply schema to DataFrame")
+
+        first = rdd.first()
+        if not first:
+            raise ValueError("The first row in RDD is empty, "
+                             "can not infer schema")
+        if type(first) is dict:
+            warnings.warn("Using RDD of dict to inferSchema is deprecated,"
+                          "please use pyspark.sql.Row instead")
+
+        if samplingRatio is None:
+            schema = _infer_schema(first)
+            if _has_nulltype(schema):
+                # Keep merging in later rows until every field has a type;
+                # the else branch runs only when the loop never breaks.
+                for row in rdd.take(100)[1:]:
+                    schema = _merge_type(schema, _infer_schema(row))
+                    if not _has_nulltype(schema):
+                        break
+                else:
+                    warnings.warn("Some of types cannot be determined by the "
+                                  "first 100 rows, please try again with sampling")
+        else:
+            # Sample only when the ratio is meaningfully below 1, and only
+            # for inference: `rdd` itself is left untouched so that the
+            # returned DataFrame still contains every row.
+            sampled = rdd
+            if samplingRatio < 0.99:
+                sampled = rdd.sample(False, float(samplingRatio))
+            schema = sampled.map(_infer_schema).reduce(_merge_type)
+
+        converter = _create_converter(schema)
+        return self.applySchema(rdd.map(converter), schema)
+
+    def applySchema(self, rdd, schema):
+        """
+        Applies the given schema to the given RDD of L{tuple} or L{list}.
+
+        These tuples or lists can contain complex nested structures like
+        lists, maps or nested rows.
+
+        The schema should be a StructType.
+
+        It is important that the schema matches the types of the objects
+        in each row or exceptions could be thrown at runtime. Only the
+        first 10 rows are checked against the schema up front.
+
+        :param rdd: the RDD of rows; must not already be a L{DataFrame}.
+        :param schema: a StructType describing the rows.
+        :return: a L{DataFrame} wrapping the JVM result.
+        :raises TypeError: if `rdd` is a L{DataFrame} or `schema` is not a
+            StructType.
+
+        >>> from pyspark.sql.types import *
+        >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
+        >>> schema = StructType([StructField("field1", IntegerType(), False),
+        ...     StructField("field2", StringType(), False)])
+        >>> df = sqlCtx.applySchema(rdd2, schema)
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        >>> df2 = sqlCtx.sql("SELECT * from table1")
+        >>> df2.collect()
+        [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
+
+        >>> from datetime import date, datetime
+        >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+        ...     date(2010, 1, 1),
+        ...     datetime(2010, 1, 1, 1, 1, 1),
+        ...     {"a": 1}, (2,), [1, 2, 3], None)])
+        >>> schema = StructType([
+        ...     StructField("byte1", ByteType(), False),
+        ...     StructField("byte2", ByteType(), False),
+        ...     StructField("short1", ShortType(), False),
+        ...     StructField("short2", ShortType(), False),
+        ...     StructField("int", IntegerType(), False),
+        ...     StructField("float", FloatType(), False),
+        ...     StructField("date", DateType(), False),
+        ...     StructField("time", TimestampType(), False),
+        ...     StructField("map",
+        ...         MapType(StringType(), IntegerType(), False), False),
+        ...     StructField("struct",
+        ...         StructType([StructField("b", ShortType(), False)]), False),
+        ...     StructField("list", ArrayType(ByteType(), False), False),
+        ...     StructField("null", DoubleType(), True)])
+        >>> df = sqlCtx.applySchema(rdd, schema)
+        >>> results = df.map(
+        ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
+        ...         x.time, x.map["a"], x.struct.b, x.list, x.null))
+        >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
+        (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
+             datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+
+        >>> df.registerTempTable("table2")
+        >>> sqlCtx.sql(
+        ...   "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+        ...     "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
+        ...     "float + 1.5 as float FROM table2").collect()
+        [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]
+
+        >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
+        >>> rdd = sc.parallelize([(127, -32768, 1.0,
+        ...     datetime(2010, 1, 1, 1, 1, 1),
+        ...     {"a": 1}, (2,), [1, 2, 3])])
+        >>> abstract = "byte short float time map{} struct(b) list[]"
+        >>> schema = _parse_schema_abstract(abstract)
+        >>> typedSchema = _infer_schema_type(rdd.first(), schema)
+        >>> df = sqlCtx.applySchema(rdd, typedSchema)
+        >>> df.collect()
+        [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
+        """
+
+        if isinstance(rdd, DataFrame):
+            raise TypeError("Cannot apply schema to DataFrame")
+
+        if not isinstance(schema, StructType):
+            raise TypeError("schema should be StructType")
+
+        # take the first few rows to verify schema
+        rows = rdd.take(10)
+        # Row objects cannot be deserialized by Pyrolite, so rows whose class
+        # is named 'Row' are turned into plain tuples first; the sample is
+        # then taken again from the converted RDD.
+        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
+            rdd = rdd.map(tuple)
+            rows = rdd.take(10)
+
+        # Only this sample is verified; later rows are not checked here.
+        for row in rows:
+            _verify_type(row, schema)
+
+        # convert python objects to sql data
+        converter = _python_to_sql_converter(schema)
+        rdd = rdd.map(converter)
+
+        # Hand the converted rows and the schema (as JSON) to the JVM.
+        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
+        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+        return DataFrame(df, self)
+
+    def registerRDDAsTable(self, rdd, tableName):
+        """Register a L{DataFrame} as a temporary table in the catalog.
+
+        Temporary tables exist only during the lifetime of this instance of
+        SQLContext.
+
+        :param rdd: the object to register; its class must be exactly
+            L{DataFrame}.
+        :param tableName: name of the temporary table.
+        :raises ValueError: if `rdd` is anything other than a L{DataFrame}.
+
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        """
+        # Guard first: anything whose class is not DataFrame is rejected.
+        if rdd.__class__ is not DataFrame:
+            raise ValueError("Can only register DataFrame as table")
+        jdf = rdd._jdf
+        self._ssql_ctx.registerRDDAsTable(jdf, tableName)
+
+    def parquetFile(self, *paths):
+        """Loads a Parquet file, returning the result as a L{DataFrame}.
+
+        :param paths: one or more paths; the first is passed to the JVM on
+            its own and the remaining ones as a Java String array.
+        :return: a L{DataFrame} wrapping the JVM result.
+        :raises IndexError: if no path is given.
+
+        >>> import tempfile, shutil
+        >>> parquetFile = tempfile.mkdtemp()
+        >>> shutil.rmtree(parquetFile)
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.saveAsParquetFile(parquetFile)
+        >>> df2 = sqlCtx.parquetFile(parquetFile)
+        >>> sorted(df.collect()) == sorted(df2.collect())
+        True
+        """
+        gateway = self._sc._gateway
+        jpath = paths[0]
+        # The array holds only the paths after the first one, so paths[1:]
+        # has to fill slots 0 .. len(paths) - 2 (not 1 .. len(paths) - 1).
+        extra_paths = paths[1:]
+        jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(extra_paths))
+        for i, path in enumerate(extra_paths):
+            jpaths[i] = path
+        jdf = self._ssql_ctx.parquetFile(jpath, jpaths)
+        return DataFrame(jdf, self)
+
+    def jsonFile(self, path, schema=None, samplingRatio=1.0):
+        """
+        Read a text file holding one JSON object per line into a
+        L{DataFrame}.
+
+        When `schema` is given it is applied to the JSON dataset;
+        otherwise the schema is determined by sampling the dataset with
+        ratio `samplingRatio`.
+
+        >>> import tempfile, shutil
+        >>> jsonFile = tempfile.mkdtemp()
+        >>> shutil.rmtree(jsonFile)
+        >>> ofn = open(jsonFile, 'w')
+        >>> for json in jsonStrings:
+        ...   print>>ofn, json
+        >>> ofn.close()
+        >>> df1 = sqlCtx.jsonFile(jsonFile)
+        >>> sqlCtx.registerRDDAsTable(df1, "table1")
+        >>> df2 = sqlCtx.sql(
+        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
+        ...   "field6 as f4 from table1")
+        >>> for r in df2.collect():
+        ...     print r
+        Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
+        Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
+        Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+
+        >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+        >>> sqlCtx.registerRDDAsTable(df3, "table2")
+        >>> df4 = sqlCtx.sql(
+        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
+        ...   "field6 as f4 from table2")
+        >>> for r in df4.collect():
+        ...    print r
+        Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
+        Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
+        Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+
+        >>> from pyspark.sql.types import *
+        >>> schema = StructType([
+        ...     StructField("field2", StringType(), True),
+        ...     StructField("field3",
+        ...         StructType([
+        ...             StructField("field5",
+        ...                 ArrayType(IntegerType(), False), True)]), False)])
+        >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
+        >>> sqlCtx.registerRDDAsTable(df5, "table3")
+        >>> df6 = sqlCtx.sql(
+        ...   "SELECT field2 AS f1, field3.field5 as f2, "
+        ...   "field3.field5[0] as f3 from table3")
+        >>> df6.collect()
+        [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
+        """
+        if schema is not None:
+            # An explicit schema travels to the JVM as JSON and is parsed
+            # there into a data type object.
+            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+            jdf = self._ssql_ctx.jsonFile(path, scala_datatype)
+        else:
+            jdf = self._ssql_ctx.jsonFile(path, samplingRatio)
+        return DataFrame(jdf, self)
+
+    def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
+        """Loads an RDD storing one JSON object per string as a L{DataFrame}.
+
+        If the schema is provided, applies the given schema to this
+        JSON dataset.
+
+        Otherwise, it samples the dataset with ratio `samplingRatio` to
+        determine the schema.
+
+        Elements that are not strings are converted with ``unicode()``
+        before being passed on.
+
+        >>> df1 = sqlCtx.jsonRDD(json)
+        >>> sqlCtx.registerRDDAsTable(df1, "table1")
+        >>> df2 = sqlCtx.sql(
+        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
+        ...   "field6 as f4 from table1")
+        >>> for r in df2.collect():
+        ...     print r
+        Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
+        Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
+        Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+
+        >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+        >>> sqlCtx.registerRDDAsTable(df3, "table2")
+        >>> df4 = sqlCtx.sql(
+        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
+        ...   "field6 as f4 from table2")
+        >>> for r in df4.collect():
+        ...     print r
+        Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
+        Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
+        Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+
+        >>> from pyspark.sql.types import *
+        >>> schema = StructType([
+        ...     StructField("field2", StringType(), True),
+        ...     StructField("field3",
+        ...         StructType([
+        ...             StructField("field5",
+        ...                 ArrayType(IntegerType(), False), True)]), False)])
+        >>> df5 = sqlCtx.jsonRDD(json, schema)
+        >>> sqlCtx.registerRDDAsTable(df5, "table3")
+        >>> df6 = sqlCtx.sql(
+        ...   "SELECT field2 AS f1, field3.field5 as f2, "
+        ...   "field3.field5[0] as f3 from table3")
+        >>> df6.collect()
+        [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
+
+        >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
+        ...         '{"key0": {"key1": "value1"}}'])).collect()
+        [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+        >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
+        ...         '{"key0": {"key1": "value1"}}'])).collect()
+        [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+        """
+
+        def func(iterator):
+            # Non-string elements are converted with unicode(); unicode
+            # values are then encoded as UTF-8; byte strings pass through
+            # unchanged.
+            for x in iterator:
+                if not isinstance(x, basestring):
+                    x = unicode(x)
+                if isinstance(x, unicode):
+                    x = x.encode("utf-8")
+                yield x
+        keyed = rdd.mapPartitions(func)
+        # NOTE(review): presumably makes the RDD hand the raw byte strings to
+        # the JVM without serializing them -- confirm against pyspark.rdd.
+        keyed._bypass_serializer = True
+        # NOTE(review): BytesToString presumably decodes each byte array into
+        # a Java String -- confirm on the JVM side.
+        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
+        if schema is None:
+            df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
+        else:
+            # The schema travels to the JVM as JSON and is parsed there.
+            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+            df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+        return DataFrame(df, self)
+
+    def sql(self, sqlQuery):
+        """Run `sqlQuery` and return its result as a L{DataFrame}.
+
+        The query string is handed to the JVM SQL context unchanged.
+
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+        >>> df2.collect()
+        [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
+        """
+        jdf = self._ssql_ctx.sql(sqlQuery)
+        return DataFrame(jdf, self)
+
+    def table(self, tableName):
+        """Look up the table called `tableName` and return it as a L{DataFrame}.
+
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        >>> df2 = sqlCtx.table("table1")
+        >>> sorted(df.collect()) == sorted(df2.collect())
+        True
+        """
+        jdf = self._ssql_ctx.table(tableName)
+        return DataFrame(jdf, self)
+
+    def cacheTable(self, tableName):
+        """Cache the table called `tableName` in memory.
+
+        The request is forwarded to the JVM SQL context; nothing is returned.
+        """
+        jvm_ctx = self._ssql_ctx
+        jvm_ctx.cacheTable(tableName)
+
+    def uncacheTable(self, tableName):
+        """Remove the table called `tableName` from the in-memory cache.
+
+        The request is forwarded to the JVM SQL context; nothing is returned.
+        """
+        jvm_ctx = self._ssql_ctx
+        jvm_ctx.uncacheTable(tableName)
+
+
+class HiveContext(SQLContext):
+
+    """A SQLContext variant that works with data stored in Hive.
+
+    Hive configuration is read from hive-site.xml on the classpath, and
+    both SQL and HiveQL commands can be run.
+    """
+
+    def __init__(self, sparkContext, hiveContext=None):
+        """Create a new HiveContext.
+
+        :param sparkContext: The SparkContext to wrap.
+        :param hiveContext: An optional JVM Scala HiveContext. When given, no new
+        HiveContext is instantiated in the JVM; all calls go to this object.
+        """
+        SQLContext.__init__(self, sparkContext)
+
+        # The attribute is deliberately left unset otherwise: _ssql_ctx uses
+        # its absence as the signal to build the JVM context on first access.
+        if hiveContext:
+            self._scala_HiveContext = hiveContext
+
+    @property
+    def _ssql_ctx(self):
+        """Return the JVM HiveContext, creating it on first access.
+
+        A Py4JError raised on the way is re-raised as a plain Exception
+        carrying a hint on how to build Spark with Hive support.
+        """
+        try:
+            if hasattr(self, '_scala_HiveContext'):
+                return self._scala_HiveContext
+            hive_ctx = self._get_hive_ctx()
+            self._scala_HiveContext = hive_ctx
+            return hive_ctx
+        except Py4JError as e:
+            raise Exception("You must build Spark with Hive. "
+                            "Export 'SPARK_HIVE=true' and run "
+                            "build/sbt assembly", e)
+
+    def _get_hive_ctx(self):
+        """Instantiate a JVM HiveContext for the wrapped SparkContext."""
+        return self._jvm.HiveContext(self._jsc.sc())
+
+
+def _create_row(fields, values):
+    """Build a Row holding `values` whose field names are `fields`.
+
+    Kept at module level because Row.__reduce__ returns it as the
+    callable used to rebuild a Row.
+    """
+    row = Row(*values)
+    row.__FIELDS__ = fields
+    return row
+
+
+class Row(tuple):
+
+    """
+    A row in L{DataFrame}. The fields in it can be accessed like attributes.
+
+    Row can be used to create a row object by using named arguments,
+    the fields will be sorted by names.
+
+    >>> row = Row(name="Alice", age=11)
+    >>> row
+    Row(age=11, name='Alice')
+    >>> row.name, row.age
+    ('Alice', 11)
+
+    Row also can be used to create another Row like class, then it
+    could be used to create Row objects, such as
+
+    >>> Person = Row("name", "age")
+    >>> Person
+    <Row(name, age)>
+    >>> Person("Alice", 11)
+    Row(name='Alice', age=11)
+    """
+
+    def __new__(self, *args, **kwargs):
+        """Create a Row from positional values or from named fields.
+
+        :raises ValueError: if both positional and keyword arguments are
+            given, or if neither is given.
+        """
+        if args and kwargs:
+            raise ValueError("Can not use both args "
+                             "and kwargs to create Row")
+        if args:
+            # create row class or objects
+            return tuple.__new__(self, args)
+
+        elif kwargs:
+            # create row objects; fields are stored sorted by name
+            names = sorted(kwargs.keys())
+            values = tuple(kwargs[n] for n in names)
+            row = tuple.__new__(self, values)
+            row.__FIELDS__ = names
+            return row
+
+        else:
+            raise ValueError("No args or kwargs")
+
+    def asDict(self):
+        """
+        Return the row as a dict mapping field names to values.
+
+        :raises TypeError: if this object has no field names, i.e. it is a
+            Row class such as ``Row("name", "age")``.
+        """
+        if not hasattr(self, "__FIELDS__"):
+            raise TypeError("Cannot convert a Row class into dict")
+        return dict(zip(self.__FIELDS__, self))
+
+    # let the object act like a class: calling it creates a new Row whose
+    # field names are the items of this one
+    def __call__(self, *args):
+        """create new Row object"""
+        return _create_row(self, args)
+
+    def __getattr__(self, item):
+        """Look up a field by name; only called when normal lookup fails.
+
+        :raises AttributeError: for dunder names and for names that are
+            not fields of this row.
+        """
+        if item.startswith("__"):
+            raise AttributeError(item)
+        try:
+            # it will be slow when it has many fields,
+            # but this will not be used in normal cases
+            idx = self.__FIELDS__.index(item)
+            return self[idx]
+        except (IndexError, ValueError):
+            # .index() signals an unknown name with ValueError; it must be
+            # reported as AttributeError so hasattr()/getattr() defaults work.
+            raise AttributeError(item)
+
+    def __reduce__(self):
+        """Pickle support: rows with field names are rebuilt via _create_row."""
+        if hasattr(self, "__FIELDS__"):
+            return (_create_row, (self.__FIELDS__, tuple(self)))
+        else:
+            return tuple.__reduce__(self)
+
+    def __repr__(self):
+        if hasattr(self, "__FIELDS__"):
+            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
+                                         for k, v in zip(self.__FIELDS__, self))
+        else:
+            return "<Row(%s)>" % ", ".join(self)
+
+def _test():
+    """Run this module's doctests against a local SparkContext.
+
+    The doctest globals are a copy of the module namespace extended with
+    `sc`, `sqlCtx`, `rdd`, `jsonStrings` and `json`. The SparkContext is
+    stopped even if the doctest run itself raises, and the process exits
+    with status -1 when any doctest fails.
+    """
+    import doctest
+    import sys
+    from pyspark.context import SparkContext
+    from pyspark.sql import Row, SQLContext
+    import pyspark.sql.context
+    globs = pyspark.sql.context.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    try:
+        globs['sc'] = sc
+        globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+        globs['rdd'] = sc.parallelize(
+            [Row(field1=1, field2="row1"),
+             Row(field1=2, field2="row2"),
+             Row(field1=3, field2="row3")]
+        )
+        jsonStrings = [
+            '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
+            '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
+            '"field6":[{"field7": "row2"}]}',
+            '{"field1" : null, "field2": "row3", '
+            '"field3":{"field4":33, "field5": []}}'
+        ]
+        globs['jsonStrings'] = jsonStrings
+        globs['json'] = sc.parallelize(jsonStrings)
+        (failure_count, test_count) = doctest.testmod(
+            pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS)
+    finally:
+        sc.stop()
+    if failure_count:
+        # sys.exit rather than the bare exit() helper, which is only
+        # installed by the site module.
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()

http://git-wip-us.apache.org/repos/asf/spark/blob/f0562b42/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
new file mode 100644
index 0000000..cda704e
--- /dev/null
+++ b/python/pyspark/sql/dataframe.py
@@ -0,0 +1,974 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import itertools
+import warnings
+import random
+import os
+from tempfile import NamedTemporaryFile
+from itertools import imap
+
+from py4j.java_collections import ListConverter, MapConverter
+
+from pyspark.context import SparkContext
+from pyspark.rdd import RDD, _prepare_for_python_RDD
+from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
+    UTF8Deserializer
+from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.sql.types import *
+from pyspark.sql.types import _create_cls, _parse_datatype_json_string
+
+
+__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
+
+
+class DataFrame(object):
+
+    """A collection of rows that have the same columns.
+
+    A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
+    and can be created using various functions in :class:`SQLContext`::
+
+        people = sqlContext.parquetFile("...")
+
+    Once created, it can be manipulated using the various domain-specific-language
+    (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.
+
+    To select a column from the data frame, use the apply method::
+
+        ageCol = people.age
+
+    Note that the :class:`Column` type can also be manipulated
+    through its various functions::
+
+        # The following creates a new column that increases everybody's age by 10.
+        people.age + 10
+
+
+    A more concrete example::
+
+        # To create DataFrame using SQLContext
+        people = sqlContext.parquetFile("...")
+        department = sqlContext.parquetFile("...")
+
+        people.filter(people.age > 30).join(department, people.deptId == department.id) \
+          .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
+    """
+
+    def __init__(self, jdf, sql_ctx):
+        self._jdf = jdf
+        self.sql_ctx = sql_ctx
+        self._sc = sql_ctx and sql_ctx._sc
+        self.is_cached = False
+
+    @property
+    def rdd(self):
+        """
+        Return the content of the :class:`DataFrame` as an :class:`RDD`
+        of :class:`Row` s.
+        """
+        if not hasattr(self, '_lazy_rdd'):
+            jrdd = self._jdf.javaToPython()
+            rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+            schema = self.schema()
+
+            def applySchema(it):
+                cls = _create_cls(schema)
+                return itertools.imap(cls, it)
+
+            self._lazy_rdd = rdd.mapPartitions(applySchema)
+
+        return self._lazy_rdd
+
+    def toJSON(self, use_unicode=False):
+        """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
+
+        >>> df.toJSON().first()
+        '{"age":2,"name":"Alice"}'
+        """
+        rdd = self._jdf.toJSON()
+        return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
+
+    def saveAsParquetFile(self, path):
+        """Save the contents as a Parquet file, preserving the schema.
+
+        Files that are written out using this method can be read back in as
+        a DataFrame using the L{SQLContext.parquetFile} method.
+
+        >>> import tempfile, shutil
+        >>> parquetFile = tempfile.mkdtemp()
+        >>> shutil.rmtree(parquetFile)
+        >>> df.saveAsParquetFile(parquetFile)
+        >>> df2 = sqlCtx.parquetFile(parquetFile)
+        >>> sorted(df2.collect()) == sorted(df.collect())
+        True
+        """
+        self._jdf.saveAsParquetFile(path)
+
+    def registerTempTable(self, name):
+        """Registers this DataFrame as a temporary table using the given name.
+
+        The lifetime of this temporary table is tied to the L{SQLContext}
+        that was used to create this DataFrame.
+
+        >>> df.registerTempTable("people")
+        >>> df2 = sqlCtx.sql("select * from people")
+        >>> sorted(df.collect()) == sorted(df2.collect())
+        True
+        """
+        self._jdf.registerTempTable(name)
+
+    def registerAsTable(self, name):
+        """DEPRECATED: use registerTempTable() instead"""
+        warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
+        self.registerTempTable(name)
+
+    def insertInto(self, tableName, overwrite=False):
+        """Inserts the contents of this DataFrame into the specified table.
+
+        Optionally overwriting any existing data.
+        """
+        self._jdf.insertInto(tableName, overwrite)
+
+    def saveAsTable(self, tableName):
+        """Creates a new table with the contents of this DataFrame."""
+        self._jdf.saveAsTable(tableName)
+
+    def schema(self):
+        """Returns the schema of this DataFrame (represented by
+        a L{StructType}).
+
+        >>> df.schema()
+        StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
+        """
+        return _parse_datatype_json_string(self._jdf.schema().json())
+
+    def printSchema(self):
+        """Prints out the schema in the tree format.
+
+        >>> df.printSchema()
+        root
+         |-- age: integer (nullable = true)
+         |-- name: string (nullable = true)
+        <BLANKLINE>
+        """
+        print (self._jdf.schema().treeString())
+
+    def count(self):
+        """Return the number of rows in this DataFrame.
+
+        Unlike the base RDD implementation of count, this implementation
+        leverages the query optimizer to compute the count on the DataFrame,
+        which supports features such as filter pushdown.
+
+        >>> df.count()
+        2L
+        """
+        return self._jdf.count()
+
+    def collect(self):
+        """Return a list that contains all of the rows.
+
+        Each object in the list is a Row, the fields can be accessed as
+        attributes.
+
+        >>> df.collect()
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        """
+        with SCCallSiteSync(self._sc) as css:
+            bytesInJava = self._jdf.javaToPython().collect().iterator()
+        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+        tempFile.close()
+        self._sc._writeToFile(bytesInJava, tempFile.name)
+        # Read the data into Python and deserialize it:
+        with open(tempFile.name, 'rb') as tempFile:
+            rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
+        os.unlink(tempFile.name)
+        cls = _create_cls(self.schema())
+        return [cls(r) for r in rs]
+
+    def limit(self, num):
+        """Limit the result count to the number specified.
+
+        >>> df.limit(1).collect()
+        [Row(age=2, name=u'Alice')]
+        >>> df.limit(0).collect()
+        []
+        """
+        jdf = self._jdf.limit(num)
+        return DataFrame(jdf, self.sql_ctx)
+
+    def take(self, num):
+        """Take the first num rows of the DataFrame.
+
+        Each object in the list is a Row, the fields can be accessed as
+        attributes.
+
+        >>> df.take(2)
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        """
+        return self.limit(num).collect()
+
+    def map(self, f):
+        """ Return a new RDD by applying a function to each Row, it's a
+        shorthand for df.rdd.map()
+
+        >>> df.map(lambda p: p.name).collect()
+        [u'Alice', u'Bob']
+        """
+        return self.rdd.map(f)
+
+    def mapPartitions(self, f, preservesPartitioning=False):
+        """
+        Return a new RDD by applying a function to each partition.
+
+        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+        >>> def f(iterator): yield 1
+        >>> rdd.mapPartitions(f).sum()
+        4
+        """
+        return self.rdd.mapPartitions(f, preservesPartitioning)
+
+    def cache(self):
+        """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+        """
+        self.is_cached = True
+        self._jdf.cache()
+        return self
+
+    def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+        """ Set the storage level to persist its values across operations
+        after the first time it is computed. This can only be used to assign
+        a new storage level if the RDD does not have a storage level set yet.
+        If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
+        """
+        self.is_cached = True
+        javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+        self._jdf.persist(javaStorageLevel)
+        return self
+
+    def unpersist(self, blocking=True):
+        """ Mark it as non-persistent, and remove all blocks for it from
+        memory and disk.
+        """
+        self.is_cached = False
+        self._jdf.unpersist(blocking)
+        return self
+
+    # def coalesce(self, numPartitions, shuffle=False):
+    #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
+    #     return DataFrame(rdd, self.sql_ctx)
+
+    def repartition(self, numPartitions):
+        """ Return a new :class:`DataFrame` that has exactly `numPartitions`
+        partitions.
+        """
+        rdd = self._jdf.repartition(numPartitions, None)
+        return DataFrame(rdd, self.sql_ctx)
+
+    def sample(self, withReplacement, fraction, seed=None):
+        """
+        Return a sampled subset of this DataFrame.
+
+        >>> df.sample(False, 0.5, 97).count()
+        1L
+        """
+        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+        seed = seed if seed is not None else random.randint(0, sys.maxint)
+        rdd = self._jdf.sample(withReplacement, fraction, long(seed))
+        return DataFrame(rdd, self.sql_ctx)
+
+    # def takeSample(self, withReplacement, num, seed=None):
+    #     """Return a fixed-size sampled subset of this DataFrame.
+    #
+    #     >>> df = sqlCtx.inferSchema(rdd)
+    #     >>> df.takeSample(False, 2, 97)
+    #     [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+    #     """
+    #     seed = seed if seed is not None else random.randint(0, sys.maxint)
+    #     with SCCallSiteSync(self.context) as css:
+    #         bytesInJava = self._jdf \
+    #             .takeSampleToPython(withReplacement, num, long(seed)) \
+    #             .iterator()
+    #     cls = _create_cls(self.schema())
+    #     return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+    @property
+    def dtypes(self):
+        """Return all column names and their data types as a list.
+
+        >>> df.dtypes
+        [('age', 'integer'), ('name', 'string')]
+        """
+        return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
+
+    @property
+    def columns(self):
+        """ Return all column names as a list.
+
+        >>> df.columns
+        [u'age', u'name']
+        """
+        return [f.name for f in self.schema().fields]
+
+    def join(self, other, joinExprs=None, joinType=None):
+        """
+        Join with another DataFrame, using the given join expression.
+        The example below performs a full outer join between `df` and `df2`.
+
+        :param other: Right side of the join
+        :param joinExprs: Join expression
+        :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
+
+        >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
+        [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
+        """
+
+        if joinExprs is None:
+            jdf = self._jdf.join(other._jdf)
+        else:
+            assert isinstance(joinExprs, Column), "joinExprs should be Column"
+            if joinType is None:
+                jdf = self._jdf.join(other._jdf, joinExprs._jc)
+            else:
+                assert isinstance(joinType, basestring), "joinType should be basestring"
+                jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
+        return DataFrame(jdf, self.sql_ctx)
+
+    def sort(self, *cols):
+        """ Return a new :class:`DataFrame` sorted by the specified column(s).
+
+        :param cols: The columns or expressions used for sorting
+
+        >>> df.sort(df.age.desc()).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.sortBy(df.age.desc()).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        """
+        if not cols:
+            raise ValueError("should sort by at least one column")
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        self._sc._gateway._gateway_client)
+        jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
+        return DataFrame(jdf, self.sql_ctx)
+
+    sortBy = sort
+
+    def head(self, n=None):
+        """ Return the first `n` rows or the first row if n is None.
+
+        >>> df.head()
+        Row(age=2, name=u'Alice')
+        >>> df.head(1)
+        [Row(age=2, name=u'Alice')]
+        """
+        if n is None:
+            rs = self.head(1)
+            return rs[0] if rs else None
+        return self.take(n)
+
+    def first(self):
+        """ Return the first row.
+
+        >>> df.first()
+        Row(age=2, name=u'Alice')
+        """
+        return self.head()
+
+    def __getitem__(self, item):
+        """ Return the column by given name
+
+        >>> df['age'].collect()
+        [Row(age=2), Row(age=5)]
+        >>> df[ ["name", "age"]].collect()
+        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+        >>> df[ df.age > 3 ].collect()
+        [Row(age=5, name=u'Bob')]
+        """
+        if isinstance(item, basestring):
+            jc = self._jdf.apply(item)
+            return Column(jc, self.sql_ctx)
+        elif isinstance(item, Column):
+            return self.filter(item)
+        elif isinstance(item, list):
+            return self.select(*item)
+        else:
+            raise IndexError("unexpected index: %s" % item)
+
+    def __getattr__(self, name):
+        """ Return the column by given name
+
+        >>> df.age.collect()
+        [Row(age=2), Row(age=5)]
+        """
+        if name.startswith("__"):
+            raise AttributeError(name)
+        jc = self._jdf.apply(name)
+        return Column(jc, self.sql_ctx)
+
+    def select(self, *cols):
+        """ Selecting a set of expressions.
+
+        >>> df.select().collect()
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        >>> df.select('*').collect()
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        >>> df.select('name', 'age').collect()
+        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
+        [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
+        """
+        if not cols:
+            cols = ["*"]
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        self._sc._gateway._gateway_client)
+        jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        return DataFrame(jdf, self.sql_ctx)
+
+    def selectExpr(self, *expr):
+        """
+        Selects a set of SQL expressions. This is a variant of
+        `select` that accepts SQL expressions.
+
+        >>> df.selectExpr("age * 2", "abs(age)").collect()
+        [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+        """
+        jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
+        jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
+        return DataFrame(jdf, self.sql_ctx)
+
+    def filter(self, condition):
+        """ Filtering rows using the given condition, which could be
+        Column expression or string of SQL expression.
+
+        where() is an alias for filter().
+
+        >>> df.filter(df.age > 3).collect()
+        [Row(age=5, name=u'Bob')]
+        >>> df.where(df.age == 2).collect()
+        [Row(age=2, name=u'Alice')]
+
+        >>> df.filter("age > 3").collect()
+        [Row(age=5, name=u'Bob')]
+        >>> df.where("age = 2").collect()
+        [Row(age=2, name=u'Alice')]
+        """
+        if isinstance(condition, basestring):
+            jdf = self._jdf.filter(condition)
+        elif isinstance(condition, Column):
+            jdf = self._jdf.filter(condition._jc)
+        else:
+            raise TypeError("condition should be string or Column")
+        return DataFrame(jdf, self.sql_ctx)
+
+    where = filter
+
+    def groupBy(self, *cols):
+        """ Group the :class:`DataFrame` using the specified columns,
+        so we can run aggregation on them. See :class:`GroupedData`
+        for all the available aggregate functions.
+
+        >>> df.groupBy().avg().collect()
+        [Row(AVG(age#0)=3.5)]
+        >>> df.groupBy('name').agg({'age': 'mean'}).collect()
+        [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+        >>> df.groupBy(df.name).avg().collect()
+        [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+        """
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        self._sc._gateway._gateway_client)
+        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        return GroupedData(jdf, self.sql_ctx)
+
+    def agg(self, *exprs):
+        """ Aggregate on the entire :class:`DataFrame` without groups
+        (shorthand for df.groupBy().agg()).
+
+        >>> df.agg({"age": "max"}).collect()
+        [Row(MAX(age#0)=5)]
+        >>> from pyspark.sql import Dsl
+        >>> df.agg(Dsl.min(df.age)).collect()
+        [Row(MIN(age#0)=2)]
+        """
+        return self.groupBy().agg(*exprs)
+
+    def unionAll(self, other):
+        """ Return a new DataFrame containing union of rows in this
+        frame and another frame.
+
+        This is equivalent to `UNION ALL` in SQL.
+        """
+        return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+
+    def intersect(self, other):
+        """ Return a new :class:`DataFrame` containing rows only in
+        both this frame and another frame.
+
+        This is equivalent to `INTERSECT` in SQL.
+        """
+        return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+
+    def subtract(self, other):
+        """ Return a new :class:`DataFrame` containing rows in this frame
+        but not in another frame.
+
+        This is equivalent to `EXCEPT` in SQL.
+        """
+        return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+
+    def addColumn(self, colName, col):
+        """ Return a new :class:`DataFrame` by adding a column.
+
+        >>> df.addColumn('age2', df.age + 2).collect()
+        [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
+        """
+        return self.select('*', col.alias(colName))
+
+    def to_pandas(self):
+        """
+        Collect all the rows and return a `pandas.DataFrame`.
+
+        >>> df.to_pandas()  # doctest: +SKIP
+           age   name
+        0    2  Alice
+        1    5    Bob
+        """
+        import pandas as pd
+        return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+
+# Having SchemaRDD for backward compatibility (for docs)
+class SchemaRDD(DataFrame):
+    """
+    SchemaRDD is deprecated, please use DataFrame
+    """
+
+
+def dfapi(f):
+    def _api(self):
+        name = f.__name__
+        jdf = getattr(self._jdf, name)()
+        return DataFrame(jdf, self.sql_ctx)
+    _api.__name__ = f.__name__
+    _api.__doc__ = f.__doc__
+    return _api
+
+
+class GroupedData(object):
+
+    """
+    A set of methods for aggregations on a :class:`DataFrame`,
+    created by DataFrame.groupBy().
+    """
+
+    def __init__(self, jdf, sql_ctx):
+        self._jdf = jdf
+        self.sql_ctx = sql_ctx
+
+    def agg(self, *exprs):
+        """ Compute aggregates by specifying a map from column name
+        to aggregate methods.
+
+        The available aggregate methods are `avg`, `max`, `min`,
+        `sum`, `count`.
+
+        :param exprs: list of aggregate columns or a map from column
+                      name to aggregate methods.
+
+        >>> gdf = df.groupBy(df.name)
+        >>> gdf.agg({"age": "max"}).collect()
+        [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
+        >>> from pyspark.sql import Dsl
+        >>> gdf.agg(Dsl.min(df.age)).collect()
+        [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
+        """
+        assert exprs, "exprs should not be empty"
+        if len(exprs) == 1 and isinstance(exprs[0], dict):
+            jmap = MapConverter().convert(exprs[0],
+                                          self.sql_ctx._sc._gateway._gateway_client)
+            jdf = self._jdf.agg(jmap)
+        else:
+            # Columns
+            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+                                            self.sql_ctx._sc._gateway._gateway_client)
+            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        return DataFrame(jdf, self.sql_ctx)
+
+    @dfapi
+    def count(self):
+        """ Count the number of rows for each group.
+
+        >>> df.groupBy(df.age).count().collect()
+        [Row(age=2, count=1), Row(age=5, count=1)]
+        """
+
+    @dfapi
+    def mean(self):
+        """Compute the average value for each numeric column
+        for each group. This is an alias for `avg`."""
+
+    @dfapi
+    def avg(self):
+        """Compute the average value for each numeric column
+        for each group."""
+
+    @dfapi
+    def max(self):
+        """Compute the max value for each numeric column for
+        each group. """
+
+    @dfapi
+    def min(self):
+        """Compute the min value for each numeric column for
+        each group."""
+
+    @dfapi
+    def sum(self):
+        """Compute the sum for each numeric column for each
+        group."""
+
+
+def _create_column_from_literal(literal):
+    sc = SparkContext._active_spark_context
+    return sc._jvm.Dsl.lit(literal)
+
+
+def _create_column_from_name(name):
+    sc = SparkContext._active_spark_context
+    return sc._jvm.Dsl.col(name)
+
+
+def _to_java_column(col):
+    if isinstance(col, Column):
+        jcol = col._jc
+    else:
+        jcol = _create_column_from_name(col)
+    return jcol
+
+
+def _unary_op(name, doc="unary operator"):
+    """ Create a method for given unary operator """
+    def _(self):
+        jc = getattr(self._jc, name)()
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
+    return _
+
+
+def _dsl_op(name, doc=''):
+    def _(self):
+        jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
+    return _
+
+
+def _bin_op(name, doc="binary operator"):
+    """ Create a method for given binary operator
+    """
+    def _(self, other):
+        jc = other._jc if isinstance(other, Column) else other
+        njc = getattr(self._jc, name)(jc)
+        return Column(njc, self.sql_ctx)
+    _.__doc__ = doc
+    return _
+
+
+def _reverse_op(name, doc="binary operator"):
+    """ Create a method for binary operator (this object is on right side)
+    """
+    def _(self, other):
+        jother = _create_column_from_literal(other)
+        jc = getattr(jother, name)(self._jc)
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
+    return _
+
+
+class Column(DataFrame):
+
+    """
+    A column in a DataFrame.
+
+    `Column` instances can be created by::
+
+        # 1. Select a column out of a DataFrame
+        df.colName
+        df["colName"]
+
+        # 2. Create from an expression
+        df.colName + 1
+        1 / df.colName
+    """
+
+    def __init__(self, jc, sql_ctx=None):
+        self._jc = jc
+        super(Column, self).__init__(jc, sql_ctx)
+
+    # arithmetic operators
+    __neg__ = _dsl_op("negate")
+    __add__ = _bin_op("plus")
+    __sub__ = _bin_op("minus")
+    __mul__ = _bin_op("multiply")
+    __div__ = _bin_op("divide")
+    __mod__ = _bin_op("mod")
+    __radd__ = _bin_op("plus")
+    __rsub__ = _reverse_op("minus")
+    __rmul__ = _bin_op("multiply")
+    __rdiv__ = _reverse_op("divide")
+    __rmod__ = _reverse_op("mod")
+
+    # comparison operators
+    __eq__ = _bin_op("equalTo")
+    __ne__ = _bin_op("notEqual")
+    __lt__ = _bin_op("lt")
+    __le__ = _bin_op("leq")
+    __ge__ = _bin_op("geq")
+    __gt__ = _bin_op("gt")
+
+    # `and`, `or`, `not` cannot be overloaded in Python,
+    # so use bitwise operators as boolean operators
+    __and__ = _bin_op('and')
+    __or__ = _bin_op('or')
+    __invert__ = _dsl_op('not')
+    __rand__ = _bin_op("and")
+    __ror__ = _bin_op("or")
+
+    # container operators
+    __contains__ = _bin_op("contains")
+    __getitem__ = _bin_op("getItem")
+    getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
+
+    # string methods
+    rlike = _bin_op("rlike")
+    like = _bin_op("like")
+    startswith = _bin_op("startsWith")
+    endswith = _bin_op("endsWith")
+
+    def substr(self, startPos, length):
+        """
+        Return a Column which is a substring of the column
+
+        :param startPos: start position (int or Column)
+        :param length:  length of the substring (int or Column)
+
+        >>> df.name.substr(1, 3).collect()
+        [Row(col=u'Ali'), Row(col=u'Bob')]
+        """
+        if type(startPos) != type(length):
+            raise TypeError("Can not mix the type")
+        if isinstance(startPos, (int, long)):
+            jc = self._jc.substr(startPos, length)
+        elif isinstance(startPos, Column):
+            jc = self._jc.substr(startPos._jc, length._jc)
+        else:
+            raise TypeError("Unexpected type: %s" % type(startPos))
+        return Column(jc, self.sql_ctx)
+
+    __getslice__ = substr
+
+    # order
+    asc = _unary_op("asc")
+    desc = _unary_op("desc")
+
+    isNull = _unary_op("isNull", "True if the current expression is null.")
+    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
+
+    def alias(self, alias):
+        """Return an alias for this column
+
+        >>> df.age.alias("age2").collect()
+        [Row(age2=2), Row(age2=5)]
+        """
+        return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
+
+    def cast(self, dataType):
+        """ Convert the column into type `dataType`
+
+        >>> df.select(df.age.cast("string").alias('ages')).collect()
+        [Row(ages=u'2'), Row(ages=u'5')]
+        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
+        [Row(ages=u'2'), Row(ages=u'5')]
+        """
+        if self.sql_ctx is None:
+            sc = SparkContext._active_spark_context
+            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        else:
+            ssql_ctx = self.sql_ctx._ssql_ctx
+        if isinstance(dataType, basestring):
+            jc = self._jc.cast(dataType)
+        elif isinstance(dataType, DataType):
+            jdt = ssql_ctx.parseDataType(dataType.json())
+            jc = self._jc.cast(jdt)
+        return Column(jc, self.sql_ctx)
+
+    def to_pandas(self):
+        """
+        Return a pandas.Series from the column
+
+        >>> df.age.to_pandas()  # doctest: +SKIP
+        0    2
+        1    5
+        dtype: int64
+        """
+        import pandas as pd
+        data = [c for c, in self.collect()]
+        return pd.Series(data)
+
+
+def _aggregate_func(name, doc=""):
+    """ Create a function for aggregator by name"""
+    def _(col):
+        sc = SparkContext._active_spark_context
+        jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
+        return Column(jc)
+    _.__name__ = name
+    _.__doc__ = doc
+    return staticmethod(_)
+
+
+class UserDefinedFunction(object):
+    def __init__(self, func, returnType):
+        self.func = func
+        self.returnType = returnType
+        self._broadcast = None
+        self._judf = self._create_judf()
+
+    def _create_judf(self):
+        f = self.func  # put it in closure `func`
+        func = lambda _, it: imap(lambda x: f(*x), it)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        sc = SparkContext._active_spark_context
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        jdt = ssql_ctx.parseDataType(self.returnType.json())
+        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+                                                 includes, sc.pythonExec, broadcast_vars,
+                                                 sc._javaAccumulator, jdt)
+        return judf
+
+    def __del__(self):
+        if self._broadcast is not None:
+            self._broadcast.unpersist()
+            self._broadcast = None
+
+    def __call__(self, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+        return Column(jc)
+
+
+class Dsl(object):
+    """
+    A collection of builtin aggregators
+    """
+    DSLS = {
+        'lit': 'Creates a :class:`Column` of literal value.',
+        'col': 'Returns a :class:`Column` based on the given column name.',
+        'column': 'Returns a :class:`Column` based on the given column name.',
+        'upper': 'Converts a string expression to upper case.',
+        'lower': 'Converts a string expression to lower case.',
+        'sqrt': 'Computes the square root of the specified float value.',
+        'abs': 'Computes the absolute value.',
+
+        'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+        'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+        'first': 'Aggregate function: returns the first value in a group.',
+        'last': 'Aggregate function: returns the last value in a group.',
+        'count': 'Aggregate function: returns the number of items in a group.',
+        'sum': 'Aggregate function: returns the sum of all values in the expression.',
+        'avg': 'Aggregate function: returns the average of the values in a group.',
+        'mean': 'Aggregate function: returns the average of the values in a group.',
+        'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+    }
+
+    for _name, _doc in DSLS.items():
+        locals()[_name] = _aggregate_func(_name, _doc)
+    del _name, _doc
+
+    @staticmethod
+    def countDistinct(col, *cols):
+        """ Return a new Column for distinct count of (col, *cols)
+
+        >>> from pyspark.sql import Dsl
+        >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
+        [Row(c=2)]
+
+        >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
+        [Row(c=2)]
+        """
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+                                       sc._jvm.PythonUtils.toSeq(jcols))
+        return Column(jc)
+
+    @staticmethod
+    def approxCountDistinct(col, rsd=None):
+        """ Return a new Column for approximate distinct count of `col`
+
+        >>> from pyspark.sql import Dsl
+        >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
+        [Row(c=2)]
+        """
+        sc = SparkContext._active_spark_context
+        if rsd is None:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+        else:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+        return Column(jc)
+
+    @staticmethod
+    def udf(f, returnType=StringType()):
+        """Create a user defined function (UDF)
+
+        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+        >>> df.select(slen(df.name).alias('slen')).collect()
+        [Row(slen=5), Row(slen=3)]
+        """
+        return UserDefinedFunction(f, returnType)
+
+
+def _test():
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import Row, SQLContext
+    import pyspark.sql.dataframe
+    globs = pyspark.sql.dataframe.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['sqlCtx'] = sqlCtx = SQLContext(sc)
+    rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
+    rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
+    globs['df'] = sqlCtx.inferSchema(rdd2)
+    globs['df2'] = sqlCtx.inferSchema(rdd3)
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()

http://git-wip-us.apache.org/repos/asf/spark/blob/f0562b42/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
new file mode 100644
index 0000000..d25c636
--- /dev/null
+++ b/python/pyspark/sql/tests.py
@@ -0,0 +1,300 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for pyspark.sql; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import sys
+import pydoc
+import shutil
+import tempfile
+
+# On Python 2.6 or earlier the unittest2 package is used under the name
+# ``unittest``; if it is not installed, print a hint and exit with status 1.
+# Newer interpreters use the standard library module.
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+
+from pyspark.sql import SQLContext, Column
+from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
+    UserDefinedType, DoubleType, LongType
+from pyspark.tests import ReusedPySparkTestCase
+
+
+class ExamplePointUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for ExamplePoint.
+
+    A point is serialized as the two-element list ``[x, y]`` and rebuilt from
+    the first two items of the stored value.
+    """
+
+    @classmethod
+    def sqlType(cls):
+        # An ArrayType of DoubleType; the second argument ``False`` is presumably
+        # ArrayType's containsNull flag -- confirm against pyspark.sql.types.
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        # NOTE(review): this class is defined in pyspark/sql/tests.py and the tests
+        # import it from ``pyspark.sql.tests``, yet the name returned here is
+        # 'pyspark.tests'. Left unchanged because the consumer of this value is not
+        # visible from this file -- confirm whether it should be 'pyspark.sql.tests'.
+        return 'pyspark.tests'
+
+    @classmethod
+    def scalaUDT(cls):
+        # Presumably the fully qualified name of the JVM-side counterpart class.
+        return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+    def serialize(self, obj):
+        # Store a point as the list [x, y].
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        # Inverse of serialize(): rebuild the point from the first two items.
+        return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+    """
+    An example class to demonstrate UDT in Scala, Java, and Python.
+
+    Holds two coordinates, ``x`` and ``y``. The class attribute ``__UDT__`` is an
+    ``ExamplePointUDT`` instance associated with this class.
+    """
+
+    __UDT__ = ExamplePointUDT()
+
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __repr__(self):
+        return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+    def __str__(self):
+        return "(%s,%s)" % (self.x, self.y)
+
+    def __eq__(self, other):
+        # Equal only to another ExamplePoint with the same coordinates.
+        return isinstance(other, ExamplePoint) and \
+            other.x == self.x and other.y == self.y
+
+    def __ne__(self, other):
+        # Python 2 does not derive != from __eq__, so define it explicitly to
+        # keep the two operators consistent.
+        return not self.__eq__(other)
+
+
+class SQLTests(ReusedPySparkTestCase):
+    # Tests for SQLContext, DataFrame, Column and UDT handling. They share the
+    # fixtures created in setUpClass: ``sqlCtx``, ``testData``, ``df`` and ``tempdir``.
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        # Only a unique, not-yet-existing path is wanted (it is later used as the
+        # parent of the Parquet output directory), so the temporary file is closed
+        # and removed right away; ``cls.tempdir.name`` stays available.
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        cls.tempdir.close()
+        os.unlink(cls.tempdir.name)
+        cls.sqlCtx = SQLContext(cls.sc)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = cls.sc.parallelize(cls.testData)
+        cls.df = cls.sqlCtx.inferSchema(rdd)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            ReusedPySparkTestCase.tearDownClass()
+        finally:
+            # Remove whatever was written under the temporary path even if the
+            # parent teardown fails.
+            shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def test_udf(self):
+        # A registered two-argument Python UDF is callable from SQL.
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_udf2(self):
+        # The same UDF is used in the projection and in the WHERE clause.
+        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
+        self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+        [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+        self.assertEqual(4, res[0])
+
+    def test_udf_with_array_type(self):
+        # list(range(...)) keeps the fixtures real lists regardless of what range() returns.
+        d = [Row(l=list(range(3)), d={"key": list(range(5))})]
+        rdd = self.sc.parallelize(d)
+        self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+        self.assertEqual(list(range(3)), l1)
+        self.assertEqual(1, l2)
+
+    def test_broadcast_in_udf(self):
+        # The UDF closes over a broadcast variable and reads it when called.
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+    def test_basic_functions(self):
+        # Smoke test: these calls only need to complete without raising.
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        df = self.sqlCtx.jsonRDD(rdd)
+        df.count()
+        df.collect()
+        df.schema()
+
+        # persist, unpersist and cache
+        self.assertFalse(df.is_cached)
+        df.persist()
+        df.unpersist()
+        df.cache()
+        self.assertTrue(df.is_cached)
+        self.assertEqual(2, df.count())
+
+        df.registerTempTable("temp")
+        df = self.sqlCtx.sql("select foo from temp")
+        df.count()
+        df.collect()
+
+    def test_apply_schema_to_row(self):
+        # An existing schema can be re-applied to mapped rows and to fresh Row objects.
+        df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+        df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+        self.assertEqual(df.collect(), df2.collect())
+
+        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+        df3 = self.sqlCtx.applySchema(rdd, df.schema())
+        self.assertEqual(10, df3.count())
+
+    def test_serialize_nested_array_and_map(self):
+        # Rows nested inside lists and dict values keep their fields.
+        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        row = df.head()
+        self.assertEqual(1, len(row.l))
+        self.assertEqual(1, row.l[0].a)
+        self.assertEqual("2", row.d["key"].d)
+
+        l = df.map(lambda x: x.l).first()
+        self.assertEqual(1, len(l))
+        self.assertEqual('s', l[0].b)
+
+        d = df.map(lambda x: x.d).first()
+        self.assertEqual(1, len(d))
+        self.assertEqual(1.0, d["key"].c)
+
+        row = df.map(lambda x: x.d["key"]).first()
+        self.assertEqual(1.0, row.c)
+        self.assertEqual("2", row.d)
+
+    def test_infer_schema(self):
+        # The first row has empty containers and no ``s`` field; the second supplies them.
+        d = [Row(l=[], d={}),
+             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        self.assertEqual([], df.map(lambda r: r.l).first())
+        self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+        df.registerTempTable("test")
+        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+        self.assertEqual(1, result.head()[0])
+
+        # Same checks with an explicit second argument to inferSchema (presumably a
+        # sampling ratio -- confirm); the inferred schema must be identical.
+        df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+        self.assertEqual(df.schema(), df2.schema())
+        self.assertEqual({}, df2.map(lambda r: r.d).first())
+        self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+        df2.registerTempTable("test2")
+        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+        self.assertEqual(1, result.head()[0])
+
+    def test_struct_in_map(self):
+        # Rows are used as both map key and map value.
+        d = [Row(m={Row(i=1): Row(s="")})]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        # list(...) so indexing works whether items() returns a list or a view.
+        k, v = list(df.head().m.items())[0]
+        self.assertEqual(1, k.i)
+        self.assertEqual("", v.s)
+
+    def test_convert_row_to_dict(self):
+        # asDict() exposes nested rows, before and after a round trip through SQL.
+        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
+        self.assertEqual(1, row.asDict()['l'][0].a)
+        rdd = self.sc.parallelize([row])
+        df = self.sqlCtx.inferSchema(rdd)
+        df.registerTempTable("test")
+        row = self.sqlCtx.sql("select l, d from test").head()
+        self.assertEqual(1, row.asDict()["l"][0].a)
+        self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+    def test_infer_schema_with_udt(self):
+        # Imported through the package path rather than taken from this module's
+        # globals -- presumably so the classes resolve to pyspark.sql.tests even
+        # when this file runs as __main__; confirm.
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        df = self.sqlCtx.inferSchema(rdd)
+        schema = df.schema()
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_apply_schema_with_udt(self):
+        # A UDT instance can be used as a field type in an explicit schema.
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df = self.sqlCtx.applySchema(rdd, schema)
+        point = df.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_parquet_with_udt(self):
+        # Round-trip a UDT column through a Parquet file under the temporary path.
+        from pyspark.sql.tests import ExamplePoint
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        df0 = self.sqlCtx.inferSchema(rdd)
+        output_dir = os.path.join(self.tempdir.name, "labeled_point")
+        df0.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        point = df1.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_column_operators(self):
+        ci = self.df.key
+        cs = self.df.value
+        # Comparing two columns must itself produce a Column.
+        self.assertTrue(isinstance(ci == cs, Column))
+        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+        self.assertTrue(all(isinstance(c, Column) for c in rcc))
+        # NOTE(review): ``and``/``or`` are decided by Python from the truthiness of
+        # the left operand and return one of the operands; they do not build a new
+        # expression. Kept as in the original -- confirm this is intended.
+        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+        self.assertTrue(all(isinstance(c, Column) for c in cb))
+        cbool = (ci & ci), (ci | ci), (~ci)
+        self.assertTrue(all(isinstance(c, Column) for c in cbool))
+        css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+        self.assertTrue(all(isinstance(c, Column) for c in css))
+        self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+    def test_column_select(self):
+        # Selecting everything, by "*" or by explicit columns, returns the fixture rows.
+        df = self.df
+        self.assertEqual(self.testData, df.select("*").collect())
+        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+    def test_aggregator(self):
+        df = self.df
+        g = df.groupBy()
+        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+        # NOTE(review): the expected column name embeds what looks like a generated
+        # identifier ("key#0"), so this assertion may depend on execution order.
+        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+
+        from pyspark.sql import Dsl
+        self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
+        self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
+        self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+
+    def test_help_command(self):
+        # Regression test for SPARK-5464
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        df = self.sqlCtx.jsonRDD(rdd)
+        # render_doc() reproduces the help() exception without printing output
+        pydoc.render_doc(df)
+        pydoc.render_doc(df.foo)
+        pydoc.render_doc(df.take(1))
+
+
+# Running this file directly hands control to unittest's command-line runner.
+if __name__ == "__main__":
+    unittest.main()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message