spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dav...@apache.org
Subject spark git commit: [SPARK-9157] [SQL] codegen substring
Date Tue, 21 Jul 2015 05:44:29 GMT
Repository: spark
Updated Branches:
  refs/heads/master c032b0bf9 -> 560b355cc


[SPARK-9157] [SQL] codegen substring

https://issues.apache.org/jira/browse/SPARK-9157

Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7534 from tarekauel/SPARK-9157 and squashes the following commits:

e65e3e9 [Tarek Auel] [SPARK-9157] indent fix
44e89f8 [Tarek Auel] [SPARK-9157] use EMPTY_UTF8
37d54c4 [Tarek Auel] Merge branch 'master' into SPARK-9157
60732ea [Tarek Auel] [SPARK-9157] created substringSQL in UTF8String
18c3576 [Tarek Auel] [SPARK-9157][SQL] remove slice pos
1a2e611 [Tarek Auel] [SPARK-9157][SQL] codegen substring


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/560b355c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/560b355c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/560b355c

Branch: refs/heads/master
Commit: 560b355ccd038ca044726c9c9fcffd14d02e6696
Parents: c032b0b
Author: Tarek Auel <tarek.auel@googlemail.com>
Authored: Mon Jul 20 22:43:30 2015 -0700
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Mon Jul 20 22:44:21 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/stringOperations.scala | 87 ++++++++++----------
 .../apache/spark/unsafe/types/UTF8String.java   | 12 +++
 .../spark/unsafe/types/UTF8StringSuite.java     | 19 +++++
 3 files changed, 75 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/560b355c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 5c1908d..438215e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -640,7 +640,7 @@ case class StringSplit(str: Expression, pattern: Expression)
  * Defined for String and Binary types.
  */
 case class Substring(str: Expression, pos: Expression, len: Expression)
-  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+  extends Expression with ImplicitCastInputTypes {
 
   def this(str: Expression, pos: Expression) = {
     this(str, pos, Literal(Integer.MAX_VALUE))
@@ -649,58 +649,59 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
   override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
   override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
 
-  override def dataType: DataType = {
-    if (!resolved) {
-      throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
-    }
-    if (str.dataType == BinaryType) str.dataType else StringType
-  }
+  override def dataType: DataType = StringType
 
   override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
 
   override def children: Seq[Expression] = str :: pos :: len :: Nil
 
-  @inline
-  def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
-    // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
-    // negative indices for start positions. If a start index i is greater than 0, it
-    // refers to element i-1 in the sequence. If a start index i is less than 0, it refers
-    // to the -ith element before the end of the sequence. If a start index i is 0, it
-    // refers to the first element.
-
-    val start = startPos match {
-      case pos if pos > 0 => pos - 1
-      case neg if neg < 0 => length() + neg
-      case _ => 0
-    }
-
-    val end = sliceLen match {
-      case max if max == Integer.MAX_VALUE => max
-      case x => start + x
+  override def eval(input: InternalRow): Any = {
+    val stringEval = str.eval(input)
+    if (stringEval != null) {
+      val posEval = pos.eval(input)
+      if (posEval != null) {
+        val lenEval = len.eval(input)
+        if (lenEval != null) {
+          stringEval.asInstanceOf[UTF8String]
+            .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int])
+        } else {
+          null
+        }
+      } else {
+        null
+      }
+    } else {
+      null
     }
-
-    (start, end)
   }
 
-  override def eval(input: InternalRow): Any = {
-    val string = str.eval(input)
-    val po = pos.eval(input)
-    val ln = len.eval(input)
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val strGen = str.gen(ctx)
+    val posGen = pos.gen(ctx)
+    val lenGen = len.gen(ctx)
 
-    if ((string == null) || (po == null) || (ln == null)) {
-      null
-    } else {
-      val start = po.asInstanceOf[Int]
-      val length = ln.asInstanceOf[Int]
-      string match {
-        case ba: Array[Byte] =>
-          val (st, end) = slicePos(start, length, () => ba.length)
-          ba.slice(st, end)
-        case s: UTF8String =>
-          val (st, end) = slicePos(start, length, () => s.numChars())
-          s.substring(st, end)
+    val start = ctx.freshName("start")
+    val end = ctx.freshName("end")
+
+    s"""
+      ${strGen.code}
+      boolean ${ev.isNull} = ${strGen.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${posGen.code}
+        if (!${posGen.isNull}) {
+          ${lenGen.code}
+          if (!${lenGen.isNull}) {
+            ${ev.primitive} = ${strGen.primitive}
+              .substringSQL(${posGen.primitive}, ${lenGen.primitive});
+          } else {
+            ${ev.isNull} = true;
+          }
+        } else {
+          ${ev.isNull} = true;
+        }
       }
-    }
+     """
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/560b355c/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index ed354f7..946d355 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -165,6 +165,18 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
     return fromBytes(bytes);
   }
 
+  public UTF8String substringSQL(int pos, int length) {
+    // Information regarding the pos calculation:
+    // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
+    // negative indices for start positions. If a start index i is greater than 0, it
+    // refers to element i-1 in the sequence. If a start index i is less than 0, it refers
+    // to the -ith element before the end of the sequence. If a start index i is 0, it
+    // refers to the first element.
+    int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0);
+    int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length;
+    return substring(start, end);
+  }
+
   /**
    * Returns whether this contains `substring` or not.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/560b355c/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 1f5572c..e2a5628 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -273,6 +273,25 @@ public class UTF8StringSuite {
   }
 
   @Test
+  public void substringSQL() {
+    UTF8String e = fromString("example");
+    assertEquals(e.substringSQL(0, 2), fromString("ex"));
+    assertEquals(e.substringSQL(1, 2), fromString("ex"));
+    assertEquals(e.substringSQL(0, 7), fromString("example"));
+    assertEquals(e.substringSQL(1, 2), fromString("ex"));
+    assertEquals(e.substringSQL(0, 100), fromString("example"));
+    assertEquals(e.substringSQL(1, 100), fromString("example"));
+    assertEquals(e.substringSQL(2, 2), fromString("xa"));
+    assertEquals(e.substringSQL(1, 6), fromString("exampl"));
+    assertEquals(e.substringSQL(2, 100), fromString("xample"));
+    assertEquals(e.substringSQL(0, 0), fromString(""));
+    assertEquals(e.substringSQL(100, 4), EMPTY_UTF8);
+    assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example"));
+    assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example"));
+    assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample"));
+  }
+
+  @Test
   public void split() {
     assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1),
       new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")}));


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message