spark-commits mailing list archives

From hol...@apache.org
Subject spark git commit: [SPARK-19454][PYTHON][SQL] DataFrame.replace improvements
Date Wed, 05 Apr 2017 18:47:54 GMT
Repository: spark
Updated Branches:
  refs/heads/master a2d8d767d -> e2773996b


[SPARK-19454][PYTHON][SQL] DataFrame.replace improvements

## What changes were proposed in this pull request?

- Allows skipping the `value` argument when `to_replace` is a `dict`:
	```python
	df = sc.parallelize([("Alice", 1, 3.0)]).toDF()
	df.replace({"Alice": "Bob"}).show()
	```
- Adds a validation step to ensure homogeneous values / replacements (a short sketch follows this list).
- Simplifies internal control flow.
- Improves unit test coverage.
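
A minimal sketch of the new call patterns, assuming an active `SparkSession` bound to `spark`; the column names below are illustrative, not part of the patch:

```python
# Hypothetical demo of the patched DataFrame.replace; assumes `spark` exists.
df = spark.createDataFrame([(u'Alice', 10, 80.1)], ["name", "age", "height"])

# `value` may now be omitted when `to_replace` is a dict.
df.replace({10: 11}).first()        # Row(name=u'Alice', age=11, height=80.1)

# Mixed-type replacements are now rejected up front.
try:
    df.replace({u'Alice': u'Bob', 10: 20})
except ValueError as e:
    print(e)                        # "Mixed type replacements are not supported"
```

Passing a `dict` together with a non-None `value` still works for backward compatibility, but now emits a warning that `value` is ignored.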

## How was this patch tested?

Existing unit tests, additional unit tests, manual testing.

Author: zero323 <zero323@users.noreply.github.com>

Closes #16793 from zero323/SPARK-19454.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e2773996
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e2773996
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e2773996

Branch: refs/heads/master
Commit: e2773996b8d1c0214d9ffac634a059b4923caf7b
Parents: a2d8d76
Author: zero323 <zero323@users.noreply.github.com>
Authored: Wed Apr 5 11:47:40 2017 -0700
Committer: Holden Karau <holden@us.ibm.com>
Committed: Wed Apr 5 11:47:40 2017 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py | 81 +++++++++++++++++++++++++-----------
 python/pyspark/sql/tests.py     | 72 ++++++++++++++++++++++++++++++++
 2 files changed, 128 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e2773996/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a24512f..774caf5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -25,6 +25,8 @@ if sys.version >= '3':
 else:
     from itertools import imap as map
 
+import warnings
+
 from pyspark import copy_func, since
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -1281,7 +1283,7 @@ class DataFrame(object):
             return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
 
     @since(1.4)
-    def replace(self, to_replace, value, subset=None):
+    def replace(self, to_replace, value=None, subset=None):
         """Returns a new :class:`DataFrame` replacing a value with another value.
         :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
         aliases of each other.
@@ -1326,43 +1328,72 @@ class DataFrame(object):
         |null|  null|null|
         +----+------+----+
         """
-        if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
+        # Helper functions
+        def all_of(types):
+            """Given a type or tuple of types and a sequence of xs
+            check if each x is instance of type(s)
+
+            >>> all_of(bool)([True, False])
+            True
+            >>> all_of(basestring)(["a", 1])
+            False
+            """
+            def all_of_(xs):
+                return all(isinstance(x, types) for x in xs)
+            return all_of_
+
+        all_of_bool = all_of(bool)
+        all_of_str = all_of(basestring)
+        all_of_numeric = all_of((float, int, long))
+
+        # Validate input types
+        valid_types = (bool, float, int, long, basestring, list, tuple)
+        if not isinstance(to_replace, valid_types + (dict, )):
             raise ValueError(
-                "to_replace should be a float, int, long, string, list, tuple, or dict")
+                "to_replace should be a float, int, long, string, list, tuple, or dict. "
+                "Got {0}".format(type(to_replace)))
 
-        if not isinstance(value, (float, int, long, basestring, list, tuple)):
-            raise ValueError("value should be a float, int, long, string, list, or tuple")
+        if not isinstance(value, valid_types) and not isinstance(to_replace, dict):
+            raise ValueError("If to_replace is not a dict, value should be "
+                             "a float, int, long, string, list, or tuple. "
+                             "Got {0}".format(type(value)))
+
+        if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
+            if len(to_replace) != len(value):
+                raise ValueError("to_replace and value lists should be of the same length.
"
+                                 "Got {0} and {1}".format(len(to_replace), len(value)))
 
-        rep_dict = dict()
+        if not (subset is None or isinstance(subset, (list, tuple, basestring))):
+            raise ValueError("subset should be a list or tuple of column names, "
+                             "column name or None. Got {0}".format(type(subset)))
 
+        # Reshape input arguments if necessary
         if isinstance(to_replace, (float, int, long, basestring)):
             to_replace = [to_replace]
 
-        if isinstance(to_replace, tuple):
-            to_replace = list(to_replace)
+        if isinstance(value, (float, int, long, basestring)):
+            value = [value for _ in range(len(to_replace))]
 
-        if isinstance(value, tuple):
-            value = list(value)
-
-        if isinstance(to_replace, list) and isinstance(value, list):
-            if len(to_replace) != len(value):
-                raise ValueError("to_replace and value lists should be of the same length")
-            rep_dict = dict(zip(to_replace, value))
-        elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)):
-            rep_dict = dict([(tr, value) for tr in to_replace])
-        elif isinstance(to_replace, dict):
+        if isinstance(to_replace, dict):
             rep_dict = to_replace
+            if value is not None:
+                warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
+        else:
+            rep_dict = dict(zip(to_replace, value))
 
-        if subset is None:
-            return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
-        elif isinstance(subset, basestring):
+        if isinstance(subset, basestring):
             subset = [subset]
 
-        if not isinstance(subset, (list, tuple)):
-            raise ValueError("subset should be a list or tuple of column names")
+        # Verify we were not passed in mixed type generics.
+        if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values())
+                   for all_of_type in [all_of_bool, all_of_str, all_of_numeric]):
+            raise ValueError("Mixed type replacements are not supported")
 
-        return DataFrame(
-            self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
+        if subset is None:
+            return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
+        else:
+            return DataFrame(
+                self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
 
     @since(2.0)
     def approxQuantile(self, col, probabilities, relativeError):

http://git-wip-us.apache.org/repos/asf/spark/blob/e2773996/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index db41b4e..2b24443 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1779,6 +1779,78 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)
 
+        # replace with lists
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
+        self.assertTupleEqual(row, (u'Ann', 10, 80.1))
+
+        # replace with dict
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
+        self.assertTupleEqual(row, (u'Alice', 11, 80.1))
+
+        # test backward compatibility with dummy value
+        dummy_value = 1
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
+        self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+        # test dict with mixed numerics
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
+        self.assertTupleEqual(row, (u'Alice', -10, 90.5))
+
+        # replace with tuples
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
+        self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+        # replace multiple columns
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.0))
+
+        # test for mixed numerics
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+        # replace with boolean
+        row = (self
+               .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
+               .selectExpr("name = 'Bob'", 'age <= 15')
+               .replace(False, True).first())
+        self.assertTupleEqual(row, (True, True))
+
+        # should fail if subset is not list, tuple or None
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()
+
+        # should fail if to_replace and value have different length
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()
+
+        # should fail when given an unexpected type
+        with self.assertRaises(ValueError):
+            from datetime import datetime
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()
+
+        # should fail if provided mixed type replacements
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()
+
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
+
     def test_capture_analysis_exception(self):
         self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
         self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

