spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject spark git commit: [SPARK-21723][ML] Fix writing LibSVM (key not found: numFeatures)
Date Wed, 16 Aug 2017 07:21:45 GMT
Repository: spark
Updated Branches:
  refs/heads/master 8c54f1eb7 -> 8321c141f


[SPARK-21723][ML] Fix writing LibSVM (key not found: numFeatures)

## What changes were proposed in this pull request?

Check the option "numFeatures" only when reading LibSVM, not when writing. When writing, Spark
was raising an exception (key not found: numFeatures). After this change, the option is ignored
completely when writing. cc: liancheng, HyukjinKwon

(Maybe the usage should be forbidden when writing, in a major version change?)

## How was this patch tested?

Manually tested that loading and writing LibSVM files works fine, both with and without the
numFeatures option.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Jan Vrsovsky <jan.vrsovsky@firma.seznam.cz>

Closes #18872 from ProtD/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8321c141
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8321c141
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8321c141

Branch: refs/heads/master
Commit: 8321c141f63a911a97ec183aefa5ff75a338c051
Parents: 8c54f1e
Author: Jan Vrsovsky <jan.vrsovsky@firma.seznam.cz>
Authored: Wed Aug 16 08:21:42 2017 +0100
Committer: Sean Owen <sowen@cloudera.com>
Committed: Wed Aug 16 08:21:42 2017 +0100

----------------------------------------------------------------------
 .../spark/ml/source/libsvm/LibSVMRelation.scala |  8 ++---
 .../ml/source/libsvm/LibSVMRelationSuite.scala  | 36 ++++++++++++++++----
 2 files changed, 34 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8321c141/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 74aaed9..4e84ff0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -76,12 +76,12 @@ private[libsvm] class LibSVMFileFormat
 
   override def toString: String = "LibSVM"
 
-  private def verifySchema(dataSchema: StructType): Unit = {
+  private def verifySchema(dataSchema: StructType, forWriting: Boolean): Unit = {
     if (
       dataSchema.size != 2 ||
         !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
         !dataSchema(1).dataType.sameType(new VectorUDT()) ||
-        !(dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
+        !(forWriting || dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
     ) {
       throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
@@ -119,7 +119,7 @@ private[libsvm] class LibSVMFileFormat
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    verifySchema(dataSchema)
+    verifySchema(dataSchema, true)
     new OutputWriterFactory {
       override def newInstance(
           path: String,
@@ -142,7 +142,7 @@ private[libsvm] class LibSVMFileFormat
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    verifySchema(dataSchema)
+    verifySchema(dataSchema, false)
     val numFeatures = dataSchema("features").metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt
     assert(numFeatures > 0)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8321c141/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index a67e49d..3eabff4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -19,13 +19,16 @@ package org.apache.spark.ml.source.libsvm
 
 import java.io.{File, IOException}
 import java.nio.charset.StandardCharsets
+import java.util.List
 
 import com.google.common.io.Files
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 
@@ -44,14 +47,14 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       """
         |0 2:4.0 4:5.0 6:6.0
       """.stripMargin
-    val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
+    val dir = Utils.createTempDir()
     val succ = new File(dir, "_SUCCESS")
     val file0 = new File(dir, "part-00000")
     val file1 = new File(dir, "part-00001")
     Files.write("", succ, StandardCharsets.UTF_8)
     Files.write(lines0, file0, StandardCharsets.UTF_8)
     Files.write(lines1, file1, StandardCharsets.UTF_8)
-    path = dir.toURI.toString
+    path = dir.getPath
   }
 
   override def afterAll(): Unit = {
@@ -108,12 +111,12 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("write libsvm data and read it again") {
     val df = spark.read.format("libsvm").load(path)
-    val tempDir2 = new File(tempDir, "read_write_test")
-    val writepath = tempDir2.toURI.toString
+    val writePath = Utils.createTempDir().getPath
+
     // TODO: Remove requirement to coalesce by supporting multiple reads.
-    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)
 
-    val df2 = spark.read.format("libsvm").load(writepath)
+    val df2 = spark.read.format("libsvm").load(writePath)
     val row1 = df2.first()
     val v = row1.getAs[SparseVector](1)
     assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
@@ -126,6 +129,27 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("write libsvm data from scratch and read it again") {
+    val rawData = new java.util.ArrayList[Row]()
+    rawData.add(Row(1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))))
+    rawData.add(Row(4.0, Vectors.sparse(3, Seq((0, 5.0), (2, 6.0)))))
+
+    val struct = StructType(
+      StructField("labelFoo", DoubleType, false) ::
+      StructField("featuresBar", VectorType, false) :: Nil
+    )
+    val df = spark.sqlContext.createDataFrame(rawData, struct)
+
+    val writePath = Utils.createTempDir().getPath
+
+    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)
+
+    val df2 = spark.read.format("libsvm").load(writePath)
+    val row1 = df2.first()
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(3, Seq((0, 2.0), (1, 3.0))))
+  }
+
   test("select features from libsvm relation") {
     val df = spark.read.format("libsvm").load(path)
     df.select("features").rdd.map { case Row(d: Vector) => d }.first


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message