From: vasia@apache.org To: commits@flink.apache.org Reply-To: dev@flink.apache.org Date: Fri, 18 Mar 2016 13:48:33 -0000 Message-Id: <27bf082ae31f42f3a19ce8cb7c2afc1c@git.apache.org> In-Reply-To: <23694d950bbb4ace842cfa98be8f6e58@git.apache.org> References: <23694d950bbb4ace842cfa98be8f6e58@git.apache.org> Subject: [39/50] [abbrv] flink git commit: [FLINK-3489] TableAPI refactoring and cleanup http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowComparator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowComparator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowComparator.scala deleted file mode 100644 index 17c6d56..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowComparator.scala +++ /dev/null @@ -1,417 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.flink.api.table.typeinfo - -import java.util - -import org.apache.flink.api.common.typeutils.{CompositeTypeComparator, TypeComparator, TypeSerializer} -import org.apache.flink.api.java.typeutils.runtime.TupleComparatorBase -import org.apache.flink.api.table.Row -import org.apache.flink.api.table.typeinfo.NullMaskUtils.readIntoNullMask -import org.apache.flink.api.table.typeinfo.RowComparator.{createAuxiliaryFields, makeNullAware} -import org.apache.flink.core.memory.{DataInputView, DataOutputView, MemorySegment} -import org.apache.flink.types.KeyFieldOutOfBoundsException - -/** - * Comparator for [[Row]]. - */ -class RowComparator private ( - /** key positions describe which fields are keys in what order */ - val keyPositions: Array[Int], - /** null-aware comparators for the key fields, in the same order as the key fields */ - val comparators: Array[NullAwareComparator[Any]], - /** serializers to deserialize the first n fields for comparison */ - val serializers: Array[TypeSerializer[Any]], - /** auxiliary fields for normalized key support */ - private val auxiliaryFields: (Array[Int], Int, Int, Boolean)) - extends CompositeTypeComparator[Row] with Serializable { - - // null masks for serialized comparison - private val nullMask1 = new Array[Boolean](serializers.length) - private val nullMask2 = new Array[Boolean](serializers.length) - - // cache for the deserialized key field objects - @transient - private lazy val deserializedKeyFields1: Array[Any] = instantiateDeserializationFields() - - @transient - private lazy val deserializedKeyFields2: Array[Any] = instantiateDeserializationFields() - - // create auxiliary fields - private val normalizedKeyLengths: Array[Int] = auxiliaryFields._1 - private val numLeadingNormalizableKeys: Int = auxiliaryFields._2 - private val normalizableKeyPrefixLen: Int = auxiliaryFields._3 - private val invertNormKey: Boolean = auxiliaryFields._4 - - /** - * Intermediate constructor for creating auxiliary fields. - */ - def this( - keyPositions: Array[Int], - comparators: Array[NullAwareComparator[Any]], - serializers: Array[TypeSerializer[Any]]) = { - this( - keyPositions, - comparators, - serializers, - createAuxiliaryFields(keyPositions, comparators)) - } - - /** - * General constructor for RowComparator. 
- * - * @param keyPositions key positions describe which fields are keys in what order - * @param comparators non-null-aware comparators for the key fields, in the same order as - * the key fields - * @param serializers serializers to deserialize the first n fields for comparison - * @param orders sorting orders for the fields - */ - def this( - keyPositions: Array[Int], - comparators: Array[TypeComparator[Any]], - serializers: Array[TypeSerializer[Any]], - orders: Array[Boolean]) = { - this( - keyPositions, - makeNullAware(comparators, orders), - serializers) - } - - private def instantiateDeserializationFields(): Array[Any] = { - val newFields = new Array[Any](serializers.length) - var i = 0 - while (i < serializers.length) { - newFields(i) = serializers(i).createInstance() - i += 1 - } - newFields - } - - // -------------------------------------------------------------------------------------------- - // Comparator Methods - // -------------------------------------------------------------------------------------------- - - override def compareToReference(referencedComparator: TypeComparator[Row]): Int = { - val other: RowComparator = referencedComparator.asInstanceOf[RowComparator] - var i = 0 - try { - while (i < keyPositions.length) { - val comparator = comparators(i) - val otherComparator = other.comparators(i) - - val cmp = comparator.compareToReference(otherComparator) - if (cmp != 0) { - return cmp - } - i = i + 1 - } - 0 - } - catch { - case iobex: IndexOutOfBoundsException => - throw new KeyFieldOutOfBoundsException(keyPositions(i)) - } - } - - override def compareSerialized(firstSource: DataInputView, secondSource: DataInputView): Int = { - val len = serializers.length - val keyLen = keyPositions.length - - readIntoNullMask(len, firstSource, nullMask1) - readIntoNullMask(len, secondSource, nullMask2) - - // deserialize - var i = 0 - while (i < len) { - val serializer = serializers(i) - - // deserialize field 1 - if (!nullMask1(i)) { - deserializedKeyFields1(i) = serializer.deserialize(deserializedKeyFields1(i), firstSource) - } - - // deserialize field 2 - if (!nullMask2(i)) { - deserializedKeyFields2(i) = serializer.deserialize(deserializedKeyFields2(i), secondSource) - } - - i += 1 - } - - // compare - i = 0 - while (i < keyLen) { - val keyPos = keyPositions(i) - val comparator = comparators(i) - - val isNull1 = nullMask1(keyPos) - val isNull2 = nullMask2(keyPos) - - var cmp = 0 - // both values are null -> equality - if (isNull1 && isNull2) { - cmp = 0 - } - // first value is null -> inequality - else if (isNull1) { - cmp = comparator.compare(null, deserializedKeyFields2(keyPos)) - } - // second value is null -> inequality - else if (isNull2) { - cmp = comparator.compare(deserializedKeyFields1(keyPos), null) - } - // no null values - else { - cmp = comparator.compare(deserializedKeyFields1(keyPos), deserializedKeyFields2(keyPos)) - } - - if (cmp != 0) { - return cmp - } - - i += 1 - } - 0 - } - - override def supportsNormalizedKey(): Boolean = numLeadingNormalizableKeys > 0 - - override def getNormalizeKeyLen: Int = normalizableKeyPrefixLen - - override def isNormalizedKeyPrefixOnly(keyBytes: Int): Boolean = - numLeadingNormalizableKeys < keyPositions.length || - normalizableKeyPrefixLen == Integer.MAX_VALUE || - normalizableKeyPrefixLen > keyBytes - - override def invertNormalizedKey(): Boolean = invertNormKey - - override def supportsSerializationWithKeyNormalization(): Boolean = false - - override def writeWithKeyNormalization(record: Row, target: DataOutputView): Unit = 
- throw new UnsupportedOperationException("Record serialization with leading normalized keys " + - "not supported.") - - override def readWithKeyDenormalization(reuse: Row, source: DataInputView): Row = - throw new UnsupportedOperationException("Record deserialization with leading normalized keys " + - "not supported.") - - override def duplicate(): TypeComparator[Row] = { - // copy comparator and serializer factories - val comparatorsCopy = comparators.map(_.duplicate().asInstanceOf[NullAwareComparator[Any]]) - val serializersCopy = serializers.map(_.duplicate()) - - new RowComparator( - keyPositions, - comparatorsCopy, - serializersCopy, - auxiliaryFields) - } - - override def hash(value: Row): Int = { - var code: Int = 0 - var i = 0 - try { - while(i < keyPositions.length) { - code *= TupleComparatorBase.HASH_SALT(i & 0x1F) - val element = value.productElement(keyPositions(i)) // element can be null - code += comparators(i).hash(element) - i += 1 - } - } catch { - case iobex: IndexOutOfBoundsException => - throw new KeyFieldOutOfBoundsException(keyPositions(i)) - } - code - } - - override def setReference(toCompare: Row) { - var i = 0 - try { - while(i < keyPositions.length) { - val comparator = comparators(i) - val element = toCompare.productElement(keyPositions(i)) - comparator.setReference(element) // element can be null - i += 1 - } - } catch { - case iobex: IndexOutOfBoundsException => - throw new KeyFieldOutOfBoundsException(keyPositions(i)) - } - } - - override def equalToReference(candidate: Row): Boolean = { - var i = 0 - try { - while(i < keyPositions.length) { - val comparator = comparators(i) - val element = candidate.productElement(keyPositions(i)) // element can be null - // check if reference is not equal - if (!comparator.equalToReference(element)) { - return false - } - i += 1 - } - } catch { - case iobex: IndexOutOfBoundsException => - throw new KeyFieldOutOfBoundsException(keyPositions(i)) - } - true - } - - override def compare(first: Row, second: Row): Int = { - var i = 0 - try { - while(i < keyPositions.length) { - val keyPos: Int = keyPositions(i) - val comparator = comparators(i) - val firstElement = first.productElement(keyPos) // element can be null - val secondElement = second.productElement(keyPos) // element can be null - - val cmp = comparator.compare(firstElement, secondElement) - if (cmp != 0) { - return cmp - } - i += 1 - } - } catch { - case iobex: IndexOutOfBoundsException => - throw new KeyFieldOutOfBoundsException(keyPositions(i)) - } - 0 - } - - override def putNormalizedKey( - record: Row, - target: MemorySegment, - offset: Int, - numBytes: Int) - : Unit = { - var bytesLeft = numBytes - var currentOffset = offset - - var i = 0 - while (i < numLeadingNormalizableKeys && bytesLeft > 0) { - var len = normalizedKeyLengths(i) - len = if (bytesLeft >= len) len else bytesLeft - - val comparator = comparators(i) - val element = record.productElement(keyPositions(i)) // element can be null - // write key - comparator.putNormalizedKey(element, target, currentOffset, len) - - bytesLeft -= len - currentOffset += len - i += 1 - } - } - - override def getFlatComparator(flatComparators: util.List[TypeComparator[_]]): Unit = - comparators.foreach { c => - c.getFlatComparators.foreach { fc => - flatComparators.add(fc) - } - } - - override def extractKeys(record: Any, target: Array[AnyRef], index: Int): Int = { - val len = comparators.length - var localIndex = index - var i = 0 - while (i < len) { - val element = 
record.asInstanceOf[Row].productElement(keyPositions(i)) // element can be null - localIndex += comparators(i).extractKeys(element, target, localIndex) - i += 1 - } - localIndex - index - } -} - -object RowComparator { - private def makeNullAware( - comparators: Array[TypeComparator[Any]], - orders: Array[Boolean]) - : Array[NullAwareComparator[Any]] = - comparators - .zip(orders) - .map { case (comp, order) => - new NullAwareComparator[Any]( - comp, - order) - } - - /** - * @return creates auxiliary fields for normalized key support - */ - private def createAuxiliaryFields( - keyPositions: Array[Int], - comparators: Array[NullAwareComparator[Any]]) - : (Array[Int], Int, Int, Boolean) = { - - val normalizedKeyLengths = new Array[Int](keyPositions.length) - var numLeadingNormalizableKeys = 0 - var normalizableKeyPrefixLen = 0 - var inverted = false - - var i = 0 - while (i < keyPositions.length) { - val k = comparators(i) - // as long as the leading keys support normalized keys, we can build up the composite key - if (k.supportsNormalizedKey()) { - if (i == 0) { - // the first comparator decides whether we need to invert the key direction - inverted = k.invertNormalizedKey() - } - else if (k.invertNormalizedKey() != inverted) { - // if a successor does not agree on the inversion direction, it cannot be part of the - // normalized key - return (normalizedKeyLengths, - numLeadingNormalizableKeys, - normalizableKeyPrefixLen, - inverted) - } - numLeadingNormalizableKeys += 1 - val len = k.getNormalizeKeyLen - if (len < 0) { - throw new RuntimeException("Comparator " + k.getClass.getName + - " specifies an invalid length for the normalized key: " + len) - } - normalizedKeyLengths(i) = len - normalizableKeyPrefixLen += len - if (normalizableKeyPrefixLen < 0) { - // overflow, which means we are out of budget for normalized key space anyways - return (normalizedKeyLengths, - numLeadingNormalizableKeys, - Integer.MAX_VALUE, - inverted) - } - } - else { - return (normalizedKeyLengths, - numLeadingNormalizableKeys, - normalizableKeyPrefixLen, - inverted) - } - i += 1 - } - (normalizedKeyLengths, - numLeadingNormalizableKeys, - normalizableKeyPrefixLen, - inverted) - } -} - http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala deleted file mode 100644 index 4f2d535..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.api.table.typeinfo - -import org.apache.flink.api.common.typeutils.TypeSerializer -import org.apache.flink.api.table.Row -import org.apache.flink.api.table.typeinfo.NullMaskUtils.{writeNullMask, readIntoNullMask, readIntoAndCopyNullMask} -import org.apache.flink.core.memory.{DataInputView, DataOutputView} - -/** - * Serializer for [[Row]]. - */ -class RowSerializer(val fieldSerializers: Array[TypeSerializer[Any]]) - extends TypeSerializer[Row] { - - private val nullMask = new Array[Boolean](fieldSerializers.length) - - override def isImmutableType: Boolean = false - - override def getLength: Int = -1 - - override def duplicate: RowSerializer = { - val duplicateFieldSerializers = fieldSerializers.map(_.duplicate()) - new RowSerializer(duplicateFieldSerializers) - } - - override def createInstance: Row = { - new Row(fieldSerializers.length) - } - - override def copy(from: Row, reuse: Row): Row = { - val len = fieldSerializers.length - - // cannot reuse, do a non-reuse copy - if (reuse == null) { - return copy(from) - } - - if (from.productArity != len || reuse.productArity != len) { - throw new RuntimeException("Row arity of reuse or from is incompatible with this " + - "RowSerializer.") - } - - var i = 0 - while (i < len) { - val fromField = from.productElement(i) - if (fromField != null) { - val reuseField = reuse.productElement(i) - if (reuseField != null) { - val copy = fieldSerializers(i).copy(fromField, reuseField) - reuse.setField(i, copy) - } - else { - val copy = fieldSerializers(i).copy(fromField) - reuse.setField(i, copy) - } - } - else { - reuse.setField(i, null) - } - i += 1 - } - reuse - } - - override def copy(from: Row): Row = { - val len = fieldSerializers.length - - if (from.productArity != len) { - throw new RuntimeException("Row arity of from does not match serializers.") - } - val result = new Row(len) - var i = 0 - while (i < len) { - val fromField = from.productElement(i).asInstanceOf[AnyRef] - if (fromField != null) { - val copy = fieldSerializers(i).copy(fromField) - result.setField(i, copy) - } - else { - result.setField(i, null) - } - i += 1 - } - result - } - - override def serialize(value: Row, target: DataOutputView) { - val len = fieldSerializers.length - - if (value.productArity != len) { - throw new RuntimeException("Row arity of value does not match serializers.") - } - - // write a null mask - writeNullMask(len, value, target) - - // serialize non-null fields - var i = 0 - while (i < len) { - val o = value.productElement(i).asInstanceOf[AnyRef] - if (o != null) { - val serializer = fieldSerializers(i) - serializer.serialize(value.productElement(i), target) - } - i += 1 - } - } - - override def deserialize(reuse: Row, source: DataInputView): Row = { - val len = fieldSerializers.length - - if (reuse.productArity != len) { - throw new RuntimeException("Row arity of reuse does not match serializers.") - } - - // read null mask - readIntoNullMask(len, source, nullMask) - - // read non-null fields - var i = 0 - while (i < len) { - if (nullMask(i)) { - reuse.setField(i, null) - } - else { - val reuseField = reuse.productElement(i).asInstanceOf[AnyRef] - if (reuseField != null) { - reuse.setField(i, fieldSerializers(i).deserialize(reuseField, source)) - } - else { - reuse.setField(i, fieldSerializers(i).deserialize(source)) - } - } - i += 1 - } - reuse - } - - override def deserialize(source: DataInputView): Row = { - val len = 
fieldSerializers.length - - val result = new Row(len) - - // read null mask - readIntoNullMask(len, source, nullMask) - - // read non-null fields - var i = 0 - while (i < len) { - if (nullMask(i)) { - result.setField(i, null) - } - else { - result.setField(i, fieldSerializers(i).deserialize(source)) - } - i += 1 - } - result - } - - override def copy(source: DataInputView, target: DataOutputView): Unit = { - val len = fieldSerializers.length - - // copy null mask - readIntoAndCopyNullMask(len, source, target, nullMask) - - // read non-null fields - var i = 0 - while (i < len) { - if (!nullMask(i)) { - fieldSerializers(i).copy(source, target) - } - i += 1 - } - } - - override def equals(any: Any): Boolean = { - any match { - case otherRS: RowSerializer => - otherRS.canEqual(this) && - fieldSerializers.sameElements(otherRS.fieldSerializers) - case _ => false - } - } - - override def canEqual(obj: AnyRef): Boolean = { - obj.isInstanceOf[RowSerializer] - } - - override def hashCode(): Int = { - java.util.Arrays.hashCode(fieldSerializers.asInstanceOf[Array[AnyRef]]) - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala deleted file mode 100644 index 81c3836..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.api.table.typeinfo - -import org.apache.flink.api.common.ExecutionConfig -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.common.typeutils.CompositeType.TypeComparatorBuilder -import org.apache.flink.api.common.typeutils.TypeComparator -import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo - -import scala.collection.mutable.ArrayBuffer -import org.apache.flink.api.common.typeutils.TypeSerializer -import org.apache.flink.api.table.Row - -/** - * TypeInformation for [[Row]]. 
- */ -class RowTypeInfo(fieldTypes: Seq[TypeInformation[_]], fieldNames: Seq[String]) - extends CaseClassTypeInfo[Row]( - classOf[Row], - Array(), - fieldTypes, - fieldNames) -{ - - if (fieldTypes.length != fieldNames.length) { - throw new IllegalArgumentException("Number of field types and names is different.") - } - if (fieldNames.length != fieldNames.toSet.size) { - throw new IllegalArgumentException("Field names are not unique.") - } - - def this(fieldTypes: Seq[TypeInformation[_]]) = { - this(fieldTypes, for (i <- fieldTypes.indices) yield "f" + i) - } - - /** - * Temporary variable for directly passing orders to comparators. - */ - var comparatorOrders: Option[Array[Boolean]] = None - - override def createSerializer(executionConfig: ExecutionConfig): TypeSerializer[Row] = { - val fieldSerializers: Array[TypeSerializer[Any]] = new Array[TypeSerializer[Any]](getArity) - for (i <- 0 until getArity) { - fieldSerializers(i) = this.types(i).createSerializer(executionConfig) - .asInstanceOf[TypeSerializer[Any]] - } - - new RowSerializer(fieldSerializers) - } - - override def createComparator( - logicalKeyFields: Array[Int], - orders: Array[Boolean], - logicalFieldOffset: Int, - config: ExecutionConfig) - : TypeComparator[Row] = { - // store the order information for the builder - comparatorOrders = Some(orders) - val comparator = super.createComparator(logicalKeyFields, orders, logicalFieldOffset, config) - comparatorOrders = None - comparator - } - - override def createTypeComparatorBuilder(): TypeComparatorBuilder[Row] = { - new RowTypeComparatorBuilder(comparatorOrders.getOrElse( - throw new IllegalStateException("Cannot create comparator builder without orders."))) - } - - private class RowTypeComparatorBuilder( - comparatorOrders: Array[Boolean]) - extends TypeComparatorBuilder[Row] { - - val fieldComparators: ArrayBuffer[TypeComparator[_]] = new ArrayBuffer[TypeComparator[_]]() - val logicalKeyFields: ArrayBuffer[Int] = new ArrayBuffer[Int]() - - override def initializeTypeComparatorBuilder(size: Int): Unit = { - fieldComparators.sizeHint(size) - logicalKeyFields.sizeHint(size) - } - - override def addComparatorField(fieldId: Int, comparator: TypeComparator[_]): Unit = { - fieldComparators += comparator - logicalKeyFields += fieldId - } - - override def createTypeComparator(config: ExecutionConfig): TypeComparator[Row] = { - val maxIndex = logicalKeyFields.max - - new RowComparator( - logicalKeyFields.toArray, - fieldComparators.toArray.asInstanceOf[Array[TypeComparator[Any]]], - types.take(maxIndex + 1).map(_.createSerializer(config).asInstanceOf[TypeSerializer[Any]]), - comparatorOrders - ) - } - } -} - http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/NullAwareComparator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/NullAwareComparator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/NullAwareComparator.scala new file mode 100644 index 0000000..86a768d --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/NullAwareComparator.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.typeutils + +import org.apache.flink.api.common.typeutils.{CompositeTypeComparator, TypeComparator} +import org.apache.flink.core.memory.{DataInputView, DataOutputView, MemorySegment} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +/** + * Null-aware comparator that wraps a comparator which does not support null references. + * + * NOTE: This class assumes to be used within a composite type comparator (such + * as [[RowComparator]]) that handles serialized comparison. + */ +class NullAwareComparator[T]( + val wrappedComparator: TypeComparator[T], + val order: Boolean) + extends TypeComparator[T] { + + // number of flat fields + private val flatFields = wrappedComparator.getFlatComparators.length + + // stores the null for reference comparison + private var nullReference = false + + override def hash(record: T): Int = { + if (record != null) { + wrappedComparator.hash(record) + } + else { + 0 + } + } + + override def getNormalizeKeyLen: Int = { + val len = wrappedComparator.getNormalizeKeyLen + if (len == Integer.MAX_VALUE) { + Integer.MAX_VALUE + } + else { + len + 1 // add one for a null byte + } + } + + override def putNormalizedKey( + record: T, + target: MemorySegment, + offset: Int, + numBytes: Int) + : Unit = { + if (numBytes > 0) { + // write a null byte with padding + if (record == null) { + target.putBoolean(offset, false) + // write padding + var j = 0 + while (j < numBytes - 1) { + target.put(offset + 1 + j, 0.toByte) + j += 1 + } + } + // write a non-null byte with key + else { + target.putBoolean(offset, true) + // write key + wrappedComparator.putNormalizedKey(record, target, offset + 1, numBytes - 1) + } + } + } + + override def invertNormalizedKey(): Boolean = wrappedComparator.invertNormalizedKey() + + override def supportsSerializationWithKeyNormalization(): Boolean = false + + override def writeWithKeyNormalization(record: T, target: DataOutputView): Unit = + throw new UnsupportedOperationException("Record serialization with leading normalized keys" + + " not supported.") + + override def readWithKeyDenormalization(reuse: T, source: DataInputView): T = + throw new UnsupportedOperationException("Record deserialization with leading normalized keys" + + " not supported.") + + override def isNormalizedKeyPrefixOnly(keyBytes: Int): Boolean = + wrappedComparator.isNormalizedKeyPrefixOnly(keyBytes - 1) + + override def setReference(toCompare: T): Unit = { + if (toCompare == null) { + nullReference = true + } + else { + nullReference = false + wrappedComparator.setReference(toCompare) + } + } + + override def compare(first: T, second: T): Int = { + // both values are null -> equality + if (first == null && second == null) { + 0 + } + // first value is null -> inequality + // but order is considered + else if (first == null) { + if (order) -1 else 1 + } + // second value is null -> inequality + // but order is 
considered + else if (second == null) { + if (order) 1 else -1 + } + // no null values + else { + wrappedComparator.compare(first, second) + } + } + + override def compareToReference(referencedComparator: TypeComparator[T]): Int = { + val otherComparator = referencedComparator.asInstanceOf[NullAwareComparator[T]] + val otherNullReference = otherComparator.nullReference + // both values are null -> equality + if (nullReference && otherNullReference) { + 0 + } + // first value is null -> inequality + // but order is considered + else if (nullReference) { + if (order) 1 else -1 + } + // second value is null -> inequality + // but order is considered + else if (otherNullReference) { + if (order) -1 else 1 + } + // no null values + else { + wrappedComparator.compareToReference(otherComparator.wrappedComparator) + } + } + + override def supportsNormalizedKey(): Boolean = wrappedComparator.supportsNormalizedKey() + + override def equalToReference(candidate: T): Boolean = { + // both values are null + if (candidate == null && nullReference) { + true + } + // one value is null + else if (candidate == null || nullReference) { + false + } + // no null value + else { + wrappedComparator.equalToReference(candidate) + } + } + + override def duplicate(): TypeComparator[T] = { + new NullAwareComparator[T](wrappedComparator.duplicate(), order) + } + + override def extractKeys(record: Any, target: Array[AnyRef], index: Int): Int = { + if (record == null) { + var i = 0 + while (i < flatFields) { + target(index + i) = null + i += 1 + } + flatFields + } + else { + wrappedComparator.extractKeys(record, target, index) + } + } + + + override def getFlatComparators: Array[TypeComparator[_]] = { + // determine the flat comparators and wrap them again in null-aware comparators + val flatComparators = new ArrayBuffer[TypeComparator[_]]() + wrappedComparator match { + case ctc: CompositeTypeComparator[_] => ctc.getFlatComparator(flatComparators) + case c: TypeComparator[_] => flatComparators += c + } + val wrappedComparators = flatComparators.map { c => + new NullAwareComparator[Any](c.asInstanceOf[TypeComparator[Any]], order) + } + wrappedComparators.toArray[TypeComparator[_]] + } + + /** + * This method is not implemented here. It must be implemented by the comparator this class + * is contained in (e.g. RowComparator). + * + * @param firstSource The input view containing the first record. + * @param secondSource The input view containing the second record. + * @return An integer defining the oder among the objects in the same way as + * { @link java.util.Comparator#compare(Object, Object)}. 
+ */ + override def compareSerialized(firstSource: DataInputView, secondSource: DataInputView): Int = + throw new UnsupportedOperationException("Comparator does not support null-aware serialized " + + "comparision.") +} http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/NullMaskUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/NullMaskUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/NullMaskUtils.scala new file mode 100644 index 0000000..dcdc775 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/NullMaskUtils.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.typeutils + +import org.apache.flink.api.table.Row +import org.apache.flink.core.memory.{DataInputView, DataOutputView} + +object NullMaskUtils { + + def writeNullMask(len: Int, value: Row, target: DataOutputView): Unit = { + var b = 0x00 + var bytePos = 0 + + var fieldPos = 0 + var numPos = 0 + while (fieldPos < len) { + b = 0x00 + // set bits in byte + bytePos = 0 + numPos = Math.min(8, len - fieldPos) + while (bytePos < numPos) { + b = b << 1 + // set bit if field is null + if(value.productElement(fieldPos + bytePos) == null) { + b |= 0x01 + } + bytePos += 1 + } + fieldPos += numPos + // shift bits if last byte is not completely filled + b <<= (8 - bytePos) + // write byte + target.writeByte(b) + } + } + + def readIntoNullMask(len: Int, source: DataInputView, nullMask: Array[Boolean]): Unit = { + var b = 0x00 + var bytePos = 0 + + var fieldPos = 0 + var numPos = 0 + while (fieldPos < len) { + // read byte + b = source.readUnsignedByte() + bytePos = 0 + numPos = Math.min(8, len - fieldPos) + while (bytePos < numPos) { + nullMask(fieldPos + bytePos) = (b & 0x80) > 0 + b = b << 1 + bytePos += 1 + } + fieldPos += numPos + } + } + + def readIntoAndCopyNullMask( + len: Int, + source: DataInputView, + target: DataOutputView, + nullMask: Array[Boolean]): Unit = { + var b = 0x00 + var bytePos = 0 + + var fieldPos = 0 + var numPos = 0 + while (fieldPos < len) { + // read byte + b = source.readUnsignedByte() + // copy byte + target.writeByte(b) + bytePos = 0 + numPos = Math.min(8, len - fieldPos) + while (bytePos < numPos) { + nullMask(fieldPos + bytePos) = (b & 0x80) > 0 + b = b << 1 + bytePos += 1 + } + fieldPos += numPos + } + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala 
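As an illustration of the null-mask format that writeNullMask/readIntoNullMask above implement (up to eight fields per byte, most-significant bit first, with a set bit marking a null field), the following standalone sketch reproduces the same packing on plain Boolean arrays. It is not part of the commit; NullMaskSketch and the sample values are hypothetical and used only for illustration.

// Standalone sketch of the null-mask packing used by NullMaskUtils (illustrative, not committed code).
// A set bit means "this field is null"; fields are packed eight per byte, most-significant bit first,
// and the last byte is left-aligned if it is not completely filled.
object NullMaskSketch {

  def pack(nulls: Array[Boolean]): Array[Byte] = {
    val bytes = new Array[Byte]((nulls.length + 7) / 8)
    var fieldPos = 0
    while (fieldPos < nulls.length) {
      var b = 0
      val numPos = math.min(8, nulls.length - fieldPos)
      var bytePos = 0
      while (bytePos < numPos) {
        b = b << 1
        if (nulls(fieldPos + bytePos)) b |= 0x01
        bytePos += 1
      }
      // shift bits if the last byte is not completely filled
      b = b << (8 - bytePos)
      bytes(fieldPos / 8) = b.toByte
      fieldPos += numPos
    }
    bytes
  }

  def unpack(bytes: Array[Byte], len: Int): Array[Boolean] = {
    val nulls = new Array[Boolean](len)
    var fieldPos = 0
    while (fieldPos < len) {
      var b = bytes(fieldPos / 8) & 0xFF
      val numPos = math.min(8, len - fieldPos)
      var bytePos = 0
      while (bytePos < numPos) {
        nulls(fieldPos + bytePos) = (b & 0x80) > 0
        b = b << 1
        bytePos += 1
      }
      fieldPos += numPos
    }
    nulls
  }

  def main(args: Array[String]): Unit = {
    // nine fields: fields 0, 3 and 8 are null, so two mask bytes are written
    val mask = Array(true, false, false, true, false, false, false, false, true)
    val packed = pack(mask)
    println(packed.map(b => f"0x${b & 0xFF}%02X").mkString(" ")) // prints: 0x90 0x80
    println(unpack(packed, mask.length).mkString(", "))          // round-trips the mask
  }
}

Running main prints 0x90 0x80 for the nine-field sample, which matches the byte layout the RowSerializer writes ahead of the non-null fields.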
---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala new file mode 100644 index 0000000..cc97656 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.typeutils + +import java.util + +import org.apache.flink.api.common.typeutils.{CompositeTypeComparator, TypeComparator, TypeSerializer} +import org.apache.flink.api.java.typeutils.runtime.TupleComparatorBase +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.typeutils.NullMaskUtils.readIntoNullMask +import org.apache.flink.api.table.typeutils.RowComparator.{createAuxiliaryFields, makeNullAware} +import org.apache.flink.core.memory.{DataInputView, DataOutputView, MemorySegment} +import org.apache.flink.types.KeyFieldOutOfBoundsException + +/** + * Comparator for [[Row]]. + */ +class RowComparator private ( + /** key positions describe which fields are keys in what order */ + val keyPositions: Array[Int], + /** null-aware comparators for the key fields, in the same order as the key fields */ + val comparators: Array[NullAwareComparator[Any]], + /** serializers to deserialize the first n fields for comparison */ + val serializers: Array[TypeSerializer[Any]], + /** auxiliary fields for normalized key support */ + private val auxiliaryFields: (Array[Int], Int, Int, Boolean)) + extends CompositeTypeComparator[Row] with Serializable { + + // null masks for serialized comparison + private val nullMask1 = new Array[Boolean](serializers.length) + private val nullMask2 = new Array[Boolean](serializers.length) + + // cache for the deserialized key field objects + @transient + private lazy val deserializedKeyFields1: Array[Any] = instantiateDeserializationFields() + + @transient + private lazy val deserializedKeyFields2: Array[Any] = instantiateDeserializationFields() + + // create auxiliary fields + private val normalizedKeyLengths: Array[Int] = auxiliaryFields._1 + private val numLeadingNormalizableKeys: Int = auxiliaryFields._2 + private val normalizableKeyPrefixLen: Int = auxiliaryFields._3 + private val invertNormKey: Boolean = auxiliaryFields._4 + + /** + * Intermediate constructor for creating auxiliary fields. 
+ */ + def this( + keyPositions: Array[Int], + comparators: Array[NullAwareComparator[Any]], + serializers: Array[TypeSerializer[Any]]) = { + this( + keyPositions, + comparators, + serializers, + createAuxiliaryFields(keyPositions, comparators)) + } + + /** + * General constructor for RowComparator. + * + * @param keyPositions key positions describe which fields are keys in what order + * @param comparators non-null-aware comparators for the key fields, in the same order as + * the key fields + * @param serializers serializers to deserialize the first n fields for comparison + * @param orders sorting orders for the fields + */ + def this( + keyPositions: Array[Int], + comparators: Array[TypeComparator[Any]], + serializers: Array[TypeSerializer[Any]], + orders: Array[Boolean]) = { + this( + keyPositions, + makeNullAware(comparators, orders), + serializers) + } + + private def instantiateDeserializationFields(): Array[Any] = { + val newFields = new Array[Any](serializers.length) + var i = 0 + while (i < serializers.length) { + newFields(i) = serializers(i).createInstance() + i += 1 + } + newFields + } + + // -------------------------------------------------------------------------------------------- + // Comparator Methods + // -------------------------------------------------------------------------------------------- + + override def compareToReference(referencedComparator: TypeComparator[Row]): Int = { + val other: RowComparator = referencedComparator.asInstanceOf[RowComparator] + var i = 0 + try { + while (i < keyPositions.length) { + val comparator = comparators(i) + val otherComparator = other.comparators(i) + + val cmp = comparator.compareToReference(otherComparator) + if (cmp != 0) { + return cmp + } + i = i + 1 + } + 0 + } + catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + } + + override def compareSerialized(firstSource: DataInputView, secondSource: DataInputView): Int = { + val len = serializers.length + val keyLen = keyPositions.length + + readIntoNullMask(len, firstSource, nullMask1) + readIntoNullMask(len, secondSource, nullMask2) + + // deserialize + var i = 0 + while (i < len) { + val serializer = serializers(i) + + // deserialize field 1 + if (!nullMask1(i)) { + deserializedKeyFields1(i) = serializer.deserialize(deserializedKeyFields1(i), firstSource) + } + + // deserialize field 2 + if (!nullMask2(i)) { + deserializedKeyFields2(i) = serializer.deserialize(deserializedKeyFields2(i), secondSource) + } + + i += 1 + } + + // compare + i = 0 + while (i < keyLen) { + val keyPos = keyPositions(i) + val comparator = comparators(i) + + val isNull1 = nullMask1(keyPos) + val isNull2 = nullMask2(keyPos) + + var cmp = 0 + // both values are null -> equality + if (isNull1 && isNull2) { + cmp = 0 + } + // first value is null -> inequality + else if (isNull1) { + cmp = comparator.compare(null, deserializedKeyFields2(keyPos)) + } + // second value is null -> inequality + else if (isNull2) { + cmp = comparator.compare(deserializedKeyFields1(keyPos), null) + } + // no null values + else { + cmp = comparator.compare(deserializedKeyFields1(keyPos), deserializedKeyFields2(keyPos)) + } + + if (cmp != 0) { + return cmp + } + + i += 1 + } + 0 + } + + override def supportsNormalizedKey(): Boolean = numLeadingNormalizableKeys > 0 + + override def getNormalizeKeyLen: Int = normalizableKeyPrefixLen + + override def isNormalizedKeyPrefixOnly(keyBytes: Int): Boolean = + numLeadingNormalizableKeys < keyPositions.length || + 
normalizableKeyPrefixLen == Integer.MAX_VALUE || + normalizableKeyPrefixLen > keyBytes + + override def invertNormalizedKey(): Boolean = invertNormKey + + override def supportsSerializationWithKeyNormalization(): Boolean = false + + override def writeWithKeyNormalization(record: Row, target: DataOutputView): Unit = + throw new UnsupportedOperationException("Record serialization with leading normalized keys " + + "not supported.") + + override def readWithKeyDenormalization(reuse: Row, source: DataInputView): Row = + throw new UnsupportedOperationException("Record deserialization with leading normalized keys " + + "not supported.") + + override def duplicate(): TypeComparator[Row] = { + // copy comparator and serializer factories + val comparatorsCopy = comparators.map(_.duplicate().asInstanceOf[NullAwareComparator[Any]]) + val serializersCopy = serializers.map(_.duplicate()) + + new RowComparator( + keyPositions, + comparatorsCopy, + serializersCopy, + auxiliaryFields) + } + + override def hash(value: Row): Int = { + var code: Int = 0 + var i = 0 + try { + while(i < keyPositions.length) { + code *= TupleComparatorBase.HASH_SALT(i & 0x1F) + val element = value.productElement(keyPositions(i)) // element can be null + code += comparators(i).hash(element) + i += 1 + } + } catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + code + } + + override def setReference(toCompare: Row) { + var i = 0 + try { + while(i < keyPositions.length) { + val comparator = comparators(i) + val element = toCompare.productElement(keyPositions(i)) + comparator.setReference(element) // element can be null + i += 1 + } + } catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + } + + override def equalToReference(candidate: Row): Boolean = { + var i = 0 + try { + while(i < keyPositions.length) { + val comparator = comparators(i) + val element = candidate.productElement(keyPositions(i)) // element can be null + // check if reference is not equal + if (!comparator.equalToReference(element)) { + return false + } + i += 1 + } + } catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + true + } + + override def compare(first: Row, second: Row): Int = { + var i = 0 + try { + while(i < keyPositions.length) { + val keyPos: Int = keyPositions(i) + val comparator = comparators(i) + val firstElement = first.productElement(keyPos) // element can be null + val secondElement = second.productElement(keyPos) // element can be null + + val cmp = comparator.compare(firstElement, secondElement) + if (cmp != 0) { + return cmp + } + i += 1 + } + } catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + 0 + } + + override def putNormalizedKey( + record: Row, + target: MemorySegment, + offset: Int, + numBytes: Int) + : Unit = { + var bytesLeft = numBytes + var currentOffset = offset + + var i = 0 + while (i < numLeadingNormalizableKeys && bytesLeft > 0) { + var len = normalizedKeyLengths(i) + len = if (bytesLeft >= len) len else bytesLeft + + val comparator = comparators(i) + val element = record.productElement(keyPositions(i)) // element can be null + // write key + comparator.putNormalizedKey(element, target, currentOffset, len) + + bytesLeft -= len + currentOffset += len + i += 1 + } + } + + override def getFlatComparator(flatComparators: util.List[TypeComparator[_]]): Unit = + comparators.foreach { c => 
+ c.getFlatComparators.foreach { fc => + flatComparators.add(fc) + } + } + + override def extractKeys(record: Any, target: Array[AnyRef], index: Int): Int = { + val len = comparators.length + var localIndex = index + var i = 0 + while (i < len) { + val element = record.asInstanceOf[Row].productElement(keyPositions(i)) // element can be null + localIndex += comparators(i).extractKeys(element, target, localIndex) + i += 1 + } + localIndex - index + } +} + +object RowComparator { + private def makeNullAware( + comparators: Array[TypeComparator[Any]], + orders: Array[Boolean]) + : Array[NullAwareComparator[Any]] = + comparators + .zip(orders) + .map { case (comp, order) => + new NullAwareComparator[Any]( + comp, + order) + } + + /** + * @return creates auxiliary fields for normalized key support + */ + private def createAuxiliaryFields( + keyPositions: Array[Int], + comparators: Array[NullAwareComparator[Any]]) + : (Array[Int], Int, Int, Boolean) = { + + val normalizedKeyLengths = new Array[Int](keyPositions.length) + var numLeadingNormalizableKeys = 0 + var normalizableKeyPrefixLen = 0 + var inverted = false + + var i = 0 + while (i < keyPositions.length) { + val k = comparators(i) + // as long as the leading keys support normalized keys, we can build up the composite key + if (k.supportsNormalizedKey()) { + if (i == 0) { + // the first comparator decides whether we need to invert the key direction + inverted = k.invertNormalizedKey() + } + else if (k.invertNormalizedKey() != inverted) { + // if a successor does not agree on the inversion direction, it cannot be part of the + // normalized key + return (normalizedKeyLengths, + numLeadingNormalizableKeys, + normalizableKeyPrefixLen, + inverted) + } + numLeadingNormalizableKeys += 1 + val len = k.getNormalizeKeyLen + if (len < 0) { + throw new RuntimeException("Comparator " + k.getClass.getName + + " specifies an invalid length for the normalized key: " + len) + } + normalizedKeyLengths(i) = len + normalizableKeyPrefixLen += len + if (normalizableKeyPrefixLen < 0) { + // overflow, which means we are out of budget for normalized key space anyways + return (normalizedKeyLengths, + numLeadingNormalizableKeys, + Integer.MAX_VALUE, + inverted) + } + } + else { + return (normalizedKeyLengths, + numLeadingNormalizableKeys, + normalizableKeyPrefixLen, + inverted) + } + i += 1 + } + (normalizedKeyLengths, + numLeadingNormalizableKeys, + normalizableKeyPrefixLen, + inverted) + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowSerializer.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowSerializer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowSerializer.scala new file mode 100644 index 0000000..825a99c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowSerializer.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.typeutils + +import org.apache.flink.api.common.typeutils.TypeSerializer +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.typeutils.NullMaskUtils.{writeNullMask, readIntoNullMask, readIntoAndCopyNullMask} +import org.apache.flink.core.memory.{DataInputView, DataOutputView} + +/** + * Serializer for [[Row]]. + */ +class RowSerializer(val fieldSerializers: Array[TypeSerializer[Any]]) + extends TypeSerializer[Row] { + + private val nullMask = new Array[Boolean](fieldSerializers.length) + + override def isImmutableType: Boolean = false + + override def getLength: Int = -1 + + override def duplicate: RowSerializer = { + val duplicateFieldSerializers = fieldSerializers.map(_.duplicate()) + new RowSerializer(duplicateFieldSerializers) + } + + override def createInstance: Row = { + new Row(fieldSerializers.length) + } + + override def copy(from: Row, reuse: Row): Row = { + val len = fieldSerializers.length + + // cannot reuse, do a non-reuse copy + if (reuse == null) { + return copy(from) + } + + if (from.productArity != len || reuse.productArity != len) { + throw new RuntimeException("Row arity of reuse or from is incompatible with this " + + "RowSerializer.") + } + + var i = 0 + while (i < len) { + val fromField = from.productElement(i) + if (fromField != null) { + val reuseField = reuse.productElement(i) + if (reuseField != null) { + val copy = fieldSerializers(i).copy(fromField, reuseField) + reuse.setField(i, copy) + } + else { + val copy = fieldSerializers(i).copy(fromField) + reuse.setField(i, copy) + } + } + else { + reuse.setField(i, null) + } + i += 1 + } + reuse + } + + override def copy(from: Row): Row = { + val len = fieldSerializers.length + + if (from.productArity != len) { + throw new RuntimeException("Row arity of from does not match serializers.") + } + val result = new Row(len) + var i = 0 + while (i < len) { + val fromField = from.productElement(i).asInstanceOf[AnyRef] + if (fromField != null) { + val copy = fieldSerializers(i).copy(fromField) + result.setField(i, copy) + } + else { + result.setField(i, null) + } + i += 1 + } + result + } + + override def serialize(value: Row, target: DataOutputView) { + val len = fieldSerializers.length + + if (value.productArity != len) { + throw new RuntimeException("Row arity of value does not match serializers.") + } + + // write a null mask + writeNullMask(len, value, target) + + // serialize non-null fields + var i = 0 + while (i < len) { + val o = value.productElement(i).asInstanceOf[AnyRef] + if (o != null) { + val serializer = fieldSerializers(i) + serializer.serialize(value.productElement(i), target) + } + i += 1 + } + } + + override def deserialize(reuse: Row, source: DataInputView): Row = { + val len = fieldSerializers.length + + if (reuse.productArity != len) { + throw new RuntimeException("Row arity of reuse does not match serializers.") + } + + // read null mask + readIntoNullMask(len, source, nullMask) + + // read non-null fields + var i = 0 + while (i < len) { + if (nullMask(i)) { + reuse.setField(i, null) + } + else { + val reuseField = 
reuse.productElement(i).asInstanceOf[AnyRef] + if (reuseField != null) { + reuse.setField(i, fieldSerializers(i).deserialize(reuseField, source)) + } + else { + reuse.setField(i, fieldSerializers(i).deserialize(source)) + } + } + i += 1 + } + reuse + } + + override def deserialize(source: DataInputView): Row = { + val len = fieldSerializers.length + + val result = new Row(len) + + // read null mask + readIntoNullMask(len, source, nullMask) + + // read non-null fields + var i = 0 + while (i < len) { + if (nullMask(i)) { + result.setField(i, null) + } + else { + result.setField(i, fieldSerializers(i).deserialize(source)) + } + i += 1 + } + result + } + + override def copy(source: DataInputView, target: DataOutputView): Unit = { + val len = fieldSerializers.length + + // copy null mask + readIntoAndCopyNullMask(len, source, target, nullMask) + + // read non-null fields + var i = 0 + while (i < len) { + if (!nullMask(i)) { + fieldSerializers(i).copy(source, target) + } + i += 1 + } + } + + override def equals(any: Any): Boolean = { + any match { + case otherRS: RowSerializer => + otherRS.canEqual(this) && + fieldSerializers.sameElements(otherRS.fieldSerializers) + case _ => false + } + } + + override def canEqual(obj: AnyRef): Boolean = { + obj.isInstanceOf[RowSerializer] + } + + override def hashCode(): Int = { + java.util.Arrays.hashCode(fieldSerializers.asInstanceOf[Array[AnyRef]]) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala new file mode 100644 index 0000000..3babb77 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.typeutils + +import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType.TypeComparatorBuilder +import org.apache.flink.api.common.typeutils.TypeComparator +import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo + +import scala.collection.mutable.ArrayBuffer +import org.apache.flink.api.common.typeutils.TypeSerializer +import org.apache.flink.api.table.Row + +/** + * TypeInformation for [[Row]]. 
+ */ +class RowTypeInfo(fieldTypes: Seq[TypeInformation[_]], fieldNames: Seq[String]) + extends CaseClassTypeInfo[Row]( + classOf[Row], + Array(), + fieldTypes, + fieldNames) +{ + + if (fieldTypes.length != fieldNames.length) { + throw new IllegalArgumentException("Number of field types and names is different.") + } + if (fieldNames.length != fieldNames.toSet.size) { + throw new IllegalArgumentException("Field names are not unique.") + } + + def this(fieldTypes: Seq[TypeInformation[_]]) = { + this(fieldTypes, for (i <- fieldTypes.indices) yield "f" + i) + } + + /** + * Temporary variable for directly passing orders to comparators. + */ + var comparatorOrders: Option[Array[Boolean]] = None + + override def createSerializer(executionConfig: ExecutionConfig): TypeSerializer[Row] = { + val fieldSerializers: Array[TypeSerializer[Any]] = new Array[TypeSerializer[Any]](getArity) + for (i <- 0 until getArity) { + fieldSerializers(i) = this.types(i).createSerializer(executionConfig) + .asInstanceOf[TypeSerializer[Any]] + } + + new RowSerializer(fieldSerializers) + } + + override def createComparator( + logicalKeyFields: Array[Int], + orders: Array[Boolean], + logicalFieldOffset: Int, + config: ExecutionConfig) + : TypeComparator[Row] = { + // store the order information for the builder + comparatorOrders = Some(orders) + val comparator = super.createComparator(logicalKeyFields, orders, logicalFieldOffset, config) + comparatorOrders = None + comparator + } + + override def createTypeComparatorBuilder(): TypeComparatorBuilder[Row] = { + new RowTypeComparatorBuilder(comparatorOrders.getOrElse( + throw new IllegalStateException("Cannot create comparator builder without orders."))) + } + + private class RowTypeComparatorBuilder( + comparatorOrders: Array[Boolean]) + extends TypeComparatorBuilder[Row] { + + val fieldComparators: ArrayBuffer[TypeComparator[_]] = new ArrayBuffer[TypeComparator[_]]() + val logicalKeyFields: ArrayBuffer[Int] = new ArrayBuffer[Int]() + + override def initializeTypeComparatorBuilder(size: Int): Unit = { + fieldComparators.sizeHint(size) + logicalKeyFields.sizeHint(size) + } + + override def addComparatorField(fieldId: Int, comparator: TypeComparator[_]): Unit = { + fieldComparators += comparator + logicalKeyFields += fieldId + } + + override def createTypeComparator(config: ExecutionConfig): TypeComparator[Row] = { + val maxIndex = logicalKeyFields.max + + new RowComparator( + logicalKeyFields.toArray, + fieldComparators.toArray.asInstanceOf[Array[TypeComparator[Any]]], + types.take(maxIndex + 1).map(_.createSerializer(config).asInstanceOf[TypeSerializer[Any]]), + comparatorOrders + ) + } + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala new file mode 100644 index 0000000..dc3abb7 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeConverter.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.typeutils + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.JoinRelType +import org.apache.calcite.rel.core.JoinRelType._ +import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.calcite.sql.`type`.SqlTypeName._ +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.operators.join.JoinType +import org.apache.flink.api.java.tuple.Tuple +import org.apache.flink.api.java.typeutils.ValueTypeInfo._ +import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo} +import org.apache.flink.api.table.{Row, TableException} + +import scala.collection.JavaConversions._ + +object TypeConverter { + + val DEFAULT_ROW_TYPE = new RowTypeInfo(Seq(), Seq()).asInstanceOf[TypeInformation[Any]] + + def typeInfoToSqlType(typeInfo: TypeInformation[_]): SqlTypeName = typeInfo match { + case BOOLEAN_TYPE_INFO => BOOLEAN + case BOOLEAN_VALUE_TYPE_INFO => BOOLEAN + case BYTE_TYPE_INFO => TINYINT + case BYTE_VALUE_TYPE_INFO => TINYINT + case SHORT_TYPE_INFO => SMALLINT + case SHORT_VALUE_TYPE_INFO => SMALLINT + case INT_TYPE_INFO => INTEGER + case INT_VALUE_TYPE_INFO => INTEGER + case LONG_TYPE_INFO => BIGINT + case LONG_VALUE_TYPE_INFO => BIGINT + case FLOAT_TYPE_INFO => FLOAT + case FLOAT_VALUE_TYPE_INFO => FLOAT + case DOUBLE_TYPE_INFO => DOUBLE + case DOUBLE_VALUE_TYPE_INFO => DOUBLE + case STRING_TYPE_INFO => VARCHAR + case STRING_VALUE_TYPE_INFO => VARCHAR + case DATE_TYPE_INFO => DATE + + case CHAR_TYPE_INFO | CHAR_VALUE_TYPE_INFO => + throw new TableException("Character type is not supported.") + + case t@_ => + throw new TableException(s"Type is not supported: $t") + } + + def sqlTypeToTypeInfo(sqlType: SqlTypeName): TypeInformation[_] = sqlType match { + case BOOLEAN => BOOLEAN_TYPE_INFO + case TINYINT => BYTE_TYPE_INFO + case SMALLINT => SHORT_TYPE_INFO + case INTEGER => INT_TYPE_INFO + case BIGINT => LONG_TYPE_INFO + case FLOAT => FLOAT_TYPE_INFO + case DOUBLE => DOUBLE_TYPE_INFO + case VARCHAR | CHAR => STRING_TYPE_INFO + case DATE => DATE_TYPE_INFO + + // symbol for special flags e.g. TRIM's BOTH, LEADING, TRAILING + // are represented as integer + case SYMBOL => INT_TYPE_INFO + + case _ => + throw new TableException("Type " + sqlType.toString + "is not supported") + } + + /** + * Determines the return type of Flink operators based on the logical fields, the expected + * physical type and configuration parameters. 
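The typeInfoToSqlType / sqlTypeToTypeInfo pair above defines the two directions of the Flink-Calcite type mapping; note that it is not a strict bijection, e.g. CHAR and VARCHAR both map back to STRING_TYPE_INFO. A tiny illustrative check (not part of the commit):

import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.table.typeutils.TypeConverter

object TypeMappingSketch {
  def main(args: Array[String]): Unit = {
    assert(TypeConverter.typeInfoToSqlType(LONG_TYPE_INFO) == SqlTypeName.BIGINT)
    // VARCHAR and CHAR both come back as STRING_TYPE_INFO
    assert(TypeConverter.sqlTypeToTypeInfo(SqlTypeName.CHAR) == STRING_TYPE_INFO)
    assert(TypeConverter.sqlTypeToTypeInfo(SqlTypeName.VARCHAR) == STRING_TYPE_INFO)
  }
}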
+ * + * For example: + * - No physical type expected, only 3 non-null fields and efficient type usage enabled + * -> return Tuple3 + * - No physical type expected, efficient type usage enabled, but 3 nullable fields + * -> return Row because Tuple does not support null values + * - Physical type expected + * -> check if physical type is compatible and return it + * + * @param logicalRowType logical row information + * @param expectedPhysicalType expected physical type + * @param nullable fields can be nullable + * @param useEfficientTypes use the most efficient types (e.g. Tuples and value types) + * @return suitable return type + */ + def determineReturnType( + logicalRowType: RelDataType, + expectedPhysicalType: Option[TypeInformation[Any]], + nullable: Boolean, + useEfficientTypes: Boolean) + : TypeInformation[Any] = { + // convert to type information + val logicalFieldTypes = logicalRowType.getFieldList map { relDataType => + TypeConverter.sqlTypeToTypeInfo(relDataType.getType.getSqlTypeName) + } + // field names + val logicalFieldNames = logicalRowType.getFieldNames + + val returnType = expectedPhysicalType match { + // a certain physical type is expected (but not Row) + // check if expected physical type is compatible with logical field type + case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] => + if (typeInfo.getArity != logicalFieldTypes.length) { + throw new TableException("Arity of result does not match expected type.") + } + typeInfo match { + + // POJO type expected + case pt: PojoTypeInfo[_] => + logicalFieldNames.zip(logicalFieldTypes) foreach { + case (fName, fType) => + val pojoIdx = pt.getFieldIndex(fName) + if (pojoIdx < 0) { + throw new TableException(s"POJO does not define field name: $fName") + } + val expectedTypeInfo = pt.getTypeAt(pojoIdx) + if (fType != expectedTypeInfo) { + throw new TableException(s"Result field does not match expected type. " + + s"Expected: $expectedTypeInfo; Actual: $fType") + } + } + + // Tuple/Case class type expected + case ct: CompositeType[_] => + logicalFieldTypes.zipWithIndex foreach { + case (fieldTypeInfo, i) => + val expectedTypeInfo = ct.getTypeAt(i) + if (fieldTypeInfo != expectedTypeInfo) { + throw new TableException(s"Result field does not match expected type. " + + s"Expected: $expectedTypeInfo; Actual: $fieldTypeInfo") + } + } + + // Atomic type expected + case at: AtomicType[_] => + val fieldTypeInfo = logicalFieldTypes.head + if (fieldTypeInfo != at) { + throw new TableException(s"Result field does not match expected type. 
" + + s"Expected: $at; Actual: $fieldTypeInfo") + } + + case _ => + throw new TableException("Unsupported result type.") + } + typeInfo + + // Row is expected, create the arity for it + case Some(typeInfo) if typeInfo.getTypeClass == classOf[Row] => + new RowTypeInfo(logicalFieldTypes, logicalFieldNames) + + // no physical type + // determine type based on logical fields and configuration parameters + case None => + // no need for efficient types -> use Row + // we cannot use efficient types if row arity > tuple arity or nullable + if (!useEfficientTypes || logicalFieldTypes.length > Tuple.MAX_ARITY || nullable) { + new RowTypeInfo(logicalFieldTypes, logicalFieldNames) + } + // use efficient type tuple or atomic type + else { + if (logicalFieldTypes.length == 1) { + logicalFieldTypes.head + } + else { + new TupleTypeInfo[Tuple](logicalFieldTypes.toArray:_*) + } + } + } + returnType.asInstanceOf[TypeInformation[Any]] + } + + def sqlJoinTypeToFlinkJoinType(sqlJoinType: JoinRelType): JoinType = sqlJoinType match { + case INNER => JoinType.INNER + case LEFT => JoinType.LEFT_OUTER + case RIGHT => JoinType.RIGHT_OUTER + case FULL => JoinType.FULL_OUTER + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java index e26bc32..2ab38e5 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java @@ -36,7 +36,7 @@ package org.apache.flink.api.java.table.test; */ import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.table.ExpressionException; +import org.apache.flink.api.table.ExpressionParserException; import org.apache.flink.api.table.Row; import org.apache.flink.api.table.Table; import org.apache.flink.api.java.ExecutionEnvironment; @@ -177,7 +177,7 @@ public class AggregationsITCase extends MultipleProgramsTestBase { compareResultAsText(results, expected); } - @Test(expected = ExpressionException.class) + @Test(expected = ExpressionParserException.class) public void testNoNestedAggregation() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); TableEnvironment tableEnv = new TableEnvironment(); http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala index ad9a66d..0ac662a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala @@ -19,7 +19,7 @@ package org.apache.flink.api.scala.table.test import org.apache.flink.api.table.plan.PlanGenException -import org.apache.flink.api.table.{ExpressionException, Row} +import 
org.apache.flink.api.table.Row import org.apache.flink.api.scala._ import org.apache.flink.api.scala.table._ import org.apache.flink.api.scala.util.CollectionDataSets http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala index 3708107..e3d1b18 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala @@ -21,7 +21,7 @@ package org.apache.flink.api.scala.table.test import org.apache.flink.api.scala._ import org.apache.flink.api.scala.table._ import org.apache.flink.api.scala.util.CollectionDataSets -import org.apache.flink.api.table.{ExpressionException, Row} +import org.apache.flink.api.table.Row import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode import org.apache.flink.test.util.TestBaseUtils import org.junit._ http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/ScalarFunctionsTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/ScalarFunctionsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/ScalarFunctionsTest.scala index 6cbae1e..8f242e9 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/ScalarFunctionsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/ScalarFunctionsTest.scala @@ -22,10 +22,9 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala.table._ import org.apache.flink.api.table.Row -import org.apache.flink.api.table.expressions.Expression -import org.apache.flink.api.table.parser.ExpressionParser +import org.apache.flink.api.table.expressions.{ExpressionParser, Expression} import org.apache.flink.api.table.test.utils.ExpressionEvaluator -import org.apache.flink.api.table.typeinfo.RowTypeInfo +import org.apache.flink.api.table.typeutils.RowTypeInfo import org.junit.Assert.assertEquals import org.junit.Test http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowComparatorTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowComparatorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowComparatorTest.scala deleted file mode 100644 index 11d1a8a..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowComparatorTest.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
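Returning to TypeConverter.determineReturnType shown earlier: with no expected physical type, the result is Row unless the fields are non-nullable, the arity fits a Tuple, and efficient types are enabled. The sketch below is illustrative only; it assumes Calcite's RelDataTypeFactory builder API (outside this commit) to construct the logical row type:

import org.apache.calcite.jdbc.JavaTypeFactoryImpl
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.flink.api.table.typeutils.TypeConverter

object DetermineReturnTypeSketch {
  def main(args: Array[String]): Unit = {
    // assumed Calcite API: build a logical row type (INT id, VARCHAR name)
    val typeFactory = new JavaTypeFactoryImpl()
    val logicalType = typeFactory.builder()
      .add("id", SqlTypeName.INTEGER)
      .add("name", SqlTypeName.VARCHAR, 256)
      .build()

    // two non-null fields, efficient types enabled -> a Tuple type is chosen
    val asTuple = TypeConverter.determineReturnType(
      logicalType, expectedPhysicalType = None, nullable = false, useEfficientTypes = true)

    // nullable fields force Row, since Tuples cannot hold nulls
    val asRow = TypeConverter.determineReturnType(
      logicalType, expectedPhysicalType = None, nullable = true, useEfficientTypes = true)

    println(asTuple) // a TupleTypeInfo
    println(asRow)   // a RowTypeInfo
  }
}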
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.api.table.typeinfo - -import org.apache.flink.api.common.ExecutionConfig -import org.apache.flink.api.common.typeinfo.BasicTypeInfo -import org.apache.flink.api.common.typeutils.{ComparatorTestBase, TypeComparator, TypeSerializer} -import org.apache.flink.api.java.tuple -import org.apache.flink.api.java.typeutils.{TupleTypeInfo, TypeExtractor} -import org.apache.flink.api.table.Row -import org.apache.flink.api.table.typeinfo.RowComparatorTest.MyPojo -import org.junit.Assert._ - -class RowComparatorTest extends ComparatorTestBase[Row] { - - val typeInfo = new RowTypeInfo( - Array( - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.DOUBLE_TYPE_INFO, - BasicTypeInfo.STRING_TYPE_INFO, - new TupleTypeInfo[tuple.Tuple2[Int, Boolean]]( - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.BOOLEAN_TYPE_INFO, - BasicTypeInfo.SHORT_TYPE_INFO), - TypeExtractor.createTypeInfo(classOf[MyPojo]))) - - val testPojo1 = new MyPojo() - // TODO we cannot test null here as PojoComparator has no support for null keys - testPojo1.name = "" - val testPojo2 = new MyPojo() - testPojo2.name = "Test1" - val testPojo3 = new MyPojo() - testPojo3.name = "Test2" - - val data: Array[Row] = Array( - createRow(null, null, null, null, null), - createRow(0, null, null, null, null), - createRow(0, 0.0, null, null, null), - createRow(0, 0.0, "a", null, null), - createRow(1, 0.0, "a", null, null), - createRow(1, 1.0, "a", null, null), - createRow(1, 1.0, "b", null, null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](1, false, 2), null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, false, 2), null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 2), null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo1), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo2), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo3) - ) - - override protected def deepEquals(message: String, should: Row, is: Row): Unit = { - val arity = should.productArity - assertEquals(message, arity, is.productArity) - var index = 0 - while (index < arity) { - val copiedValue: Any = should.productElement(index) - val element: Any = is.productElement(index) - assertEquals(message, element, copiedValue) - index += 1 - } - } - - override protected def createComparator(ascending: Boolean): TypeComparator[Row] = { - typeInfo.createComparator( - Array(0, 1, 2, 3, 4, 5, 6), - Array(ascending, ascending, ascending, ascending, ascending, ascending, ascending), - 0, - new ExecutionConfig()) - } - - override protected def createSerializer(): TypeSerializer[Row] = { - typeInfo.createSerializer(new ExecutionConfig()) - } - - override protected def getSortedTestData: Array[Row] = { - data - } - - override protected def 
supportsNullKeys: Boolean = true - - def createRow(f0: Any, f1: Any, f2: Any, f3: Any, f4: Any): Row = { - val r: Row = new Row(5) - r.setField(0, f0) - r.setField(1, f1) - r.setField(2, f2) - r.setField(3, f3) - r.setField(4, f4) - r - } -} - -object RowComparatorTest { - class MyPojo() extends Serializable with Comparable[MyPojo] { - // we cannot use null because the PojoComparator does not support null properly - var name: String = "" - - override def compareTo(o: MyPojo): Int = { - if (name == null && o.name == null) { - 0 - } - else if (name == null) { - -1 - } - else if (o.name == null) { - 1 - } - else { - name.compareTo(o.name) - } - } - - override def equals(other: Any): Boolean = other match { - case that: MyPojo => compareTo(that) == 0 - case _ => false - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/63c6dad4/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowSerializerTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowSerializerTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowSerializerTest.scala deleted file mode 100644 index fc000fd..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowSerializerTest.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
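The deleted comparator test above pins down the sort order, including rows with null key fields. A minimal sketch of obtaining and using a row comparator through the new typeutils.RowTypeInfo (illustrative only; the null-first ordering matches the sorted test data above):

import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.typeutils.RowTypeInfo

object RowComparatorSketch {
  def main(args: Array[String]): Unit = {
    val typeInfo = new RowTypeInfo(
      Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO))

    // sort on both fields, ascending
    val comparator = typeInfo.createComparator(
      Array(0, 1), Array(true, true), 0, new ExecutionConfig())

    val allNull = new Row(2)
    allNull.setField(0, null)
    allNull.setField(1, null)

    val row = new Row(2)
    row.setField(0, 1)
    row.setField(1, "a")

    // null keys order before non-null keys in ascending order
    assert(comparator.compare(allNull, row) < 0)
  }
}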
- */ - -package org.apache.flink.api.table.typeinfo - -import org.apache.flink.api.common.ExecutionConfig -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.common.typeutils.{SerializerTestInstance, TypeSerializer} -import org.apache.flink.api.java.tuple -import org.apache.flink.api.java.typeutils.{TypeExtractor, TupleTypeInfo} -import org.apache.flink.api.table.Row -import org.apache.flink.api.table.typeinfo.RowSerializerTest.MyPojo -import org.junit.Assert._ -import org.junit.Test - -class RowSerializerTest { - - class RowSerializerTestInstance( - serializer: TypeSerializer[Row], - testData: Array[Row]) - extends SerializerTestInstance[Row](serializer, classOf[Row], -1, testData: _*) { - - override protected def deepEquals(message: String, should: Row, is: Row): Unit = { - val arity = should.productArity - assertEquals(message, arity, is.productArity) - var index = 0 - while (index < arity) { - val copiedValue: Any = should.productElement(index) - val element: Any = is.productElement(index) - assertEquals(message, element, copiedValue) - index += 1 - } - } - } - - @Test - def testRowSerializer(): Unit = { - val rowInfo: TypeInformation[Row] = new RowTypeInfo( - Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO)) - - val row1 = new Row(2) - row1.setField(0, 1) - row1.setField(1, "a") - - val row2 = new Row(2) - row2.setField(0, 2) - row2.setField(1, null) - - val testData: Array[Row] = Array(row1, row2) - - val rowSerializer: TypeSerializer[Row] = rowInfo.createSerializer(new ExecutionConfig) - - val testInstance = new RowSerializerTestInstance(rowSerializer, testData) - - testInstance.testAll() - } - - @Test - def testLargeRowSerializer(): Unit = { - val rowInfo: TypeInformation[Row] = new RowTypeInfo(Seq( - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.STRING_TYPE_INFO)) - - val row = new Row(13) - row.setField(0, 2) - row.setField(1, null) - row.setField(3, null) - row.setField(4, null) - row.setField(5, null) - row.setField(6, null) - row.setField(7, null) - row.setField(8, null) - row.setField(9, null) - row.setField(10, null) - row.setField(11, null) - row.setField(12, "Test") - - val testData: Array[Row] = Array(row) - - val rowSerializer: TypeSerializer[Row] = rowInfo.createSerializer(new ExecutionConfig) - - val testInstance = new RowSerializerTestInstance(rowSerializer, testData) - - testInstance.testAll() - } - - @Test - def testRowSerializerWithComplexTypes(): Unit = { - val rowInfo = new RowTypeInfo( - Array( - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.DOUBLE_TYPE_INFO, - BasicTypeInfo.STRING_TYPE_INFO, - new TupleTypeInfo[tuple.Tuple2[Int, Boolean]]( - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.BOOLEAN_TYPE_INFO, - BasicTypeInfo.SHORT_TYPE_INFO), - TypeExtractor.createTypeInfo(classOf[MyPojo]))) - - val testPojo1 = new MyPojo() - testPojo1.name = null - val testPojo2 = new MyPojo() - testPojo2.name = "Test1" - val testPojo3 = new MyPojo() - testPojo3.name = "Test2" - - val testData: Array[Row] = Array( - createRow(null, null, null, null, null), - createRow(0, null, null, null, null), - createRow(0, 0.0, null, null, null), - createRow(0, 0.0, "a", null, null), - 
createRow(1, 0.0, "a", null, null), - createRow(1, 1.0, "a", null, null), - createRow(1, 1.0, "b", null, null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](1, false, 2), null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, false, 2), null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 2), null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), null), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo1), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo2), - createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo3) - ) - - val rowSerializer: TypeSerializer[Row] = rowInfo.createSerializer(new ExecutionConfig) - - val testInstance = new RowSerializerTestInstance(rowSerializer, testData) - - testInstance.testAll() - } - - // ---------------------------------------------------------------------------------------------- - - def createRow(f0: Any, f1: Any, f2: Any, f3: Any, f4: Any): Row = { - val r: Row = new Row(5) - r.setField(0, f0) - r.setField(1, f1) - r.setField(2, f2) - r.setField(3, f3) - r.setField(4, f4) - r - } -} - -object RowSerializerTest { - class MyPojo() extends Serializable with Comparable[MyPojo] { - var name: String = null - - override def compareTo(o: MyPojo): Int = { - if (name == null && o.name == null) { - 0 - } - else if (name == null) { - -1 - } - else if (o.name == null) { - 1 - } - else { - name.compareTo(o.name) - } - } - - override def equals(other: Any): Boolean = other match { - case that: MyPojo => compareTo(that) == 0 - case _ => false - } - } -}
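Both deleted tests compare rows by arity and then field by field via productElement rather than relying on Row equality. An illustrative standalone helper with the same logic:

import org.apache.flink.api.table.Row

object RowDeepEquals {
  // mirrors the deepEquals overrides in the two tests above:
  // same arity, then element-wise comparison of the fields
  def rowsEqual(a: Row, b: Row): Boolean =
    a.productArity == b.productArity &&
      (0 until a.productArity).forall { i =>
        a.productElement(i) == b.productElement(i)
      }
}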