flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From chiwanp...@apache.org
Subject flink git commit: [FLINK-2673] [core] Add a comparator for Scala Option type
Date Tue, 31 May 2016 02:05:15 GMT
Repository: flink
Updated Branches:
  refs/heads/master da23ee38e -> c60326f85


[FLINK-2673] [core] Add a comparator for Scala Option type

This closes #2017.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/c60326f8
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/c60326f8
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/c60326f8

Branch: refs/heads/master
Commit: c60326f85faaa38bcc359d555cd2d2818ef2e4e7
Parents: da23ee3
Author: Chiwan Park <chiwanpark@apache.org>
Authored: Sun May 22 20:39:10 2016 +0900
Committer: Chiwan Park <chiwanpark@apache.org>
Committed: Tue May 31 11:04:37 2016 +0900

----------------------------------------------------------------------
 .../scala/typeutils/OptionTypeComparator.scala  | 157 +++++++++++++++++++
 .../api/scala/typeutils/OptionTypeInfo.scala    |  20 ++-
 .../typeutils/OptionTypeComparatorTest.scala    |  31 ++++
 .../flink/api/scala/operators/JoinITCase.scala  |  11 ++
 4 files changed, 214 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/c60326f8/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeComparator.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeComparator.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeComparator.scala
new file mode 100644
index 0000000..e20ec16
--- /dev/null
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeComparator.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.typeutils
+
+import org.apache.flink.annotation.Internal
+import org.apache.flink.api.common.typeutils.TypeComparator
+import org.apache.flink.core.memory.{DataInputView, DataOutputView, MemorySegment}
+
+/**
+  * Comparator for [[Option]] values. Note that [[None]] is lesser than any [[Some]] values.
+  *
+  * Two [[Some]] values are ordered by the wrapped element comparator; `ascending` only
+  * decides whether [[None]] is reported as smaller (true) or larger (false) than [[Some]].
+  *
+  * @param ascending whether [[None]] compares as smaller than every [[Some]]
+  * @param typeComparator comparator applied to the values wrapped in [[Some]]
+  */
+@Internal
+class OptionTypeComparator[A](
+  private val ascending: Boolean,
+  private val typeComparator: TypeComparator[A]
+) extends TypeComparator[Option[A]] {
+  private var reference: Option[A] = _
+
+  override def hash(record: Option[A]) = record.hashCode()
+
+  /** Compares two options; mixed [[Some]]/[[None]] pairs are decided by `ascending`. */
+  override def compare(first: Option[A], second: Option[A]) = {
+    first match {
+      case Some(firstValue) =>
+        second match {
+          case Some(secondValue) => typeComparator.compare(firstValue, secondValue)
+          case None => if (ascending) 1 else -1
+        }
+      case None =>
+        second match {
+          case Some(_) => if (ascending) -1 else 1
+          case None => 0
+        }
+    }
+  }
+
+  /**
+    * Compares two serialized options. A boolean flag is read from each source first; the
+    * element comparator consumes the remaining bytes only when both flags are set.
+    */
+  override def compareSerialized(firstSource: DataInputView, secondSource: DataInputView) = {
+    val firstSome = firstSource.readBoolean()
+    val secondSome = secondSource.readBoolean()
+
+    if (firstSome && secondSome) {
+      typeComparator.compareSerialized(firstSource, secondSource)
+    } else if (firstSome) {
+      if (ascending) 1 else -1
+    } else if (secondSome) {
+      if (ascending) -1 else 1
+    } else {
+      0
+    }
+  }
+
+  /** Stores the record itself as the single key at `index` and reports one key written. */
+  override def extractKeys(record: AnyRef, target: Array[AnyRef], index: Int) = {
+    target(index) = record
+    1
+  }
+
+  override def setReference(toCompare: Option[A]) = {
+    reference = toCompare
+  }
+
+  override def equalToReference(candidate: Option[A]) = compare(reference, candidate) == 0
+
+  /** Compares the other comparator's reference against this comparator's reference. */
+  override def compareToReference(referencedComparator: TypeComparator[Option[A]]) = {
+    compare(referencedComparator.asInstanceOf[this.type].reference, reference)
+  }
+
+  override lazy val getFlatComparators = Array(this).asInstanceOf[Array[TypeComparator[_]]]
+
+  /**
+    * One marker byte plus the element key length. `Int.MaxValue` (unbounded) is passed
+    * through as is, since adding the marker byte to it would overflow to a negative length.
+    */
+  override def getNormalizeKeyLen = {
+    val elemKeyLen = typeComparator.getNormalizeKeyLen
+    if (elemKeyLen == Int.MaxValue) Int.MaxValue else 1 + elemKeyLen
+  }
+
+  /**
+    * Writes a marker byte (1 for [[Some]], 0 for [[None]]) followed by the element key for
+    * [[Some]] or zero padding for [[None]]. Nothing is written if `numBytes` is below 1.
+    */
+  override def putNormalizedKey(
+    record: Option[A],
+    target: MemorySegment,
+    offset: Int,
+    numBytes: Int
+  ) = {
+    if (numBytes >= 1) {
+      record match {
+        case Some(v) =>
+          target.put(offset, OptionTypeComparator.OneInByte)
+          typeComparator.putNormalizedKey(v, target, offset + 1, numBytes - 1)
+        case None =>
+          // Marker and padding are both zero bytes, so fill the whole key range in one loop.
+          var i = 0
+          while (i < numBytes) {
+            target.put(offset + i, OptionTypeComparator.ZeroInByte)
+            i += 1
+          }
+      }
+    }
+  }
+
+  override def invertNormalizedKey() = !ascending
+
+  override def readWithKeyDenormalization(reuse: Option[A], source: DataInputView) = {
+    throw new UnsupportedOperationException
+  }
+
+  override def writeWithKeyNormalization(record: Option[A], target: DataOutputView) = {
+    throw new UnsupportedOperationException
+  }
+
+  override def isNormalizedKeyPrefixOnly(keyBytes: Int) = {
+    typeComparator.isNormalizedKeyPrefixOnly(keyBytes - 1)
+  }
+
+  override def supportsSerializationWithKeyNormalization() = false
+
+  override def supportsNormalizedKey() = typeComparator.supportsNormalizedKey()
+
+  /** Creates a copy that owns its own duplicate of the element comparator. */
+  override def duplicate() = new OptionTypeComparator[A](ascending, typeComparator.duplicate())
+}
+
+/** Marker bytes that putNormalizedKey writes in front of the element key. */
+object OptionTypeComparator {
+  val ZeroInByte = 0.toByte
+  val OneInByte = 1.toByte
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/c60326f8/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeInfo.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeInfo.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeInfo.scala
index 70db4fa..58ae77c 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeInfo.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/OptionTypeInfo.scala
@@ -17,10 +17,10 @@
  */
 package org.apache.flink.api.scala.typeutils
 
-import org.apache.flink.annotation.{PublicEvolving, Public}
+import org.apache.flink.annotation.{Public, PublicEvolving}
 import org.apache.flink.api.common.ExecutionConfig
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.common.typeutils.TypeSerializer
+import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
 
 import scala.collection.JavaConverters._
 
@@ -29,14 +29,14 @@ import scala.collection.JavaConverters._
  */
 @Public
 class OptionTypeInfo[A, T <: Option[A]](private val elemTypeInfo: TypeInformation[A])
-  extends TypeInformation[T] {
+  extends TypeInformation[T] with AtomicType[T] {
 
   @PublicEvolving
   override def isBasicType: Boolean = false
   @PublicEvolving
   override def isTupleType: Boolean = false
   @PublicEvolving
-  override def isKeyType: Boolean = false
+  override def isKeyType: Boolean = elemTypeInfo.isKeyType
   @PublicEvolving
   override def getTotalFields: Int = 1
   @PublicEvolving
@@ -46,6 +46,16 @@ class OptionTypeInfo[A, T <: Option[A]](private val elemTypeInfo: TypeInformatio
   @PublicEvolving
   override def getGenericParameters = List[TypeInformation[_]](elemTypeInfo).asJava
 
+  /** Wraps the element comparator in an [[OptionTypeComparator]]; needs an atomic key type. */
+  @PublicEvolving
+  override def createComparator(ascending: Boolean, executionConfig: ExecutionConfig) =
+    elemTypeInfo match {
+      case atomic: AtomicType[A @unchecked] if isKeyType =>
+        val elemComparator = atomic.createComparator(ascending, executionConfig)
+        new OptionTypeComparator[A](ascending, elemComparator).asInstanceOf[TypeComparator[T]]
+      case _ => throw new UnsupportedOperationException(
+        s"Element type $elemTypeInfo cannot be used as a key, so its Option has no comparator.")
+    }
 
   @PublicEvolving
   def createSerializer(executionConfig: ExecutionConfig): TypeSerializer[T] = {

http://git-wip-us.apache.org/repos/asf/flink/blob/c60326f8/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/OptionTypeComparatorTest.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/OptionTypeComparatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/OptionTypeComparatorTest.scala
new file mode 100644
index 0000000..5b171c0
--- /dev/null
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/OptionTypeComparatorTest.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.typeutils
+
+import org.apache.flink.api.common.typeutils.ComparatorTestBase
+import org.apache.flink.api.common.typeutils.base.{StringComparator, StringSerializer}
+
+class OptionTypeComparatorTest extends ComparatorTestBase[Option[String]] {
+  // Reference data in ascending order: None is listed ahead of every Some.
+  override protected def getSortedTestData = Array(None, Some("a"), Some("b"), Some("c"))
+
+  override protected def createSerializer() = new OptionSerializer[String](new StringSerializer)
+
+  override protected def createComparator(ascending: Boolean) =
+    new OptionTypeComparator[String](ascending, new StringComparator(ascending))
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/c60326f8/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala
index a958250..85b6ea4 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala
@@ -406,4 +406,15 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
     env.execute()
     expected = "1,(1,1,Hi)\n2,(2,2,Hello)"
   }
+
+  /** Joins two Option-typed data sets keyed by "_"; expects None to be matched with None. */
+  @Test
+  def testWithScalaOptionValues(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val lhs = env.fromElements(None, Some("a"), Some("b"))
+    val rhs = env.fromElements(None, Some("a"))
+    lhs.join(rhs).where("_").equalTo("_").writeAsCsv(resultPath, writeMode = WriteMode.OVERWRITE)
+    env.execute()
+    expected = "None,None\nSome(a),Some(a)"
+  }
 }


Mime
View raw message