From: rmetzger@apache.org
To: commits@flink.incubator.apache.org
Reply-To: dev@flink.incubator.apache.org
Date: Wed, 08 Oct 2014 09:40:22 -0000
Message-Id: <614124ae64c14a8a844db0eff5d5f30f@git.apache.org>
In-Reply-To: <42e3404d984c4f18b03bbfad4302d33b@git.apache.org>
References: <42e3404d984c4f18b03bbfad4302d33b@git.apache.org>
X-Mailer: ASF-Git Admin Mailer
Subject: [02/11] git commit: Really add POJO support and nested keys for Scala API

Really add POJO support and nested keys for Scala API

This also adds more integration tests, but not all tests of the Java API have been ported to Scala yet.
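For context, this change lets the Scala DataSet API key on POJO fields and nested field expressions (e.g. "nested.myLong") in addition to tuple positions, by passing field-name strings straight to Keys.ExpressionKeys instead of resolving them to indices up front. The following is a minimal usage sketch assembled from the tests in the diff below (the Pojo/Nested types mirror ExamplesITCase.scala; object and path names are illustrative, not part of the commit):

    import org.apache.flink.api.scala._

    // Mirrors the test types in ExamplesITCase.scala: a POJO (public var fields,
    // zero-argument constructor) with a nested case-class field.
    case class Nested(myLong: Long)

    class Pojo(var myString: String, var myInt: Int, var nested: Nested) {
      def this() = this("", 0, Nested(1))
      override def toString = s"myString=$myString myInt=$myInt nested.myLong=${nested.myLong}"
    }

    object PojoKeysExample {
      def main(args: Array[String]): Unit = {
        val env = ExecutionEnvironment.getExecutionEnvironment

        val pojos = env.fromElements(
          new Pojo("one", 1, Nested(1L)),
          new Pojo("one", 1, Nested(1L)),
          new Pojo("two", 666, Nested(2L)))

        // Nested key expressions on POJOs now work for grouping.
        val grouped = pojos.groupBy("nested.myLong").reduce { (p1, p2) =>
          p1.myInt += p2.myInt
          p1
        }
        grouped.writeAsText("/tmp/pojo-group-result")  // illustrative output path

        // String key expressions also work on tuples, e.g. in distinct()
        // and in iterateDelta(..., Array("_1")) as used in ConnectedComponents.
        val tuples = env.fromElements((1, "a"), (1, "b"), (2, "c"))
        tuples.distinct("_1").writeAsText("/tmp/distinct-result")

        env.execute()
      }
    }

As in the ported integration tests, results are written with writeAsText and the job is triggered explicitly via env.execute(), matching the API of this Flink version.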
Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/6be85554 Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/6be85554 Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/6be85554 Branch: refs/heads/master Commit: 6be85554332fb576abc16e39d866ed05a7c38b72 Parents: 598ae37 Author: Aljoscha Krettek Authored: Tue Oct 7 12:00:15 2014 +0200 Committer: Robert Metzger Committed: Wed Oct 8 11:39:01 2014 +0200 ---------------------------------------------------------------------- .../scala/graph/ConnectedComponents.scala | 2 +- .../apache/flink/api/java/operators/Keys.java | 8 +- .../type/extractor/PojoTypeExtractionTest.java | 4 +- .../org/apache/flink/api/scala/DataSet.scala | 15 +- .../apache/flink/api/scala/GroupedDataSet.scala | 42 +- .../flink/api/scala/codegen/TypeAnalyzer.scala | 167 ++---- .../api/scala/codegen/TypeDescriptors.scala | 45 +- .../api/scala/codegen/TypeInformationGen.scala | 75 ++- .../api/scala/typeutils/CaseClassTypeInfo.scala | 84 ++- .../api/scala/unfinishedKeyPairOperation.scala | 16 +- .../api/scala/operators/AggregateITCase.scala | 39 +- .../scala/operators/CoGroupOperatorTest.scala | 34 +- .../flink/api/scala/operators/CustomType.scala | 33 -- .../api/scala/operators/DistinctITCase.scala | 191 ++++++ .../scala/operators/DistinctOperatorTest.scala | 20 +- .../api/scala/operators/ExamplesITCase.scala | 79 ++- .../api/scala/operators/GroupReduceITCase.scala | 574 +++++++++++++++++++ .../api/scala/operators/GroupingTest.scala | 31 +- .../api/scala/operators/JoinOperatorTest.scala | 33 +- .../api/scala/operators/PartitionITCase.scala | 84 +-- .../scala/types/TypeInformationGenTest.scala | 10 +- .../api/scala/util/CollectionDataSets.scala | 394 +++++++++++++ 22 files changed, 1587 insertions(+), 393 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/ConnectedComponents.scala ---------------------------------------------------------------------- diff --git a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/ConnectedComponents.scala b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/ConnectedComponents.scala index 4462e45..d261173 100644 --- a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/ConnectedComponents.scala +++ b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/ConnectedComponents.scala @@ -76,7 +76,7 @@ object ConnectedComponents { val edges = getEdgesDataSet(env).flatMap { edge => Seq(edge, (edge._2, edge._1)) } // open a delta iteration - val verticesWithComponents = vertices.iterateDelta(vertices, maxIterations, Array(0)) { + val verticesWithComponents = vertices.iterateDelta(vertices, maxIterations, Array("_1")) { (s, ws) => // apply the step logic: join with the edges http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java index 482370e..40ce238 100644 --- 
a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.LinkedList; import java.util.List; +import com.google.common.base.Joiner; import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -306,7 +307,12 @@ public abstract class Keys { } return Ints.toArray(logicalKeys); } - + + @Override + public String toString() { + Joiner join = Joiner.on('.'); + return "ExpressionKeys: " + join.join(keyFields); + } } private static String[] removeDuplicates(String[] in) { http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java index 01d32c1..83c81f7 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java @@ -345,7 +345,9 @@ public class PojoTypeExtractionTest { Assert.assertEquals(typeInfo.getTypeClass(), WC.class); Assert.assertEquals(typeInfo.getArity(), 2); } - + + // Kryo is required for this, so disable for now. + @Ignore @Test public void testPojoAllPublic() { TypeInformation typeForClass = TypeExtractor.createTypeInfo(AllPublic.class); http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala index 895b964..7a2c699 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala @@ -550,14 +550,11 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { /** * Creates a new DataSet containing the distinct elements of this DataSet. The decision whether * two elements are distinct or not is made based on only the specified fields. - * - * This only works on CaseClass DataSets */ def distinct(firstField: String, otherFields: String*): DataSet[T] = { - val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray) wrap(new DistinctOperator[T]( javaSet, - new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType, true))) + new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType))) } /** @@ -615,8 +612,6 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { * This only works on CaseClass DataSets. 
*/ def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = { - // val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray) - new GroupedDataSet[T]( this, new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType)) @@ -862,10 +857,8 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { */ def iterateDelta[R: ClassTag](workset: DataSet[R], maxIterations: Int, keyFields: Array[String])( stepFunction: (DataSet[T], DataSet[R]) => (DataSet[T], DataSet[R])) = { - val fieldIndices = fieldNames2Indices(javaSet.getType, keyFields) - - val key = new ExpressionKeys[T](fieldIndices, javaSet.getType, false) + val key = new ExpressionKeys[T](keyFields, javaSet.getType) val iterativeSet = new DeltaIteration[T, R]( javaSet.getExecutionEnvironment, javaSet.getType, @@ -931,12 +924,10 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { * significant amount of time. */ def partitionByHash(firstField: String, otherFields: String*): DataSet[T] = { - val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray) - val op = new PartitionOperator[T]( javaSet, PartitionMethod.HASH, - new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType, false)) + new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType)) wrap(op) } http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala index d715939..e7d8978 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala @@ -47,7 +47,7 @@ class GroupedDataSet[T: ClassTag]( // These are for optional secondary sort. They are only used // when using a group-at-a-time reduce function. - private val groupSortKeyPositions = mutable.MutableList[Int]() + private val groupSortKeyPositions = mutable.MutableList[Either[Int, String]]() private val groupSortOrders = mutable.MutableList[Order]() /** @@ -64,7 +64,7 @@ class GroupedDataSet[T: ClassTag]( if (field >= set.getType.getArity) { throw new IllegalArgumentException("Order key out of tuple bounds.") } - groupSortKeyPositions += field + groupSortKeyPositions += Left(field) groupSortOrders += order this } @@ -76,9 +76,7 @@ class GroupedDataSet[T: ClassTag]( * This only works on CaseClass DataSets. 
*/ def sortGroup(field: String, order: Order): GroupedDataSet[T] = { - val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0) - - groupSortKeyPositions += fieldIndex + groupSortKeyPositions += Right(field) groupSortOrders += order this } @@ -88,14 +86,32 @@ class GroupedDataSet[T: ClassTag]( */ private def maybeCreateSortedGrouping(): Grouping[T] = { if (groupSortKeyPositions.length > 0) { - val grouping = new SortedGrouping[T]( - set.javaSet, - keys, - groupSortKeyPositions(0), - groupSortOrders(0)) + val grouping = groupSortKeyPositions(0) match { + case Left(pos) => + new SortedGrouping[T]( + set.javaSet, + keys, + pos, + groupSortOrders(0)) + + case Right(field) => + new SortedGrouping[T]( + set.javaSet, + keys, + field, + groupSortOrders(0)) + + } // now manually add the rest of the keys for (i <- 1 until groupSortKeyPositions.length) { - grouping.sortGroup(groupSortKeyPositions(i), groupSortOrders(i)) + groupSortKeyPositions(i) match { + case Left(pos) => + grouping.sortGroup(pos, groupSortOrders(i)) + + case Right(field) => + grouping.sortGroup(field, groupSortOrders(i)) + + } } grouping } else { @@ -209,7 +225,7 @@ class GroupedDataSet[T: ClassTag]( } } wrap( - new GroupReduceOperator[T, R](createUnsortedGrouping(), + new GroupReduceOperator[T, R](maybeCreateSortedGrouping(), implicitly[TypeInformation[R]], reducer)) } @@ -227,7 +243,7 @@ class GroupedDataSet[T: ClassTag]( } } wrap( - new GroupReduceOperator[T, R](createUnsortedGrouping(), + new GroupReduceOperator[T, R](maybeCreateSortedGrouping(), implicitly[TypeInformation[R]], reducer)) } http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala index f0ba195..3a4deca 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala @@ -17,7 +17,6 @@ */ package org.apache.flink.api.scala.codegen -import scala.Option.option2Iterable import scala.collection.GenTraversableOnce import scala.collection.mutable import scala.reflect.macros.Context @@ -59,12 +58,17 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C] case PrimitiveType(default, wrapper) => PrimitiveDescriptor(id, tpe, default, wrapper) case BoxedPrimitiveType(default, wrapper, box, unbox) => BoxedPrimitiveDescriptor(id, tpe, default, wrapper, box, unbox) - case ListType(elemTpe, iter) => analyzeList(id, tpe, elemTpe, iter) + case ListType(elemTpe, iter) => + analyzeList(id, tpe, elemTpe, iter) case CaseClassType() => analyzeCaseClass(id, tpe) - case BaseClassType() => analyzeClassHierarchy(id, tpe) case ValueType() => ValueDescriptor(id, tpe) case WritableType() => WritableDescriptor(id, tpe) - case _ => GenericClassDescriptor(id, tpe) + case JavaType() => + // It's a Java Class, let the TypeExtractor deal with it... + c.warning(c.enclosingPosition, s"Type $tpe is a java class. 
Will be analyzed by " + + s"TypeExtractor at runtime.") + GenericClassDescriptor(id, tpe) + case _ => analyzePojo(id, tpe) } } } @@ -78,110 +82,63 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C] case desc => ListDescriptor(id, tpe, iter, desc) } - private def analyzeClassHierarchy(id: Int, tpe: Type): UDTDescriptor = { - - val tagField = { - val (intTpe, intDefault, intWrapper) = PrimitiveType.intPrimitive - FieldAccessor( - NoSymbol, - NoSymbol, - NullaryMethodType(intTpe), - isBaseField = true, - PrimitiveDescriptor(cache.newId, intTpe, intDefault, intWrapper)) + private def analyzePojo(id: Int, tpe: Type): UDTDescriptor = { + val immutableFields = tpe.members filter { _.isTerm } map { _.asTerm } filter { _.isVal } + if (immutableFields.nonEmpty) { + // We don't support POJOs with immutable fields + c.warning( + c.enclosingPosition, + s"Type $tpe is no POJO, has immutable fields: ${immutableFields.mkString(", ")}.") + return GenericClassDescriptor(id, tpe) } - - val subTypes = tpe.typeSymbol.asClass.knownDirectSubclasses.toList flatMap { d => - - val dTpe = - { - val tArgs = (tpe.typeSymbol.asClass.typeParams, typeArgs(tpe)).zipped.toMap - val dArgs = d.asClass.typeParams map { dp => - val tArg = tArgs.keySet.find { tp => - dp == tp.typeSignature.asSeenFrom(d.typeSignature, tpe.typeSymbol).typeSymbol - } - tArg map { tArgs(_) } getOrElse dp.typeSignature - } - appliedType(d.asType.toType, dArgs) - } + val fields = tpe.members + .filter { _.isTerm } + .map { _.asTerm } + .filter { _.isVar } + .filterNot { _.annotations.exists( _.tpe <:< typeOf[scala.transient]) } - if (dTpe <:< tpe) { - Some(analyze(dTpe)) - } else { - None - } + if (fields.isEmpty) { + c.warning(c.enclosingPosition, "Type $tpe has no fields that are visible from Scala Type" + + " analysis. Falling back to Java Type Analysis (TypeExtractor).") + return GenericClassDescriptor(id, tpe) } - val errors = subTypes flatMap { _.findByType[UnsupportedDescriptor] } - - errors match { - case _ :: _ => - val errorMessage = errors flatMap { - case UnsupportedDescriptor(_, subType, errs) => - errs map { err => "Subtype " + subType + " - " + err } - } - UnsupportedDescriptor(id, tpe, errorMessage) - - case Nil if subTypes.isEmpty => - UnsupportedDescriptor(id, tpe, Seq("No instantiable subtypes found for base class")) - case Nil => - val (tParams, _) = tpe.typeSymbol.asClass.typeParams.zip(typeArgs(tpe)).unzip - val baseMembers = - tpe.members filter { f => f.isMethod } filter { f => f.asMethod.isSetter } map { - f => (f, f.asMethod.setter, f.asMethod.returnType) - } - - val subMembers = subTypes map { - case BaseClassDescriptor(_, _, getters, _) => getters - case CaseClassDescriptor(_, _, _, _, getters) => getters - case _ => Seq() - } - - val baseFields = baseMembers flatMap { - case (bGetter, bSetter, bTpe) => - val accessors = subMembers map { - _ find { sf => - sf.getter.name == bGetter.name && - sf.tpe.termSymbol.asMethod.returnType <:< bTpe.termSymbol.asMethod.returnType - } - } - accessors.forall { _.isDefined } match { - case true => - Some( - FieldAccessor( - bGetter, - bSetter, - bTpe, - isBaseField = true, - analyze(bTpe.termSymbol.asMethod.returnType))) - case false => None - } - } + // check whether all fields are either: 1. public, 2. 
have getter/setter + val invalidFields = fields filterNot { + f => + f.isPublic || + (f.getter != NoSymbol && f.getter.isPublic && f.setter != NoSymbol && f.setter.isPublic) + } - def wireBaseFields(desc: UDTDescriptor): UDTDescriptor = { + if (invalidFields.nonEmpty) { + c.warning(c.enclosingPosition, s"Type $tpe is no POJO because it has non-public fields '" + + s"${invalidFields.mkString(", ")}' that don't have public getters/setters.") + return GenericClassDescriptor(id, tpe) + } - def updateField(field: FieldAccessor) = { - baseFields find { bf => bf.getter.name == field.getter.name } match { - case Some(FieldAccessor(_, _, _, _, fieldDesc)) => - field.copy(isBaseField = true, desc = fieldDesc) - case None => field - } - } + // check whether we have a zero-parameter ctor + val hasZeroCtor = tpe.declarations exists { + case m: MethodSymbol + if m.isConstructor && m.paramss.length == 1 && m.paramss(0).length == 0 => true + case _ => false + } - desc match { - case desc @ BaseClassDescriptor(_, _, getters, baseSubTypes) => - desc.copy( - getters = getters map updateField, - subTypes = baseSubTypes map wireBaseFields) - case desc @ CaseClassDescriptor(_, _, _, _, getters) => - desc.copy(getters = getters map updateField) - case _ => desc - } - } + if (!hasZeroCtor) { + // We don't support POJOs without zero-paramter ctor + c.warning( + c.enclosingPosition, + s"Class $tpe is no POJO, has no zero-parameters constructor.") + return GenericClassDescriptor(id, tpe) + } - BaseClassDescriptor(id, tpe, tagField +: baseFields.toSeq, subTypes map wireBaseFields) + val fieldDescriptors = fields map { + f => + val fieldTpe = f.getter.asMethod.returnType.asSeenFrom(tpe, tpe.typeSymbol) + FieldDescriptor(f.name.toString.trim, f.getter, f.setter, fieldTpe, analyze(fieldTpe)) } + PojoDescriptor(id, tpe, fieldDescriptors.toSeq) } private def analyzeCaseClass(id: Int, tpe: Type): UDTDescriptor = { @@ -216,7 +173,7 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C] } val fields = caseFields map { case (fgetter, fsetter, fTpe) => - FieldAccessor(fgetter, fsetter, fTpe, isBaseField = false, analyze(fTpe)) + FieldDescriptor(fgetter.name.toString.trim, fgetter, fsetter, fTpe, analyze(fTpe)) } val mutable = enableMutableUDTs && (fields forall { f => f.setter != NoSymbol }) if (mutable) { @@ -226,8 +183,9 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C] case errs @ _ :: _ => val msgs = errs flatMap { f => (f: @unchecked) match { - case FieldAccessor(fgetter, _,_,_, UnsupportedDescriptor(_, fTpe, errors)) => - errors map { err => "Field " + fgetter.name + ": " + fTpe + " - " + err } + case FieldDescriptor( + fName, _, _, _, UnsupportedDescriptor(_, fTpe, errors)) => + errors map { err => "Field " + fName + ": " + fTpe + " - " + err } } } UnsupportedDescriptor(id, tpe, msgs) @@ -296,11 +254,6 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C] def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isCaseClass } - private object BaseClassType { - def unapply(tpe: Type): Boolean = - tpe.typeSymbol.asClass.isAbstractClass && tpe.typeSymbol.asClass.isSealed - } - private object ValueType { def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.baseClasses exists { @@ -315,6 +268,10 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C] } } + private object JavaType { + def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isJava + } + private class UDTAnalyzerCache { private val caches = new 
DynamicVariable[Map[Type, RecursiveDescriptor]](Map()) http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala index 8201a68..66299c7 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala @@ -36,10 +36,8 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C] def canBeKey: Boolean - def mkRoot: UDTDescriptor = this - def flatten: Seq[UDTDescriptor] - def getters: Seq[FieldAccessor] = Seq() + def getters: Seq[FieldDescriptor] = Seq() def select(member: String): Option[UDTDescriptor] = getters find { _.getter.name.toString == member } map { _.desc } @@ -48,7 +46,7 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C] case Nil => Seq(Some(this)) case head :: tail => getters find { _.getter.name.toString == head } match { case None => Seq(None) - case Some(d : FieldAccessor) => d.desc.select(tail) + case Some(d : FieldDescriptor) => d.desc.select(tail) } } @@ -60,7 +58,7 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C] } def getRecursiveRefs: Seq[UDTDescriptor] = - findByType[RecursiveDescriptor].flatMap { rd => findById(rd.refId) }.map { _.mkRoot }.distinct + findByType[RecursiveDescriptor].flatMap { rd => findById(rd.refId) }.distinct } case class GenericClassDescriptor(id: Int, tpe: Type) extends UDTDescriptor { @@ -116,30 +114,45 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C] } } - case class BaseClassDescriptor( - id: Int, tpe: Type, override val getters: Seq[FieldAccessor], subTypes: Seq[UDTDescriptor]) + case class PojoDescriptor(id: Int, tpe: Type, override val getters: Seq[FieldDescriptor]) extends UDTDescriptor { - override def flatten = - this +: ((getters flatMap { _.desc.flatten }) ++ (subTypes flatMap { _.flatten })) + override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct) + + override def flatten = this +: (getters flatMap { _.desc.flatten }) + override def canBeKey = flatten forall { f => f.canBeKey } - + + // Hack: ignore the ctorTpe, since two Type instances representing + // the same ctor function type don't appear to be considered equal. + // Equality of the tpe and ctor fields implies equality of ctorTpe anyway. 
+ override def hashCode = (id, tpe, getters).hashCode + override def equals(that: Any) = that match { + case PojoDescriptor(thatId, thatTpe, thatGetters) => + (id, tpe, getters).equals( + thatId, thatTpe, thatGetters) + case _ => false + } + override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match { case Nil => getters flatMap { g => g.desc.select(Nil) } case head :: tail => getters find { _.getter.name.toString == head } match { case None => Seq(None) - case Some(d : FieldAccessor) => d.desc.select(tail) + case Some(d : FieldDescriptor) => d.desc.select(tail) } } } case class CaseClassDescriptor( - id: Int, tpe: Type, mutable: Boolean, ctor: Symbol, override val getters: Seq[FieldAccessor]) + id: Int, + tpe: Type, + mutable: Boolean, + ctor: Symbol, + override val getters: Seq[FieldDescriptor]) extends UDTDescriptor { override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct) - override def mkRoot = this.copy(getters = getters map { _.copy(isBaseField = false) }) override def flatten = this +: (getters flatMap { _.desc.flatten }) override def canBeKey = flatten forall { f => f.canBeKey } @@ -159,16 +172,16 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C] case Nil => getters flatMap { g => g.desc.select(Nil) } case head :: tail => getters find { _.getter.name.toString == head } match { case None => Seq(None) - case Some(d : FieldAccessor) => d.desc.select(tail) + case Some(d : FieldDescriptor) => d.desc.select(tail) } } } - case class FieldAccessor( + case class FieldDescriptor( + name: String, getter: Symbol, setter: Symbol, tpe: Type, - isBaseField: Boolean, desc: UDTDescriptor) case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor { http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala index f6a89d3..0686668 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala @@ -20,7 +20,6 @@ package org.apache.flink.api.scala.codegen import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo import org.apache.flink.api.common.typeinfo.BasicTypeInfo -import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.api.common.typeutils.TypeSerializer import org.apache.flink.api.java.typeutils._ @@ -29,6 +28,8 @@ import org.apache.flink.types.Value import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.hadoop.io.Writable +import scala.collection.JavaConverters._ + import scala.reflect.macros.Context private[flink] trait TypeInformationGen[C <: Context] { @@ -41,7 +42,7 @@ private[flink] trait TypeInformationGen[C <: Context] { // This is for external calling by TypeUtils.createTypeInfo def mkTypeInfo[T: c.WeakTypeTag]: c.Expr[TypeInformation[T]] = { - val desc = getUDTDescriptor(weakTypeOf[T]) + val desc = getUDTDescriptor(weakTypeTag[T].tpe) val result: c.Expr[TypeInformation[T]] = mkTypeInfo(desc)(c.WeakTypeTag(desc.tpe)) result } @@ -61,6 +62,7 @@ private[flink] trait 
TypeInformationGen[C <: Context] { case d : WritableDescriptor => mkWritableTypeInfo(d)(c.WeakTypeTag(d.tpe).asInstanceOf[c.WeakTypeTag[Writable]]) .asInstanceOf[c.Expr[TypeInformation[T]]] + case pojo: PojoDescriptor => mkPojo(pojo) case d => mkGenericTypeInfo(d) } @@ -96,7 +98,7 @@ private[flink] trait TypeInformationGen[C <: Context] { def mkListTypeInfo[T: c.WeakTypeTag](desc: ListDescriptor): c.Expr[TypeInformation[T]] = { val arrayClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe))) val elementClazz = c.Expr[Class[T]](Literal(Constant(desc.elem.tpe))) - val elementTypeInfo = mkTypeInfo(desc.elem) + val elementTypeInfo = mkTypeInfo(desc.elem)(c.WeakTypeTag(desc.elem.tpe)) desc.elem match { // special case for string, which in scala is a primitive, but not in java case p: PrimitiveDescriptor if p.tpe <:< typeOf[String] => @@ -115,7 +117,8 @@ private[flink] trait TypeInformationGen[C <: Context] { reify { ObjectArrayTypeInfo.getInfoFor( arrayClazz.splice, - elementTypeInfo.splice).asInstanceOf[TypeInformation[T]] + elementTypeInfo.splice.asInstanceOf[TypeInformation[_]]) + .asInstanceOf[TypeInformation[T]] } } } @@ -136,6 +139,35 @@ private[flink] trait TypeInformationGen[C <: Context] { } } + def mkPojo[T: c.WeakTypeTag](desc: PojoDescriptor): c.Expr[TypeInformation[T]] = { + val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe))) + val fieldsTrees = desc.getters map { + f => + val name = c.Expr(Literal(Constant(f.name))) + val fieldType = mkTypeInfo(f.desc)(c.WeakTypeTag(f.tpe)) + reify { (name.splice, fieldType.splice) }.tree + } + + val fieldsList = c.Expr[List[(String, TypeInformation[_])]](mkList(fieldsTrees.toList)) + + reify { + val fields = fieldsList.splice + val clazz: Class[T] = tpeClazz.splice + + val fieldMap = TypeExtractor.getAllDeclaredFields(clazz).asScala map { + f => (f.getName, f) + } toMap + + val pojoFields = fields map { + case (fName, fTpe) => + new PojoField(fieldMap(fName), fTpe) + } + + new PojoTypeInfo(clazz, pojoFields.asJava) + + } + } + def mkGenericTypeInfo[T: c.WeakTypeTag](desc: UDTDescriptor): c.Expr[TypeInformation[T]] = { val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe))) reify { @@ -158,39 +190,4 @@ private[flink] trait TypeInformationGen[C <: Context] { val result = Apply(Select(New(TypeTree(desc.tpe)), nme.CONSTRUCTOR), fields.toList) c.Expr[T](result) } - -// def mkCaseClassTypeInfo[T: c.WeakTypeTag]( -// desc: CaseClassDescriptor): c.Expr[TypeInformation[T]] = { -// val tpeClazz = c.Expr[Class[_]](Literal(Constant(desc.tpe))) -// val caseFields = mkCaseFields(desc) -// reify { -// new ScalaTupleTypeInfo[T] { -// def createSerializer: TypeSerializer[T] = { -// null -// } -// -// val fields: Map[String, TypeInformation[_]] = caseFields.splice -// val clazz = tpeClazz.splice -// } -// } -// } -// -// private def mkCaseFields(desc: UDTDescriptor): c.Expr[Map[String, TypeInformation[_]]] = { -// val fields = getFields("_root_", desc).toList map { case (fieldName, fieldDesc) => -// val nameTree = c.Expr(Literal(Constant(fieldName))) -// val fieldTypeInfo = mkTypeInfo(fieldDesc)(c.WeakTypeTag(fieldDesc.tpe)) -// reify { (nameTree.splice, fieldTypeInfo.splice) }.tree -// } -// -// c.Expr(mkMap(fields)) -// } -// -// protected def getFields(name: String, desc: UDTDescriptor): Seq[(String, UDTDescriptor)] = -// desc match { -// // Flatten product types -// case CaseClassDescriptor(_, _, _, _, getters) => -// getters filterNot { _.isBaseField } flatMap { -// f => getFields(name + "." 
+ f.getter.name, f.desc) } -// case _ => Seq((name, desc)) -// } } http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala index 3e9d4c6..53d1dea 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala @@ -19,8 +19,10 @@ package org.apache.flink.api.scala.typeutils import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.common.typeinfo.AtomicType +import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor +import org.apache.flink.api.java.operators.Keys.ExpressionKeys import org.apache.flink.api.java.typeutils.TupleTypeInfoBase -import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer} +import org.apache.flink.api.common.typeutils.{CompositeType, TypeComparator} /** * TypeInformation for Case Classes. Creation and access is different from @@ -58,16 +60,82 @@ abstract class CaseClassTypeInfo[T <: Product]( override protected def getNewComparator: TypeComparator[T] = { val finalLogicalKeyFields = logicalKeyFields.take(comparatorHelperIndex) val finalComparators = fieldComparators.take(comparatorHelperIndex) - var maxKey: Int = 0 - for (key <- finalLogicalKeyFields) { - maxKey = Math.max(maxKey, key) + val maxKey = finalLogicalKeyFields.max + + // create serializers only up to the last key, fields after that are not needed + val fieldSerializers = types.take(maxKey + 1).map(_.createSerializer) + new CaseClassComparator[T](finalLogicalKeyFields, finalComparators, fieldSerializers.toArray) + } + + override def getKey( + fieldExpression: String, + offset: Int, + result: java.util.List[FlatFieldDescriptor]): Unit = { + + if (fieldExpression == ExpressionKeys.SELECT_ALL_CHAR) { + var keyPosition = 0 + for (tpe <- types) { + tpe match { + case a: AtomicType[_] => + result.add(new CompositeType.FlatFieldDescriptor(offset + keyPosition, tpe)) + + case co: CompositeType[_] => + co.getKey(ExpressionKeys.SELECT_ALL_CHAR, offset + keyPosition, result) + keyPosition += co.getTotalFields - 1 + + case _ => throw new RuntimeException(s"Unexpected key type: $tpe") + + } + keyPosition += 1 + } + return + } + + if (fieldExpression == null || fieldExpression.length <= 0) { + throw new IllegalArgumentException("Field expression must not be empty.") } - val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](maxKey + 1) - for (i <- 0 to maxKey) { - fieldSerializers(i) = types(i).createSerializer + fieldExpression.split('.').toList match { + case headField :: Nil => + var fieldId = 0 + for (i <- 0 until fieldNames.length) { + fieldId += types(i).getTotalFields - 1 + + if (fieldNames(i) == headField) { + if (fieldTypes(i).isInstanceOf[CompositeType[_]]) { + throw new IllegalArgumentException( + s"The specified field '$fieldExpression' is refering to a composite type.\n" + + s"Either select all elements in this type with the " + + s"'${ExpressionKeys.SELECT_ALL_CHAR}' operator or specify a field in" + + s" the sub-type") + } + result.add(new CompositeType.FlatFieldDescriptor(offset + fieldId, fieldTypes(i))) + return + } + + fieldId 
+= 1 + } + case firstField :: rest => + var fieldId = 0 + for (i <- 0 until fieldNames.length) { + + if (fieldNames(i) == firstField) { + fieldTypes(i) match { + case co: CompositeType[_] => + co.getKey(rest.mkString("."), offset + fieldId, result) + return + + case _ => + throw new RuntimeException(s"Field ${fieldTypes(i)} is not a composite type.") + + } + } + + fieldId += types(i).getTotalFields + } } - new CaseClassComparator[T](finalLogicalKeyFields, finalComparators, fieldSerializers) + + throw new RuntimeException(s"Unable to find field $fieldExpression in type $this.") } override def toString = clazz.getSimpleName + "(" + fieldNames.zip(types).map { http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala index 9d9a19f..b2929b9 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala @@ -69,11 +69,9 @@ private[flink] abstract class UnfinishedKeyPairOperation[L, R, O]( * This only works on a CaseClass [[DataSet]]. */ def where(firstLeftField: String, otherLeftFields: String*) = { - val fieldIndices = fieldNames2Indices( - leftInput.getType, - firstLeftField +: otherLeftFields.toArray) - - val leftKey = new ExpressionKeys[L](fieldIndices, leftInput.getType) + val leftKey = new ExpressionKeys[L]( + firstLeftField +: otherLeftFields.toArray, + leftInput.getType) new HalfUnfinishedKeyPairOperation[L, R, O](this, leftKey) } @@ -118,11 +116,9 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O]( * This only works on a CaseClass [[DataSet]]. */ def equalTo(firstRightField: String, otherRightFields: String*): O = { - val fieldIndices = fieldNames2Indices( - unfinished.rightInput.getType, - firstRightField +: otherRightFields.toArray) - - val rightKey = new ExpressionKeys[R](fieldIndices, unfinished.rightInput.getType) + val rightKey = new ExpressionKeys[R]( + firstRightField +: otherRightFields.toArray, + unfinished.rightInput.getType) if (!leftKey.areCompatible(rightKey)) { throw new InvalidProgramException("The types of the key fields do not match. 
Left: " + leftKey + " Right: " + rightKey) http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala index 631e68a..0e3f2ed 100644 --- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala @@ -19,6 +19,7 @@ package org.apache.flink.api.scala.operators import org.apache.flink.api.java.aggregation.Aggregations import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.configuration.Configuration import org.apache.flink.test.util.JavaProgramTestBase import org.junit.runner.RunWith @@ -34,38 +35,14 @@ import org.apache.flink.api.scala._ object AggregateProgs { var NUM_PROGRAMS: Int = 3 - val tupleInput = Array( - (1,1L,"Hi"), - (2,2L,"Hello"), - (3,2L,"Hello world"), - (4,3L,"Hello world, how are you?"), - (5,3L,"I am fine."), - (6,3L,"Luke Skywalker"), - (7,4L,"Comment#1"), - (8,4L,"Comment#2"), - (9,4L,"Comment#3"), - (10,4L,"Comment#4"), - (11,5L,"Comment#5"), - (12,5L,"Comment#6"), - (13,5L,"Comment#7"), - (14,5L,"Comment#8"), - (15,5L,"Comment#9"), - (16,6L,"Comment#10"), - (17,6L,"Comment#11"), - (18,6L,"Comment#12"), - (19,6L,"Comment#13"), - (20,6L,"Comment#14"), - (21,6L,"Comment#15") - ) - - def runProgram(progId: Int, resultPath: String): String = { progId match { case 1 => // Full aggregate val env = ExecutionEnvironment.getExecutionEnvironment env.setDegreeOfParallelism(10) - val ds = env.fromCollection(tupleInput) +// val ds = CollectionDataSets.get3TupleDataSet(env) + val ds = CollectionDataSets.get3TupleDataSet(env) val aggregateDs = ds .aggregate(Aggregations.SUM,0) @@ -84,7 +61,7 @@ object AggregateProgs { case 2 => // Grouped aggregate val env = ExecutionEnvironment.getExecutionEnvironment - val ds = env.fromCollection(tupleInput) + val ds = CollectionDataSets.get3TupleDataSet(env) val aggregateDs = ds .groupBy(1) @@ -103,7 +80,7 @@ object AggregateProgs { case 3 => // Nested aggregate val env = ExecutionEnvironment.getExecutionEnvironment - val ds = env.fromCollection(tupleInput) + val ds = CollectionDataSets.get3TupleDataSet(env) val aggregateDs = ds .groupBy(1) @@ -111,7 +88,7 @@ object AggregateProgs { .aggregate(Aggregations.MIN, 0) // Ensure aggregate operator correctly copies other fields .filter(_._3 != null) - .map { t => Tuple1(t._1) } + .map { t => new Tuple1(t._1) } aggregateDs.writeAsCsv(resultPath) @@ -140,7 +117,7 @@ class AggregateITCase(config: Configuration) extends JavaProgramTestBase(config) } protected def testProgram(): Unit = { - expectedResult = AggregateProgs.runProgram(curProgId, resultPath) + expectedResult = DistinctProgs.runProgram(curProgId, resultPath) } protected override def postSubmit(): Unit = { @@ -152,7 +129,7 @@ object AggregateITCase { @Parameters def getConfigurations: java.util.Collection[Array[AnyRef]] = { val configs = mutable.MutableList[Array[AnyRef]]() - for (i <- 1 to AggregateProgs.NUM_PROGRAMS) { + for (i <- 1 to DistinctProgs.NUM_PROGRAMS) { val config = new Configuration() config.setInteger("ProgramId", i) configs += Array(config) 
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala index d962b76..3f0ca5f 100644 --- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala @@ -17,13 +17,11 @@ */ package org.apache.flink.api.scala.operators -import java.io.Serializable -import org.apache.flink.api.common.InvalidProgramException import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException import org.junit.Assert -import org.junit.Ignore import org.junit.Test import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets.CustomType class CoGroupOperatorTest { @@ -130,7 +128,7 @@ class CoGroupOperatorTest { ds1.coGroup(ds2).where("_1", "_2").equalTo("_3") } - @Test(expected = classOf[IllegalArgumentException]) + @Test(expected = classOf[RuntimeException]) def testCoGroupKeyFieldNames4(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val ds1 = env.fromCollection(emptyTupleData) @@ -140,7 +138,7 @@ class CoGroupOperatorTest { ds1.coGroup(ds2).where("_6").equalTo("_1") } - @Test(expected = classOf[IllegalArgumentException]) + @Test(expected = classOf[RuntimeException]) def testCoGroupKeyFieldNames5(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val ds1 = env.fromCollection(emptyTupleData) @@ -150,7 +148,7 @@ class CoGroupOperatorTest { ds1.coGroup(ds2).where("_1").equalTo("bar") } - @Test(expected = classOf[UnsupportedOperationException]) + @Test(expected = classOf[RuntimeException]) def testCoGroupKeyFieldNames6(): Unit = { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val ds1 = env.fromCollection(emptyTupleData) @@ -160,7 +158,6 @@ class CoGroupOperatorTest { ds1.coGroup(ds2).where("_3").equalTo("_1") } - @Ignore @Test def testCoGroupKeyExpressions1(): Unit = { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment @@ -176,29 +173,26 @@ class CoGroupOperatorTest { } } - @Ignore - @Test(expected = classOf[InvalidProgramException]) + @Test(expected = classOf[IncompatibleKeysException]) def testCoGroupKeyExpressions2(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val ds1 = env.fromCollection(customTypeData) val ds2 = env.fromCollection(customTypeData) // should not work, incompatible key types -// ds1.coGroup(ds2).where("i").equalTo("s") + ds1.coGroup(ds2).where("myInt").equalTo("myString") } - @Ignore - @Test(expected = classOf[InvalidProgramException]) + @Test(expected = classOf[IncompatibleKeysException]) def testCoGroupKeyExpressions3(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val ds1 = env.fromCollection(customTypeData) val ds2 = env.fromCollection(customTypeData) // should not work, incompatible number of keys -// ds1.coGroup(ds2).where("i", "s").equalTo("s") + ds1.coGroup(ds2).where("myInt", "myString").equalTo("myString") } - @Ignore @Test(expected = classOf[IllegalArgumentException]) def testCoGroupKeyExpressions4(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment @@ -207,7 +201,7 @@ class CoGroupOperatorTest { // should not work, 
key non-existent -// ds1.coGroup(ds2).where("myNonExistent").equalTo("i") + ds1.coGroup(ds2).where("myNonExistent").equalTo("i") } @Test @@ -218,7 +212,7 @@ class CoGroupOperatorTest { // Should work try { - ds1.coGroup(ds2).where { _.l } equalTo { _.l } + ds1.coGroup(ds2).where { _.myLong } equalTo { _.myLong } } catch { case e: Exception => Assert.fail() @@ -233,7 +227,7 @@ class CoGroupOperatorTest { // Should work try { - ds1.coGroup(ds2).where { _.l}.equalTo(3) + ds1.coGroup(ds2).where { _.myLong }.equalTo(3) } catch { case e: Exception => Assert.fail() @@ -248,7 +242,7 @@ class CoGroupOperatorTest { // Should work try { - ds1.coGroup(ds2).where(3).equalTo { _.l } + ds1.coGroup(ds2).where(3).equalTo { _.myLong } } catch { case e: Exception => Assert.fail() @@ -262,7 +256,7 @@ class CoGroupOperatorTest { val ds2 = env.fromCollection(customTypeData) // Should not work, incompatible types - ds1.coGroup(ds2).where(2).equalTo { _.l } + ds1.coGroup(ds2).where(2).equalTo { _.myLong } } @Test(expected = classOf[IncompatibleKeysException]) @@ -272,7 +266,7 @@ class CoGroupOperatorTest { val ds2 = env.fromCollection(customTypeData) // Should not work, more than one field position key - ds1.coGroup(ds2).where(1, 3).equalTo { _.l } + ds1.coGroup(ds2).where(1, 3).equalTo { _.myLong } } } http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CustomType.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CustomType.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CustomType.scala deleted file mode 100644 index 94627b9..0000000 --- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CustomType.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.api.scala.operators - -import java.io.Serializable - -/** - * A custom data type that is used by the operator Tests. 
- */ -class CustomType(var i:Int, var l: Long, var s: String) extends Serializable { - def this() { - this(0, 0, null) - } - - override def toString: String = { - i + "," + l + "," + s - } -} http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctITCase.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctITCase.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctITCase.scala new file mode 100644 index 0000000..855335d --- /dev/null +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctITCase.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.flink.api.scala._ + + +object DistinctProgs { + var NUM_PROGRAMS: Int = 8 + + + def runProgram(progId: Int, resultPath: String): String = { + progId match { + case 1 => + /* + * Check correctness of distinct on tuples with key field selector + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getSmall3TupleDataSet(env) + + val distinctDs = ds.union(ds).distinct(0, 1, 2) + distinctDs.writeAsCsv(resultPath) + + env.execute() + + // return expected result + "1,1,Hi\n" + + "2,2,Hello\n" + + "3,2,Hello world\n" + + case 2 => + /* + * check correctness of distinct on tuples with key field selector with not all fields + * selected + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getSmall5TupleDataSet(env) + + val distinctDs = ds.union(ds).distinct(0).map(_._1) + + distinctDs.writeAsText(resultPath) + env.execute() + "1\n" + "2\n" + + case 3 => + /* + * check correctness of distinct on tuples with key extractor + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getSmall5TupleDataSet(env) + + val reduceDs = ds.union(ds).distinct(_._1).map(_._1) + + reduceDs.writeAsText(resultPath) + env.execute() + "1\n" + "2\n" + + case 4 => + /* + * check correctness of distinct on custom type with type extractor + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getCustomTypeDataSet(env) + + val reduceDs = 
ds.distinct(_.myInt).map( t => new Tuple1(t.myInt)) + + reduceDs.writeAsCsv(resultPath) + env.execute() + "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" + + case 5 => + /* + * check correctness of distinct on tuples + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getSmall3TupleDataSet(env) + + val distinctDs = ds.union(ds).distinct() + + distinctDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n" + + case 6 => + /* + * check correctness of distinct on custom type with tuple-returning type extractor + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get5TupleDataSet(env) + + val reduceDs = ds.distinct( t => (t._1, t._5)).map( t => (t._1, t._5) ) + + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1\n" + "2,1\n" + "2,2\n" + "3,2\n" + "3,3\n" + "4,1\n" + "4,2\n" + "5," + + "1\n" + "5,2\n" + "5,3\n" + + case 7 => + /* + * check correctness of distinct on tuples with field expressions + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getSmall5TupleDataSet(env) + + val reduceDs = ds.union(ds).distinct("_1").map(t => new Tuple1(t._1)) + + reduceDs.writeAsCsv(resultPath) + env.execute() + "1\n" + "2\n" + + case 8 => + /* + * check correctness of distinct on Pojos + */ + + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getDuplicatePojoDataSet(env) + + val reduceDs = ds.distinct("nestedPojo.longNumber").map(_.nestedPojo.longNumber.toInt) + + reduceDs.writeAsText(resultPath) + env.execute() + "10000\n20000\n30000\n" + + case _ => + throw new IllegalArgumentException("Invalid program id") + } + } +} + + +@RunWith(classOf[Parameterized]) +class DistinctITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = DistinctProgs.runProgram(curProgId, resultPath) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object DistinctITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to DistinctProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} + http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala index b146e1c..e9d214b 100644 --- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala @@ -17,6 +17,7 @@ */ package org.apache.flink.api.scala.operators +import org.apache.flink.api.scala.util.CollectionDataSets.CustomType import org.junit.Assert import org.apache.flink.api.common.InvalidProgramException import 
org.junit.Test @@ -102,7 +103,7 @@ class DistinctOperatorTest { } } - @Test(expected = classOf[UnsupportedOperationException]) + @Test(expected = classOf[RuntimeException]) def testDistinctByKeyFields2(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val longDs = env.fromCollection(emptyLongData) @@ -111,16 +112,16 @@ class DistinctOperatorTest { longDs.distinct("_1") } - @Test(expected = classOf[UnsupportedOperationException]) + @Test(expected = classOf[RuntimeException]) def testDistinctByKeyFields3(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val customDs = env.fromCollection(customTypeData) - // should not work: field key on custom type + // should not work: invalid fields customDs.distinct("_1") } - @Test(expected = classOf[IllegalArgumentException]) + @Test(expected = classOf[RuntimeException]) def testDistinctByKeyFields4(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tupleDs = env.fromCollection(emptyTupleData) @@ -130,11 +131,20 @@ class DistinctOperatorTest { } @Test + def testDistinctByKeyFields5(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val customDs = env.fromCollection(customTypeData) + + // should work + customDs.distinct("myInt") + } + + @Test def testDistinctByKeySelector1(): Unit = { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment try { val customDs = env.fromCollection(customTypeData) - customDs.distinct {_.l} + customDs.distinct {_.myLong} } catch { case e: Exception => Assert.fail() http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala index d5ae6b6..f43052a 100644 --- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala @@ -29,25 +29,36 @@ import scala.collection.JavaConverters._ import scala.collection.mutable // TODO case class Tuple2[T1, T2](_1: T1, _2: T2) -// TODO case class Foo(a: Int, b: String) +// TODO case class Foo(a: Int, b: String -class Nested(var myLong: Long) { +case class Nested(myLong: Long) + +class Pojo(var myString: String, var myInt: Int, var nested: Nested) { def this() = { - this(0); + this("", 0, new Nested(1)) } + + def this(myString: String, myInt: Int, myLong: Long) { this(myString, myInt, new Nested(myLong)) } + + override def toString = s"myString=$myString myInt=$myInt nested.myLong=${nested.myLong}" +} + +class NestedPojo(var myLong: Long) { + def this() { this(0) } } -class Pojo(var myString: String, var myInt: Int, myLong: Long) { - var nested = new Nested(myLong) +class PojoWithPojo(var myString: String, var myInt: Int, var nested: Nested) { def this() = { - this("", 0, 0) + this("", 0, new Nested(1)) } - override def toString() = "myString="+myString+" myInt="+myInt+" nested.myLong="+nested.myLong + def this(myString: String, myInt: Int, myLong: Long) { this(myString, myInt, new Nested(myLong)) } + + override def toString = s"myString=$myString myInt=$myInt nested.myLong=${nested.myLong}" } object ExampleProgs { - var NUM_PROGRAMS: Int = 3 + var NUM_PROGRAMS: Int = 4 def runProgram(progId: Int, resultPath: String, onCollection: Boolean): 
String = { progId match { @@ -58,27 +69,53 @@ object ExampleProgs { val env = ExecutionEnvironment.getExecutionEnvironment val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) - val grouped = ds.groupBy(0).reduce( { (e1, e2) => ((e1._1._1,e1._1._2), e1._2+e2._2)}) + val grouped = ds.groupBy(0).reduce( { (e1, e2) => ((e1._1._1, e1._1._2), e1._2 + e2._2)}) grouped.writeAsText(resultPath) env.execute() "((this,hello),3)\n((this,is),3)\n" + case 2 => - /* - Test nested tuples with int offset - */ - val env = ExecutionEnvironment.getExecutionEnvironment - val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) - - val grouped = ds.groupBy("f0.f0").reduce( { (e1, e2) => ((e1._1._1,e1._1._2), e1._2+e2._2)}) - grouped.writeAsText(resultPath) - env.execute() - "((this,is),6)\n" + /* + Test nested tuples with int offset + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) + + val grouped = ds.groupBy("_1._1").reduce{ + (e1, e2) => ((e1._1._1, e1._1._2), e1._2 + e2._2) + } + grouped.writeAsText(resultPath) + env.execute() + "((this,is),6)\n" + case 3 => /* Test nested pojos */ val env = ExecutionEnvironment.getExecutionEnvironment - val ds = env.fromElements( new Pojo("one", 1, 1L),new Pojo("one", 1, 1L),new Pojo("two", 666, 2L) ) + val ds = env.fromElements( + new PojoWithPojo("one", 1, 1L), + new PojoWithPojo("one", 1, 1L), + new PojoWithPojo("two", 666, 2L) ) + + val grouped = ds.groupBy("nested.myLong").reduce { + (p1, p2) => + p1.myInt += p2.myInt + p1 + } + grouped.writeAsText(resultPath) + env.execute() + "myString=two myInt=666 nested.myLong=2\nmyString=one myInt=2 nested.myLong=1\n" + + case 4 => + /* + Test pojo with nested case class + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( + new Pojo("one", 1, 1L), + new Pojo("one", 1, 1L), + new Pojo("two", 666, 2L) ) val grouped = ds.groupBy("nested.myLong").reduce { (p1, p2) => @@ -124,4 +161,4 @@ object ExamplesITCase { configs.asJavaCollection } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/6be85554/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala new file mode 100644 index 0000000..b796a81 --- /dev/null +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala @@ -0,0 +1,574 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import java.lang.Iterable + +import org.apache.flink.api.common.functions._ +import org.apache.flink.api.common.operators.Order +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.scala.util.CollectionDataSets.CustomType +import org.apache.flink.compiler.PactCompiler +import org.apache.flink.configuration.Configuration +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.apache.flink.test.util.JavaProgramTestBase +import org.apache.flink.util.Collector +import org.junit.runner.RunWith +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters +import org.junit.runners.Parameterized.Parameters + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.flink.api.scala._ + + +object GroupReduceProgs { + var NUM_PROGRAMS: Int = 8 + + def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = { + progId match { + case 1 => + /* + * check correctness of groupReduce on tuples with key field selector + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(1).reduceGroup { + in => + in.map(t => (t._1, t._2)).reduce((l, r) => (l._1 + r._1, l._2)) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1\n" + "5,2\n" + "15,3\n" + "34,4\n" + "65,5\n" + "111,6\n" + + case 2 => + /* + * check correctness of groupReduce on tuples with multiple key field selector + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets + .get5TupleDataSet(env) + val reduceDs = ds.groupBy(4, 0).reduceGroup { + in => + val (i, l, l2) = in + .map( t => (t._1, t._2, t._5)) + .reduce((l, r) => (l._1, l._2 + r._2, l._3)) + (i, l, 0, "P-)", l2) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,0,P-),1\n" + "2,3,0,P-),1\n" + "2,2,0,P-),2\n" + "3,9,0,P-),2\n" + "3,6,0," + + "P-),3\n" + "4,17,0,P-),1\n" + "4,17,0,P-),2\n" + "5,11,0,P-),1\n" + "5,29,0,P-)," + + "2\n" + "5,25,0,P-),3\n" + + case 3 => + /* + * check correctness of groupReduce on tuples with key field selector and group sorting + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(1).sortGroup(2, Order.ASCENDING).reduceGroup { + in => + in.reduce((l, r) => (l._1 + r._1, l._2, l._3 + "-" + r._3)) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + + "5,2,Hello-Hello world\n" + + "15,3,Hello world, how are you?-I am fine.-Luke Skywalker\n" + + "34,4,Comment#1-Comment#2-Comment#3-Comment#4\n" + + "65,5,Comment#5-Comment#6-Comment#7-Comment#8-Comment#9\n" + + "111,6,Comment#10-Comment#11-Comment#12-Comment#13-Comment#14-Comment#15\n" + + case 4 => + /* + * check correctness of groupReduce on tuples with key extractor + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(_._2).reduceGroup { + in => + in.map(t => (t._1, 
t._2)).reduce((l, r) => (l._1 + r._1, l._2)) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1\n" + "5,2\n" + "15,3\n" + "34,4\n" + "65,5\n" + "111,6\n" + + case 5 => + /* + * check correctness of groupReduce on custom type with type extractor + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getCustomTypeDataSet(env) + val reduceDs = ds.groupBy(_.myInt).reduceGroup { + in => + val iter = in.toIterator + val o = new CustomType + var c = iter.next() + + o.myString = "Hello!" + o.myInt = c.myInt + o.myLong = c.myLong + + while (iter.hasNext) { + val next = iter.next() + o.myLong += next.myLong + } + o + } + reduceDs.writeAsText(resultPath) + env.execute() + "1,0,Hello!\n" + "2,3,Hello!\n" + "3,12,Hello!\n" + "4,30,Hello!\n" + "5,60," + + "Hello!\n" + "6,105,Hello!\n" + + case 6 => + /* + * check correctness of all-groupreduce for tuples + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.reduceGroup { + in => + var i = 0 + var l = 0L + for (t <- in) { + i += t._1 + l += t._2 + } + (i, l, "Hello World") + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "231,91,Hello World\n" + + case 7 => + /* + * check correctness of all-groupreduce for custom types + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getCustomTypeDataSet(env) + val reduceDs = ds.reduceGroup { + in => + val o = new CustomType(0, 0, "Hello!") + for (t <- in) { + o.myInt += t.myInt + o.myLong += t.myLong + } + o + } + reduceDs.writeAsText(resultPath) + env.execute() + "91,210,Hello!" + + case 8 => + /* + * check correctness of groupReduce with broadcast set + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val intDs = CollectionDataSets.getIntDataSet(env) + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(1).reduceGroup( + new RichGroupReduceFunction[(Int, Long, String), (Int, Long, String)] { + private var f2Replace = "" + + override def open(config: Configuration) { + val ints = this.getRuntimeContext.getBroadcastVariable[Int]("ints").asScala + f2Replace = ints.sum + "" + } + + override def reduce( + values: Iterable[(Int, Long, String)], + out: Collector[(Int, Long, String)]): Unit = { + var i: Int = 0 + var l: Long = 0L + for (t <- values.asScala) { + i += t._1 + l = t._2 + } + out.collect((i, l, f2Replace)) + } + }).withBroadcastSet(intDs, "ints") + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,55\n" + "5,2,55\n" + "15,3,55\n" + "34,4,55\n" + "65,5,55\n" + "111,6,55\n" + +// case 9 => +// val env = ExecutionEnvironment.getExecutionEnvironment +// val ds = CollectionDataSets.get3TupleDataSet(env) +// val reduceDs = ds.groupBy(1).reduceGroup(new +// GroupReduceITCase.InputReturningTuple3GroupReduce) +// reduceDs.writeAsCsv(resultPath) +// env.execute() +// "11,1,Hi!\n" + "21,1,Hi again!\n" + "12,2,Hi!\n" + "22,2,Hi again!\n" + "13,2," + +// "Hi!\n" + "23,2,Hi again!\n" + +// case 10 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// val ds = CollectionDataSets.getCustomTypeDataSet +// (env) +// val reduceDs = ds.groupBy(new +// KeySelector[CollectionDataSets.CustomType, Integer] { +// def getKey(in: CollectionDataSets.CustomType): Integer = { +// return in.myInt +// } +// }).reduceGroup(new GroupReduceITCase.CustomTypeGroupReduceWithCombine) +// reduceDs.writeAsText(resultPath) +// env.execute() +// if (collectionExecution) { +// return null +// } +// else { +// "1,0,test1\n" + 
"2,3,test2\n" + "3,12,test3\n" + "4,30,test4\n" + "5,60," + +// "test5\n" + "6,105,test6\n" +// } +// } +// case 11 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(2) +// val ds = CollectionDataSets.get3TupleDataSet(env) +// val reduceDs = ds.groupBy(1).reduceGroup(new +// GroupReduceITCase.Tuple3GroupReduceWithCombine) +// reduceDs.writeAsCsv(resultPath) +// env.execute() +// if (collectionExecution) { +// return null +// } +// else { +// "1,test1\n" + "5,test2\n" + "15,test3\n" + "34,test4\n" + "65,test5\n" + "111," + +// "test6\n" +// } +// } +// +// +// // all-groupreduce with combine +// +// +// case 12 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// val ds = CollectionDataSets.get3TupleDataSet(env) +// .map(new GroupReduceITCase.IdentityMapper[Tuple3[Integer, Long, +// String]]).setParallelism(4) +// val cfg: Configuration = new Configuration +// cfg.setString(PactCompiler.HINT_SHIP_STRATEGY, +// PactCompiler.HINT_SHIP_STRATEGY_REPARTITION) +// val reduceDs = ds.reduceGroup(new GroupReduceITCase +// .Tuple3AllGroupReduceWithCombine).withParameters(cfg) +// reduceDs.writeAsCsv(resultPath) +// env.execute() +// if (collectionExecution) { +// return null +// } +// else { +// "322," + +// "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttest\n" +// } +// } +// case 13 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets.get3TupleDataSet(env) +// val reduceDs = ds.groupBy(1).sortGroup(2, +// Order.DESCENDING).reduceGroup(new GroupReduceITCase.Tuple3SortedGroupReduce) +// reduceDs.writeAsCsv(resultPath) +// env.execute() +// "1,1,Hi\n" + "5,2,Hello world-Hello\n" + "15,3,Luke Skywalker-I am fine.-Hello " + +// "world, how are you?\n" + "34,4,Comment#4-Comment#3-Comment#2-Comment#1\n" + "65,5," + +// "Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" + "111,6," + +// "Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n" +// } +// case 14 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// val ds = CollectionDataSets +// .get5TupleDataSet(env) +// val reduceDs: DataSet[Tuple5[Integer, Long, Integer, String, +// Long]] = ds.groupBy(new KeySelector[Tuple5[Integer, Long, Integer, String, Long], +// Tuple2[Integer, Long]] { +// def getKey(t: Tuple5[Integer, Long, Integer, String, Long]): Tuple2[Integer, Long] = { +// return new Tuple2[Integer, Long](t.f0, t.f4) +// } +// }).reduceGroup(new GroupReduceITCase.Tuple5GroupReduce) +// reduceDs.writeAsCsv(resultPath) +// env.execute() +// "1,1,0,P-),1\n" + "2,3,0,P-),1\n" + "2,2,0,P-),2\n" + "3,9,0,P-),2\n" + "3,6,0," + +// "P-),3\n" + "4,17,0,P-),1\n" + "4,17,0,P-),2\n" + "5,11,0,P-),1\n" + "5,29,0,P-)," + +// "2\n" + "5,25,0,P-),3\n" +// } +// case 15 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets.get3TupleDataSet(env) +// val reduceDs = ds.groupBy(1).sortGroup(0, +// Order.ASCENDING).reduceGroup(new GroupReduceITCase.OrderCheckingCombinableReduce) +// reduceDs.writeAsCsv(resultPath) +// env.execute() +// "1,1,Hi\n" + "2,2,Hello\n" + "4,3,Hello world, how are you?\n" + "7,4," + +// "Comment#1\n" + "11,5,Comment#5\n" + "16,6,Comment#10\n" +// } +// case 16 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// val ds = CollectionDataSets +// .getCrazyNestedDataSet(env) +// val reduceDs = ds.groupBy("nest_Lvl1.nest_Lvl2" + +// 
".nest_Lvl3.nest_Lvl4.f1nal").reduceGroup(new GroupReduceFunction[CollectionDataSets +// .CrazyNested, Tuple2[String, Integer]] { +// def reduce(values: Iterable[CollectionDataSets.CrazyNested], +// out: Collector[Tuple2[String, Integer]]) { +// var c: Int = 0 +// var n: String = null +// import scala.collection.JavaConversions._ +// for (v <- values) { +// c += 1 +// n = v.nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4.f1nal +// } +// out.collect(new Tuple2[String, Integer](n, c)) +// } +// }) +// reduceDs.writeAsCsv(resultPath) +// env.execute() +// "aa,1\nbb,2\ncc,3\n" +// } +// case 17 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// val ds = CollectionDataSets +// .getPojoExtendingFromTuple(env) +// val reduceDs = ds.groupBy("special", +// "f2") +// .reduceGroup(new GroupReduceFunction[CollectionDataSets.FromTupleWithCTor, Integer] { +// def reduce(values: Iterable[CollectionDataSets.FromTupleWithCTor], +// out: Collector[Integer]) { +// var c: Int = 0 +// import scala.collection.JavaConversions._ +// for (v <- values) { +// c += 1 +// } +// out.collect(c) +// } +// }) +// reduceDs.writeAsText(resultPath) +// env.execute() +// "3\n2\n" +// } +// case 18 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// val ds = CollectionDataSets +// .getPojoContainingTupleAndWritable(env) +// val reduceDs = ds.groupBy("hadoopFan", "theTuple.*").reduceGroup(new +// GroupReduceFunction[CollectionDataSets.PojoContainingTupleAndWritable, Integer] { +// def reduce(values: Iterable[CollectionDataSets.PojoContainingTupleAndWritable], +// out: Collector[Integer]) { +// var c: Int = 0 +// import scala.collection.JavaConversions._ +// for (v <- values) { +// c += 1 +// } +// out.collect(c) +// } +// }) +// reduceDs.writeAsText(resultPath) +// env.execute() +// "1\n5\n" +// } +// case 19 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// val ds: DataSet[Tuple3[Integer, CollectionDataSets.CrazyNested, +// CollectionDataSets.POJO]] = CollectionDataSets.getTupleContainingPojos(env) +// val reduceDs = ds.groupBy("f0", "f1.*").reduceGroup(new +// GroupReduceFunction[Tuple3[Integer, CollectionDataSets.CrazyNested, +// CollectionDataSets.POJO], Integer] { +// def reduce(values: Iterable[Tuple3[Integer, CollectionDataSets.CrazyNested, +// CollectionDataSets.POJO]], out: Collector[Integer]) { +// var c: Int = 0 +// import scala.collection.JavaConversions._ +// for (v <- values) { +// c += 1 +// } +// out.collect(c) +// } +// }) +// reduceDs.writeAsText(resultPath) +// env.execute() +// "3\n1\n" +// } +// case 20 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets.get3TupleDataSet(env) +// val reduceDs = ds.groupBy(1).sortGroup("f2", +// Order.DESCENDING).reduceGroup(new GroupReduceITCase.Tuple3SortedGroupReduce) +// reduceDs.writeAsCsv(resultPath) +// env.execute() +// "1,1,Hi\n" + "5,2,Hello world-Hello\n" + "15,3,Luke Skywalker-I am fine.-Hello " + +// "world, how are you?\n" + "34,4,Comment#4-Comment#3-Comment#2-Comment#1\n" + "65,5," + +// "Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" + "111,6," + +// "Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n" +// } +// case 21 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets +// .getGroupSortedNestedTupleDataSet(env) +// val reduceDs = ds.groupBy("f1").sortGroup(0, +// Order.DESCENDING).reduceGroup(new GroupReduceITCase.NestedTupleReducer) +// 
reduceDs.writeAsText(resultPath) +// env.execute() +// "a--(1,1)-(1,2)-(1,3)-\n" + "b--(2,2)-\n" + "c--(3,3)-(3,6)-(3,9)-\n" +// } +// case 22 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets +// .getGroupSortedNestedTupleDataSet(env) +// val reduceDs = ds.groupBy("f1").sortGroup("f0.f0", +// Order.ASCENDING).reduceGroup(new GroupReduceITCase.NestedTupleReducer) +// reduceDs.writeAsText(resultPath) +// env.execute() +// "a--(1,3)-(1,2)-(2,1)-\n" + "b--(2,2)-\n" + "c--(3,3)-(3,6)-(4,9)-\n" +// } +// case 23 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets +// .getGroupSortedNestedTupleDataSet(env) +// val reduceDs = ds.groupBy("f1").sortGroup("f0.f0", +// Order.DESCENDING).reduceGroup(new GroupReduceITCase.NestedTupleReducer) +// reduceDs.writeAsText(resultPath) +// env.execute() +// "a--(2,1)-(1,3)-(1,2)-\n" + "b--(2,2)-\n" + "c--(4,9)-(3,3)-(3,6)-\n" +// } +// case 24 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets +// .getGroupSortedNestedTupleDataSet(env) +// val reduceDs = ds.groupBy("f1").sortGroup("f0.f0", +// Order.DESCENDING).sortGroup("f0.f1", Order.DESCENDING).reduceGroup(new +// GroupReduceITCase.NestedTupleReducer) +// reduceDs.writeAsText(resultPath) +// env.execute() +// "a--(2,1)-(1,3)-(1,2)-\n" + "b--(2,2)-\n" + "c--(4,9)-(3,6)-(3,3)-\n" +// } +// case 25 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets +// .getGroupSortedPojoContainingTupleAndWritable(env) +// val reduceDs = ds.groupBy("hadoopFan").sortGroup("theTuple.f0", +// Order.DESCENDING) +// .sortGroup("theTuple.f1", Order.DESCENDING) +// .reduceGroup(new GroupReduceFunction[CollectionDataSets.PojoContainingTupleAndWritable, String] { +// def reduce(values: Iterable[CollectionDataSets.PojoContainingTupleAndWritable], +// out: Collector[String]) { +// var once: Boolean = false +// val concat: StringBuilder = new StringBuilder +// import scala.collection.JavaConversions._ +// for (value <- values) { +// if (!once) { +// concat.append(value.hadoopFan.get) +// concat.append("---") +// once = true +// } +// concat.append(value.theTuple) +// concat.append("-") +// } +// out.collect(concat.toString) +// } +// }) +// reduceDs.writeAsText(resultPath) +// env.execute() +// "1---(10,100)-\n" + "2---(30,600)-(30,400)-(30,200)-(20,201)-(20,200)-\n" +// } +// case 26 => { +// val env = ExecutionEnvironment.getExecutionEnvironment +// env.setDegreeOfParallelism(1) +// val ds = CollectionDataSets.getPojoWithMultiplePojos(env) +// val reduceDs = ds.groupBy("hadoopFan") +// .sortGroup("theTuple.f0", Order.DESCENDING).sortGroup("theTuple.f1", Order.DESCENDING) +// .reduceGroup(new GroupReduceFunction[CollectionDataSets.PojoContainingTupleAndWritable, String] { +// def reduce(values: Iterable[CollectionDataSets.PojoContainingTupleAndWritable], +// out: Collector[String]) { +// var once: Boolean = false +// val concat: StringBuilder = new StringBuilder +// import scala.collection.JavaConversions._ +// for (value <- values) { +// if (!once) { +// concat.append(value.hadoopFan.get) +// concat.append("---") +// once = true +// } +// concat.append(value.theTuple) +// concat.append("-") +// } +// out.collect(concat.toString) +// } +// }) +// reduceDs.writeAsText(resultPath) +// env.execute() +// "1---(10,100)-\n" + 
"2---(30,600)-(30,400)-(30,200)-(20,201)-(20,200)-\n" +// } +// case _ => { +// throw new IllegalArgumentException("Invalid program id") +// } + } + } +} + + +@RunWith(classOf[Parameterized]) +class GroupReduceITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = GroupReduceProgs.runProgram(curProgId, resultPath, isCollectionExecution) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object GroupReduceITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to GroupReduceProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} + +