flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From fhue...@apache.org
Subject [2/5] flink git commit: [FLINK-1512] [scala api] Add CsvReader for reading into POJOs
Date Wed, 25 Mar 2015 19:45:15 GMT
[FLINK-1512] [scala api] Add CsvReader for reading into POJOs


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7a6f2960
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7a6f2960
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7a6f2960

Branch: refs/heads/master
Commit: 7a6f296094b26b940f9f9f66f64e5e2a0f700cb1
Parents: 7b1c19c
Author: Chiwan Park <chiwanpark@icloud.com>
Authored: Fri Feb 20 02:23:56 2015 +0900
Committer: Fabian Hueske <fhueske@apache.org>
Committed: Wed Mar 25 20:38:59 2015 +0100

----------------------------------------------------------------------
 .../scala/operators/ScalaCsvInputFormat.java    | 270 ++++++++-----------
 .../flink/api/scala/ExecutionEnvironment.scala  |  47 +++-
 .../flink/api/scala/io/CsvInputFormatTest.scala | 125 ++++++++-
 .../scala/io/ScalaCsvReaderWithPOJOITCase.scala | 124 +++++++++
 4 files changed, 378 insertions(+), 188 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/7a6f2960/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java
index 79c6659..9adbed8 100644
--- a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java
+++ b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java
@@ -19,66 +19,91 @@
 package org.apache.flink.api.scala.operators;
 
 
-import com.google.common.base.Charsets;
 import com.google.common.base.Preconditions;
-
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.io.GenericCsvInputFormat;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.PojoTypeInfo;
 import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
 import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase;
 import org.apache.flink.core.fs.FileInputSplit;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.types.parser.FieldParser;
-import org.apache.flink.util.StringUtils;
 
+import org.apache.flink.types.parser.FieldParser;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.nio.charset.Charset;
-import java.nio.charset.IllegalCharsetNameException;
-import java.nio.charset.UnsupportedCharsetException;
-import java.util.Map;
-import java.util.TreeMap;
+import java.lang.reflect.Field;
+import java.util.Arrays;
 
-import scala.Product;
-
-public class ScalaCsvInputFormat<OUT extends Product> extends GenericCsvInputFormat<OUT>
{
+public class ScalaCsvInputFormat<OUT> extends GenericCsvInputFormat<OUT> {
 
 	private static final long serialVersionUID = 1L;
 
 	private static final Logger LOG = LoggerFactory.getLogger(ScalaCsvInputFormat.class);
-	
-	private transient Object[] parsedValues;
-	
-	// To speed up readRecord processing. Used to find windows line endings.
-	// It is set when open so that readRecord does not have to evaluate it
-	private boolean lineDelimiterIsLinebreak = false;
 
-	private final TupleSerializerBase<OUT> serializer;
+	private transient Object[] parsedValues;
 
-	private byte[] commentPrefix = null;
+	private final TupleSerializerBase<OUT> tupleSerializer;
 
-	private transient int commentCount;
-	private transient int invalidLineCount;
+	private Class<OUT> pojoTypeClass = null;
+	private String[] pojoFieldsName = null;
+	private transient Field[] pojoFields = null;
+	private transient PojoTypeInfo<OUT> pojoTypeInfo = null;
 
 	public ScalaCsvInputFormat(Path filePath, TypeInformation<OUT> typeInfo) {
 		super(filePath);
 
-		if (!(typeInfo.isTupleType())) {
-			throw new UnsupportedOperationException("This only works on tuple types.");
+		Class<?>[] classes = new Class[typeInfo.getArity()];
+
+		if (typeInfo instanceof TupleTypeInfoBase) {
+			TupleTypeInfoBase<OUT> tupleType = (TupleTypeInfoBase<OUT>) typeInfo;
+			// We can use an empty config here, since we only use the serializer to create
+			// the top-level case class
+			tupleSerializer = (TupleSerializerBase<OUT>) tupleType.createSerializer(new ExecutionConfig());
+
+			for (int i = 0; i < tupleType.getArity(); i++) {
+				classes[i] = tupleType.getTypeAt(i).getTypeClass();
+			}
+
+			setFieldTypes(classes);
+		} else {
+			tupleSerializer = null;
+			pojoTypeInfo = (PojoTypeInfo<OUT>) typeInfo;
+			pojoTypeClass = typeInfo.getTypeClass();
+			pojoFieldsName = pojoTypeInfo.getFieldNames();
+
+			for (int i = 0, arity = pojoTypeInfo.getArity(); i < arity; i++) {
+				classes[i] = pojoTypeInfo.getTypeAt(i).getTypeClass();
+			}
+
+			setFieldTypes(classes);
+			setOrderOfPOJOFields(pojoFieldsName);
+		}
+	}
+
+	public void setOrderOfPOJOFields(String[] fieldsOrder) {
+		Preconditions.checkNotNull(pojoTypeClass, "Field order can only be specified if output
type is a POJO.");
+		Preconditions.checkNotNull(fieldsOrder);
+
+		int includedCount = 0;
+		for (boolean isIncluded : fieldIncluded) {
+			if (isIncluded) {
+				includedCount++;
+			}
 		}
-		TupleTypeInfoBase<OUT> tupleType = (TupleTypeInfoBase<OUT>) typeInfo;
-		// We can use an empty config here, since we only use the serializer to create
-		// the top-level case class
-		serializer = (TupleSerializerBase<OUT>) tupleType.createSerializer(new ExecutionConfig());
-
-		Class<?>[] classes = new Class[tupleType.getArity()];
-		for (int i = 0; i < tupleType.getArity(); i++) {
-			classes[i] = tupleType.getTypeAt(i).getTypeClass();
+
+		Preconditions.checkArgument(includedCount == fieldsOrder.length,
+			"The number of selected POJO fields should be the same as that of CSV fields.");
+
+		for (String field : fieldsOrder) {
+			Preconditions.checkNotNull(field, "The field name cannot be null.");
+			Preconditions.checkArgument(pojoTypeInfo.getFieldIndex(field) != -1,
+				"The given field name isn't matched to POJO fields.");
 		}
-		setFieldTypes(classes);
+
+		pojoFieldsName = Arrays.copyOfRange(fieldsOrder, 0, fieldsOrder.length);
 	}
 
 	public void setFieldTypes(Class<?>[] fieldTypes) {
@@ -98,98 +123,66 @@ public class ScalaCsvInputFormat<OUT extends Product> extends GenericCsvInputFor
 		setFieldsGeneric(sourceFieldIndices, fieldTypes);
 	}
 
-	public byte[] getCommentPrefix() {
-		return commentPrefix;
-	}
-
-	public void setCommentPrefix(byte[] commentPrefix) {
-		this.commentPrefix = commentPrefix;
-	}
-
-	public void setCommentPrefix(char commentPrefix) {
-		setCommentPrefix(String.valueOf(commentPrefix));
-	}
+	public void setFields(boolean[] sourceFieldMask, Class<?>[] fieldTypes) {
+		Preconditions.checkNotNull(sourceFieldMask);
+		Preconditions.checkNotNull(fieldTypes);
 
-	public void setCommentPrefix(String commentPrefix) {
-		setCommentPrefix(commentPrefix, Charsets.UTF_8);
+		setFieldsGeneric(sourceFieldMask, fieldTypes);
 	}
 
-	public void setCommentPrefix(String commentPrefix, String charsetName) throws IllegalCharsetNameException,
UnsupportedCharsetException {
-		if (charsetName == null) {
-			throw new IllegalArgumentException("Charset name must not be null");
-		}
-
-		if (commentPrefix != null) {
-			Charset charset = Charset.forName(charsetName);
-			setCommentPrefix(commentPrefix, charset);
-		} else {
-			this.commentPrefix = null;
-		}
+	public Class<?>[] getFieldTypes() {
+		return super.getGenericFieldTypes();
 	}
 
-	public void setCommentPrefix(String commentPrefix, Charset charset) {
-		if (charset == null) {
-			throw new IllegalArgumentException("Charset must not be null");
-		}
-		if (commentPrefix != null) {
-			this.commentPrefix = commentPrefix.getBytes(charset);
-		} else {
-			this.commentPrefix = null;
-		}
-	}
-
-	@Override
-	public void close() throws IOException {
-		if (this.invalidLineCount > 0) {
-			if (LOG.isWarnEnabled()) {
-				LOG.warn("In file \""+ this.filePath + "\" (split start: " + this.splitStart + ") " +
this.invalidLineCount +" invalid line(s) were skipped.");
-			}
-		}
-
-		if (this.commentCount > 0) {
-			if (LOG.isInfoEnabled()) {
-				LOG.info("In file \""+ this.filePath + "\" (split start: " + this.splitStart + ") " +
this.commentCount +" comment line(s) were skipped.");
-			}
-		}
-		super.close();
-	}
-
-	@Override
-	public OUT nextRecord(OUT record) throws IOException {
-		OUT returnRecord = null;
-		do {
-			returnRecord = super.nextRecord(record);
-		} while (returnRecord == null && !reachedEnd());
-
-		return returnRecord;
-	}
-	
 	@Override
 	public void open(FileInputSplit split) throws IOException {
 		super.open(split);
-		
+
 		@SuppressWarnings("unchecked")
 		FieldParser<Object>[] fieldParsers = (FieldParser<Object>[]) getFieldParsers();
-		
+
 		//throw exception if no field parsers are available
 		if (fieldParsers.length == 0) {
 			throw new IOException("CsvInputFormat.open(FileInputSplit split) - no field parsers to
parse input");
 		}
-		
+
 		// create the value holders
 		this.parsedValues = new Object[fieldParsers.length];
 		for (int i = 0; i < fieldParsers.length; i++) {
 			this.parsedValues[i] = fieldParsers[i].createValue();
 		}
 
-		this.commentCount = 0;
-		this.invalidLineCount = 0;
-		
 		// left to right evaluation makes access [0] okay
 		// this marker is used to fasten up readRecord, so that it doesn't have to check each call
if the line ending is set to default
 		if (this.getDelimiter().length == 1 && this.getDelimiter()[0] == '\n' ) {
 			this.lineDelimiterIsLinebreak = true;
 		}
+
+		// for POJO type
+		if (pojoTypeClass != null) {
+			pojoFields = new Field[pojoFieldsName.length];
+			for (int i = 0; i < pojoFieldsName.length; i++) {
+				try {
+					pojoFields[i] = pojoTypeClass.getDeclaredField(pojoFieldsName[i]);
+					pojoFields[i].setAccessible(true);
+				} catch (NoSuchFieldException e) {
+					throw new RuntimeException("There is no field called \"" + pojoFieldsName[i] + "\" in
" + pojoTypeClass.getName(), e);
+				}
+			}
+		}
+
+		this.commentCount = 0;
+		this.invalidLineCount = 0;
+	}
+
+	@Override
+	public OUT nextRecord(OUT record) throws IOException {
+		OUT returnRecord = null;
+		do {
+			returnRecord = super.nextRecord(record);
+		} while (returnRecord == null && !reachedEnd());
+
+		return returnRecord;
 	}
 
 	@Override
@@ -219,73 +212,22 @@ public class ScalaCsvInputFormat<OUT extends Product> extends GenericCsvInputFor
 		}
 
 		if (parseRecord(parsedValues, bytes, offset, numBytes)) {
-			OUT result = serializer.createInstance(parsedValues);
-			return result;
+			if (tupleSerializer != null) {
+				return tupleSerializer.createInstance(parsedValues);
+			} else {
+				for (int i = 0; i < pojoFields.length; i++) {
+					try {
+						pojoFields[i].set(reuse, parsedValues[i]);
+					} catch (IllegalAccessException e) {
+						throw new RuntimeException("Parsed value could not be set in POJO field \"" + pojoFieldsName[i]
+ "\"", e);
+					}
+				}
+
+				return reuse;
+			}
 		} else {
 			this.invalidLineCount++;
 			return null;
 		}
 	}
-	
-	
-	@Override
-	public String toString() {
-		return "CSV Input (" + StringUtils.showControlCharacters(String.valueOf(getFieldDelimiter()))
+ ") " + getFilePath();
-	}
-	
-	// --------------------------------------------------------------------------------------------
-	
-	@SuppressWarnings("unused")
-	private static void checkAndCoSort(int[] positions, Class<?>[] types) {
-		if (positions.length != types.length) {
-			throw new IllegalArgumentException("The positions and types must be of the same length");
-		}
-		
-		TreeMap<Integer, Class<?>> map = new TreeMap<Integer, Class<?>>();
-		
-		for (int i = 0; i < positions.length; i++) {
-			if (positions[i] < 0) {
-				throw new IllegalArgumentException("The field " + " (" + positions[i] + ") is invalid.");
-			}
-			if (types[i] == null) {
-				throw new IllegalArgumentException("The type " + i + " is invalid (null)");
-			}
-			
-			if (map.containsKey(positions[i])) {
-				throw new IllegalArgumentException("The position " + positions[i] + " occurs multiple
times.");
-			}
-			
-			map.put(positions[i], types[i]);
-		}
-		
-		int i = 0;
-		for (Map.Entry<Integer, Class<?>> entry : map.entrySet()) {
-			positions[i] = entry.getKey();
-			types[i] = entry.getValue();
-			i++;
-		}
-	}
-	
-	private static void checkForMonotonousOrder(int[] positions, Class<?>[] types) {
-		if (positions.length != types.length) {
-			throw new IllegalArgumentException("The positions and types must be of the same length");
-		}
-		
-		int lastPos = -1;
-		
-		for (int i = 0; i < positions.length; i++) {
-			if (positions[i] < 0) {
-				throw new IllegalArgumentException("The field " + " (" + positions[i] + ") is invalid.");
-			}
-			if (types[i] == null) {
-				throw new IllegalArgumentException("The type " + i + " is invalid (null)");
-			}
-			
-			if (positions[i] <= lastPos) {
-				throw new IllegalArgumentException("The positions must be strictly increasing (no permutations
are supported).");
-			}
-			
-			lastPos = positions[i];
-		}
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/7a6f2960/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala
index 4c1e627..7073f07 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala
@@ -26,7 +26,7 @@ import org.apache.flink.api.java.io._
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer
-import org.apache.flink.api.java.typeutils.{ValueTypeInfo, TupleTypeInfoBase}
+import org.apache.flink.api.java.typeutils.{PojoTypeInfo, ValueTypeInfo, TupleTypeInfoBase}
 import org.apache.flink.api.scala.hadoop.mapred
 import org.apache.flink.api.scala.hadoop.mapreduce
 import org.apache.flink.api.scala.operators.ScalaCsvInputFormat
@@ -46,6 +46,7 @@ import org.apache.hadoop.fs.{Path => HadoopPath}
 
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import scala.reflect.ClassTag
 
@@ -243,8 +244,9 @@ class ExecutionEnvironment(javaEnv: JavaEnv) {
    * @param lenient Whether the parser should silently ignore malformed lines.
    * @param includedFields The fields in the file that should be read. Per default all fields
    *                       are read.
+   * @param pojoFields The fields of the POJO which are mapped to CSV fields.
    */
-  def readCsvFile[T <: Product : ClassTag : TypeInformation](
+  def readCsvFile[T : ClassTag : TypeInformation](
       filePath: String,
       lineDelimiter: String = "\n",
       fieldDelimiter: String = ",",
@@ -252,9 +254,10 @@ class ExecutionEnvironment(javaEnv: JavaEnv) {
       ignoreFirstLine: Boolean = false,
       ignoreComments: String = null,
       lenient: Boolean = false,
-      includedFields: Array[Int] = null): DataSet[T] = {
+      includedFields: Array[Int] = null,
+      pojoFields: Array[String] = null): DataSet[T] = {
 
-    val typeInfo = implicitly[TypeInformation[T]].asInstanceOf[TupleTypeInfoBase[T]]
+    val typeInfo = implicitly[TypeInformation[T]]
 
     val inputFormat = new ScalaCsvInputFormat[T](new Path(filePath), typeInfo)
     inputFormat.setDelimiter(lineDelimiter)
@@ -267,16 +270,40 @@ class ExecutionEnvironment(javaEnv: JavaEnv) {
       inputFormat.enableQuotedStringParsing(quoteCharacter);
     }
 
-    val classes: Array[Class[_]] = new Array[Class[_]](typeInfo.getArity)
-    for (i <- 0 until typeInfo.getArity) {
-      classes(i) = typeInfo.getTypeAt(i).getTypeClass
+    val classesBuf: ArrayBuffer[Class[_]] = new ArrayBuffer[Class[_]]
+    typeInfo match {
+      case info: TupleTypeInfoBase[T] =>
+        for (i <- 0 until info.getArity) {
+          classesBuf += info.getTypeAt(i).getTypeClass()
+        }
+      case info: PojoTypeInfo[T] =>
+        if (pojoFields == null) {
+          throw new IllegalArgumentException(
+            "POJO fields must be specified (not null) if output type is a POJO.")
+        } else {
+          for (i <- 0 until pojoFields.length) {
+            val pos = info.getFieldIndex(pojoFields(i))
+            if (pos < 0) {
+              throw new IllegalArgumentException(
+                "Field \"" + pojoFields(i) + "\" not part of POJO type " +
+                  info.getTypeClass.getCanonicalName);
+            }
+            classesBuf += info.getPojoFieldAt(pos).`type`.getTypeClass
+          }
+        }
+      case _ => throw new IllegalArgumentException("Type information is not valid.")
     }
+
     if (includedFields != null) {
-      Validate.isTrue(typeInfo.getArity == includedFields.length, "Number of tuple fields
and" +
+      Validate.isTrue(classesBuf.size == includedFields.length, "Number of tuple fields and"
+
         " included fields must match.")
-      inputFormat.setFields(includedFields, classes)
+      inputFormat.setFields(includedFields, classesBuf.toArray)
     } else {
-      inputFormat.setFieldTypes(classes)
+      inputFormat.setFieldTypes(classesBuf.toArray)
+    }
+
+    if (pojoFields != null) {
+      inputFormat.setOrderOfPOJOFields(pojoFields)
     }
 
     wrap(new DataSource[T](javaEnv, inputFormat, typeInfo, getCallLocationName()))

http://git-wip-us.apache.org/repos/asf/flink/blob/7a6f2960/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
index 9964a9d..4bcd35a 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
@@ -17,21 +17,15 @@
  */
 package org.apache.flink.api.scala.io
 
+import java.io.{File, FileOutputStream, FileWriter, OutputStreamWriter}
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala._
 import org.apache.flink.api.scala.operators.ScalaCsvInputFormat
-import org.junit.Assert._
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertNotNull
-import org.junit.Assert.assertNull
-import org.junit.Assert.assertTrue
-import org.junit.Assert.fail
-import java.io.File
-import java.io.FileOutputStream
-import java.io.FileWriter
-import java.io.OutputStreamWriter
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.core.fs.{FileInputSplit, Path}
+import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue, fail}
 import org.junit.Test
-import org.apache.flink.api.scala._
 
 class CsvInputFormatTest {
 
@@ -315,7 +309,8 @@ class CsvInputFormatTest {
         PATH,
         createTypeInformation[(Int, Int, Int)])
       format.setFieldDelimiter("|")
-      format.setFields(Array(0, 3, 7), Array(classOf[Integer], classOf[Integer], classOf[Integer]))
+      format.setFields(Array(0, 3, 7),
+        Array(classOf[Integer], classOf[Integer], classOf[Integer]): Array[Class[_]])
       format.configure(new Configuration)
       format.open(split)
       var result: (Int, Int, Int) = null
@@ -347,7 +342,8 @@ class CsvInputFormatTest {
         createTypeInformation[(Int, Int, Int)])
       format.setFieldDelimiter("|")
       try {
-        format.setFields(Array(8, 1, 3), Array(classOf[Integer],classOf[Integer],classOf[Integer]))
+        format.setFields(Array(8, 1, 3),
+          Array(classOf[Integer], classOf[Integer], classOf[Integer]): Array[Class[_]])
         fail("Input sequence should have been rejected.")
       }
       catch {
@@ -408,5 +404,106 @@ class CsvInputFormatTest {
         fail("Test erroneous")
     }
   }
-}
 
+  class POJOItem(var field1: Int, var field2: String, var field3: Double) {
+    def this() {
+      this(-1, "", -1)
+    }
+  }
+
+  case class CaseClassItem(field1: Int, field2: String, field3: Double)
+
+  @Test
+  def testPOJOType(): Unit = {
+    val fileContent = "123,HELLO,3.123\n" + "456,ABC,1.234"
+    val tempFile = createTempFile(fileContent)
+    val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem]
+    val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo)
+
+    format.setDelimiter('\n')
+    format.setFieldDelimiter(',')
+    format.configure(new Configuration)
+    format.open(tempFile)
+
+    var result = new POJOItem()
+    result = format.nextRecord(result)
+    assertEquals(123, result.field1)
+    assertEquals("HELLO", result.field2)
+    assertEquals(3.123, result.field3, 0.001)
+
+    result = format.nextRecord(result)
+    assertEquals(456, result.field1)
+    assertEquals("ABC", result.field2)
+    assertEquals(1.234, result.field3, 0.001)
+  }
+
+  @Test
+  def testCaseClass(): Unit = {
+    val fileContent = "123,HELLO,3.123\n" + "456,ABC,1.234"
+    val tempFile = createTempFile(fileContent)
+    val typeInfo: TypeInformation[CaseClassItem] = createTypeInformation[CaseClassItem]
+    val format = new ScalaCsvInputFormat[CaseClassItem](PATH, typeInfo)
+
+    format.setDelimiter('\n')
+    format.setFieldDelimiter(',')
+    format.configure(new Configuration)
+    format.open(tempFile)
+
+    var result = format.nextRecord(null)
+    assertEquals(123, result.field1)
+    assertEquals("HELLO", result.field2)
+    assertEquals(3.123, result.field3, 0.001)
+
+    result = format.nextRecord(null)
+    assertEquals(456, result.field1)
+    assertEquals("ABC", result.field2)
+    assertEquals(1.234, result.field3, 0.001)
+  }
+
+  @Test
+  def testPOJOTypeWithFieldMapping(): Unit = {
+    val fileContent = "HELLO,123,3.123\n" + "ABC,456,1.234"
+    val tempFile = createTempFile(fileContent)
+    val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem]
+    val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo)
+
+    format.setDelimiter('\n')
+    format.setFieldDelimiter(',')
+    format.setFieldTypes(Array(classOf[String], classOf[Integer], classOf[java.lang.Double]))
+    format.setOrderOfPOJOFields(Array("field2", "field1", "field3"))
+    format.configure(new Configuration)
+    format.open(tempFile)
+
+    var result = new POJOItem()
+    result = format.nextRecord(result)
+    assertEquals(123, result.field1)
+    assertEquals("HELLO", result.field2)
+    assertEquals(3.123, result.field3, 0.001)
+
+    result = format.nextRecord(result)
+    assertEquals(456, result.field1)
+    assertEquals("ABC", result.field2)
+    assertEquals(1.234, result.field3, 0.001)
+  }
+  
+  @Test
+  def testPOJOTypeWithFieldSubsetAndDataSubset(): Unit = {
+    val fileContent = "123,HELLO,3.123\n" + "456,ABC,1.234"
+    val tempFile = createTempFile(fileContent)
+    val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem]
+    val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo)
+
+    format.setDelimiter('\n')
+    format.setFieldDelimiter(',')
+    format.setFields(Array(false, true), Array(classOf[String]): Array[Class[_]])
+    format.setOrderOfPOJOFields(Array("field2", "field1", "field3"))
+    format.configure(new Configuration)
+    format.open(tempFile)
+
+    var result = format.nextRecord(new POJOItem())
+    assertEquals("HELLO", result.field2)
+
+    result = format.nextRecord(result)
+    assertEquals("ABC", result.field2)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/7a6f2960/flink-tests/src/test/scala/org/apache/flink/api/scala/io/ScalaCsvReaderWithPOJOITCase.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/io/ScalaCsvReaderWithPOJOITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/ScalaCsvReaderWithPOJOITCase.scala
new file mode 100644
index 0000000..21aa93d
--- /dev/null
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/ScalaCsvReaderWithPOJOITCase.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.io
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.apache.flink.api.scala._
+import org.apache.flink.core.fs.FileSystem.WriteMode
+import org.apache.flink.test.util.MultipleProgramsTestBase
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.junit.Assert._
+import org.junit.rules.TemporaryFolder
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.junit.{After, Before, Rule, Test}
+
+@RunWith(classOf[Parameterized])
+class ScalaCsvReaderWithPOJOITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
{
+  private val _tempFolder = new TemporaryFolder()
+  private var resultPath: String = null
+  private var expected: String = null
+
+  @Rule
+  def tempFolder = _tempFolder
+
+  @Before
+  def before(): Unit = {
+    resultPath = tempFolder.newFile("result").toURI.toString
+  }
+
+  @After
+  def after(): Unit = {
+    compareResultsByLinesInMemory(expected, resultPath)
+  }
+
+  def createInputData(data: String): String = {
+    val dataFile = tempFolder.newFile("data")
+    Files.write(data, dataFile, Charsets.UTF_8)
+    dataFile.toURI.toString
+  }
+
+  @Test
+  def testPOJOType(): Unit = {
+    val dataPath = createInputData("ABC,2.20,3\nDEF,5.1,5\nDEF,3.30,1\nGHI,3.30,10")
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val data = env.readCsvFile[POJOItem](dataPath, pojoFields = Array("f1", "f2", "f3"))
+
+    implicit val typeInfo = createTypeInformation[(String, Int)]
+    data.writeAsText(resultPath, WriteMode.OVERWRITE)
+
+    env.execute()
+
+    expected = "ABC,2.20,3\nDEF,5.10,5\nDEF,3.30,1\nGHI,3.30,10"
+  }
+
+  @Test
+  def testPOJOTypeWithFieldsOrder(): Unit = {
+    val dataPath = createInputData("2.20,ABC,3\n5.1,DEF,5\n3.30,DEF,1\n3.30,GHI,10")
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val data = env.readCsvFile[POJOItem](dataPath, pojoFields = Array("f2", "f1", "f3"))
+
+    implicit val typeInfo = createTypeInformation[(String, Int)]
+    data.writeAsText(resultPath, WriteMode.OVERWRITE)
+
+    env.execute()
+
+    expected = "ABC,2.20,3\nDEF,5.10,5\nDEF,3.30,1\nGHI,3.30,10"
+  }
+
+  @Test
+  def testPOJOTypeWithoutFieldsOrder(): Unit = {
+    val dataPath = createInputData("")
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    try {
+      val data = env.readCsvFile[POJOItem](dataPath)
+      fail("POJO type without fields order must raise IllegalArgumentException!")
+    } catch {
+      case _: IllegalArgumentException => // success
+    }
+
+    expected = ""
+    resultPath = dataPath
+  }
+
+  @Test
+  def testPOJOTypeWithFieldsOrderAndFieldsSelection(): Unit = {
+    val dataPath = createInputData("2.20,3,ABC\n5.1,5,DEF\n3.30,1,DEF\n3.30,10,GHI")
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val data = env.readCsvFile[POJOItem](dataPath, includedFields = Array(1, 2),
+      pojoFields = Array("f3", "f1"))
+
+    implicit val typeInfo = createTypeInformation[(String, Int)]
+    data.writeAsText(resultPath, WriteMode.OVERWRITE)
+
+    env.execute()
+
+    expected = "ABC,0.00,3\nDEF,0.00,5\nDEF,0.00,1\nGHI,0.00,10"
+  }
+}
+
+class POJOItem(var f1: String, var f2: Double, var f3: Int) {
+  def this() {
+    this("", 0.0, 0)
+  }
+
+  override def toString: String = "%s,%.02f,%d".format(f1, f2, f3)
+}


Mime
View raw message