spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-5554] [SQL] [PySpark] add more tests for DataFrame Python API
Date Wed, 04 Feb 2015 00:02:15 GMT
Repository: spark
Updated Branches:
  refs/heads/master 1e8b5394b -> 068c0e2ee


[SPARK-5554] [SQL] [PySpark] add more tests for DataFrame Python API

Add more tests and docs for DataFrame Python API, improve test coverage, fix bugs.

Author: Davies Liu <davies@databricks.com>

Closes #4331 from davies/fix_df and squashes the following commits:

dd9919f [Davies Liu] fix tests
467332c [Davies Liu] support string in cast()
83c92fe [Davies Liu] address comments
c052f6f [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_df
8dd19a9 [Davies Liu] fix tests in python 2.6
35ccb9f [Davies Liu] fix build
78ebcfa [Davies Liu] add sql_test.py in run_tests
9ab78b4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_df
6040ba7 [Davies Liu] fix docs
3ab2661 [Davies Liu] add more tests for DataFrame


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/068c0e2e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/068c0e2e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/068c0e2e

Branch: refs/heads/master
Commit: 068c0e2ee05ee8b133c2dc26b8fa094ab2712d45
Parents: 1e8b539
Author: Davies Liu <davies@databricks.com>
Authored: Tue Feb 3 16:01:56 2015 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Tue Feb 3 16:01:56 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py                           | 467 +++++++++++--------
 python/pyspark/sql_tests.py                     | 299 ++++++++++++
 python/pyspark/tests.py                         | 261 -----------
 python/run-tests                                |   1 +
 .../scala/org/apache/spark/sql/Column.scala     |  38 +-
 .../apache/spark/sql/test/ExamplePointUDT.scala |   2 +-
 6 files changed, 586 insertions(+), 482 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/068c0e2e/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 32bff0c..268c7ef 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -62,7 +62,7 @@ __all__ = [
     "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
     "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
     "ShortType", "ArrayType", "MapType", "StructField", "StructType",
-    "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row",
+    "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", "Dsl",
     "SchemaRDD"]
 
 
@@ -1804,7 +1804,7 @@ class DataFrame(object):
         people = sqlContext.parquetFile("...")
 
     Once created, it can be manipulated using the various domain-specific-language
-    (DSL) functions defined in: [[DataFrame]], [[Column]].
+    (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.
 
     To select a column from the data frame, use the apply method::
 
@@ -1835,8 +1835,10 @@ class DataFrame(object):
 
     @property
     def rdd(self):
-        """Return the content of the :class:`DataFrame` as an :class:`RDD`
-        of :class:`Row`s. """
+        """
+        Return the content of the :class:`DataFrame` as an :class:`RDD`
+        of :class:`Row` s.
+        """
         if not hasattr(self, '_lazy_rdd'):
             jrdd = self._jdf.javaToPython()
             rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
@@ -1850,18 +1852,6 @@ class DataFrame(object):
 
         return self._lazy_rdd
 
-    def limit(self, num):
-        """Limit the result count to the number specified.
-
-        >>> df = sqlCtx.inferSchema(rdd)
-        >>> df.limit(2).collect()
-        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
-        >>> df.limit(0).collect()
-        []
-        """
-        jdf = self._jdf.limit(num)
-        return DataFrame(jdf, self.sql_ctx)
-
     def toJSON(self, use_unicode=False):
         """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
 
@@ -1886,7 +1876,6 @@ class DataFrame(object):
         >>> import tempfile, shutil
         >>> parquetFile = tempfile.mkdtemp()
         >>> shutil.rmtree(parquetFile)
-        >>> df = sqlCtx.inferSchema(rdd)
         >>> df.saveAsParquetFile(parquetFile)
         >>> df2 = sqlCtx.parquetFile(parquetFile)
         >>> sorted(df2.collect()) == sorted(df.collect())
@@ -1900,9 +1889,8 @@ class DataFrame(object):
         The lifetime of this temporary table is tied to the L{SQLContext}
         that was used to create this DataFrame.
 
-        >>> df = sqlCtx.inferSchema(rdd)
-        >>> df.registerTempTable("test")
-        >>> df2 = sqlCtx.sql("select * from test")
+        >>> df.registerTempTable("people")
+        >>> df2 = sqlCtx.sql("select * from people")
         >>> sorted(df.collect()) == sorted(df2.collect())
         True
         """
@@ -1926,11 +1914,22 @@ class DataFrame(object):
 
     def schema(self):
         """Returns the schema of this DataFrame (represented by
-        a L{StructType})."""
+        a L{StructType}).
+
+        >>> df.schema()
+        StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
+        """
         return _parse_datatype_json_string(self._jdf.schema().json())
 
     def printSchema(self):
-        """Prints out the schema in the tree format."""
+        """Prints out the schema in the tree format.
+
+        >>> df.printSchema()
+        root
+         |-- age: integer (nullable = true)
+         |-- name: string (nullable = true)
+        <BLANKLINE>
+        """
         print (self._jdf.schema().treeString())
 
     def count(self):
@@ -1940,11 +1939,8 @@ class DataFrame(object):
         leverages the query optimizer to compute the count on the DataFrame,
         which supports features such as filter pushdown.
 
-        >>> df = sqlCtx.inferSchema(rdd)
         >>> df.count()
-        3L
-        >>> df.count() == df.map(lambda x: x).count()
-        True
+        2L
         """
         return self._jdf.count()
 
@@ -1954,13 +1950,11 @@ class DataFrame(object):
         Each object in the list is a Row, the fields can be accessed as
         attributes.
 
-        >>> df = sqlCtx.inferSchema(rdd)
         >>> df.collect()
-        [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
             bytesInJava = self._jdf.javaToPython().collect().iterator()
-        cls = _create_cls(self.schema())
         tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
         tempFile.close()
         self._sc._writeToFile(bytesInJava, tempFile.name)
@@ -1968,23 +1962,37 @@ class DataFrame(object):
         with open(tempFile.name, 'rb') as tempFile:
             rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
         os.unlink(tempFile.name)
+        cls = _create_cls(self.schema())
         return [cls(r) for r in rs]
 
+    def limit(self, num):
+        """Limit the result count to the number specified.
+
+        >>> df.limit(1).collect()
+        [Row(age=2, name=u'Alice')]
+        >>> df.limit(0).collect()
+        []
+        """
+        jdf = self._jdf.limit(num)
+        return DataFrame(jdf, self.sql_ctx)
+
     def take(self, num):
         """Take the first num rows of the RDD.
 
         Each object in the list is a Row, the fields can be accessed as
         attributes.
 
-        >>> df = sqlCtx.inferSchema(rdd)
         >>> df.take(2)
-        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         return self.limit(num).collect()
 
     def map(self, f):
         """ Return a new RDD by applying a function to each Row, it's a
         shorthand for df.rdd.map()
+
+        >>> df.map(lambda p: p.name).collect()
+        [u'Alice', u'Bob']
         """
         return self.rdd.map(f)
 
@@ -2067,140 +2075,167 @@ class DataFrame(object):
     @property
     def dtypes(self):
         """Return all column names and their data types as a list.
+
+        >>> df.dtypes
+        [(u'age', 'IntegerType'), (u'name', 'StringType')]
         """
         return [(f.name, str(f.dataType)) for f in self.schema().fields]
 
     @property
     def columns(self):
         """ Return all column names as a list.
+
+        >>> df.columns
+        [u'age', u'name']
         """
         return [f.name for f in self.schema().fields]
 
-    def show(self):
-        raise NotImplemented
-
     def join(self, other, joinExprs=None, joinType=None):
         """
         Join with another DataFrame, using the given join expression.
         The following performs a full outer join between `df1` and `df2`::
 
-            df1.join(df2, df1.key == df2.key, "outer")
-
         :param other: Right side of the join
         :param joinExprs: Join expression
-        :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`,
-                         `semijoin`.
+        :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
+
+        >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
+        [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
         """
-        if joinType is None:
-            if joinExprs is None:
-                jdf = self._jdf.join(other._jdf)
-            else:
-                jdf = self._jdf.join(other._jdf, joinExprs)
+
+        if joinExprs is None:
+            jdf = self._jdf.join(other._jdf)
         else:
-            jdf = self._jdf.join(other._jdf, joinExprs, joinType)
+            assert isinstance(joinExprs, Column), "joinExprs should be Column"
+            if joinType is None:
+                jdf = self._jdf.join(other._jdf, joinExprs._jc)
+            else:
+                assert isinstance(joinType, basestring), "joinType should be basestring"
+                jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
         return DataFrame(jdf, self.sql_ctx)
 
     def sort(self, *cols):
-        """ Return a new [[DataFrame]] sorted by the specified column,
-        in ascending column.
+        """ Return a new :class:`DataFrame` sorted by the specified column.
 
         :param cols: The columns or expressions used for sorting
+
+        >>> df.sort(df.age.desc()).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.sortBy(df.age.desc()).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         """
         if not cols:
             raise ValueError("should sort by at least one column")
-        for i, c in enumerate(cols):
-            if isinstance(c, basestring):
-                cols[i] = Column(c)
-        jcols = [c._jc for c in cols]
-        jdf = self._jdf.join(*jcols)
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]],
+                                        self._sc._gateway._gateway_client)
+        jdf = self._jdf.sort(_to_java_column(cols[0]),
+                             self._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     sortBy = sort
 
     def head(self, n=None):
-        """ Return the first `n` rows or the first row if n is None. """
+        """ Return the first `n` rows or the first row if n is None.
+
+        >>> df.head()
+        Row(age=2, name=u'Alice')
+        >>> df.head(1)
+        [Row(age=2, name=u'Alice')]
+        """
         if n is None:
             rs = self.head(1)
             return rs[0] if rs else None
         return self.take(n)
 
     def first(self):
-        """ Return the first row. """
-        return self.head()
+        """ Return the first row.
 
-    def tail(self):
-        raise NotImplemented
+        >>> df.first()
+        Row(age=2, name=u'Alice')
+        """
+        return self.head()
 
     def __getitem__(self, item):
+        """ Return the column by given name
+
+        >>> df['age'].collect()
+        [Row(age=2), Row(age=5)]
+        """
         if isinstance(item, basestring):
-            return Column(self._jdf.apply(item))
+            jc = self._jdf.apply(item)
+            return Column(jc, self.sql_ctx)
 
         # TODO projection
         raise IndexError
 
     def __getattr__(self, name):
-        """ Return the column by given name """
+        """ Return the column by given name
+
+        >>> df.age.collect()
+        [Row(age=2), Row(age=5)]
+        """
         if name.startswith("__"):
             raise AttributeError(name)
-        return Column(self._jdf.apply(name))
-
-    def alias(self, name):
-        """ Alias the current DataFrame """
-        return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)
+        jc = self._jdf.apply(name)
+        return Column(jc, self.sql_ctx)
 
     def select(self, *cols):
-        """ Selecting a set of expressions.::
-
-            df.select()
-            df.select('colA', 'colB')
-            df.select(df.colA, df.colB + 1)
-
+        """ Selecting a set of expressions.
+
+        >>> df.select().collect()
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        >>> df.select('*').collect()
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        >>> df.select('name', 'age').collect()
+        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+        >>> df.select(df.name, (df.age + 10).As('age')).collect()
+        [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
         if not cols:
             cols = ["*"]
-        if isinstance(cols[0], basestring):
-            cols = [_create_column_from_name(n) for n in cols]
-        else:
-            cols = [c._jc for c in cols]
-        jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        self._sc._gateway._gateway_client)
         jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     def filter(self, condition):
-        """ Filtering rows using the given condition::
-
-            df.filter(df.age > 15)
-            df.where(df.age > 15)
+        """ Filtering rows using the given condition.
 
+        >>> df.filter(df.age > 3).collect()
+        [Row(age=5, name=u'Bob')]
+        >>> df.where(df.age == 2).collect()
+        [Row(age=2, name=u'Alice')]
         """
         return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
 
     where = filter
 
     def groupBy(self, *cols):
-        """ Group the [[DataFrame]] using the specified columns,
+        """ Group the :class:`DataFrame` using the specified columns,
         so we can run aggregation on them. See :class:`GroupedDataFrame`
-        for all the available aggregate functions::
-
-            df.groupBy(df.department).avg()
-            df.groupBy("department", "gender").agg({
-                "salary": "avg",
-                "age":    "max",
-            })
+        for all the available aggregate functions.
+
+        >>> df.groupBy().avg().collect()
+        [Row(AVG(age#0)=3.5)]
+        >>> df.groupBy('name').agg({'age': 'mean'}).collect()
+        [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+        >>> df.groupBy(df.name).avg().collect()
+        [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
         """
-        if cols and isinstance(cols[0], basestring):
-            cols = [_create_column_from_name(n) for n in cols]
-        else:
-            cols = [c._jc for c in cols]
-        jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        self._sc._gateway._gateway_client)
         jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return GroupedDataFrame(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
-        """ Aggregate on the entire [[DataFrame]] without groups
-        (shorthand for df.groupBy.agg())::
-
-            df.agg({"age": "max", "salary": "avg"})
+        """ Aggregate on the entire :class:`DataFrame` without groups
+        (shorthand for df.groupBy.agg()).
+
+        >>> df.agg({"age": "max"}).collect()
+        [Row(MAX(age#0)=5)]
+        >>> from pyspark.sql import Dsl
+        >>> df.agg(Dsl.min(df.age)).collect()
+        [Row(MIN(age#0)=2)]
         """
         return self.groupBy().agg(*exprs)
 
@@ -2213,7 +2248,7 @@ class DataFrame(object):
         return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
 
     def intersect(self, other):
-        """ Return a new [[DataFrame]] containing rows only in
+        """ Return a new :class:`DataFrame` containing rows only in
         both this frame and another frame.
 
         This is equivalent to `INTERSECT` in SQL.
@@ -2221,7 +2256,7 @@ class DataFrame(object):
         return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
 
     def subtract(self, other):
-        """ Return a new [[DataFrame]] containing rows in this frame
+        """ Return a new :class:`DataFrame` containing rows in this frame
         but not in another frame.
 
         This is equivalent to `EXCEPT` in SQL.
@@ -2229,7 +2264,11 @@ class DataFrame(object):
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
 
     def sample(self, withReplacement, fraction, seed=None):
-        """ Return a new DataFrame by sampling a fraction of rows. """
+        """ Return a new DataFrame by sampling a fraction of rows.
+
+        >>> df.sample(False, 0.5, 10).collect()
+        [Row(age=2, name=u'Alice')]
+        """
         if seed is None:
             jdf = self._jdf.sample(withReplacement, fraction)
         else:
@@ -2237,11 +2276,12 @@ class DataFrame(object):
         return DataFrame(jdf, self.sql_ctx)
 
     def addColumn(self, colName, col):
-        """ Return a new [[DataFrame]] by adding a column. """
-        return self.select('*', col.alias(colName))
+        """ Return a new :class:`DataFrame` by adding a column.
 
-    def removeColumn(self, colName):
-        raise NotImplemented
+        >>> df.addColumn('age2', df.age + 2).collect()
+        [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
+        """
+        return self.select('*', col.As(colName))
 
 
 # Having SchemaRDD for backward compatibility (for docs)
@@ -2280,7 +2320,14 @@ class GroupedDataFrame(object):
         `sum`, `count`.
 
         :param exprs: list or aggregate columns or a map from column
-                      name to agregate methods.
+                      name to aggregate methods.
+
+        >>> gdf = df.groupBy(df.name)
+        >>> gdf.agg({"age": "max"}).collect()
+        [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
+        >>> from pyspark.sql import Dsl
+        >>> gdf.agg(Dsl.min(df.age)).collect()
+        [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
@@ -2297,7 +2344,11 @@ class GroupedDataFrame(object):
 
     @dfapi
     def count(self):
-        """ Count the number of rows for each group. """
+        """ Count the number of rows for each group.
+
+        >>> df.groupBy(df.age).count().collect()
+        [Row(age=2, count=1), Row(age=5, count=1)]
+        """
 
     @dfapi
     def mean(self):
@@ -2349,18 +2400,25 @@ SCALA_METHOD_MAPPINGS = {
 
 def _create_column_from_literal(literal):
     sc = SparkContext._active_spark_context
-    return sc._jvm.org.apache.spark.sql.Dsl.lit(literal)
+    return sc._jvm.Dsl.lit(literal)
 
 
 def _create_column_from_name(name):
     sc = SparkContext._active_spark_context
-    return sc._jvm.IncomputableColumn(name)
+    return sc._jvm.Dsl.col(name)
+
+
+def _to_java_column(col):
+    if isinstance(col, Column):
+        jcol = col._jc
+    else:
+        jcol = _create_column_from_name(col)
+    return jcol
 
 
 def _scalaMethod(name):
     """ Translate operators into methodName in Scala
 
-    For example:
     >>> _scalaMethod('+')
     '$plus'
     >>> _scalaMethod('>=')
@@ -2371,37 +2429,34 @@ def _scalaMethod(name):
     return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
 
 
-def _unary_op(name):
+def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
-        return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx)
+        jc = getattr(self._jc, _scalaMethod(name))()
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
     return _
 
 
-def _bin_op(name, pass_literal_through=True):
+def _bin_op(name, doc="binary operator"):
     """ Create a method for given binary operator
-
-    Keyword arguments:
-    pass_literal_through -- whether to pass literal value directly through to the JVM.
     """
     def _(self, other):
-        if isinstance(other, Column):
-            jc = other._jc
-        else:
-            if pass_literal_through:
-                jc = other
-            else:
-                jc = _create_column_from_literal(other)
-        return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
+        jc = other._jc if isinstance(other, Column) else other
+        njc = getattr(self._jc, _scalaMethod(name))(jc)
+        return Column(njc, self.sql_ctx)
+    _.__doc__ = doc
     return _
 
 
-def _reverse_op(name):
+def _reverse_op(name, doc="binary operator"):
     """ Create a method for binary operator (this object is on right side)
     """
     def _(self, other):
-        return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc),
-                      self._jdf, self.sql_ctx)
+        jother = _create_column_from_literal(other)
+        jc = getattr(jother, _scalaMethod(name))(self._jc)
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
     return _
 
 
@@ -2410,20 +2465,20 @@ class Column(DataFrame):
     """
     A column in a DataFrame.
 
-    `Column` instances can be created by:
-    {{{
-    // 1. Select a column out of a DataFrame
-    df.colName
-    df["colName"]
+    `Column` instances can be created by::
+
+        # 1. Select a column out of a DataFrame
+        df.colName
+        df["colName"]
 
-    // 2. Create from an expression
-    df["colName"] + 1
-    }}}
+        # 2. Create from an expression
+        df.colName + 1
+        1 / df.colName
     """
 
-    def __init__(self, jc, jdf=None, sql_ctx=None):
+    def __init__(self, jc, sql_ctx=None):
         self._jc = jc
-        super(Column, self).__init__(jdf, sql_ctx)
+        super(Column, self).__init__(jc, sql_ctx)
 
     # arithmetic operators
     __neg__ = _unary_op("unary_-")
@@ -2438,8 +2493,6 @@ class Column(DataFrame):
     __rdiv__ = _reverse_op("/")
     __rmod__ = _reverse_op("%")
     __abs__ = _unary_op("abs")
-    abs = _unary_op("abs")
-    sqrt = _unary_op("sqrt")
 
     # logistic operators
     __eq__ = _bin_op("===")
@@ -2448,47 +2501,45 @@ class Column(DataFrame):
     __le__ = _bin_op("<=")
     __ge__ = _bin_op(">=")
     __gt__ = _bin_op(">")
-    # `and`, `or`, `not` cannot be overloaded in Python
-    And = _bin_op('&&')
-    Or = _bin_op('||')
-    Not = _unary_op('unary_!')
-
-    # bitwise operators
-    __and__ = _bin_op("&")
-    __or__ = _bin_op("|")
-    __invert__ = _unary_op("unary_~")
-    __xor__ = _bin_op("^")
-    # __lshift__ = _bin_op("<<")
-    # __rshift__ = _bin_op(">>")
-    __rand__ = _bin_op("&")
-    __ror__ = _bin_op("|")
-    __rxor__ = _bin_op("^")
-    # __rlshift__ = _reverse_op("<<")
-    # __rrshift__ = _reverse_op(">>")
+
+    # `and`, `or`, `not` cannot be overloaded in Python,
+    # so use bitwise operators as boolean operators
+    __and__ = _bin_op('&&')
+    __or__ = _bin_op('||')
+    __invert__ = _unary_op('unary_!')
+    __rand__ = _bin_op("&&")
+    __ror__ = _bin_op("||")
 
     # container operators
     __contains__ = _bin_op("contains")
     __getitem__ = _bin_op("getItem")
-    # __getattr__ = _bin_op("getField")
+    getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
 
     # string methods
     rlike = _bin_op("rlike")
     like = _bin_op("like")
     startswith = _bin_op("startsWith")
     endswith = _bin_op("endsWith")
-    upper = _unary_op("upper")
-    lower = _unary_op("lower")
 
-    def substr(self, startPos, pos):
-        if type(startPos) != type(pos):
+    def substr(self, startPos, length):
+        """
+        Return a Column which is a substring of the column
+
+        :param startPos: start position (int or Column)
+        :param length:  length of the substring (int or Column)
+
+        >>> df.name.substr(1, 3).collect()
+        [Row(col=u'Ali'), Row(col=u'Bob')]
+        """
+        if type(startPos) != type(length):
             raise TypeError("Can not mix the type")
         if isinstance(startPos, (int, long)):
-            jc = self._jc.substr(startPos, pos)
+            jc = self._jc.substr(startPos, length)
         elif isinstance(startPos, Column):
-            jc = self._jc.substr(startPos._jc, pos._jc)
+            jc = self._jc.substr(startPos._jc, length._jc)
         else:
             raise TypeError("Unexpected type: %s" % type(startPos))
-        return Column(jc, self._jdf, self.sql_ctx)
+        return Column(jc, self.sql_ctx)
 
     __getslice__ = substr
 
@@ -2496,55 +2547,89 @@ class Column(DataFrame):
     asc = _unary_op("asc")
     desc = _unary_op("desc")
 
-    isNull = _unary_op("isNull")
-    isNotNull = _unary_op("isNotNull")
+    isNull = _unary_op("isNull", "True if the current expression is null.")
+    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
     # `as` is keyword
     def alias(self, alias):
-        return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx)
+        """Return an alias for this column
+
+        >>> df.age.As("age2").collect()
+        [Row(age2=2), Row(age2=5)]
+        >>> df.age.alias("age2").collect()
+        [Row(age2=2), Row(age2=5)]
+        """
+        return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
+    As = alias
 
     def cast(self, dataType):
+        """ Convert the column into type `dataType`
+
+        >>> df.select(df.age.cast("string").As('ages')).collect()
+        [Row(ages=u'2'), Row(ages=u'5')]
+        >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+        [Row(ages=u'2'), Row(ages=u'5')]
+        """
         if self.sql_ctx is None:
             sc = SparkContext._active_spark_context
             ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         else:
             ssql_ctx = self.sql_ctx._ssql_ctx
-        jdt = ssql_ctx.parseDataType(dataType.json())
-        return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
+        if isinstance(dataType, basestring):
+            jc = self._jc.cast(dataType)
+        elif isinstance(dataType, DataType):
+            jdt = ssql_ctx.parseDataType(dataType.json())
+            jc = self._jc.cast(jdt)
+        return Column(jc, self.sql_ctx)
 
 
-def _to_java_column(col):
-    if isinstance(col, Column):
-        jcol = col._jc
-    else:
-        jcol = _create_column_from_name(col)
-    return jcol
-
-
-def _aggregate_func(name):
+def _aggregate_func(name, doc=""):
     """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
         jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
         return Column(jc)
-
+    _.__name__ = name
+    _.__doc__ = doc
     return staticmethod(_)
 
 
-class Aggregator(object):
+class Dsl(object):
     """
     A collections of builtin aggregators
     """
-    AGGS = [
-        'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
-        'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
-    ]
-    for _name in AGGS:
-        locals()[_name] = _aggregate_func(_name)
-    del _name
+    DSLS = {
+        'lit': 'Creates a :class:`Column` of literal value.',
+        'col': 'Returns a :class:`Column` based on the given column name.',
+        'column': 'Returns a :class:`Column` based on the given column name.',
+        'upper': 'Converts a string expression to upper case.',
+        'lower': 'Converts a string expression to lower case.',
+        'sqrt': 'Computes the square root of the specified float value.',
+        'abs': 'Computes the absolute value.',
+
+        'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+        'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+        'first': 'Aggregate function: returns the first value in a group.',
+        'last': 'Aggregate function: returns the last value in a group.',
+        'count': 'Aggregate function: returns the number of items in a group.',
+        'sum': 'Aggregate function: returns the sum of all values in the expression.',
+        'avg': 'Aggregate function: returns the average of the values in a group.',
+        'mean': 'Aggregate function: returns the average of the values in a group.',
+        'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+    }
+
+    for _name, _doc in DSLS.items():
+        locals()[_name] = _aggregate_func(_name, _doc)
+    del _name, _doc
 
     @staticmethod
     def countDistinct(col, *cols):
+        """ Return a new Column for distinct count of (col, *cols)
+
+        >>> from pyspark.sql import Dsl
+        >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+        [Row(c=2)]
+        """
         sc = SparkContext._active_spark_context
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         sc._gateway._gateway_client)
@@ -2554,6 +2639,12 @@ class Aggregator(object):
 
     @staticmethod
     def approxCountDistinct(col, rsd=None):
+        """ Return a new Column for approximate distinct count of col
+
+        >>> from pyspark.sql import Dsl
+        >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+        [Row(c=2)]
+        """
         sc = SparkContext._active_spark_context
         if rsd is None:
             jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
@@ -2568,16 +2659,20 @@ def _test():
     # let doctest run in pyspark.sql, so DataTypes can be picklable
     import pyspark.sql
     from pyspark.sql import Row, SQLContext
-    from pyspark.tests import ExamplePoint, ExamplePointUDT
+    from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
     globs = pyspark.sql.__dict__.copy()
     sc = SparkContext('local[4]', 'PythonTest')
     globs['sc'] = sc
-    globs['sqlCtx'] = SQLContext(sc)
+    globs['sqlCtx'] = sqlCtx = SQLContext(sc)
     globs['rdd'] = sc.parallelize(
         [Row(field1=1, field2="row1"),
          Row(field1=2, field2="row2"),
          Row(field1=3, field2="row3")]
     )
+    rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
+    rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
+    globs['df'] = sqlCtx.inferSchema(rdd2)
+    globs['df2'] = sqlCtx.inferSchema(rdd3)
     globs['ExamplePoint'] = ExamplePoint
     globs['ExamplePointUDT'] = ExamplePointUDT
     jsonStrings = [

http://git-wip-us.apache.org/repos/asf/spark/blob/068c0e2e/python/pyspark/sql_tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql_tests.py
new file mode 100644
index 0000000..d314f46
--- /dev/null
+++ b/python/pyspark/sql_tests.py
@@ -0,0 +1,299 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for pyspark.sql; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import sys
+import pydoc
+import shutil
+import tempfile
+
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+    UserDefinedType, DoubleType
+from pyspark.tests import ReusedPySparkTestCase
+
+
+class ExamplePointUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for ExamplePoint.
+    """
+
+    @classmethod
+    def sqlType(self):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return 'pyspark.tests'
+
+    @classmethod
+    def scalaUDT(cls):
+        return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+    """
+    An example class to demonstrate UDT in Scala, Java, and Python.
+    """
+
+    __UDT__ = ExamplePointUDT()
+
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __repr__(self):
+        return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+    def __str__(self):
+        return "(%s,%s)" % (self.x, self.y)
+
+    def __eq__(self, other):
+        return isinstance(other, ExamplePoint) and \
+            other.x == self.x and other.y == self.y
+
+
+class SQLTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.sqlCtx = SQLContext(cls.sc)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = cls.sc.parallelize(cls.testData)
+        cls.df = cls.sqlCtx.inferSchema(rdd)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def test_udf(self):
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_udf2(self):
+        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
+        self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+        [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+        self.assertEqual(4, res[0])
+
+    def test_udf_with_array_type(self):
+        d = [Row(l=range(3), d={"key": range(5)})]
+        rdd = self.sc.parallelize(d)
+        self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+        self.assertEqual(range(3), l1)
+        self.assertEqual(1, l2)
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+    def test_basic_functions(self):
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        df = self.sqlCtx.jsonRDD(rdd)
+        df.count()
+        df.collect()
+        df.schema()
+
+        # cache, persist and unpersist
+        self.assertFalse(df.is_cached)
+        df.persist()
+        df.unpersist()
+        df.cache()
+        self.assertTrue(df.is_cached)
+        self.assertEqual(2, df.count())
+
+        df.registerTempTable("temp")
+        df = self.sqlCtx.sql("select foo from temp")
+        df.count()
+        df.collect()
+
+    def test_apply_schema_to_row(self):
+        df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+        df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+        self.assertEqual(df.collect(), df2.collect())
+
+        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+        df3 = self.sqlCtx.applySchema(rdd, df.schema())
+        self.assertEqual(10, df3.count())
+
+    def test_serialize_nested_array_and_map(self):
+        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        row = df.head()
+        self.assertEqual(1, len(row.l))
+        self.assertEqual(1, row.l[0].a)
+        self.assertEqual("2", row.d["key"].d)
+
+        l = df.map(lambda x: x.l).first()
+        self.assertEqual(1, len(l))
+        self.assertEqual('s', l[0].b)
+
+        d = df.map(lambda x: x.d).first()
+        self.assertEqual(1, len(d))
+        self.assertEqual(1.0, d["key"].c)
+
+        row = df.map(lambda x: x.d["key"]).first()
+        self.assertEqual(1.0, row.c)
+        self.assertEqual("2", row.d)
+
+    def test_infer_schema(self):
+        d = [Row(l=[], d={}),
+             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        self.assertEqual([], df.map(lambda r: r.l).first())
+        self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+        df.registerTempTable("test")
+        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+        self.assertEqual(1, result.head()[0])
+
+        df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+        self.assertEqual(df.schema(), df2.schema())
+        self.assertEqual({}, df2.map(lambda r: r.d).first())
+        self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+        df2.registerTempTable("test2")
+        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+        self.assertEqual(1, result.head()[0])
+
+    def test_struct_in_map(self):
+        d = [Row(m={Row(i=1): Row(s="")})]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        k, v = df.head().m.items()[0]
+        self.assertEqual(1, k.i)
+        self.assertEqual("", v.s)
+
+    def test_convert_row_to_dict(self):
+        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
+        self.assertEqual(1, row.asDict()['l'][0].a)
+        rdd = self.sc.parallelize([row])
+        df = self.sqlCtx.inferSchema(rdd)
+        df.registerTempTable("test")
+        row = self.sqlCtx.sql("select l, d from test").head()
+        self.assertEqual(1, row.asDict()["l"][0].a)
+        self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+    def test_infer_schema_with_udt(self):
+        from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        df = self.sqlCtx.inferSchema(rdd)
+        schema = df.schema()
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_apply_schema_with_udt(self):
+        from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df = self.sqlCtx.applySchema(rdd, schema)
+        point = df.head().point
+        self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+    def test_parquet_with_udt(self):
+        from pyspark.sql_tests import ExamplePoint
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        df0 = self.sqlCtx.inferSchema(rdd)
+        output_dir = os.path.join(self.tempdir.name, "labeled_point")
+        df0.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        point = df1.head().point
+        self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+    def test_column_operators(self):
+        from pyspark.sql import Column, LongType
+        ci = self.df.key
+        cs = self.df.value
+        c = ci == cs
+        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+        self.assertTrue(all(isinstance(c, Column) for c in rcc))
+        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+        self.assertTrue(all(isinstance(c, Column) for c in cb))
+        cbool = (ci & ci), (ci | ci), (~ci)
+        self.assertTrue(all(isinstance(c, Column) for c in cbool))
+        css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+        self.assertTrue(all(isinstance(c, Column) for c in css))
+        self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+    def test_column_select(self):
+        df = self.df
+        self.assertEqual(self.testData, df.select("*").collect())
+        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+    def test_aggregator(self):
+        df = self.df
+        g = df.groupBy()
+        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+
+        from pyspark.sql import Dsl
+        self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
+        self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
+        self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+
+    def test_help_command(self):
+        # Regression test for SPARK-5464
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        df = self.sqlCtx.jsonRDD(rdd)
+        # render_doc() reproduces the help() exception without printing output
+        pydoc.render_doc(df)
+        pydoc.render_doc(df.foo)
+        pydoc.render_doc(df.take(1))
+
+
+if __name__ == "__main__":
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/spark/blob/068c0e2e/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c7d0622..b5e28c4 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -23,7 +23,6 @@ from array import array
 from fileinput import input
 from glob import glob
 import os
-import pydoc
 import re
 import shutil
 import subprocess
@@ -52,8 +51,6 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
     CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType
 from pyspark import shuffle
 from pyspark.profiler import BasicProfiler
 
@@ -795,264 +792,6 @@ class ProfilerTests(PySparkTestCase):
         rdd.foreach(heavy_foo)
 
 
-class ExamplePointUDT(UserDefinedType):
-    """
-    User-defined type (UDT) for ExamplePoint.
-    """
-
-    @classmethod
-    def sqlType(self):
-        return ArrayType(DoubleType(), False)
-
-    @classmethod
-    def module(cls):
-        return 'pyspark.tests'
-
-    @classmethod
-    def scalaUDT(cls):
-        return 'org.apache.spark.sql.test.ExamplePointUDT'
-
-    def serialize(self, obj):
-        return [obj.x, obj.y]
-
-    def deserialize(self, datum):
-        return ExamplePoint(datum[0], datum[1])
-
-
-class ExamplePoint:
-    """
-    An example class to demonstrate UDT in Scala, Java, and Python.
-    """
-
-    __UDT__ = ExamplePointUDT()
-
-    def __init__(self, x, y):
-        self.x = x
-        self.y = y
-
-    def __repr__(self):
-        return "ExamplePoint(%s,%s)" % (self.x, self.y)
-
-    def __str__(self):
-        return "(%s,%s)" % (self.x, self.y)
-
-    def __eq__(self, other):
-        return isinstance(other, ExamplePoint) and \
-            other.x == self.x and other.y == self.y
-
-
-class SQLTests(ReusedPySparkTestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
-        os.unlink(cls.tempdir.name)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
-
-    def setUp(self):
-        self.sqlCtx = SQLContext(self.sc)
-        self.testData = [Row(key=i, value=str(i)) for i in range(100)]
-        rdd = self.sc.parallelize(self.testData)
-        self.df = self.sqlCtx.inferSchema(rdd)
-
-    def test_udf(self):
-        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
-        self.assertEqual(row[0], 5)
-
-    def test_udf2(self):
-        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
-        self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
-        [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
-        self.assertEqual(4, res[0])
-
-    def test_udf_with_array_type(self):
-        d = [Row(l=range(3), d={"key": range(5)})]
-        rdd = self.sc.parallelize(d)
-        self.sqlCtx.inferSchema(rdd).registerTempTable("test")
-        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
-        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
-        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
-        self.assertEqual(range(3), l1)
-        self.assertEqual(1, l2)
-
-    def test_broadcast_in_udf(self):
-        bar = {"a": "aa", "b": "bb", "c": "abc"}
-        foo = self.sc.broadcast(bar)
-        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
-        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
-        self.assertEqual("abc", res[0])
-        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
-        self.assertEqual("", res[0])
-
-    def test_basic_functions(self):
-        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        df = self.sqlCtx.jsonRDD(rdd)
-        df.count()
-        df.collect()
-        df.schema()
-
-        # cache and checkpoint
-        self.assertFalse(df.is_cached)
-        df.persist()
-        df.unpersist()
-        df.cache()
-        self.assertTrue(df.is_cached)
-        self.assertEqual(2, df.count())
-
-        df.registerTempTable("temp")
-        df = self.sqlCtx.sql("select foo from temp")
-        df.count()
-        df.collect()
-
-    def test_apply_schema_to_row(self):
-        df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
-        df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
-        self.assertEqual(df.collect(), df2.collect())
-
-        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
-        df3 = self.sqlCtx.applySchema(rdd, df.schema())
-        self.assertEqual(10, df3.count())
-
-    def test_serialize_nested_array_and_map(self):
-        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
-        rdd = self.sc.parallelize(d)
-        df = self.sqlCtx.inferSchema(rdd)
-        row = df.head()
-        self.assertEqual(1, len(row.l))
-        self.assertEqual(1, row.l[0].a)
-        self.assertEqual("2", row.d["key"].d)
-
-        l = df.map(lambda x: x.l).first()
-        self.assertEqual(1, len(l))
-        self.assertEqual('s', l[0].b)
-
-        d = df.map(lambda x: x.d).first()
-        self.assertEqual(1, len(d))
-        self.assertEqual(1.0, d["key"].c)
-
-        row = df.map(lambda x: x.d["key"]).first()
-        self.assertEqual(1.0, row.c)
-        self.assertEqual("2", row.d)
-
-    def test_infer_schema(self):
-        d = [Row(l=[], d={}),
-             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
-        rdd = self.sc.parallelize(d)
-        df = self.sqlCtx.inferSchema(rdd)
-        self.assertEqual([], df.map(lambda r: r.l).first())
-        self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
-        df.registerTempTable("test")
-        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
-        self.assertEqual(1, result.head()[0])
-
-        df2 = self.sqlCtx.inferSchema(rdd, 1.0)
-        self.assertEqual(df.schema(), df2.schema())
-        self.assertEqual({}, df2.map(lambda r: r.d).first())
-        self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
-        df2.registerTempTable("test2")
-        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
-        self.assertEqual(1, result.head()[0])
-
-    def test_struct_in_map(self):
-        d = [Row(m={Row(i=1): Row(s="")})]
-        rdd = self.sc.parallelize(d)
-        df = self.sqlCtx.inferSchema(rdd)
-        k, v = df.head().m.items()[0]
-        self.assertEqual(1, k.i)
-        self.assertEqual("", v.s)
-
-    def test_convert_row_to_dict(self):
-        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
-        self.assertEqual(1, row.asDict()['l'][0].a)
-        rdd = self.sc.parallelize([row])
-        df = self.sqlCtx.inferSchema(rdd)
-        df.registerTempTable("test")
-        row = self.sqlCtx.sql("select l, d from test").head()
-        self.assertEqual(1, row.asDict()["l"][0].a)
-        self.assertEqual(1.0, row.asDict()['d']['key'].c)
-
-    def test_infer_schema_with_udt(self):
-        from pyspark.tests import ExamplePoint, ExamplePointUDT
-        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        rdd = self.sc.parallelize([row])
-        df = self.sqlCtx.inferSchema(rdd)
-        schema = df.schema()
-        field = [f for f in schema.fields if f.name == "point"][0]
-        self.assertEqual(type(field.dataType), ExamplePointUDT)
-        df.registerTempTable("labeled_point")
-        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
-        self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-    def test_apply_schema_with_udt(self):
-        from pyspark.tests import ExamplePoint, ExamplePointUDT
-        row = (1.0, ExamplePoint(1.0, 2.0))
-        rdd = self.sc.parallelize([row])
-        schema = StructType([StructField("label", DoubleType(), False),
-                             StructField("point", ExamplePointUDT(), False)])
-        df = self.sqlCtx.applySchema(rdd, schema)
-        point = df.head().point
-        self.assertEquals(point, ExamplePoint(1.0, 2.0))
-
-    def test_parquet_with_udt(self):
-        from pyspark.tests import ExamplePoint
-        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        rdd = self.sc.parallelize([row])
-        df0 = self.sqlCtx.inferSchema(rdd)
-        output_dir = os.path.join(self.tempdir.name, "labeled_point")
-        df0.saveAsParquetFile(output_dir)
-        df1 = self.sqlCtx.parquetFile(output_dir)
-        point = df1.head().point
-        self.assertEquals(point, ExamplePoint(1.0, 2.0))
-
-    def test_column_operators(self):
-        from pyspark.sql import Column, LongType
-        ci = self.df.key
-        cs = self.df.value
-        c = ci == cs
-        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
-        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
-        self.assertTrue(all(isinstance(c, Column) for c in rcc))
-        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
-        self.assertTrue(all(isinstance(c, Column) for c in cb))
-        cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci)
-        self.assertTrue(all(isinstance(c, Column) for c in cbit))
-        css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
-        self.assertTrue(all(isinstance(c, Column) for c in css))
-        self.assertTrue(isinstance(ci.cast(LongType()), Column))
-
-    def test_column_select(self):
-        df = self.df
-        self.assertEqual(self.testData, df.select("*").collect())
-        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
-        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
-
-    def test_aggregator(self):
-        df = self.df
-        g = df.groupBy()
-        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
-        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
-
-        from pyspark.sql import Aggregator as Agg
-        self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
-        self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
-        self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])
-
-    def test_help_command(self):
-        # Regression test for SPARK-5464
-        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        df = self.sqlCtx.jsonRDD(rdd)
-        # render_doc() reproduces the help() exception without printing output
-        pydoc.render_doc(df)
-        pydoc.render_doc(df.foo)
-        pydoc.render_doc(df.take(1))
-
-
 class InputFormatTests(ReusedPySparkTestCase):
 
     @classmethod

http://git-wip-us.apache.org/repos/asf/spark/blob/068c0e2e/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index e91f1a8..649a2c4 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -65,6 +65,7 @@ function run_core_tests() {
 function run_sql_tests() {
     echo "Run sql tests ..."
     run_test "pyspark/sql.py"
+    run_test "pyspark/sql_tests.py"
 }
 
 function run_mllib_tests() {

http://git-wip-us.apache.org/repos/asf/spark/blob/068c0e2e/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 4aa3721..ddce77d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -116,15 +116,6 @@ trait Column extends DataFrame {
   def unary_- : Column = exprToColumn(UnaryMinus(expr))
 
   /**
-   * Bitwise NOT.
-   * {{{
-   *   // Scala: select the flags column and negate every bit.
-   *   df.select( ~df("flags") )
-   * }}}
-   */
-  def unary_~ : Column = exprToColumn(BitwiseNot(expr))
-
-  /**
    * Inversion of boolean expression, i.e. NOT.
    * {{
    *   // Scala: select rows that are not active (isActive === false)
@@ -363,27 +354,6 @@ trait Column extends DataFrame {
   def and(other: Column): Column = this && other
 
   /**
-   * Bitwise AND.
-   */
-  def & (other: Any): Column = constructColumn(other) { o =>
-    BitwiseAnd(expr, o.expr)
-  }
-
-  /**
-   * Bitwise OR with an expression.
-   */
-  def | (other: Any): Column = constructColumn(other) { o =>
-    BitwiseOr(expr, o.expr)
-  }
-
-  /**
-   * Bitwise XOR with an expression.
-   */
-  def ^ (other: Any): Column = constructColumn(other) { o =>
-    BitwiseXor(expr, o.expr)
-  }
-
-  /**
    * Sum of this expression and another expression.
    * {{{
    *   // Scala: The following selects the sum of a person's height and weight.
@@ -527,16 +497,16 @@ trait Column extends DataFrame {
    * @param startPos expression for the starting position.
    * @param len expression for the length of the substring.
    */
-  def substr(startPos: Column, len: Column): Column = {
-    new IncomputableColumn(Substring(expr, startPos.expr, len.expr))
-  }
+  def substr(startPos: Column, len: Column): Column =
+    exprToColumn(Substring(expr, startPos.expr, len.expr), computable = false)
 
   /**
    * An expression that returns a substring.
    * @param startPos starting position.
    * @param len length of the substring.
    */
-  def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len))
+  def substr(startPos: Int, len: Int): Column =
+    exprToColumn(Substring(expr, lit(startPos).expr, lit(len).expr))
 
   def contains(other: Any): Column = constructColumn(other) { o =>
     Contains(expr, o.expr)

http://git-wip-us.apache.org/repos/asf/spark/blob/068c0e2e/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 006b16f..e6f622e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -37,7 +37,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def sqlType: DataType = ArrayType(DoubleType, false)
 
-  override def pyUDT: String = "pyspark.tests.ExamplePointUDT"
+  override def pyUDT: String = "pyspark.sql_tests.ExamplePointUDT"
 
   override def serialize(obj: Any): Seq[Double] = {
     obj match {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message