spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From gatorsmile <...@git.apache.org>
Subject [GitHub] spark pull request #20171: [SPARK-22978] [PySpark] Register Vectorized UDFs ...
Date Tue, 16 Jan 2018 00:21:21 GMT
Github user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20171#discussion_r161637979
  
    --- Diff: python/pyspark/sql/catalog.py ---
    @@ -256,27 +258,58 @@ def registerFunction(self, name, f, returnType=StringType()):
             >>> spark.sql("SELECT stringLengthInt('test')").collect()
             [Row(stringLengthInt(test)=4)]
     
    +        >>> from pyspark.sql.types import IntegerType
    +        >>> from pyspark.sql.functions import udf
    +        >>> slen = udf(lambda s: len(s), IntegerType())
    +        >>> _ = spark.udf.register("slen", slen)
    +        >>> spark.sql("SELECT slen('test')").collect()
    +        [Row(slen(test)=4)]
    +
             >>> import random
             >>> from pyspark.sql.functions import udf
    -        >>> from pyspark.sql.types import IntegerType, StringType
    +        >>> from pyspark.sql.types import IntegerType
             >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
    -        >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
    +        >>> newRandom_udf = spark.udf.register("random_udf", random_udf)
             >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
    -        [Row(random_udf()=u'82')]
    +        [Row(random_udf()=82)]
             >>> spark.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
    -        [Row(random_udf()=u'62')]
    +        [Row(<lambda>()=26)]
    +
    +        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
    +        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
    +        ... def add_one(x):
    +        ...     return x + 1
    +        ...
    +        >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
    +        >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
    +        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
             """
     
             # This is to check whether the input function is a wrapped/native UserDefinedFunction
             if hasattr(f, 'asNondeterministic'):
    -            udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
    -                                      evalType=PythonEvalType.SQL_BATCHED_UDF,
    -                                      deterministic=f.deterministic)
    +            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
    +                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
    +                raise ValueError(
    +                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
    +            if returnType is not None and not isinstance(returnType, DataType):
    +                returnType = _parse_datatype_string(returnType)
    +            if returnType is not None and returnType != f.returnType:
    --- End diff --
    
    Since the API already has this parameter `returnType`, we should support it if possible.
We need to do our best to avoid raising unnecessary exceptions. 


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message