spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l...@apache.org
Subject spark git commit: [SPARK-14843][ML] Fix encoding error in LibSVMRelation
Date Fri, 22 Apr 2016 17:11:47 GMT
Repository: spark
Updated Branches:
  refs/heads/master c089c6f4e -> 8098f1585


[SPARK-14843][ML] Fix encoding error in LibSVMRelation

## What changes were proposed in this pull request?

We use `RowEncoder` in the libsvm data source to serialize the label and features read from libsvm
files. However, the schema passed to this encoder is not correct. As a result, we can't
correctly select the `features` column from the DataFrame. We should use the full data schema instead
of `requiredSchema` to serialize the data that is read in, and then apply a projection to select the
required columns.

## How was this patch tested?
Added a test case ("select features from libsvm relation") to `LibSVMRelationSuite`.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #12611 from viirya/fix-libsvm.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8098f158
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8098f158
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8098f158

Branch: refs/heads/master
Commit: 8098f158576b07343f74e2061d217b106c71b62d
Parents: c089c6f
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Authored: Sat Apr 23 01:11:36 2016 +0800
Committer: Cheng Lian <lian@databricks.com>
Committed: Sat Apr 23 01:11:36 2016 +0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/source/libsvm/LibSVMRelation.scala  | 9 ++++++---
 .../apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala | 9 +++++++--
 2 files changed, 13 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8098f158/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index e8b0dd6..dc2a6f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -202,7 +202,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
             LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
           }
 
-      val converter = RowEncoder(requiredSchema)
+      val converter = RowEncoder(dataSchema)
 
       val unsafeRowIterator = points.map { pt =>
         val features = if (sparse) pt.features.toSparse else pt.features.toDense
@@ -213,9 +213,12 @@ class DefaultSource extends FileFormat with DataSourceRegister {
         AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
 
       // Appends partition values
-      val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+      val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute)
+      val requiredOutput = fullOutput.filter { a =>
+        requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name)
+      }
       val joinedRow = new JoinedRow()
-      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
 
       unsafeRowIterator.map { dataRow =>
         appendPartitionColumns(joinedRow(dataRow, file.partitionValues))

http://git-wip-us.apache.org/repos/asf/spark/blob/8098f158/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 0bd1497..e52fbd7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -23,9 +23,9 @@ import java.nio.charset.StandardCharsets
 import com.google.common.io.Files
 
 import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.util.Utils
 
 
@@ -104,4 +104,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       df.write.format("libsvm").save(path + "_2")
     }
   }
+
+  test("select features from libsvm relation") {
+    val df = sqlContext.read.format("libsvm").load(path)
+    df.select("features").rdd.map { case Row(d: Vector) => d }.first
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message