spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-1406] Mllib pmml model export
Date Thu, 30 Apr 2015 06:21:39 GMT
Repository: spark
Updated Branches:
  refs/heads/master 445951449 -> 254e05097


[SPARK-1406] Mllib pmml model export

See PDF attached to the JIRA issue 1406.

The contribution is my original work and I license the work to the project under the project's open source license.

Author: Vincenzo Selvaggio <vselvaggio@hotmail.it>
Author: Xiangrui Meng <meng@databricks.com>
Author: selvinsource <vselvaggio@hotmail.it>

Closes #3062 from selvinsource/mllib_pmml_model_export_SPARK-1406 and squashes the following commits:

852aac6 [Vincenzo Selvaggio] [SPARK-1406] Update JPMML version to 1.1.15 in LICENSE file
085cf42 [Vincenzo Selvaggio] [SPARK-1406] Added Double Min and Max; fixed scala style
30165c4 [Vincenzo Selvaggio] [SPARK-1406] Fixed extreme cases for logit
7a5e0ec [Vincenzo Selvaggio] [SPARK-1406] Binary classification for SVM and Logistic Regression
cfcb596 [Vincenzo Selvaggio] [SPARK-1406] Throw IllegalArgumentException when exporting a multinomial logistic regression
25dce33 [Vincenzo Selvaggio] [SPARK-1406] Update code to latest pmml model
dea98ca [Vincenzo Selvaggio] [SPARK-1406] Exclude transitive dependency for pmml model
66b7c12 [Vincenzo Selvaggio] [SPARK-1406] Updated pmml model lib to 1.1.15, latest Java 6 compatible
a0a55f7 [Vincenzo Selvaggio] Merge pull request #2 from mengxr/SPARK-1406
3c22f79 [Xiangrui Meng] more code style
e2313df [Vincenzo Selvaggio] Merge pull request #1 from mengxr/SPARK-1406
472d757 [Xiangrui Meng] fix code style
1676e15 [Vincenzo Selvaggio] fixed scala issue
e2ffae8 [Vincenzo Selvaggio] fixed scala style
b8823b0 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
b25bbf7 [Vincenzo Selvaggio] [SPARK-1406] Added export of pmml to distributed file system using the spark context
7a949d0 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style
f46c75c [Vincenzo Selvaggio] [SPARK-1406] Added PMMLExportable to supported models
7b33b4e [Vincenzo Selvaggio] [SPARK-1406] Added a PMMLExportable interface; restructured code in a new package mllib.pmml; supported models implement the new PMMLExportable interface: LogisticRegression, SVM, KMeansModel, LinearRegression, RidgeRegression, Lasso
d559ec5 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
8fe12bb [Vincenzo Selvaggio] [SPARK-1406] Adjusted logistic regression export description and target categories
03bc3a5 [Vincenzo Selvaggio] added logistic regression
da2ec11 [Vincenzo Selvaggio] [SPARK-1406] added linear SVM PMML export
82f2131 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
19adf29 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style
1faf985 [Vincenzo Selvaggio] [SPARK-1406] Added target field to the regression model for completeness; adjusted unit test to deal with this change
3ae8ae5 [Vincenzo Selvaggio] [SPARK-1406] Adjusted imported order according to the guidelines
c67ce81 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406
78515ec [Vincenzo Selvaggio] [SPARK-1406] added pmml export for LinearRegressionModel, RidgeRegressionModel and LassoModel
e29dfb9 [Vincenzo Selvaggio] removed version, by default is set to 4.2 (latest from jpmml); removed copyright
ae8b993 [Vincenzo Selvaggio] updated some commented tests to use the new ModelExporter object; reordered the imports
df8a89e [Vincenzo Selvaggio] added pmml version to pmml model; changed the copyright to spark
a1b4dc3 [Vincenzo Selvaggio] updated imports
834ca44 [Vincenzo Selvaggio] reordered the import accordingly to the guidelines
349a76b [Vincenzo Selvaggio] new helper object to serialize the models to pmml format
c3ef9b8 [Vincenzo Selvaggio] set it to private
6357b98 [Vincenzo Selvaggio] set it to private
e1eb251 [Vincenzo Selvaggio] removed serialization part, this will be part of the ModelExporter helper object
aba5ee1 [Vincenzo Selvaggio] fixed cluster export
cd6c07c [Vincenzo Selvaggio] fixed scala style to run tests
f75b988 [Vincenzo Selvaggio] Merge remote-tracking branch 'origin/master' into mllib_pmml_model_export_SPARK-1406
07a29bf [selvinsource] Update LICENSE
8841439 [Vincenzo Selvaggio] adjust scala style in order to compile
1433b11 [Vincenzo Selvaggio] complete suite tests
8e71b8d [Vincenzo Selvaggio] kmeans pmml export implementation
9bc494f [Vincenzo Selvaggio] added scala suite tests; added saveLocalFile to ModelExport trait
226e184 [Vincenzo Selvaggio] added javadoc and export model type in case there is a need to support other types of export (not just PMML)
a0e3679 [Vincenzo Selvaggio] export and pmml export traits; kmeans test implementation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/254e0509
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/254e0509
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/254e0509

Branch: refs/heads/master
Commit: 254e0509762937acc9c72b432d5d953bf72c3c52
Parents: 4459514
Author: Vincenzo Selvaggio <vselvaggio@hotmail.it>
Authored: Wed Apr 29 23:21:21 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Apr 29 23:21:21 2015 -0700

----------------------------------------------------------------------
 LICENSE                                         |  1 +
 mllib/pom.xml                                   | 15 ++++
 .../classification/LogisticRegression.scala     |  3 +-
 .../apache/spark/mllib/classification/SVM.scala |  3 +-
 .../spark/mllib/clustering/KMeansModel.scala    |  4 +-
 .../spark/mllib/pmml/PMMLExportable.scala       | 74 +++++++++++++++
 .../BinaryClassificationPMMLModelExport.scala   | 90 +++++++++++++++++++
 .../GeneralizedLinearPMMLModelExport.scala      | 75 ++++++++++++++++
 .../pmml/export/KMeansPMMLModelExport.scala     | 83 +++++++++++++++++
 .../mllib/pmml/export/PMMLModelExport.scala     | 47 ++++++++++
 .../pmml/export/PMMLModelExportFactory.scala    | 64 +++++++++++++
 .../apache/spark/mllib/regression/Lasso.scala   |  3 +-
 .../mllib/regression/LinearRegression.scala     |  3 +-
 .../mllib/regression/RidgeRegression.scala      |  3 +-
 ...naryClassificationPMMLModelExportSuite.scala | 84 +++++++++++++++++
 .../GeneralizedLinearPMMLModelExportSuite.scala | 84 +++++++++++++++++
 .../export/KMeansPMMLModelExportSuite.scala     | 49 ++++++++++
 .../export/PMMLModelExportFactorySuite.scala    | 95 ++++++++++++++++++++
 18 files changed, 774 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/LICENSE
----------------------------------------------------------------------
diff --git a/LICENSE b/LICENSE
index 9b364a4..21c42e9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -814,6 +814,7 @@ BSD-style licenses
 The following components are provided under a BSD-style license. See project link for details.
 
      (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
+     (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model)
      (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/)
      (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
      (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/pom.xml
----------------------------------------------------------------------
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 5dfab36..a3c57ae 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -109,6 +109,21 @@
       <type>test-jar</type>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.jpmml</groupId>
+      <artifactId>pmml-model</artifactId>
+      <version>1.1.15</version>
+      <exclusions>
+        <exclusion>
+          <groupId>com.sun.xml.fastinfoset</groupId>
+          <artifactId>FastInfoset</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.istack</groupId>
+          <artifactId>istack-commons-runtime</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
   </dependencies>
   <profiles>
     <profile>

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 057b628..bd2e907 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.linalg.{DenseVector, Vector}
 import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
 import org.apache.spark.rdd.RDD
@@ -46,7 +47,7 @@ class LogisticRegressionModel (
     val numFeatures: Int,
     val numClasses: Int)
   extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
-  with Saveable {
+  with Saveable with PMMLExportable {
 
   if (numClasses == 2) {
     require(weights.size == numFeatures,

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 52fb62d..33104cf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -22,6 +22,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
 import org.apache.spark.rdd.RDD
@@ -36,7 +37,7 @@ class SVMModel (
     override val weights: Vector,
     override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
-  with Saveable {
+  with Saveable with PMMLExportable {
 
   private var threshold: Option[Double] = Some(0.0)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index e4e411a..ba228b1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.pmml.PMMLExportable
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
@@ -34,7 +35,8 @@ import org.apache.spark.sql.Row
 /**
  * A clustering model for K-means. Each point belongs to the cluster with the closest center.
  */
-class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
+class KMeansModel (
+    val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable {
 
   /** A Java-friendly constructor that takes an Iterable of Vectors. */
   def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
new file mode 100644
index 0000000..354e90f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml
+
+import java.io.{File, OutputStream, StringWriter}
+import javax.xml.transform.stream.StreamResult
+
+import org.jpmml.model.JAXBUtil
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
+
+/**
+ * Export model to the PMML format
+ * Predictive Model Markup Language (PMML) is an XML-based file format
+ * developed by the Data Mining Group (www.dmg.org).
+ */
+trait PMMLExportable {
+
+  /**
+   * Export the model to the stream result in PMML format
+   */
+  private def toPMML(streamResult: StreamResult): Unit = {
+    val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
+    JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
+  }
+
+  /**
+   * Export the model to a local file in PMML format
+   */
+  def toPMML(localPath: String): Unit = {
+    toPMML(new StreamResult(new File(localPath)))
+  }
+
+  /**
+   * Export the model to a directory on a distributed file system in PMML format
+   */
+  def toPMML(sc: SparkContext, path: String): Unit = {
+    val pmml = toPMML()
+    sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
+  }
+
+  /**
+   * Export the model to the OutputStream in PMML format
+   */
+  def toPMML(outputStream: OutputStream): Unit = {
+    toPMML(new StreamResult(outputStream))
+  }
+
+  /**
+   * Export the model to a String in PMML format
+   */
+  def toPMML(): String = {
+    val writer = new StringWriter
+    toPMML(new StreamResult(writer))
+    writer.toString
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
new file mode 100644
index 0000000..34b4475
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import scala.{Array => SArray}
+
+import org.dmg.pmml._
+
+import org.apache.spark.mllib.regression.GeneralizedLinearModel
+
+/**
+ * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
+ */
+private[mllib] class BinaryClassificationPMMLModelExport(
+    model : GeneralizedLinearModel, 
+    description : String,
+    normalizationMethod : RegressionNormalizationMethodType,
+    threshold: Double) 
+  extends PMMLModelExport {
+
+  populateBinaryClassificationPMML()
+
+  /**
+   * Export the input LogisticRegressionModel or SVMModel to PMML format.
+   */
+  private def populateBinaryClassificationPMML(): Unit = {
+     pmml.getHeader.setDescription(description)
+
+     if (model.weights.size > 0) {
+       val fields = new SArray[FieldName](model.weights.size)
+       val dataDictionary = new DataDictionary
+       val miningSchema = new MiningSchema
+       val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
+       var interceptNO = threshold
+       if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) {
+         if (threshold <= 0) {
+           interceptNO = Double.MinValue
+         } else if (threshold >= 1) {
+           interceptNO = Double.MaxValue
+         } else {
+           interceptNO = -math.log(1 / threshold - 1)
+         }
+       }
+       val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0")
+       val regressionModel = new RegressionModel()
+         .withFunctionName(MiningFunctionType.CLASSIFICATION)
+         .withMiningSchema(miningSchema)
+         .withModelName(description)
+         .withNormalizationMethod(normalizationMethod)
+         .withRegressionTables(regressionTableYES, regressionTableNO)
+
+       for (i <- 0 until model.weights.size) {
+         fields(i) = FieldName.create("field_" + i)
+         dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
+         miningSchema
+           .withMiningFields(new MiningField(fields(i))
+           .withUsageType(FieldUsageType.ACTIVE))
+         regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
+       }
+       
+       // add target field
+       val targetField = FieldName.create("target")
+       dataDictionary
+         .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING))
+       miningSchema
+         .withMiningFields(new MiningField(targetField)
+         .withUsageType(FieldUsageType.TARGET))
+       
+       dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
+       
+       pmml.setDataDictionary(dataDictionary)
+       pmml.withModels(regressionModel)
+     }
+  }
+}
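
Note on the LOGIT branch above: for a threshold t strictly between 0 and 1, -log(1/t - 1) equals log(t / (1 - t)), the logit of t, so applying the logistic function to interceptNO gives back t. The sketch below isolates that mapping in plain Scala, with no Spark or JPMML dependency. The names LogitThreshold, interceptForThreshold and sigmoid are hypothetical and not part of the patch. How a PMML consumer combines the two regression tables is not stated in this message, so the round trip is only an illustration of the arithmetic.

```scala
object LogitThreshold {
  // Mirrors the LOGIT branch of populateBinaryClassificationPMML (hypothetical helper name).
  def interceptForThreshold(threshold: Double): Double =
    if (threshold <= 0) Double.MinValue        // most negative finite Double in Scala
    else if (threshold >= 1) Double.MaxValue   // largest finite Double
    else -math.log(1 / threshold - 1)          // logit(threshold)

  // Logistic function, the inverse of logit on (0, 1).
  def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

  def main(args: Array[String]): Unit = {
    for (t <- Seq(0.2, 0.5, 0.8)) {
      val b = interceptForThreshold(t)
      println(s"threshold=$t intercept=$b sigmoid(intercept)=${sigmoid(b)}")
    }
  }
}
```

Scala's Double.MinValue is -Double.MaxValue, unlike Java's Double.MIN_VALUE, which is the smallest positive value. The two extreme cases therefore map to the lowest and highest finite intercepts.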

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala
new file mode 100644
index 0000000..1874786
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import scala.{Array => SArray}
+
+import org.dmg.pmml._
+
+import org.apache.spark.mllib.regression.GeneralizedLinearModel
+
+/**
+ * PMML Model Export for GeneralizedLinearModel abstract class
+ */
+private[mllib] class GeneralizedLinearPMMLModelExport(
+    model: GeneralizedLinearModel,
+    description: String)
+  extends PMMLModelExport {
+
+  populateGeneralizedLinearPMML(model)
+
+  /**
+   * Export the input GeneralizedLinearModel model to PMML format.
+   */
+  private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
+    pmml.getHeader.setDescription(description)
+
+    if (model.weights.size > 0) {
+      val fields = new SArray[FieldName](model.weights.size)
+      val dataDictionary = new DataDictionary
+      val miningSchema = new MiningSchema
+      val regressionTable = new RegressionTable(model.intercept)
+      val regressionModel = new RegressionModel()
+        .withFunctionName(MiningFunctionType.REGRESSION)
+        .withMiningSchema(miningSchema)
+        .withModelName(description)
+        .withRegressionTables(regressionTable)
+
+      for (i <- 0 until model.weights.size) {
+        fields(i) = FieldName.create("field_" + i)
+        dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
+        miningSchema
+          .withMiningFields(new MiningField(fields(i))
+          .withUsageType(FieldUsageType.ACTIVE))
+        regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
+      }
+
+      // for completeness add target field
+      val targetField = FieldName.create("target")
+      dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
+      miningSchema
+        .withMiningFields(new MiningField(targetField)
+        .withUsageType(FieldUsageType.TARGET))
+
+      dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
+
+      pmml.setDataDictionary(dataDictionary)
+      pmml.withModels(regressionModel)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
new file mode 100644
index 0000000..069e7af
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import scala.{Array => SArray}
+
+import org.dmg.pmml._
+
+import org.apache.spark.mllib.clustering.KMeansModel
+
+/**
+ * PMML Model Export for KMeansModel class
+ */
+private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
+
+  populateKMeansPMML(model)
+
+  /**
+   * Export the input KMeansModel model to PMML format.
+   */
+  private def populateKMeansPMML(model : KMeansModel): Unit = {
+    pmml.getHeader.setDescription("k-means clustering")
+
+    if (model.clusterCenters.length > 0) {
+      val clusterCenter = model.clusterCenters(0)
+      val fields = new SArray[FieldName](clusterCenter.size)
+      val dataDictionary = new DataDictionary
+      val miningSchema = new MiningSchema
+      val comparisonMeasure = new ComparisonMeasure()
+        .withKind(ComparisonMeasure.Kind.DISTANCE)
+        .withMeasure(new SquaredEuclidean())
+      val clusteringModel = new ClusteringModel()
+        .withModelName("k-means")
+        .withMiningSchema(miningSchema)
+        .withComparisonMeasure(comparisonMeasure)
+        .withFunctionName(MiningFunctionType.CLUSTERING)
+        .withModelClass(ClusteringModel.ModelClass.CENTER_BASED)
+        .withNumberOfClusters(model.clusterCenters.length)
+
+      for (i <- 0 until clusterCenter.size) {
+        fields(i) = FieldName.create("field_" + i)
+        dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
+        miningSchema
+          .withMiningFields(new MiningField(fields(i))
+          .withUsageType(FieldUsageType.ACTIVE))
+        clusteringModel.withClusteringFields(
+          new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF))
+      }
+
+      dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
+
+      for (i <- 0 until model.clusterCenters.length) {
+        val cluster = new Cluster()
+          .withName("cluster_" + i)
+          .withArray(new org.dmg.pmml.Array()
+          .withType(Array.Type.REAL)
+          .withN(clusterCenter.size)
+          .withValue(model.clusterCenters(i).toArray.mkString(" ")))
+        // we don't have the size of the single cluster but only the centroids (withValue)
+        // .withSize(value)
+        clusteringModel.withClusters(cluster)
+      }
+
+      pmml.setDataDictionary(dataDictionary)
+      pmml.withModels(clusteringModel)
+    }
+  }
+}
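
Note on the cluster loop above: each centroid is written as a space-separated string via toArray.mkString(" "), and the declared length n is taken from the first center (clusterCenter.size) for every cluster. A minimal plain-Scala sketch of that rendering follows, assuming centers are already available as Array[Double]. ClusterArrayText and render are hypothetical names, not part of the patch.

```scala
object ClusterArrayText {
  // Hypothetical helper: renders one centroid the way the patch fills the Array value.
  def render(center: Array[Double]): String = center.mkString(" ")

  def main(args: Array[String]): Unit = {
    // Two made-up 3-dimensional centroids, for illustration only.
    val centers = Array(Array(1.0, 2.0, 6.0), Array(1.0, 3.0, 0.0))
    centers.zipWithIndex.foreach { case (c, i) =>
      println(s"cluster_$i n=${c.length} value=${render(c)}")
    }
  }
}
```

On the JVM the elements are formatted with the default Double string conversion, so very small or very large magnitudes appear in scientific notation.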

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
new file mode 100644
index 0000000..ebdeae5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import java.text.SimpleDateFormat
+import java.util.Date
+
+import scala.beans.BeanProperty
+
+import org.dmg.pmml.{Application, Header, PMML, Timestamp}
+
+private[mllib] trait PMMLModelExport {
+  
+  /**
+   * Holder of the exported model in PMML format
+   */
+  @BeanProperty
+  val pmml: PMML = new PMML
+
+  setHeader(pmml)
+  
+  private def setHeader(pmml: PMML): Unit = {
+    val version = getClass.getPackage.getImplementationVersion
+    val app = new Application().withName("Apache Spark MLlib").withVersion(version)
+    val timestamp = new Timestamp()
+      .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date()))
+    val header = new Header()
+      .withApplication(app)
+      .withTimestamp(timestamp)
+    pmml.setHeader(header)
+  }
+}
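
Note on setHeader above: the timestamp pattern has no zone field, so the string reflects the JVM default time zone without an offset, and getImplementationVersion returns null when the class was not loaded from a JAR whose manifest declares a version. The sketch below reproduces both pieces using only the JDK. HeaderFields, timestamp and versionOf are hypothetical names, and wrapping the version in Option is an illustration rather than what the patch does.

```scala
import java.text.SimpleDateFormat
import java.util.Date

object HeaderFields {
  // Same pattern string as in setHeader; the quoted 'T' is emitted literally.
  def timestamp(now: Date): String =
    new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(now)

  // Null-safe read of the manifest Implementation-Version (None when absent).
  def versionOf(cls: Class[_]): Option[String] =
    Option(cls.getPackage).flatMap(p => Option(p.getImplementationVersion))

  def main(args: Array[String]): Unit = {
    println(timestamp(new Date()))
    println(versionOf(getClass).getOrElse("(no Implementation-Version in manifest)"))
  }
}
```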

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
new file mode 100644
index 0000000..c16e83d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.dmg.pmml.RegressionNormalizationMethodType
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel
+import org.apache.spark.mllib.classification.SVMModel
+import org.apache.spark.mllib.clustering.KMeansModel
+import org.apache.spark.mllib.regression.LassoModel
+import org.apache.spark.mllib.regression.LinearRegressionModel
+import org.apache.spark.mllib.regression.RidgeRegressionModel
+
+private[mllib] object PMMLModelExportFactory {
+  
+  /**
+   * Factory object to help creating the necessary PMMLModelExport implementation 
+   * taking as input the machine learning model (for example KMeansModel).
+   */
+  def createPMMLModelExport(model: Any): PMMLModelExport = {
+    model match {
+      case kmeans: KMeansModel =>
+        new KMeansPMMLModelExport(kmeans)
+      case linear: LinearRegressionModel =>
+        new GeneralizedLinearPMMLModelExport(linear, "linear regression")
+      case ridge: RidgeRegressionModel =>
+        new GeneralizedLinearPMMLModelExport(ridge, "ridge regression")
+      case lasso: LassoModel =>
+        new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
+      case svm: SVMModel =>
+        new BinaryClassificationPMMLModelExport(
+          svm, "linear SVM", RegressionNormalizationMethodType.NONE, 
+          svm.getThreshold.getOrElse(0.0))
+      case logistic: LogisticRegressionModel =>
+        if (logistic.numClasses == 2) {
+          new BinaryClassificationPMMLModelExport(
+            logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT,
+            logistic.getThreshold.getOrElse(0.5))
+        } else {
+          throw new IllegalArgumentException(
+            "PMML Export not supported for Multinomial Logistic Regression")
+        }
+      case _ =>
+        throw new IllegalArgumentException(
+          "PMML Export not supported for model: " + model.getClass.getName)
+    }
+  }
+  
+}
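
The dispatch in createPMMLModelExport above can be sketched without Spark on the classpath. This is a minimal illustration under stated assumptions, not the committed implementation: KMeansLike, LinearLike, LogisticLike and the *ExportSketch types are hypothetical stand-ins for the MLlib models and for PMMLModelExport. Only the matching order, the numClasses check, and the error messages are taken from the diff.

```scala
// Hypothetical stand-ins for the MLlib model classes (not the real Spark types).
final case class KMeansLike(centers: Seq[Seq[Double]])
final case class LinearLike(weights: Seq[Double], intercept: Double)
final case class LogisticLike(weights: Seq[Double], intercept: Double, numClasses: Int)

// Hypothetical stand-in for PMMLModelExport: it only carries a description.
sealed trait ExportSketch { def description: String }
final case class ClusteringExportSketch(description: String) extends ExportSketch
final case class RegressionExportSketch(description: String) extends ExportSketch
final case class BinaryExportSketch(description: String, threshold: Double) extends ExportSketch

object ExportFactorySketch {
  // Pick an exporter from the runtime type of the model, as the factory in the diff does.
  def create(model: Any): ExportSketch = model match {
    case _: KMeansLike =>
      ClusteringExportSketch("k-means clustering")
    case _: LinearLike =>
      RegressionExportSketch("linear regression")
    case logistic: LogisticLike if logistic.numClasses == 2 =>
      // Binary case only; 0.5 is the fallback threshold used in the diff.
      BinaryExportSketch("logistic regression", 0.5)
    case _: LogisticLike =>
      throw new IllegalArgumentException(
        "PMML Export not supported for Multinomial Logistic Regression")
    case other =>
      // Anything unrecognised is rejected rather than silently ignored.
      throw new IllegalArgumentException(
        "PMML Export not supported for model: " + other.getClass.getName)
  }
}
```

Calling ExportFactorySketch.create(new Object) throws IllegalArgumentException, which is the same behaviour the factory suite in this commit checks for the real factory.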

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index e8b0381..4f48238 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
 import org.apache.spark.mllib.regression.impl.GLMRegressionModel
 import org.apache.spark.mllib.util.{Saveable, Loader}
 import org.apache.spark.rdd.RDD
@@ -34,7 +35,7 @@ class LassoModel (
     override val weights: Vector,
     override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept)
-  with RegressionModel with Serializable with Saveable {
+  with RegressionModel with Serializable with Saveable with PMMLExportable {
 
   override protected def predictPoint(
       dataMatrix: Vector,

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 6fa7ad5..9453c4f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
 import org.apache.spark.mllib.regression.impl.GLMRegressionModel
 import org.apache.spark.mllib.util.{Saveable, Loader}
 import org.apache.spark.rdd.RDD
@@ -34,7 +35,7 @@ class LinearRegressionModel (
     override val weights: Vector,
     override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable
-  with Saveable {
+  with Saveable with PMMLExportable {
 
   override protected def predictPoint(
       dataMatrix: Vector,

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 309f9af..e0c03d8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.pmml.PMMLExportable
 import org.apache.spark.mllib.regression.impl.GLMRegressionModel
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
@@ -35,7 +36,7 @@ class RidgeRegressionModel (
     override val weights: Vector,
     override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept)
-  with RegressionModel with Serializable with Saveable {
+  with RegressionModel with Serializable with Saveable with PMMLExportable {
 
   override protected def predictPoint(
       dataMatrix: Vector,

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
new file mode 100644
index 0000000..0b646cf
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.dmg.pmml.RegressionModel
+import org.dmg.pmml.RegressionNormalizationMethodType
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel
+import org.apache.spark.mllib.classification.SVMModel
+import org.apache.spark.mllib.util.LinearDataGenerator
+
+class BinaryClassificationPMMLModelExportSuite extends FunSuite {
+
+  test("logistic regression PMML export") {
+    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+    val logisticRegressionModel =
+      new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
+
+    val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
+
+    // assert that the PMML format is as expected
+    assert(logisticModelExport.isInstanceOf[PMMLModelExport])
+    val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml
+    assert(pmml.getHeader.getDescription === "logistic regression")
+    // check that the number of fields matches the weights size plus one
+    assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.weights.size + 1)
+    // This verifies that there is a model attached to the pmml object and the model is a regression
+    // one.  It also verifies that the pmml model has a regression table (for target category 1)
+    // with as many predictors as the model has weights.
+    val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+    assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1")
+    assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+      === logisticRegressionModel.weights.size)
+    // verify that there is a second table with target category 0 and no predictors
+    assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
+    assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
+    // ensure logistic regression has normalization method set to LOGIT
+    assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
+  }
+  
+  test("linear SVM PMML export") {
+    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+    val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
+    
+    val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
+    
+    // assert that the PMML format is as expected
+    assert(svmModelExport.isInstanceOf[PMMLModelExport])
+    val pmml = svmModelExport.getPmml
+    assert(pmml.getHeader.getDescription
+      === "linear SVM")
+    // check that the number of fields matches the weights size plus one
+    assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1)
+    // This verifies that there is a model attached to the pmml object and the model is a regression
+    // one.  It also verifies that the pmml model has a regression table (for target category 1)
+    // with as many predictors as the model has weights.
+    val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+    assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1")
+    assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+      === svmModel.weights.size)
+    // verify that there is a second table with target category 0 and no predictors
+    assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
+    assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
+    // ensure linear SVM has normalization method set to NONE
+    assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
+  }
+  
+}
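
The two normalization methods asserted above (LOGIT for logistic regression, NONE for linear SVM), together with the fallback thresholds 0.5 and 0.0 passed by the factory, can be illustrated with a small scoring sketch. It is a sketch only: the object and method names are hypothetical, and the strict "greater than threshold" comparison is an assumption about how the exported threshold is meant to be applied, not something this diff shows.

```scala
object BinaryScoringSketch {
  // Raw linear score: dot(weights, features) + intercept.
  def margin(weights: Seq[Double], intercept: Double, features: Seq[Double]): Double = {
    require(weights.length == features.length, "weights and features must have the same length")
    weights.zip(features).map { case (w, x) => w * x }.sum + intercept
  }

  // Logistic function, i.e. what a "logit" normalization applies to the raw score.
  def sigmoid(m: Double): Double = 1.0 / (1.0 + math.exp(-m))

  // No normalization: the raw margin is compared with the threshold (assumed strict).
  def predictSvm(m: Double, threshold: Double = 0.0): Double =
    if (m > threshold) 1.0 else 0.0

  // Logit normalization: the margin is mapped into (0, 1) before the comparison.
  def predictLogistic(m: Double, threshold: Double = 0.5): Double =
    if (sigmoid(m) > threshold) 1.0 else 0.0
}
```

With the fallback thresholds the two rules agree on the sign of the margin, since sigmoid(0.0) is 0.5; they differ only once a caller sets a custom threshold.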

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
new file mode 100644
index 0000000..f9afbd8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.dmg.pmml.RegressionModel
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
+import org.apache.spark.mllib.util.LinearDataGenerator
+
+class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
+
+  test("linear regression PMML export") {
+    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+    val linearRegressionModel =
+      new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
+    val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
+    // assert that the PMML format is as expected
+    assert(linearModelExport.isInstanceOf[PMMLModelExport])
+    val pmml = linearModelExport.getPmml
+    assert(pmml.getHeader.getDescription === "linear regression")
+    // check that the number of fields matches the weights size plus one
+    assert(pmml.getDataDictionary.getNumberOfFields === linearRegressionModel.weights.size + 1)
+    // This verifies that there is a model attached to the pmml object and the model is a regression
+    // one.  It also verifies that the pmml model has a regression table with as many predictors
+    // as the model has weights.
+    val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+    assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+      === linearRegressionModel.weights.size)
+  }
+
+  test("ridge regression PMML export") {
+    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+    val ridgeRegressionModel =
+      new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
+    val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
+    // assert that the PMML format is as expected
+    assert(ridgeModelExport.isInstanceOf[PMMLModelExport])
+    val pmml = ridgeModelExport.getPmml
+    assert(pmml.getHeader.getDescription === "ridge regression")
+    // check that the number of fields matches the weights size plus one
+    assert(pmml.getDataDictionary.getNumberOfFields === ridgeRegressionModel.weights.size + 1)
+    // This verifies that there is a model attached to the pmml object and the model is a regression
+    // one.  It also verifies that the pmml model has a regression table with as many predictors
+    // as the model has weights.
+    val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+    assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+      === ridgeRegressionModel.weights.size)
+  }
+
+  test("lasso PMML export") {
+    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+    val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
+    val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
+    // assert that the PMML format is as expected
+    assert(lassoModelExport.isInstanceOf[PMMLModelExport])
+    val pmml = lassoModelExport.getPmml
+    assert(pmml.getHeader.getDescription === "lasso regression")
+    // check that the number of fields match the weights size
+    assert(pmml.getDataDictionary.getNumberOfFields === lassoModel.weights.size + 1)
+    // This verify that there is a model attached to the pmml object and the model is a regression
+    // one. It also verifies that the pmml model has a regression table with the same number of
+    // predictors of the model weights.
+    val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
+    assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
+      === lassoModel.weights.size)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
new file mode 100644
index 0000000..b985d04
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.dmg.pmml.ClusteringModel
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.clustering.KMeansModel
+import org.apache.spark.mllib.linalg.Vectors
+
+class KMeansPMMLModelExportSuite extends FunSuite {
+
+  test("KMeansPMMLModelExport generate PMML format") {
+    val clusterCenters = Array(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0))
+    val kmeansModel = new KMeansModel(clusterCenters)
+
+    val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)
+
+    // assert that the PMML format is as expected
+    assert(modelExport.isInstanceOf[PMMLModelExport])
+    val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml
+    assert(pmml.getHeader.getDescription === "k-means clustering")
+    // check that the number of fields matches the size of a single cluster center
+    assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
+    // This verifies that there is a model attached to the pmml object and the model is a clustering
+    // one. It also verifies that the pmml model has the same number of clusters as the spark model.
+    val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
+    assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
+  }
+  
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/254e0509/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
new file mode 100644
index 0000000..f28a4ac
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
+import org.apache.spark.mllib.clustering.KMeansModel
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
+import org.apache.spark.mllib.util.LinearDataGenerator
+
+class PMMLModelExportFactorySuite extends FunSuite {
+
+  test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
+    val clusterCenters = Array(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0))
+    val kmeansModel = new KMeansModel(clusterCenters)
+
+    val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)
+
+    assert(modelExport.isInstanceOf[KMeansPMMLModelExport])
+  }
+
+  test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
+    + "LinearRegressionModel, RidgeRegressionModel or LassoModel") {
+    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+
+    val linearRegressionModel =
+      new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
+    val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
+    assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
+
+    val ridgeRegressionModel =
+      new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
+    val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
+    assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
+
+    val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
+    val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
+    assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
+  }
+
+  test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
+    + "when passing a LogisticRegressionModel or SVMModel") {
+    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
+    
+    val logisticRegressionModel =
+      new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
+    val logisticRegressionModelExport =
+      PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
+    assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
+    
+    val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
+    val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
+    assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
+  }
+  
+  test("PMMLModelExportFactory throw IllegalArgumentException "
+    + "when passing a Multinomial Logistic Regression") {
+    /** 3 classes, 2 features */
+    val multiclassLogisticRegressionModel = new LogisticRegressionModel(
+      weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, 
+      numFeatures = 2, numClasses = 3)
+    
+    intercept[IllegalArgumentException] {
+      PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
+    }
+  }
+
+  test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
+    val invalidModel = new Object
+
+    intercept[IllegalArgumentException] {
+      PMMLModelExportFactory.createPMMLModelExport(invalidModel)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

