flink-commits mailing list archives

From: va...@apache.org
Subject: flink git commit: [FLINK-3612] remove PageRank Table example
Date: Wed, 16 Mar 2016 10:59:47 GMT
Repository: flink
Updated Branches:
  refs/heads/tableOnCalcite 3c85e2c5d -> 6e4018b0b


[FLINK-3612] remove PageRank Table example

add a type conversion mapper after aggregations if the expected type is not a Row

This closes #1793
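
For context, a minimal self-contained Scala sketch of the usage this change enables, modeled on the testPojoAggregation cases added in the diff below (the WC case class comes from the WordCountTable example, as imported by the Scala ITCase; the object name PojoAggregationSketch is only illustrative). Converting an aggregation result to a non-Row type such as this POJO previously failed at plan generation:

import org.apache.flink.api.scala._
import org.apache.flink.api.scala.table._
import org.apache.flink.examples.scala.WordCountTable.WC

object PojoAggregationSketch {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val input = env.fromElements(WC("hello", 1), WC("hello", 1), WC("ciao", 1))

    // Aggregate via the Table API and convert the result back to the WC POJO.
    // WC is not a Row type, so DataSetAggregate now appends a generated
    // conversion mapper instead of throwing a PlanGenException.
    val result = input.toTable
      .groupBy('word)
      .select('word, 'count.sum as 'count)
      .toDataSet[WC]

    result.print()
  }
}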


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6e4018b0
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6e4018b0
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6e4018b0

Branch: refs/heads/tableOnCalcite
Commit: 6e4018b0bb5d5c7fe08687c4809628eb47a96e0f
Parents: 3c85e2c
Author: vasia <vasia@apache.org>
Authored: Tue Mar 15 11:35:25 2016 +0100
Committer: vasia <vasia@apache.org>
Committed: Wed Mar 16 11:17:29 2016 +0100

----------------------------------------------------------------------
 .../flink/examples/java/JavaTableExample.java   |   5 +-
 .../plan/nodes/dataset/DataSetAggregate.scala   |  78 +++++--
 .../flink/examples/scala/PageRankTable.scala    | 210 -------------------
 .../api/java/table/test/AggregationsITCase.java |  31 +++
 .../scala/table/test/PageRankTableITCase.java   | 103 ---------
 .../scala/table/test/AggregationsITCase.scala   |  25 +++
 6 files changed, 114 insertions(+), 338 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/6e4018b0/flink-libraries/flink-table/src/main/java/org/apache/flink/examples/java/JavaTableExample.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/java/org/apache/flink/examples/java/JavaTableExample.java b/flink-libraries/flink-table/src/main/java/org/apache/flink/examples/java/JavaTableExample.java
index c043508..c7e69c9 100644
--- a/flink-libraries/flink-table/src/main/java/org/apache/flink/examples/java/JavaTableExample.java
+++ b/flink-libraries/flink-table/src/main/java/org/apache/flink/examples/java/JavaTableExample.java
@@ -17,7 +17,6 @@
  */
 package org.apache.flink.examples.java;
 
-
 import org.apache.flink.api.table.Table;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
@@ -30,14 +29,14 @@ public class JavaTableExample {
 
 	public static class WC {
 		public String word;
-		public int count;
+		public long count;
 
 		// Public constructor to make it a Flink POJO
 		public WC() {
 
 		}
 
-		public WC(String word, int count) {
+		public WC(String word, long count) {
 			this.word = word;
 			this.count = count;
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/6e4018b0/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala
index ce60621..01710fb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -22,9 +22,11 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.AggregateCall
 import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
-import org.apache.flink.api.table.plan.PlanGenException
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.runtime.MapRunner
 import org.apache.flink.api.table.runtime.aggregate.AggregateUtil
 import org.apache.flink.api.table.runtime.aggregate.AggregateUtil.CalcitePair
 import org.apache.flink.api.table.typeutils.{TypeConverter, RowTypeInfo}
@@ -69,12 +71,6 @@ class DataSetAggregate(
       config: TableConfig,
       expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
 
-    expectedType match {
-      case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
-        throw new PlanGenException("Aggregate operations currently only support returning Rows.")
-      case _ => // ok
-    }
-
     val groupingKeys = grouping.indices.toArray
     // add grouping fields, position keys in the input, and input type
     val aggregateResult = AggregateUtil.createOperatorFunctionsForAggregates(namedAggregates,
@@ -96,24 +92,34 @@ class DataSetAggregate(
     val mappedInput = inputDS.map(aggregateResult._1).name(s"prepare $aggString")
     val groupReduceFunction = aggregateResult._2
 
-    if (groupingKeys.length > 0) {
+    val result = {
+      if (groupingKeys.length > 0) {
+        val inFields = inputType.getFieldNames.asScala.toList
+        val groupByString = s"groupBy: (${grouping.map(inFields(_)).mkString(", ")})"
 
-      val inFields = inputType.getFieldNames.asScala.toList
-      val groupByString = s"groupBy: (${grouping.map( inFields(_) ).mkString(", ")})"
-
-      mappedInput.asInstanceOf[DataSet[Row]]
-        .groupBy(groupingKeys: _*)
-        .reduceGroup(groupReduceFunction)
-        .returns(rowTypeInfo)
+        mappedInput.asInstanceOf[DataSet[Row]]
+          .groupBy(groupingKeys: _*)
+          .reduceGroup(groupReduceFunction)
+          .returns(rowTypeInfo)
           .name(groupByString + ", " + aggString)
-        .asInstanceOf[DataSet[Any]]
+          .asInstanceOf[DataSet[Any]]
+      }
+      else {
+        // global aggregation
+        mappedInput.asInstanceOf[DataSet[Row]]
+          .reduceGroup(groupReduceFunction)
+          .returns(rowTypeInfo)
+          .asInstanceOf[DataSet[Any]]
+      }
     }
-    else {
-      // global aggregation
-      mappedInput.asInstanceOf[DataSet[Row]]
-        .reduceGroup(groupReduceFunction)
-        .returns(rowTypeInfo)
-        .asInstanceOf[DataSet[Any]]
+
+    // if the expected type is not a Row, inject a mapper to convert to the expected type
+    expectedType match {
+      case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
+        val mapName = s"convert: (${rowType.getFieldNames.asScala.toList.mkString(", ")})"
+        result.map(typeConversion(config, rowTypeInfo, expectedType.get))
+        .name(mapName)
+      case _ => result
     }
   }
 
@@ -137,4 +143,32 @@ class DataSetAggregate(
     s"select: (${outFieldsString.mkString(", ")})"
   }
 
+  private def typeConversion(
+      config: TableConfig,
+      rowTypeInfo: RowTypeInfo,
+      expectedType: TypeInformation[Any]): MapFunction[Any, Any] = {
+
+    val generator = new CodeGenerator(config, rowTypeInfo.asInstanceOf[TypeInformation[Any]])
+    val conversion = generator.generateConverterResultExpression(
+      expectedType, rowType.getFieldNames.asScala)
+
+    val body =
+      s"""
+          |${conversion.code}
+          |return ${conversion.resultTerm};
+          |""".stripMargin
+
+    val genFunction = generator.generateFunction(
+      "AggregateOutputConversion",
+      classOf[MapFunction[Any, Any]],
+      body,
+      expectedType)
+
+    new MapRunner[Any, Any](
+      genFunction.name,
+      genFunction.code,
+      genFunction.returnType)
+
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/6e4018b0/flink-libraries/flink-table/src/main/scala/org/apache/flink/examples/scala/PageRankTable.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/examples/scala/PageRankTable.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/examples/scala/PageRankTable.scala
deleted file mode 100644
index dda6265..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/examples/scala/PageRankTable.scala
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.examples.scala
-
-import org.apache.flink.api.table.expressions.Literal
-import org.apache.flink.api.common.functions.GroupReduceFunction
-import org.apache.flink.api.scala._
-import org.apache.flink.api.scala.table._
-import org.apache.flink.examples.java.graph.util.PageRankData
-import org.apache.flink.util.Collector
-
-import _root_.scala.collection.JavaConverters._
-
-/**
-* A basic implementation of the Page Rank algorithm using a bulk iteration.
-*
-* This implementation requires a set of pages and a set of directed links as input and works as
-* follows.
-*
-* In each iteration, the rank of every page is evenly distributed to all pages it points to. Each
-* page collects the partial ranks of all pages that point to it, sums them up, and applies a
-* dampening factor to the sum. The result is the new rank of the page. A new iteration is started
-* with the new ranks of all pages. This implementation terminates after a fixed number of
-* iterations. This is the Wikipedia entry for the
-* [[http://en.wikipedia.org/wiki/Page_rank Page Rank algorithm]]
-*
-* Input files are plain text files and must be formatted as follows:
-*
-*  - Pages represented as an (long) ID separated by new-line characters.
-*    For example `"1\n2\n12\n42\n63"` gives five pages with IDs 1, 2, 12, 42, and 63.
-*  - Links are represented as pairs of page IDs which are separated by space  characters. Links
-*    are separated by new-line characters.
-*    For example `"1 2\n2 12\n1 12\n42 63"` gives four (directed) links (1)->(2), (2)->(12),
-*    (1)->(12), and (42)->(63). For this simple implementation it is required that each page has
-*    at least one incoming and one outgoing link (a page can point to itself).
-*
-* Usage:
-* {{{
-*   PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>
-* }}}
-*
-* If no parameters are provided, the program is run with default data from
-* [[org.apache.flink.examples.java.graph.util.PageRankData]] and 10 iterations.
-*
-* This example shows how to use:
-*
-*  - Bulk Iterations
-*  - Table API expressions
-*/
-object PageRankTable {
-
-  private final val DAMPENING_FACTOR: Double = 0.85
-  private final val EPSILON: Double = 0.0001
-
-  def main(args: Array[String]) {
-    if (!parseParameters(args)) {
-      return
-    }
-
-    // set up execution environment
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    // read input data
-    val pagesWithRanks = getPagesDataSet(env).map { p => (p, 1.0 / numPages) }
-      .as('pageId, 'rank)
-
-    val links = getLinksDataSet(env)
-
-    // build adjacency list from link input
-    val adjacencyLists = links
-      .groupBy("sourceId").reduceGroup( new GroupReduceFunction[Link, AdjacencyList] {
-
-        override def reduce(
-            values: _root_.java.lang.Iterable[Link],
-            out: Collector[AdjacencyList]): Unit = {
-          var outputId = -1L
-          val outputList = values.asScala map { t => outputId = t.sourceId; t.targetId }
-          out.collect(new AdjacencyList(outputId, outputList.toArray))
-        }
-
-      }).as('sourceId, 'targetIds)
-
-    // start iteration
-    val finalRanks = pagesWithRanks.iterateWithTermination(maxIterations) {
-      currentRanks =>
-        val newRanks = currentRanks.toTable
-          // distribute ranks to target pages
-          .join(adjacencyLists).where('pageId === 'sourceId)
-          .select('rank, 'targetIds).toDataSet[RankOutput]
-          .flatMap {
-            (in, out: Collector[(Long, Double)]) =>
-              val targets = in.targetIds
-              val len = targets.length
-              targets foreach { t => out.collect((t, in.rank / len )) }
-          }
-          .as('pageId, 'rank)
-          // collect ranks and sum them up
-          .groupBy('pageId).select('pageId, 'rank.sum as 'rank)
-          // apply dampening factor
-          .select(
-            'pageId,
-            ('rank * DAMPENING_FACTOR) + (Literal(1) - DAMPENING_FACTOR) / numPages as 'rank)
-
-
-        val termination = currentRanks.toTable
-          .as('curId, 'curRank).join(newRanks.as('newId, 'newRank))
-          .where('curId === 'newId && ('curRank - 'newRank).abs > EPSILON)
-
-        (newRanks, termination)
-    }
-
-    val result = finalRanks
-
-    // emit result
-    if (fileOutput) {
-      result.writeAsCsv(outputPath, "\n", " ")
-      // execute program
-      env.execute("Expression PageRank Example")
-    } else {
-      // execute program and print result
-      result.print()
-    }
-  }
-
-  // *************************************************************************
-  //     USER TYPES
-  // *************************************************************************
-
-  case class Link(sourceId: Long, targetId: Long)
-
-  case class Page(pageId: Long, rank: Double)
-
-  case class AdjacencyList(sourceId: Long, targetIds: Array[Long])
-
-  case class RankOutput(rank: Double, targetIds: Array[Long])
-
-  // *************************************************************************
-  //     UTIL METHODS
-  // *************************************************************************
-
-  private def parseParameters(args: Array[String]): Boolean = {
-    if (args.length > 0) {
-      fileOutput = true
-      if (args.length == 5) {
-        pagesInputPath = args(0)
-        linksInputPath = args(1)
-        outputPath = args(2)
-        numPages = args(3).toLong
-        maxIterations = args(4).toInt
-      } else {
-        System.err.println("Usage: PageRankBasic <pages path> <links path> <output path> <num " +
-          "pages> <num iterations>")
-        false
-      }
-    } else {
-      System.out.println("Executing PageRank Basic example with default parameters and built-in " +
-        "default data.")
-      System.out.println("  Provide parameters to read input data from files.")
-      System.out.println("  See the documentation for the correct format of input files.")
-      System.out.println("  Usage: PageRankBasic <pages path> <links path> <output path> <num " +
-        "pages> <num iterations>")
-
-      numPages = PageRankData.getNumberOfPages
-    }
-    true
-  }
-
-  private def getPagesDataSet(env: ExecutionEnvironment): DataSet[Long] = {
-    if (fileOutput) {
-      env.readCsvFile[Tuple1[Long]](pagesInputPath, fieldDelimiter = " ", lineDelimiter = "\n")
-        .map(x => x._1)
-    } else {
-      env.generateSequence(1, 15)
-    }
-  }
-
-  private def getLinksDataSet(env: ExecutionEnvironment): DataSet[Link] = {
-    if (fileOutput) {
-      env.readCsvFile[Link](linksInputPath, fieldDelimiter = " ",
-        includedFields = Array(0, 1))
-    } else {
-      val edges = PageRankData.EDGES.map { case Array(v1, v2) => Link(v1.asInstanceOf[Long], v2.asInstanceOf[Long])}
-      env.fromCollection(edges)
-    }
-  }
-
-  private var fileOutput: Boolean = false
-  private var pagesInputPath: String = null
-  private var linksInputPath: String = null
-  private var outputPath: String = null
-  private var numPages: Double = 0
-  private var maxIterations: Int = 10
-
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/6e4018b0/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java
index 2ab38e5..9797950 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/AggregationsITCase.java
@@ -35,6 +35,7 @@ package org.apache.flink.api.java.table.test;
  * limitations under the License.
  */
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.table.ExpressionParserException;
 import org.apache.flink.api.table.Row;
@@ -45,11 +46,13 @@ import org.apache.flink.api.java.operators.DataSource;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple7;
 import org.apache.flink.api.table.plan.PlanGenException;
+import org.apache.flink.examples.java.JavaTableExample;
 import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
 import org.apache.flink.test.util.MultipleProgramsTestBase;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.apache.flink.examples.java.JavaTableExample.WC;
 
 import java.util.List;
 
@@ -195,5 +198,33 @@ public class AggregationsITCase extends MultipleProgramsTestBase {
 		String expected = "";
 		compareResultAsText(results, expected);
 	}
+
+	@Test
+	public void testPojoAggregation() throws Exception {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		TableEnvironment tableEnv = new TableEnvironment();
+		DataSet<WC> input = env.fromElements(
+				new WC("Hello", 1),
+				new WC("Ciao", 1),
+				new WC("Hello", 1),
+				new WC("Hola", 1),
+				new WC("Hola", 1));
+
+		Table table = tableEnv.fromDataSet(input);
+
+		Table filtered = table
+				.groupBy("word")
+				.select("word.count as count, word")
+				.filter("count = 2");
+
+		List<String> result = tableEnv.toDataSet(filtered, WC.class)
+				.map(new MapFunction<WC, String>() {
+					public String map(WC value) throws Exception {
+						return value.word;
+					}
+				}).collect();
+		String expected = "Hello\n" + "Hola";
+		compareResultAsText(result, expected);
+	}
 }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/6e4018b0/flink-libraries/flink-table/src/test/java/org/apache/flink/api/scala/table/test/PageRankTableITCase.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/scala/table/test/PageRankTableITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/scala/table/test/PageRankTableITCase.java
deleted file mode 100644
index a893d4d..0000000
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/scala/table/test/PageRankTableITCase.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.flink.api.scala.table.test;
-
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.examples.scala.PageRankTable;
-import org.apache.flink.test.testdata.PageRankData;
-import org.apache.flink.test.util.JavaProgramTestBase;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameters;
-
-import java.io.IOException;
-import java.util.Collection;
-import java.util.LinkedList;
-
-@RunWith(Parameterized.class)
-public class PageRankTableITCase extends JavaProgramTestBase {
-
-	private static int NUM_PROGRAMS = 2;
-
-	private int curProgId = config.getInteger("ProgramId", -1);
-
-	private String verticesPath;
-	private String edgesPath;
-	private String resultPath;
-	private String expectedResult;
-
-	public PageRankTableITCase(Configuration config) {
-		super(config);
-	}
-
-	@Override
-	protected void preSubmit() throws Exception {
-		resultPath = getTempDirPath("result");
-		verticesPath = createTempFile("vertices.txt", PageRankData.VERTICES);
-		edgesPath = createTempFile("edges.txt", PageRankData.EDGES);
-	}
-
-	@Override
-	protected void testProgram() throws Exception {
-		expectedResult = runProgram(curProgId);
-	}
-
-	@Override
-	protected void postSubmit() throws Exception {
-		compareKeyValuePairsWithDelta(expectedResult, resultPath, " ", 0.01);
-	}
-
-	@Parameters
-	public static Collection<Object[]> getConfigurations() throws IOException {
-
-		LinkedList<Configuration> tConfigs = new LinkedList<Configuration>();
-
-		for(int i=1; i <= NUM_PROGRAMS; i++) {
-			Configuration config = new Configuration();
-			config.setInteger("ProgramId", i);
-			tConfigs.add(config);
-		}
-
-		// TODO: Enable test again once:
-		//   1) complex types (long[]) can be shipped through Table API
-		//   2) abs function is available
-//		return toParameterList(tConfigs);
-		return new LinkedList<>();
-	}
-
-
-	public String runProgram(int progId) throws Exception {
-
-		switch(progId) {
-		case 1: {
-			PageRankTable.main(new String[]{verticesPath, edgesPath, resultPath, PageRankData
-					.NUM_VERTICES + "", "3"});
-			return PageRankData.RANKS_AFTER_3_ITERATIONS;
-		}
-		case 2: {
-			// start with a very high number of iteration such that the dynamic convergence criterion must handle termination
-			PageRankTable.main(new String[] {verticesPath, edgesPath, resultPath, PageRankData.NUM_VERTICES+"", "1000"});
-			return PageRankData.RANKS_AFTER_EPSILON_0_0001_CONVERGENCE;
-		}
-
-		default:
-			throw new IllegalArgumentException("Invalid program id");
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/6e4018b0/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
index 0ac662a..81b22ba 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala
@@ -29,6 +29,7 @@ import org.junit._
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
 import scala.collection.JavaConverters._
+import org.apache.flink.examples.scala.WordCountTable.{WC => MyWC}
 
 @RunWith(classOf[Parameterized])
 class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {
@@ -160,4 +161,28 @@ class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBa
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
+  @Test
+  def testPojoAggregation(): Unit = {
+
+    // test aggregations with a custom WordCount class
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val input = env.fromElements(
+      MyWC("hello", 1),
+      MyWC("hello", 1),
+      MyWC("ciao", 1),
+      MyWC("hola", 1),
+      MyWC("hola", 1))
+    val expr = input.toTable
+    val result = expr
+      .groupBy('word)
+      .select('word, 'count.sum as 'count)
+      .filter('count === 2)
+      .toDataSet[MyWC]
+
+    val mappedResult = result.map(w => (w.word, w.count * 10)).collect()
+    val expected = "(hello,20)\n" + "(hola,20)"
+    TestBaseUtils.compareResultAsText(mappedResult.asJava, expected)
+  }
+
 }
+

