From: rmetzger@apache.org
To: commits@flink.incubator.apache.org
Date: Sat, 18 Oct 2014 15:17:19 -0000
Subject: [4/9] [FLINK-1171] Move Scala API tests to flink-tests project

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/a0ad9031/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala new file mode 100644 index 0000000..f43052a --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.apache.flink.api.scala._ +import org.junit.runners.Parameterized.Parameters +import scala.collection.JavaConverters._ + +import scala.collection.mutable + +// TODO case class Tuple2[T1, T2](_1: T1, _2: T2) +// TODO case class Foo(a: Int, b: String) + +case class Nested(myLong: Long) + +class Pojo(var myString: String, var myInt: Int, var nested: Nested) { + def this() = { + this("", 0, new Nested(1)) + } + + def this(myString: String, myInt: Int, myLong: Long) { this(myString, myInt, new Nested(myLong)) } + + override def toString = s"myString=$myString myInt=$myInt nested.myLong=${nested.myLong}" +} + +class NestedPojo(var myLong: Long) { + def this() { this(0) } +} + +class PojoWithPojo(var myString: String, var myInt: Int, var nested: Nested) { + def this() = { + this("", 0, new Nested(1)) + } + + def this(myString: String, myInt: Int, myLong: Long) { this(myString, myInt, new Nested(myLong)) } + + override def toString = s"myString=$myString myInt=$myInt nested.myLong=${nested.myLong}" +} + +object ExampleProgs { + var NUM_PROGRAMS: Int = 4 + + def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = { + progId match { + case 1 => + /* + Test nested tuples with int offset + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) + + val grouped = ds.groupBy(0).reduce( { (e1, e2) => ((e1._1._1, e1._1._2), e1._2 + e2._2)}) + grouped.writeAsText(resultPath) + env.execute() + "((this,hello),3)\n((this,is),3)\n" + + case 2 => + /* + Test nested tuples with string key expression + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) + + val grouped = ds.groupBy("_1._1").reduce{ + (e1, e2) => ((e1._1._1, e1._1._2), e1._2 + e2._2) + } + grouped.writeAsText(resultPath) + env.execute() + "((this,is),6)\n" + + case 3 => + /* + Test nested pojos + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( + new PojoWithPojo("one", 1, 1L), + new PojoWithPojo("one", 1, 1L), + new PojoWithPojo("two", 666, 2L) ) + + val grouped = ds.groupBy("nested.myLong").reduce { + (p1, p2) => + p1.myInt += p2.myInt + p1 + } + grouped.writeAsText(resultPath) + env.execute() + "myString=two myInt=666 nested.myLong=2\nmyString=one myInt=2 nested.myLong=1\n" + + case 4 => + /* + Test pojo with nested case class + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( + new Pojo("one", 1, 1L), + new Pojo("one", 1, 1L), + new Pojo("two", 666, 2L) ) + + val grouped = ds.groupBy("nested.myLong").reduce { + (p1, p2) => + p1.myInt += p2.myInt + p1 + } + grouped.writeAsText(resultPath) + env.execute() + "myString=two myInt=666 nested.myLong=2\nmyString=one myInt=2 nested.myLong=1\n" + } + } +} + +@RunWith(classOf[Parameterized]) +class ExamplesITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + 
protected def testProgram(): Unit = { + expectedResult = ExampleProgs.runProgram(curProgId, resultPath, isCollectionExecution) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object ExamplesITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to ExampleProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/a0ad9031/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FilterITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FilterITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FilterITCase.scala new file mode 100644 index 0000000..973028b --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FilterITCase.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.common.functions.RichFilterFunction +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.flink.api.scala._ + + +object FilterProgs { + var NUM_PROGRAMS: Int = 7 + + def runProgram(progId: Int, resultPath: String): String = { + progId match { + case 1 => + /* + * Test all-rejecting filter. + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val filterDs = ds.filter( t => false ) + filterDs.writeAsCsv(resultPath) + env.execute() + "\n" + + case 2 => + /* + * Test all-passing filter. 
+ */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val filterDs = ds.filter( t => true ) + filterDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n" + "4,3,Hello world, " + + "how are you?\n" + "5,3,I am fine.\n" + "6,3,Luke Skywalker\n" + "7,4," + + "Comment#1\n" + "8,4,Comment#2\n" + "9,4,Comment#3\n" + "10,4,Comment#4\n" + "11,5," + + "Comment#5\n" + "12,5,Comment#6\n" + "13,5,Comment#7\n" + "14,5,Comment#8\n" + "15,5," + + "Comment#9\n" + "16,6,Comment#10\n" + "17,6,Comment#11\n" + "18,6,Comment#12\n" + "19," + + "6,Comment#13\n" + "20,6,Comment#14\n" + "21,6,Comment#15\n" + + case 3 => + /* + * Test filter on String tuple field. + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val filterDs = ds.filter( _._3.contains("world") ) + filterDs.writeAsCsv(resultPath) + env.execute() + "3,2,Hello world\n" + "4,3,Hello world, how are you?\n" + + case 4 => + /* + * Test filter on Integer tuple field. + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val filterDs = ds.filter( _._1 % 2 == 0 ) + filterDs.writeAsCsv(resultPath) + env.execute() + "2,2,Hello\n" + "4,3,Hello world, how are you?\n" + "6,3,Luke Skywalker\n" + "8,4," + + "Comment#2\n" + "10,4,Comment#4\n" + "12,5,Comment#6\n" + "14,5,Comment#8\n" + "16,6," + + "Comment#10\n" + "18,6,Comment#12\n" + "20,6,Comment#14\n" + + case 5 => + /* + * Test filter on basic type + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getStringDataSet(env) + val filterDs = ds.filter( _.startsWith("H") ) + filterDs.writeAsText(resultPath) + env.execute() + "Hi\n" + "Hello\n" + "Hello world\n" + "Hello world, how are you?\n" + + case 6 => + /* + * Test filter on custom type + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getCustomTypeDataSet(env) + val filterDs = ds.filter( _.myString.contains("a") ) + filterDs.writeAsText(resultPath) + env.execute() + "3,3,Hello world, how are you?\n" + "3,4,I am fine.\n" + "3,5,Luke Skywalker\n" + + case 7 => + /* + * Test filter on String tuple field. 
+ */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ints = CollectionDataSets.getIntDataSet(env) + val ds = CollectionDataSets.get3TupleDataSet(env) + val filterDs = ds.filter( new RichFilterFunction[(Int, Long, String)] { + var literal = -1 + override def open(config: Configuration): Unit = { + val ints = getRuntimeContext.getBroadcastVariable[Int]("ints") + for (i <- ints.asScala) { + literal = if (literal < i) i else literal + } + } + override def filter(value: (Int, Long, String)): Boolean = value._1 < literal + }).withBroadcastSet(ints, "ints") + filterDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n" + "4,3,Hello world, how are you?\n" + + case _ => + throw new IllegalArgumentException("Invalid program id") + } + } +} + + +@RunWith(classOf[Parameterized]) +class FilterITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = FilterProgs.runProgram(curProgId, resultPath) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object FilterITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to FilterProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} + http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/a0ad9031/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FirstNITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FirstNITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FirstNITCase.scala new file mode 100644 index 0000000..6882885 --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FirstNITCase.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.common.operators.Order +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.flink.api.scala._ + + +object FirstNProgs { + var NUM_PROGRAMS: Int = 3 + + def runProgram(progId: Int, resultPath: String): String = { + progId match { + case 1 => + /* + * First-n on ungrouped data set + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val seven = ds.first(7).map( t => new Tuple1(1) ).sum(0) + seven.writeAsText(resultPath) + env.execute() + "(7)\n" + + case 2 => + /* + * First-n on grouped data set + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val first = ds.groupBy(1).first(4).map( t => (t._2, 1)).groupBy(0).sum(1) + first.writeAsText(resultPath) + env.execute() + "(1,1)\n(2,2)\n(3,3)\n(4,4)\n(5,4)\n(6,4)\n" + + case 3 => + /* + * First-n on grouped and sorted data set + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val first = ds.groupBy(1) + .sortGroup(0, Order.DESCENDING) + .first(3) + .map ( t => (t._2, t._1)) + first.writeAsText(resultPath) + env.execute() + "(1,1)\n" + "(2,3)\n(2,2)\n" + "(3,6)\n(3,5)\n(3,4)\n" + "(4,10)\n(4,9)\n(4," + + "8)\n" + "(5,15)\n(5,14)\n(5,13)\n" + "(6,21)\n(6,20)\n(6,19)\n" + + case _ => + throw new IllegalArgumentException("Invalid program id") + } + } +} + + +@RunWith(classOf[Parameterized]) +class FirstNITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = FirstNProgs.runProgram(curProgId, resultPath) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object FirstNITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to FirstNProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} + http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/a0ad9031/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala new file mode 100644 index 0000000..7c259b8 --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FirstNOperatorTest.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.common.InvalidProgramException +import org.apache.flink.api.common.operators.Order +import org.apache.flink.api.scala.ExecutionEnvironment +import org.junit.{Assert, Test} + +import org.apache.flink.api.scala._ + +class FirstNOperatorTest { + + private val emptyTupleData = Array[(Int, Long, String, Long, Int)]() + + @Test + def testUngroupedFirstN(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + + val tupleDs = env.fromCollection(emptyTupleData) + + try { + tupleDs.first(1) + } + catch { + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.first(10) + } + catch { + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.first(0) + Assert.fail() + } + catch { + case ipe: InvalidProgramException => { + } + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.first(-1) + Assert.fail() + } + catch { + case ipe: InvalidProgramException => { + } + case e: Exception => { + Assert.fail() + } + } + } + + @Test + def testGroupedFirstN(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + + try { + tupleDs.groupBy(2).first(1) + } + catch { + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.groupBy(1, 3).first(10) + } + catch { + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.groupBy(0).first(0) + Assert.fail() + } + catch { + case ipe: InvalidProgramException => { + } + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.groupBy(2).first(-1) + Assert.fail() + } + catch { + case ipe: InvalidProgramException => { + } + case e: Exception => { + Assert.fail() + } + } + } + + @Test + def testGroupedSortedFirstN(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + + try { + tupleDs.groupBy(2).sortGroup(4, Order.ASCENDING).first(1) + } + catch { + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.groupBy(1, 3).sortGroup(4, Order.ASCENDING).first(10) + } + catch { + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.groupBy(0).sortGroup(4, Order.ASCENDING).first(0) + Assert.fail() + } + catch { + case ipe: InvalidProgramException => { + } + case e: Exception => { + Assert.fail() + } + } + try { + tupleDs.groupBy(2).sortGroup(4, Order.ASCENDING).first(-1) + Assert.fail() + } + catch { + case ipe: InvalidProgramException => { + } + case e: Exception => { + Assert.fail() + } + } + } + +} + http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/a0ad9031/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FlatMapITCase.scala ---------------------------------------------------------------------- diff --git 
a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FlatMapITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FlatMapITCase.scala new file mode 100644 index 0000000..0d80d22 --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/FlatMapITCase.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.common.functions.RichFlatMapFunction +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.scala.util.CollectionDataSets.MutableTuple3 +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.apache.flink.util.Collector +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.flink.api.scala._ + + +object FlatMapProgs { + var NUM_PROGRAMS: Int = 7 + + def runProgram(progId: Int, resultPath: String): String = { + progId match { + case 1 => + /* + * Test non-passing flatmap + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getStringDataSet(env) + val nonPassingFlatMapDs = ds.flatMap( in => if (in.contains("banana")) Some(in) else None ) + nonPassingFlatMapDs.writeAsText(resultPath) + env.execute() + "\n" + + case 2 => + /* + * Test data duplicating flatmap + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getStringDataSet(env) + val duplicatingFlatMapDs = ds.flatMap( in => Seq(in, in.toUpperCase) ) + duplicatingFlatMapDs.writeAsText(resultPath) + env.execute() + "Hi\n" + "HI\n" + "Hello\n" + "HELLO\n" + "Hello world\n" + "HELLO WORLD\n" + + "Hello world, how are you?\n" + "HELLO WORLD, HOW ARE YOU?\n" + "I am fine.\n" + "I AM " + + "FINE.\n" + "Luke Skywalker\n" + "LUKE SKYWALKER\n" + "Random comment\n" + "RANDOM " + + "COMMENT\n" + "LOL\n" + "LOL\n" + + case 3 => + /* + * Test flatmap with varying number of emitted tuples + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val varyingTuplesMapDs = ds.flatMap { + in => + val numTuples = in._1 % 3 + (0 until numTuples) map { i => in } + } + varyingTuplesMapDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + "2,2,Hello\n" + "2,2,Hello\n" + "4,3,Hello world, " + + "how are you?\n" + "5,3,I am fine.\n" + "5,3,I am fine.\n" + "7,4,Comment#1\n" + "8,4," + + "Comment#2\n" + "8,4,Comment#2\n" + "10,4,Comment#4\n" + "11,5,Comment#5\n" + "11,5," + + "Comment#5\n" + "13,5,Comment#7\n" + "14,5,Comment#8\n" + "14,5,Comment#8\n" + "16,6," + 
+ "Comment#10\n" + "17,6,Comment#11\n" + "17,6,Comment#11\n" + "19,6,Comment#13\n" + "20," + + "6,Comment#14\n" + "20,6,Comment#14\n" + + case 4 => + /* + * Test type conversion flatmapper (Custom -> Tuple) + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getCustomTypeDataSet(env) + val typeConversionFlatMapDs = ds.flatMap { in => Some((in.myInt, in.myLong, in.myString)) } + typeConversionFlatMapDs.writeAsCsv(resultPath) + env.execute() + "1,0,Hi\n" + "2,1,Hello\n" + "2,2,Hello world\n" + "3,3,Hello world, " + + "how are you?\n" + "3,4,I am fine.\n" + "3,5,Luke Skywalker\n" + "4,6," + + "Comment#1\n" + "4,7,Comment#2\n" + "4,8,Comment#3\n" + "4,9,Comment#4\n" + "5,10," + + "Comment#5\n" + "5,11,Comment#6\n" + "5,12,Comment#7\n" + "5,13,Comment#8\n" + "5,14," + + "Comment#9\n" + "6,15,Comment#10\n" + "6,16,Comment#11\n" + "6,17,Comment#12\n" + "6," + + "18,Comment#13\n" + "6,19,Comment#14\n" + "6,20,Comment#15\n" + + case 5 => + /* + * Test type conversion flatmapper (Tuple -> Basic) + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val typeConversionFlatMapDs = ds.flatMap ( in => Some(in._3) ) + typeConversionFlatMapDs.writeAsText(resultPath) + env.execute() + "Hi\n" + "Hello\n" + "Hello world\n" + "Hello world, how are you?\n" + "I am fine" + + ".\n" + "Luke Skywalker\n" + "Comment#1\n" + "Comment#2\n" + "Comment#3\n" + + "Comment#4\n" + "Comment#5\n" + "Comment#6\n" + "Comment#7\n" + "Comment#8\n" + + "Comment#9\n" + "Comment#10\n" + "Comment#11\n" + "Comment#12\n" + "Comment#13\n" + + "Comment#14\n" + "Comment#15\n" + + case 6 => + /* + * Test flatmapper if UDF returns input object + * multiple times and changes it in between + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env).map { + t => MutableTuple3(t._1, t._2, t._3) + } + val inputObjFlatMapDs = ds.flatMap { + (in, out: Collector[MutableTuple3[Int, Long, String]]) => + val numTuples = in._1 % 4 + (0 until numTuples) foreach { i => in._1 = i; out.collect(in) } + } + inputObjFlatMapDs.writeAsCsv(resultPath) + env.execute() + "0,1,Hi\n" + "0,2,Hello\n" + "1,2,Hello\n" + "0,2,Hello world\n" + "1,2," + + "Hello world\n" + "2,2,Hello world\n" + "0,3,I am fine.\n" + "0,3," + + "Luke Skywalker\n" + "1,3,Luke Skywalker\n" + "0,4,Comment#1\n" + "1,4," + + "Comment#1\n" + "2,4,Comment#1\n" + "0,4,Comment#3\n" + "0,4,Comment#4\n" + "1,4," + + "Comment#4\n" + "0,5,Comment#5\n" + "1,5,Comment#5\n" + "2,5,Comment#5\n" + "0,5," + + "Comment#7\n" + "0,5,Comment#8\n" + "1,5,Comment#8\n" + "0,5,Comment#9\n" + "1,5," + + "Comment#9\n" + "2,5,Comment#9\n" + "0,6,Comment#11\n" + "0,6,Comment#12\n" + "1,6," + + "Comment#12\n" + "0,6,Comment#13\n" + "1,6,Comment#13\n" + "2,6,Comment#13\n" + "0,6," + + "Comment#15\n" + + case 7 => + /* + * Test flatmap with broadcast set + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ints = CollectionDataSets.getIntDataSet(env) + val ds = CollectionDataSets.get3TupleDataSet(env).map { + t => MutableTuple3(t._1, t._2, t._3) + } + val bcFlatMapDs = ds.flatMap( + new RichFlatMapFunction[MutableTuple3[Int, Long, String], + MutableTuple3[Int, Long, String]] { + private var f2Replace = 0 + private val outTuple = MutableTuple3(0, 0L, "") + override def open(config: Configuration): Unit = { + val ints = getRuntimeContext.getBroadcastVariable[Int]("ints").asScala + f2Replace = ints.sum + } + override def flatMap( + value: MutableTuple3[Int, 
Long, String], + out: Collector[MutableTuple3[Int, Long, String]]): Unit = { + outTuple._1 = f2Replace + outTuple._2 = value._2 + outTuple._3 = value._3 + out.collect(outTuple) + } + }).withBroadcastSet(ints, "ints") + bcFlatMapDs.writeAsCsv(resultPath) + env.execute() + "55,1,Hi\n" + "55,2,Hello\n" + "55,2,Hello world\n" + "55,3,Hello world, " + + "how are you?\n" + "55,3,I am fine.\n" + "55,3,Luke Skywalker\n" + "55,4," + + "Comment#1\n" + "55,4,Comment#2\n" + "55,4,Comment#3\n" + "55,4,Comment#4\n" + "55,5," + + "Comment#5\n" + "55,5,Comment#6\n" + "55,5,Comment#7\n" + "55,5,Comment#8\n" + "55,5," + + "Comment#9\n" + "55,6,Comment#10\n" + "55,6,Comment#11\n" + "55,6,Comment#12\n" + "55," + + "6,Comment#13\n" + "55,6,Comment#14\n" + "55,6,Comment#15\n" + + case _ => + throw new IllegalArgumentException("Invalid program id") + } + } +} + + +@RunWith(classOf[Parameterized]) +class FlatMapITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = FlatMapProgs.runProgram(curProgId, resultPath) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object FlatMapITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to FlatMapProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} + http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/a0ad9031/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala new file mode 100644 index 0000000..7b81933 --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupReduceITCase.scala @@ -0,0 +1,748 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.scala.operators + +import java.lang.Iterable + +import org.apache.flink.api.common.functions._ +import org.apache.flink.api.common.operators.Order +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.scala.util.CollectionDataSets.{CrazyNested, POJO, MutableTuple3, +CustomType} +import org.apache.flink.compiler.PactCompiler +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.apache.flink.util.Collector +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.flink.api.scala._ + + +object GroupReduceProgs { + var NUM_PROGRAMS: Int = 26 + + def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = { + progId match { + case 1 => + /* + * check correctness of groupReduce on tuples with key field selector + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(1).reduceGroup { + in => + in.map(t => (t._1, t._2)).reduce((l, r) => (l._1 + r._1, l._2)) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1\n" + "5,2\n" + "15,3\n" + "34,4\n" + "65,5\n" + "111,6\n" + + case 2 => + /* + * check correctness of groupReduce on tuples with multiple key field selector + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets + .get5TupleDataSet(env) + val reduceDs = ds.groupBy(4, 0).reduceGroup { + in => + val (i, l, l2) = in + .map( t => (t._1, t._2, t._5)) + .reduce((l, r) => (l._1, l._2 + r._2, l._3)) + (i, l, 0, "P-)", l2) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,0,P-),1\n" + "2,3,0,P-),1\n" + "2,2,0,P-),2\n" + "3,9,0,P-),2\n" + "3,6,0," + + "P-),3\n" + "4,17,0,P-),1\n" + "4,17,0,P-),2\n" + "5,11,0,P-),1\n" + "5,29,0,P-)," + + "2\n" + "5,25,0,P-),3\n" + + case 3 => + /* + * check correctness of groupReduce on tuples with key field selector and group sorting + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(1).sortGroup(2, Order.ASCENDING).reduceGroup { + in => + in.reduce((l, r) => (l._1 + r._1, l._2, l._3 + "-" + r._3)) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + + "5,2,Hello-Hello world\n" + + "15,3,Hello world, how are you?-I am fine.-Luke Skywalker\n" + + "34,4,Comment#1-Comment#2-Comment#3-Comment#4\n" + + "65,5,Comment#5-Comment#6-Comment#7-Comment#8-Comment#9\n" + + "111,6,Comment#10-Comment#11-Comment#12-Comment#13-Comment#14-Comment#15\n" + + case 4 => + /* + * check correctness of groupReduce on tuples with key extractor + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(_._2).reduceGroup { + in => + in.map(t => (t._1, t._2)).reduce((l, r) => (l._1 + r._1, l._2)) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1\n" + "5,2\n" + "15,3\n" + "34,4\n" + "65,5\n" + "111,6\n" + + case 5 => + /* + * check correctness of groupReduce on custom type with type extractor + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getCustomTypeDataSet(env) + val reduceDs = ds.groupBy(_.myInt).reduceGroup { + in 
=> + val iter = in.toIterator + val o = new CustomType + val c = iter.next() + + o.myString = "Hello!" + o.myInt = c.myInt + o.myLong = c.myLong + + while (iter.hasNext) { + val next = iter.next() + o.myLong += next.myLong + } + o + } + reduceDs.writeAsText(resultPath) + env.execute() + "1,0,Hello!\n" + "2,3,Hello!\n" + "3,12,Hello!\n" + "4,30,Hello!\n" + "5,60," + + "Hello!\n" + "6,105,Hello!\n" + + case 6 => + /* + * check correctness of all-groupreduce for tuples + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.reduceGroup { + in => + var i = 0 + var l = 0L + for (t <- in) { + i += t._1 + l += t._2 + } + (i, l, "Hello World") + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "231,91,Hello World\n" + + case 7 => + /* + * check correctness of all-groupreduce for custom types + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getCustomTypeDataSet(env) + val reduceDs = ds.reduceGroup { + in => + val o = new CustomType(0, 0, "Hello!") + for (t <- in) { + o.myInt += t.myInt + o.myLong += t.myLong + } + o + } + reduceDs.writeAsText(resultPath) + env.execute() + "91,210,Hello!" + + case 8 => + /* + * check correctness of groupReduce with broadcast set + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val intDs = CollectionDataSets.getIntDataSet(env) + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(1).reduceGroup( + new RichGroupReduceFunction[(Int, Long, String), (Int, Long, String)] { + private var f2Replace = "" + + override def open(config: Configuration) { + val ints = this.getRuntimeContext.getBroadcastVariable[Int]("ints").asScala + f2Replace = ints.sum + "" + } + + override def reduce( + values: Iterable[(Int, Long, String)], + out: Collector[(Int, Long, String)]): Unit = { + var i: Int = 0 + var l: Long = 0L + for (t <- values.asScala) { + i += t._1 + l = t._2 + } + out.collect((i, l, f2Replace)) + } + }).withBroadcastSet(intDs, "ints") + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,55\n" + "5,2,55\n" + "15,3,55\n" + "34,4,55\n" + "65,5,55\n" + "111,6,55\n" + + case 9 => + /* + * check correctness of groupReduce if UDF returns input objects multiple times and + * changes it in between + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + .map( t => MutableTuple3(t._1, t._2, t._3) ) + val reduceDs = ds.groupBy(1).reduceGroup { + (in, out: Collector[MutableTuple3[Int, Long, String]]) => + for (t <- in) { + if (t._1 < 4) { + t._3 = "Hi!" + t._1 += 10 + out.collect(t) + t._1 += 10 + t._3 = "Hi again!" 
+ out.collect(t) + } + } + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "11,1,Hi!\n" + "21,1,Hi again!\n" + "12,2,Hi!\n" + "22,2,Hi again!\n" + "13,2," + + "Hi!\n" + "23,2,Hi again!\n" + + case 10 => + /* + * check correctness of groupReduce on custom type with key extractor and combine + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getCustomTypeDataSet(env) + + @RichGroupReduceFunction.Combinable + class CustomTypeGroupReduceWithCombine + extends RichGroupReduceFunction[CustomType, CustomType] { + override def combine(values: Iterable[CustomType], out: Collector[CustomType]): Unit = { + val o = new CustomType() + for (c <- values.asScala) { + o.myInt = c.myInt + o.myLong += c.myLong + o.myString = "test" + c.myInt + } + out.collect(o) + } + + override def reduce(values: Iterable[CustomType], out: Collector[CustomType]): Unit = { + val o = new CustomType(0, 0, "") + for (c <- values.asScala) { + o.myInt = c.myInt + o.myLong += c.myLong + o.myString = c.myString + } + out.collect(o) + } + } + val reduceDs = ds.groupBy(_.myInt).reduceGroup(new CustomTypeGroupReduceWithCombine) + + reduceDs.writeAsText(resultPath) + env.execute() + if (onCollection) { + null + } + else { + "1,0,test1\n" + "2,3,test2\n" + "3,12,test3\n" + "4,30,test4\n" + "5,60," + + "test5\n" + "6,105,test6\n" + } + + case 11 => + /* + * check correctness of groupReduce on tuples with combine + */ + val env = ExecutionEnvironment.getExecutionEnvironment + // important because it determines how often the combiner is called + env.setDegreeOfParallelism(2) + val ds = CollectionDataSets.get3TupleDataSet(env) + @RichGroupReduceFunction.Combinable + class Tuple3GroupReduceWithCombine + extends RichGroupReduceFunction[(Int, Long, String), (Int, String)] { + override def combine( + values: Iterable[(Int, Long, String)], + out: Collector[(Int, Long, String)]): Unit = { + var i = 0 + var l = 0L + var s = "" + for (t <- values.asScala) { + i += t._1 + l = t._2 + s = "test" + t._2 + } + out.collect((i, l, s)) + } + + override def reduce( + values: Iterable[(Int, Long, String)], + out: Collector[(Int, String)]): Unit = { + var i = 0 + var s = "" + for (t <- values.asScala) { + i += t._1 + s = t._3 + } + out.collect((i, s)) + } + } + val reduceDs = ds.groupBy(1).reduceGroup(new Tuple3GroupReduceWithCombine) + reduceDs.writeAsCsv(resultPath) + env.execute() + if (onCollection) { + null + } + else { + "1,test1\n" + "5,test2\n" + "15,test3\n" + "34,test4\n" + "65,test5\n" + "111," + + "test6\n" + } + + + // all-groupreduce with combine + + + case 12 => + /* + * check correctness of all-groupreduce for tuples with combine + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env).map(t => t).setParallelism(4) + + val cfg: Configuration = new Configuration + cfg.setString(PactCompiler.HINT_SHIP_STRATEGY, PactCompiler.HINT_SHIP_STRATEGY_REPARTITION) + + @RichGroupReduceFunction.Combinable + class Tuple3AllGroupReduceWithCombine + extends RichGroupReduceFunction[(Int, Long, String), (Int, String)] { + override def combine( + values: Iterable[(Int, Long, String)], + out: Collector[(Int, Long, String)]): Unit = { + var i = 0 + var l = 0L + var s = "" + for (t <- values.asScala) { + i += t._1 + l += t._2 + s += "test" + } + out.collect((i, l, s)) + } + + override def reduce( + values: Iterable[(Int, Long, String)], + out: Collector[(Int, String)]): Unit = { + var i = 0 + var s = "" + for (t <- values.asScala) { + i += t._1 + t._2.toInt + s += 
t._3 + } + out.collect((i, s)) + } + } + val reduceDs = ds.reduceGroup(new Tuple3AllGroupReduceWithCombine).withParameters(cfg) + + reduceDs.writeAsCsv(resultPath) + env.execute() + if (onCollection) { + null + } + else { + "322," + + "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttest\n" + } + + case 13 => + /* + * check correctness of groupReduce with descending group sort + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(1).sortGroup(2, Order.DESCENDING).reduceGroup { + in => + in.reduce((l, r) => (l._1 + r._1, l._2, l._3 + "-" + r._3)) + } + + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + "5,2,Hello world-Hello\n" + "15,3,Luke Skywalker-I am fine.-Hello " + + "world, how are you?\n" + "34,4,Comment#4-Comment#3-Comment#2-Comment#1\n" + "65,5," + + "Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" + "111,6," + + "Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n" + + case 14 => + /* + * check correctness of groupReduce on tuples with tuple-returning key selector + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets + .get5TupleDataSet(env) + val reduceDs = ds.groupBy( t => (t._1, t._5)).reduceGroup { + in => + val (i, l, l2) = in + .map( t => (t._1, t._2, t._5)) + .reduce((l, r) => (l._1, l._2 + r._2, l._3)) + (i, l, 0, "P-)", l2) + } + + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,0,P-),1\n" + "2,3,0,P-),1\n" + "2,2,0,P-),2\n" + "3,9,0,P-),2\n" + "3,6,0," + + "P-),3\n" + "4,17,0,P-),1\n" + "4,17,0,P-),2\n" + "5,11,0,P-),1\n" + "5,29,0,P-)," + + "2\n" + "5,25,0,P-),3\n" + + case 15 => + /* + * check that input of combiner is also sorted for combinable groupReduce with group + * sorting + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.get3TupleDataSet(env).map { t => + MutableTuple3(t._1, t._2, t._3) + } + + @RichGroupReduceFunction.Combinable + class OrderCheckingCombinableReduce + extends RichGroupReduceFunction[MutableTuple3[Int, Long, String], + MutableTuple3[Int, Long, String]] { + def reduce( + values: Iterable[MutableTuple3[Int, Long, String]], + out: Collector[MutableTuple3[Int, Long, String]]) { + val it = values.iterator() + var t = it.next() + val i = t._1 + out.collect(t) + + while (it.hasNext) { + t = it.next() + if (i > t._1 || (t._3 == "INVALID-ORDER!")) { + t._3 = "INVALID-ORDER!" + out.collect(t) + } + } + } + + override def combine( + values: Iterable[MutableTuple3[Int, Long, String]], + out: Collector[MutableTuple3[Int, Long, String]]) { + val it = values.iterator() + var t = it.next + val i: Int = t._1 + out.collect(t) + while (it.hasNext) { + t = it.next + if (i > t._1) { + t._3 = "INVALID-ORDER!" 
+ out.collect(t) + } + } + } + } + + val reduceDs = ds.groupBy(1) + .sortGroup(0, Order.ASCENDING).reduceGroup(new OrderCheckingCombinableReduce) + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + "2,2,Hello\n" + "4,3,Hello world, how are you?\n" + "7,4," + + "Comment#1\n" + "11,5,Comment#5\n" + "16,6,Comment#10\n" + + case 16 => + /* + * Deep nesting test + * + null value in pojo + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets + .getCrazyNestedDataSet(env) + val reduceDs = ds.groupBy("nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4.f1nal") + .reduceGroup { + in => + var c = 0 + var n: String = null + for (v <- in) { + c += 1 + n = v.nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4.f1nal + } + (n, c) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "aa,1\nbb,2\ncc,3\n" + + + case 17 => + // We don't have that test but keep numbering compatible to Java GroupReduceITCase + val env = ExecutionEnvironment.getExecutionEnvironment + env.fromElements("Hello world").writeAsText(resultPath) + env.execute() + "Hello world" + + case 18 => + /* + * Test Pojo containing a Writable and Tuples + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets + .getPojoContainingTupleAndWritable(env) + val reduceDs = ds.groupBy("hadoopFan", "theTuple.*").reduceGroup(new + GroupReduceFunction[CollectionDataSets.PojoContainingTupleAndWritable, Integer] { + def reduce( + values: Iterable[CollectionDataSets.PojoContainingTupleAndWritable], + out: Collector[Integer]) { + var c: Int = 0 + for (v <- values.asScala) { + c += 1 + } + out.collect(c) + } + }) + reduceDs.writeAsText(resultPath) + env.execute() + "1\n5\n" + + case 19 => + /* + * Test Tuple containing pojos and regular fields + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.getTupleContainingPojos(env) + val reduceDs = ds.groupBy("_1", "_2.*").reduceGroup( + new GroupReduceFunction[(Int, CrazyNested, POJO), Int] { + def reduce(values: Iterable[(Int, CrazyNested, POJO)], out: Collector[Int]) { + var c: Int = 0 + for (v <- values.asScala) { + c += 1 + } + out.collect(c) + } + }) + reduceDs.writeAsText(resultPath) + env.execute() + "3\n1\n" + + case 20 => + /* + * Test string-based definition on group sort, based on test: + * check correctness of groupReduce with descending group sort + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.get3TupleDataSet(env) + val reduceDs = ds.groupBy(1) + .sortGroup("_3", Order.DESCENDING) + .reduceGroup { + in => + in.reduce((l, r) => (l._1 + r._1, l._2, l._3 + "-" + r._3)) + } + reduceDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + "5,2,Hello world-Hello\n" + "15,3,Luke Skywalker-I am fine.-Hello " + + "world, how are you?\n" + "34,4,Comment#4-Comment#3-Comment#2-Comment#1\n" + "65,5," + + "Comment#9-Comment#8-Comment#7-Comment#6-Comment#5\n" + "111,6," + + "Comment#15-Comment#14-Comment#13-Comment#12-Comment#11-Comment#10\n" + + case 21 => + /* + * Test int-based definition on group sort, for (full) nested Tuple + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.getGroupSortedNestedTupleDataSet(env) + val reduceDs = ds.groupBy("_2").sortGroup(0, Order.DESCENDING) + .reduceGroup(new NestedTupleReducer) + reduceDs.writeAsText(resultPath) + env.execute() + "a--(2,1)-(1,3)-(1,2)-\n" + "b--(2,2)-\n" + "c--(4,9)-(3,6)-(3,3)-\n" + + + case 22 => + /* + * Test 
int-based definition on group sort, for (partial) nested Tuple ASC + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.getGroupSortedNestedTupleDataSet(env) + val reduceDs = ds.groupBy("_2") + .sortGroup("_1._1", Order.ASCENDING) + .sortGroup("_1._2", Order.ASCENDING) + .reduceGroup(new NestedTupleReducer) + reduceDs.writeAsText(resultPath) + env.execute() + "a--(1,2)-(1,3)-(2,1)-\n" + "b--(2,2)-\n" + "c--(3,3)-(3,6)-(4,9)-\n" + + case 23 => + /* + * Test string-based definition on group sort, for (partial) nested Tuple DESC + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.getGroupSortedNestedTupleDataSet(env) + val reduceDs = ds.groupBy("_2") + .sortGroup("_1._1", Order.DESCENDING) + .sortGroup("_1._2", Order.ASCENDING) + .reduceGroup(new NestedTupleReducer) + reduceDs.writeAsText(resultPath) + env.execute() + "a--(2,1)-(1,2)-(1,3)-\n" + "b--(2,2)-\n" + "c--(4,9)-(3,3)-(3,6)-\n" + + case 24 => + /* + * Test string-based definition on group sort, for two grouping keys + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.getGroupSortedNestedTupleDataSet(env) + val reduceDs = ds.groupBy("_2") + .sortGroup("_1._1", Order.DESCENDING) + .sortGroup("_1._2", Order.DESCENDING) + .reduceGroup(new NestedTupleReducer) + reduceDs.writeAsText(resultPath) + env.execute() + "a--(2,1)-(1,3)-(1,2)-\n" + "b--(2,2)-\n" + "c--(4,9)-(3,6)-(3,3)-\n" + + case 25 => + /* + * Test string-based definition on group sort, for two grouping keys with Pojos + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.getGroupSortedPojoContainingTupleAndWritable(env) + val reduceDs = ds.groupBy("hadoopFan") + .sortGroup("theTuple._1", Order.DESCENDING) + .sortGroup("theTuple._2", Order.DESCENDING) + .reduceGroup( + new GroupReduceFunction[CollectionDataSets.PojoContainingTupleAndWritable, String] { + def reduce( + values: Iterable[CollectionDataSets.PojoContainingTupleAndWritable], + out: Collector[String]) { + var once: Boolean = false + val concat: StringBuilder = new StringBuilder + for (value <- values.asScala) { + if (!once) { + concat.append(value.hadoopFan.get) + concat.append("---") + once = true + } + concat.append(value.theTuple) + concat.append("-") + } + out.collect(concat.toString()) + } + }) + reduceDs.writeAsText(resultPath) + env.execute() + "1---(10,100)-\n" + "2---(30,600)-(30,400)-(30,200)-(20,201)-(20,200)-\n" + + case 26 => + /* + * Test grouping with pojo containing multiple pojos (was a bug) + */ + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(1) + val ds = CollectionDataSets.getPojoWithMultiplePojos(env) + val reduceDs = ds.groupBy("p2.a2") + .reduceGroup( + new GroupReduceFunction[CollectionDataSets.PojoWithMultiplePojos, String] { + def reduce( + values: Iterable[CollectionDataSets.PojoWithMultiplePojos], + out: Collector[String]) { + val concat: StringBuilder = new StringBuilder + for (value <- values.asScala) { + concat.append(value.p2.a2) + } + out.collect(concat.toString()) + } + }) + reduceDs.writeAsText(resultPath) + env.execute() + "b\nccc\nee\n" + + case _ => + throw new IllegalArgumentException("Invalid program id") + } + } +} + + +@RunWith(classOf[Parameterized]) +class GroupReduceITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private val 
curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = GroupReduceProgs.runProgram(curProgId, resultPath, isCollectionExecution) + } + + protected override def postSubmit(): Unit = { + if (expectedResult != null) compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object GroupReduceITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to GroupReduceProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} + +class NestedTupleReducer extends GroupReduceFunction[((Int, Int), String), String] { + def reduce(values: Iterable[((Int, Int), String)], out: Collector[String]) { + var once: Boolean = false + val concat: StringBuilder = new StringBuilder + for (value <- values.asScala) { + if (!once) { + concat.append(value._2).append("--") + once = true + } + concat.append(value._1) + concat.append("-") + } + out.collect(concat.toString()) + } +} + + http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/a0ad9031/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala new file mode 100644 index 0000000..fe1dd43 --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.scala.util.CollectionDataSets.CustomType +import org.junit.Assert +import org.apache.flink.api.common.InvalidProgramException +import org.apache.flink.api.common.operators.Order +import org.junit.Test + +import org.apache.flink.api.scala._ + + +class GroupingTest { + + private val emptyTupleData = Array[(Int, Long, String, Long, Int)]() + private val customTypeData = Array[CustomType](new CustomType()) + private val emptyLongData = Array[Long]() + + @Test + def testGroupByKeyIndices1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + + // should work + try { + tupleDs.groupBy(0) + } + catch { + case e: Exception => Assert.fail() + } + } + + @Test(expected = classOf[InvalidProgramException]) + def testGroupByKeyIndices2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val longDs = env.fromCollection(emptyLongData) + + // should not work, grouping on basic type + longDs.groupBy(0) + } + + @Test(expected = classOf[InvalidProgramException]) + def testGroupByKeyIndices3(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val customDs = env.fromCollection(customTypeData) + + // should not work, field position key on custom type + customDs.groupBy(0) + } + + @Test(expected = classOf[IllegalArgumentException]) + def testGroupByKeyIndices4(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + + // should not work, field position out of range + tupleDs.groupBy(5) + } + + @Test(expected = classOf[IllegalArgumentException]) + def testGroupByKeyIndices5(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + + // should not work, negative field position + tupleDs.groupBy(-1) + } + + @Test + def testGroupByKeyFields1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + + // should work + try { + tupleDs.groupBy("_1") + } + catch { + case e: Exception => Assert.fail() + } + } + + @Test(expected = classOf[IllegalArgumentException]) + def testGroupByKeyFields2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val longDs = env.fromCollection(emptyLongData) + + // should not work, grouping on basic type + longDs.groupBy("_1") + } + + @Test(expected = classOf[IllegalArgumentException]) + def testGroupByKeyFields3(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val customDs = env.fromCollection(customTypeData) + + // should not work, field key on custom type + customDs.groupBy("_1") + } + + @Test(expected = classOf[RuntimeException]) + def testGroupByKeyFields4(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + + // should not work, invalid field + tupleDs.groupBy("foo") + } + + @Test + def testGroupByKeyFields5(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val customDs = env.fromCollection(customTypeData) + + // should not work + customDs.groupBy("myInt") + } + + @Test + def testGroupByKeyExpressions1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromCollection(customTypeData) + + // should work + try { + ds.groupBy("myInt") + } + catch { + case e: Exception => Assert.fail() + } + } + + @Test(expected = classOf[IllegalArgumentException]) + def 
testGroupByKeyExpressions2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + + // should not work: groups on basic type + val longDs = env.fromCollection(emptyLongData) + longDs.groupBy("l") + } + + @Test(expected = classOf[InvalidProgramException]) + def testGroupByKeyExpressions3(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val customDs = env.fromCollection(customTypeData) + + // should not work: groups on custom type + customDs.groupBy(0) + } + + @Test(expected = classOf[IllegalArgumentException]) + def testGroupByKeyExpressions4(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromCollection(customTypeData) + + // should not work, non-existent field + ds.groupBy("myNonExistent") + } + + @Test + def testGroupByKeySelector1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + try { + val customDs = env.fromCollection(customTypeData) + customDs.groupBy { _.myLong } + } + catch { + case e: Exception => Assert.fail() + } + } + + @Test + def testGroupSortKeyFields1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + try { + tupleDs.groupBy(0).sortGroup(0, Order.ASCENDING) + } + catch { + case e: Exception => Assert.fail() + } + } + + @Test(expected = classOf[IllegalArgumentException]) + def testGroupSortKeyFields2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + + // should not work, field position out of range + tupleDs.groupBy(0).sortGroup(5, Order.ASCENDING) + } + + @Test(expected = classOf[InvalidProgramException]) + def testGroupSortKeyFields3(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val longDs = env.fromCollection(emptyLongData) + longDs.groupBy { x: Long => x } .sortGroup(0, Order.ASCENDING) + } + + @Test + def testChainedGroupSortKeyFields(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tupleDs = env.fromCollection(emptyTupleData) + try { + tupleDs.groupBy(0).sortGroup(0, Order.ASCENDING).sortGroup(2, Order.DESCENDING) + } + catch { + case e: Exception => Assert.fail() + } + } +} + http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/a0ad9031/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala new file mode 100644 index 0000000..2605830 --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/JoinITCase.scala @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.common.functions.RichJoinFunction +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.scala.util.CollectionDataSets.MutableTuple3 +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.apache.flink.util.Collector +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.flink.api.scala._ + + +object JoinProgs { + var NUM_PROGRAMS: Int = 20 + + def runProgram(progId: Int, resultPath: String): String = { + progId match { + case 1 => + /* + * UDF Join on tuples with key field positions + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = CollectionDataSets.get5TupleDataSet(env) + val joinDs = ds1.join(ds2).where(1).equalTo(1) { (l, r) => (l._3, r._4) } + joinDs.writeAsCsv(resultPath) + env.execute() + "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" + + case 2 => + /* + * UDF Join on tuples with multiple key field positions + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.get3TupleDataSet(env) + val ds2 = CollectionDataSets.get5TupleDataSet(env) + val joinDs = ds1.join(ds2).where(0, 1).equalTo(0, 4) { (l, r) => (l._3, r._4) } + joinDs.writeAsCsv(resultPath) + env.execute() + "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt wie gehts?\n" + + "Hello world,ABC\n" + "I am fine.,HIJ\n" + "I am fine.,IJK\n" + + case 3 => + /* + * Default Join on tuples + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = CollectionDataSets.get5TupleDataSet(env) + val joinDs = ds1.join(ds2).where(0).equalTo(2) + joinDs.writeAsCsv(resultPath) + env.execute() + "(1,1,Hi),(2,2,1,Hallo Welt,2)\n" + "(2,2,Hello),(2,3,2,Hallo Welt wie," + + "1)\n" + "(3,2,Hello world),(3,4,3,Hallo Welt wie gehts?,2)\n" + + case 4 => + /* + * Join with Huge + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = CollectionDataSets.get5TupleDataSet(env) + val joinDs = ds1.joinWithHuge(ds2).where(1).equalTo(1) { (l, r) => (l._3, r._4) } + joinDs.writeAsCsv(resultPath) + env.execute() + "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" + + case 5 => + /* + * Join with Tiny + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = CollectionDataSets.get5TupleDataSet(env) + val joinDs = ds1.joinWithTiny(ds2).where(1).equalTo(1) { (l, r) => (l._3, r._4) } + joinDs.writeAsCsv(resultPath) + env.execute() + "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" + + case 6 => + /* + * Join that returns the left input object + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = CollectionDataSets.get5TupleDataSet(env) + val joinDs = ds1.join(ds2).where(1).equalTo(1) { (l, r) => l } + 
joinDs.writeAsCsv(resultPath) + env.execute() + "1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n" + + case 7 => + /* + * Join that returns the right input object + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = CollectionDataSets.get5TupleDataSet(env) + val joinDs = ds1.join(ds2).where(1).equalTo(1) { (l, r) => r } + joinDs.writeAsCsv(resultPath) + env.execute() + "1,1,0,Hallo,1\n" + "2,2,1,Hallo Welt,2\n" + "2,2,1,Hallo Welt,2\n" + + case 8 => + /* + * Join with broadcast set + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val intDs = CollectionDataSets.getIntDataSet(env) + val ds1 = CollectionDataSets.get3TupleDataSet(env) + val ds2 = CollectionDataSets.getSmall5TupleDataSet(env) + val joinDs = ds1.join(ds2).where(1).equalTo(4).apply( + new RichJoinFunction[ + (Int, Long, String), + (Int, Long, Int, String, Long), + (String, String, Int)] { + private var broadcast = 41 + + override def open(config: Configuration) { + val ints = this.getRuntimeContext.getBroadcastVariable[Int]("ints").asScala + broadcast = ints.sum + } + + override def join( + first: (Int, Long, String), + second: (Int, Long, Int, String, Long)): (String, String, Int) = { + (first._3, second._4, broadcast) + } + } + ).withBroadcastSet(intDs, "ints") + joinDs.writeAsCsv(resultPath) + env.execute() + "Hi,Hallo,55\n" + "Hi,Hallo Welt wie,55\n" + "Hello,Hallo Welt," + "55\n" + "Hello world,Hallo Welt,55\n" + + case 9 => + /* + * Join on a custom type input with key extractor and a tuple input with key field selector + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallCustomTypeDataSet(env) + val ds2 = CollectionDataSets.get3TupleDataSet(env) + val joinDs = ds1.join(ds2).where( _.myInt ).equalTo(0) { (l, r) => (l.myString, r._3) } + joinDs.writeAsCsv(resultPath) + env.execute() + "Hi,Hi\n" + "Hello,Hello\n" + "Hello world,Hello\n" + + case 10 => // 12 in Java ITCase + /* + * Join on a tuple input with key field selector and a custom type input with key extractor + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env) + val ds2 = CollectionDataSets.getCustomTypeDataSet(env) + val joinDs = ds1.join(ds2).where(1).equalTo(_.myLong) apply { (l, r) => (l._3, r.myString) } + joinDs.writeAsCsv(resultPath) + env.execute() + "Hi,Hello\n" + "Hello,Hello world\n" + "Hello world,Hello world\n" + + case 11 => // 13 in Java ITCase + /* + * (Default) Join on two custom type inputs with key extractors + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getCustomTypeDataSet(env) + val ds2 = CollectionDataSets.getSmallCustomTypeDataSet(env) + val joinDs = ds1.join(ds2).where(_.myInt).equalTo(_.myInt) + joinDs.writeAsCsv(resultPath) + env.execute() + "1,0,Hi,1,0,Hi\n" + "2,1,Hello,2,1,Hello\n" + "2,1,Hello,2,2,Hello world\n" + "2," + "2,Hello world,2,1,Hello\n" + "2,2,Hello world,2,2,Hello world\n" + + case 12 => // 14 in Java ITCase + /* + * UDF Join on tuples with tuple-returning key selectors + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.get3TupleDataSet(env) + val ds2 = CollectionDataSets.get5TupleDataSet(env) + val joinDs = ds1.join(ds2).where( t => (t._1, t._2)).equalTo( t => (t._1, t._5)) apply { + (l, r) => 
(l._3, r._4) + } + joinDs.writeAsCsv(resultPath) + env.execute() + "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt wie gehts?\n" + + "Hello world,ABC\n" + "I am fine.,HIJ\n" + "I am fine.,IJK\n" + + /** + * Joins with POJOs + */ + case 13 => // 15 in Java ITCase + /* + * Join nested pojo against tuple (selected using a string) + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallPojoDataSet(env) + val ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env) + val joinDs = ds1.join(ds2).where("nestedPojo.longNumber").equalTo("_7") + joinDs.writeAsCsv(resultPath) + env.execute() + "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One," + + "10000)\n" + "2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two," + + "20000)\n" + "3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n" + + case 14 => // 16 in Java ITCase + /* + * Join nested pojo against tuple (selected as an integer) + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallPojoDataSet(env) + val ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env) + val joinDs = ds1.join(ds2).where("nestedPojo.longNumber").equalTo(6) // <-- difference + joinDs.writeAsCsv(resultPath) + env.execute() + "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One," + + "10000)\n" + "2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two," + + "20000)\n" + "3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n" + + case 15 => // 17 in Java ITCase + /* + * selecting multiple fields using expression language + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallPojoDataSet(env) + val ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env) + val joinDs = ds1.join(ds2) + .where("nestedPojo.longNumber", "number", "str") + .equalTo("_7", "_1", "_2") + joinDs.writeAsCsv(resultPath) + env.setDegreeOfParallelism(1) + env.execute() + "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One," + + "10000)\n" + "2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two," + + "20000)\n" + "3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n" + + case 16 => // 18 in Java ITCase + /* + * nested into tuple + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallPojoDataSet(env) + val ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env) + val joinDs = ds1.join(ds2).where("nestedPojo.longNumber", "number", + "nestedTupleWithCustom._1").equalTo("_7", "_1", "_3") + joinDs.writeAsCsv(resultPath) + env.setDegreeOfParallelism(1) + env.execute() + "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One," + + "10000)\n" + "2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two," + + "20000)\n" + "3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n" + + case 17 => // 19 in Java ITCase + /* + * nested into tuple into pojo + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallPojoDataSet(env) + val ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env) + val joinDs = ds1.join(ds2) + .where("nestedTupleWithCustom._1", + "nestedTupleWithCustom._2.myInt", + "nestedTupleWithCustom._2.myLong") + .equalTo("_3", "_4", "_5") + joinDs.writeAsCsv(resultPath) + env.setDegreeOfParallelism(1) + 
env.execute() + "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One," + + "10000)\n" + "2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two," + + "20000)\n" + "3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n" + + case 18 => // 20 in Java ITCase + /* + * Non-POJO test to verify that full-tuple keys are working. + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallNestedTupleDataSet(env) + val ds2 = CollectionDataSets.getSmallNestedTupleDataSet(env) + val joinDs = ds1.join(ds2).where(0).equalTo("_1._1", "_1._2") + joinDs.writeAsCsv(resultPath) + env.setDegreeOfParallelism(1) + env.execute() + "((1,1),one),((1,1),one)\n" + "((2,2),two),((2,2),two)\n" + "((3,3),three),((3,3)," + + "three)\n" + + case 19 => // 21 in Java ITCase + /* + * Non-POJO test to verify "nested" tuple-element selection. + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallNestedTupleDataSet(env) + val ds2 = CollectionDataSets.getSmallNestedTupleDataSet(env) + val joinDs = ds1.join(ds2).where("_1._1").equalTo("_1._1") + joinDs.writeAsCsv(resultPath) + env.setDegreeOfParallelism(1) + env.execute() + "((1,1),one),((1,1),one)\n" + "((2,2),two),((2,2),two)\n" + "((3,3),three),((3,3),three)\n" + + case 20 => // 22 in Java ITCase + /* + * full pojo with full tuple + */ + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmallPojoDataSet(env) + val ds2 = CollectionDataSets.getSmallTuplebasedDataSetMatchingPojo(env) + val joinDs = ds1.join(ds2).where("*").equalTo("*") + joinDs.writeAsCsv(resultPath) + env.setDegreeOfParallelism(1) + env.execute() + "1 First (10,100,1000,One) 10000,(10000,10,100,1000,One,1,First)\n" + + "2 Second (20,200,2000,Two) 20000,(20000,20,200,2000,Two,2,Second)\n" + + "3 Third (30,300,3000,Three) 30000,(30000,30,300,3000,Three,3,Third)\n" + + case _ => + throw new IllegalArgumentException("Invalid program id: " + progId) + } + } +} + + +@RunWith(classOf[Parameterized]) +class JoinITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = JoinProgs.runProgram(curProgId, resultPath) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object JoinITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to JoinProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} +
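For reference, the parameterized ITCases in this change share one pattern: the companion object's @Parameters method (getConfigurations) emits one Configuration per program id, JUnit's Parameterized runner instantiates the test class once per Configuration, and testProgram dispatches on that ProgramId via runProgram. The sketch below is not part of the commit; it only shows the same join shape as JoinProgs case 1 as a small standalone program. The object name JoinSketch, the in-memory element values, and the output path are illustrative assumptions, not code from this change.

import org.apache.flink.api.scala._

// Minimal standalone sketch, assuming the Scala DataSet API as used above.
object JoinSketch {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment

    // Small in-memory inputs, analogous to the CollectionDataSets helpers
    // but with made-up values.
    val left = env.fromElements((1, 1L, "Hi"), (2, 2L, "Hello"), (3, 2L, "Hello world"))
    val right = env.fromElements((1, 1L, 0, "Hallo", 1L), (2, 2L, 1, "Hallo Welt", 2L))

    // UDF join on key field positions, same shape as JoinProgs case 1:
    // match on left._2 == right._2 and emit (left text, right text).
    val joined = left.join(right).where(1).equalTo(1) { (l, r) => (l._3, r._4) }

    joined.writeAsCsv("/tmp/join-sketch") // output path is illustrative
    env.execute()
  }
}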