spark-commits mailing list archives

From joshro...@apache.org
Subject spark git commit: [SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.2)
Date Sat, 28 Feb 2015 04:04:32 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.2 17b7cc733 -> 576fc54e5


[SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.2)

The __eq__ of DataType is not correct: the class cache is not used correctly (a created class cannot be found by its dataType), so lots of classes are created (saved in _cached_cls) and never released.

Also, all instances of the same DataType have the same hash code, so a dict ends up holding many objects with the same hash code; with that many collisions, accessing the dict becomes very slow (depending on the CPython implementation).

This PR also improves the performance of inferSchema (it avoids creating unnecessary converters for objects).
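
To make the failure mode concrete, here is a minimal sketch (not the actual pyspark code; Key, get_cls, and the object() stand-in are illustrative): an identity-based __eq__ paired with a value-based __hash__ means logically equal keys never match, so every cache lookup misses and the dict fills with entries that all collide on one hash bucket.

    class Key(object):
        def __init__(self, name):
            self.name = name

        def __hash__(self):
            return hash(self.name)   # every Key("string") hashes the same

        def __eq__(self, other):
            return self is other     # ...but only ever equals itself

    _cached_cls = {}

    def get_cls(key):
        if key not in _cached_cls:       # always a miss for a fresh Key
            _cached_cls[key] = object()  # stand-in for a generated class
        return _cached_cls[key]

    for _ in range(1000):
        get_cls(Key("string"))

    print(len(_cached_cls))  # 1000 colliding entries instead of 1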

Author: Davies Liu <davies@databricks.com>

Closes #4809 from davies/leak2 and squashes the following commits:

65c222f [Davies Liu] Update sql.py
9b4dadc [Davies Liu] fix __eq__ of singleton
b576107 [Davies Liu] fix tests
6c2909a [Davies Liu] fix incorrect DataType.__eq__


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/576fc54e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/576fc54e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/576fc54e

Branch: refs/heads/branch-1.2
Commit: 576fc54e5c154fc28af1a732a6bea452d0a5cabb
Parents: 17b7cc7
Author: Davies Liu <davies@databricks.com>
Authored: Fri Feb 27 20:04:16 2015 -0800
Committer: Josh Rosen <joshrosen@databricks.com>
Committed: Fri Feb 27 20:04:16 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py | 67 ++++++++++++++++++++++++++++++----------------
 1 file changed, 44 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/576fc54e/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index aa5af1b..4410925 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -36,6 +36,7 @@ import keyword
 import warnings
 import json
 import re
+import weakref
 from array import array
 from operator import itemgetter
 from itertools import imap
@@ -68,8 +69,7 @@ class DataType(object):
         return hash(str(self))
 
     def __eq__(self, other):
-        return (isinstance(other, self.__class__) and
-                self.__dict__ == other.__dict__)
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -105,10 +105,6 @@ class PrimitiveType(DataType):
 
     __metaclass__ = PrimitiveTypeSingleton
 
-    def __eq__(self, other):
-        # because they should be the same object
-        return self is other
-
 
 class NullType(PrimitiveType):
 
@@ -251,9 +247,9 @@ class ArrayType(DataType):
         :param elementType: the data type of elements.
         :param containsNull: indicates whether the list contains None values.
 
-        >>> ArrayType(StringType) == ArrayType(StringType, True)
+        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
         True
-        >>> ArrayType(StringType, False) == ArrayType(StringType)
+        >>> ArrayType(StringType(), False) == ArrayType(StringType())
         False
         """
         self.elementType = elementType
@@ -298,11 +294,11 @@ class MapType(DataType):
         :param valueContainsNull: indicates whether values contains
         null values.
 
-        >>> (MapType(StringType, IntegerType)
-        ...        == MapType(StringType, IntegerType, True))
+        >>> (MapType(StringType(), IntegerType())
+        ...        == MapType(StringType(), IntegerType(), True))
         True
-        >>> (MapType(StringType, IntegerType, False)
-        ...        == MapType(StringType, FloatType))
+        >>> (MapType(StringType(), IntegerType(), False)
+        ...        == MapType(StringType(), FloatType()))
         False
         """
         self.keyType = keyType
@@ -351,11 +347,11 @@ class StructField(DataType):
                          to simple type that can be serialized to JSON
                          automatically
 
-        >>> (StructField("f1", StringType, True)
-        ...      == StructField("f1", StringType, True))
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f1", StringType(), True))
         True
-        >>> (StructField("f1", StringType, True)
-        ...      == StructField("f2", StringType, True))
+        >>> (StructField("f1", StringType(), True)
+        ...      == StructField("f2", StringType(), True))
         False
         """
         self.name = name
@@ -393,13 +389,13 @@ class StructType(DataType):
     def __init__(self, fields):
         """Creates a StructType
 
-        >>> struct1 = StructType([StructField("f1", StringType, True)])
-        >>> struct2 = StructType([StructField("f1", StringType, True)])
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True)])
         >>> struct1 == struct2
         True
-        >>> struct1 = StructType([StructField("f1", StringType, True)])
-        >>> struct2 = StructType([StructField("f1", StringType, True),
-        ...   [StructField("f2", IntegerType, False)]])
+        >>> struct1 = StructType([StructField("f1", StringType(), True)])
+        >>> struct2 = StructType([StructField("f1", StringType(), True),
+        ...                       StructField("f2", IntegerType(), False)])
         >>> struct1 == struct2
         False
         """
@@ -499,6 +495,10 @@ _all_complex_types = dict((v.typeName(), v)
 
 def _parse_datatype_json_string(json_string):
     """Parses the given data type JSON string.
+
+    >>> import pickle
+    >>> LongType() == pickle.loads(pickle.dumps(LongType()))
+    True
     >>> def check_datatype(datatype):
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
@@ -781,8 +781,25 @@ def _merge_type(a, b):
         return a
 
 
+def _need_converter(dataType):
+    if isinstance(dataType, StructType):
+        return True
+    elif isinstance(dataType, ArrayType):
+        return _need_converter(dataType.elementType)
+    elif isinstance(dataType, MapType):
+        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
+    elif isinstance(dataType, NullType):
+        return True
+    else:
+        return False
+
+
 def _create_converter(dataType):
     """Create an converter to drop the names of fields in obj """
+
+    if not _need_converter(dataType):
+        return lambda x: x
+
     if isinstance(dataType, ArrayType):
         conv = _create_converter(dataType.elementType)
         return lambda row: map(conv, row)
@@ -800,6 +817,7 @@ def _create_converter(dataType):
     # dataType must be StructType
     names = [f.name for f in dataType.fields]
     converters = [_create_converter(f.dataType) for f in dataType.fields]
+    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
 
     def convert_struct(obj):
         if obj is None:
@@ -822,7 +840,10 @@ def _create_converter(dataType):
         else:
             raise ValueError("Unexpected obj: %s" % obj)
 
-        return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+        if convert_fields:
+            return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+        else:
+            return tuple([d.get(name) for name in names])
 
     return convert_struct
 
@@ -1039,7 +1060,7 @@ def _verify_type(obj, dataType):
             _verify_type(v, f.dataType)
 
 
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
 
 
 def _restore_object(dataType, obj):
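
For reference, a minimal sketch of why the switch to weakref.WeakValueDictionary matters (again not the pyspark code; the "schema" key and the dynamically created class are illustrative): an entry disappears once the cached class has no other strong references, so the cache can no longer grow without bound.

    import gc
    import weakref

    cache = weakref.WeakValueDictionary()

    cls = type("Row_example", (object,), {})  # stand-in for a generated Row class
    cache["schema"] = cls

    print("schema" in cache)  # True: `cls` still holds a strong reference
    del cls                   # drop the last strong reference
    gc.collect()              # CPython frees it promptly; collect() to be sure
    print("schema" in cache)  # False: the entry went away with the class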


