spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject [1/2] spark git commit: [SPARK-3530][MLLIB] pipeline and parameters with examples
Date Wed, 12 Nov 2014 18:39:13 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.2 14b933fb6 -> 38f9f2e1c


http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 17c753c..2067b36 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.regression
 
+import scala.beans.BeanInfo
+
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.SparkException
@@ -27,6 +29,7 @@ import org.apache.spark.SparkException
  * @param label Label for this data point.
  * @param features List of features for this data point.
  */
+@BeanInfo
 case class LabeledPoint(label: Double, features: Vector) {
   override def toString: String = {
     "(%s,%s)".format(label, features)

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
new file mode 100644
index 0000000..4284667
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.feature.StandardScaler;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+  .generateLogisticInputAsList;
+
+/**
+ * Test Pipeline construction and fitting in Java.
+ */
+public class JavaPipelineSuite {
+
+  private transient JavaSparkContext jsc;
+  private transient JavaSQLContext jsql;
+  private transient JavaSchemaRDD dataset;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaPipelineSuite");
+    jsql = new JavaSQLContext(jsc);
+    JavaRDD<LabeledPoint> points =
+      jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
+    dataset = jsql.applySchema(points, LabeledPoint.class);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void pipeline() {
+    StandardScaler scaler = new StandardScaler()
+      .setInputCol("features")
+      .setOutputCol("scaledFeatures");
+    LogisticRegression lr = new LogisticRegression()
+      .setFeaturesCol("scaledFeatures");
+    Pipeline pipeline = new Pipeline()
+      .setStages(new PipelineStage[] {scaler, lr});
+    PipelineModel model = pipeline.fit(dataset);
+    model.transform(dataset).registerTempTable("prediction");
+    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    predictions.collect();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
new file mode 100644
index 0000000..76eb7f0
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+  .generateLogisticInputAsList;
+
+public class JavaLogisticRegressionSuite implements Serializable {
+
+  private transient JavaSparkContext jsc;
+  private transient JavaSQLContext jsql;
+  private transient JavaSchemaRDD dataset;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+    jsql = new JavaSQLContext(jsc);
+    List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+    dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void logisticRegression() {
+    LogisticRegression lr = new LogisticRegression();
+    LogisticRegressionModel model = lr.fit(dataset);
+    model.transform(dataset).registerTempTable("prediction");
+    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    predictions.collect();
+  }
+
+  @Test
+  public void logisticRegressionWithSetters() {
+    LogisticRegression lr = new LogisticRegression()
+      .setMaxIter(10)
+      .setRegParam(1.0);
+    LogisticRegressionModel model = lr.fit(dataset);
+    model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
+      .registerTempTable("prediction");
+    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    predictions.collect();
+  }
+
+  @Test
+  public void logisticRegressionFitWithVarargs() {
+    LogisticRegression lr = new LogisticRegression();
+    lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
new file mode 100644
index 0000000..a266ebd
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+  .generateLogisticInputAsList;
+
+public class JavaCrossValidatorSuite implements Serializable {
+
+  private transient JavaSparkContext jsc;
+  private transient JavaSQLContext jsql;
+  private transient JavaSchemaRDD dataset;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
+    jsql = new JavaSQLContext(jsc);
+    List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+    dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void crossValidationWithLogisticRegression() {
+    LogisticRegression lr = new LogisticRegression();
+    ParamMap[] lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam(), new double[] {0.001, 1000.0})
+      .addGrid(lr.maxIter(), new int[] {0, 10})
+      .build();
+    BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
+    CrossValidator cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setNumFolds(3);
+    CrossValidatorModel cvModel = cv.fit(dataset);
+    ParamMap bestParamMap = cvModel.bestModel().fittingParamMap();
+    Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam()));
+    Assert.assertEquals(10, bestParamMap.apply(lr.maxIter()));
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
new file mode 100644
index 0000000..4515084
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.when
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar.mock
+
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.SchemaRDD
+
+class PipelineSuite extends FunSuite {
+
+  abstract class MyModel extends Model[MyModel]
+
+  test("pipeline") {
+    val estimator0 = mock[Estimator[MyModel]]
+    val model0 = mock[MyModel]
+    val transformer1 = mock[Transformer]
+    val estimator2 = mock[Estimator[MyModel]]
+    val model2 = mock[MyModel]
+    val transformer3 = mock[Transformer]
+    val dataset0 = mock[SchemaRDD]
+    val dataset1 = mock[SchemaRDD]
+    val dataset2 = mock[SchemaRDD]
+    val dataset3 = mock[SchemaRDD]
+    val dataset4 = mock[SchemaRDD]
+
+    when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
+    when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
+    when(model0.parent).thenReturn(estimator0)
+    when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2)
+    when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2)
+    when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3)
+    when(model2.parent).thenReturn(estimator2)
+    when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(estimator0, transformer1, estimator2, transformer3))
+    val pipelineModel = pipeline.fit(dataset0)
+
+    assert(pipelineModel.stages.size === 4)
+    assert(pipelineModel.stages(0).eq(model0))
+    assert(pipelineModel.stages(1).eq(transformer1))
+    assert(pipelineModel.stages(2).eq(model2))
+    assert(pipelineModel.stages(3).eq(transformer3))
+
+    assert(pipelineModel.getModel(estimator0).eq(model0))
+    assert(pipelineModel.getModel(estimator2).eq(model2))
+    intercept[NoSuchElementException] {
+      pipelineModel.getModel(mock[Estimator[MyModel]])
+    }
+    val output = pipelineModel.transform(dataset0)
+    assert(output.eq(dataset4))
+  }
+
+  test("pipeline with duplicate stages") {
+    val estimator = mock[Estimator[MyModel]]
+    val pipeline = new Pipeline()
+      .setStages(Array(estimator, estimator))
+    val dataset = mock[SchemaRDD]
+    intercept[IllegalArgumentException] {
+      pipeline.fit(dataset)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
new file mode 100644
index 0000000..625af29
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.sql.SchemaRDD
+
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
+
+  import sqlContext._
+
+  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+
+  test("logistic regression") {
+    val lr = new LogisticRegression
+    val model = lr.fit(dataset)
+    model.transform(dataset)
+      .select('label, 'prediction)
+      .collect()
+  }
+
+  test("logistic regression with setters") {
+    val lr = new LogisticRegression()
+      .setMaxIter(10)
+      .setRegParam(1.0)
+    val model = lr.fit(dataset)
+    model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
+      .select('label, 'score, 'prediction)
+      .collect()
+  }
+
+  test("logistic regression fit and transform with varargs") {
+    val lr = new LogisticRegression
+    val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
+    model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
+      .select('label, 'probability, 'prediction)
+      .collect()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
new file mode 100644
index 0000000..1ce2987
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+import org.scalatest.FunSuite
+
+class ParamsSuite extends FunSuite {
+
+  val solver = new TestParams()
+  import solver.{inputCol, maxIter}
+
+  test("param") {
+    assert(maxIter.name === "maxIter")
+    assert(maxIter.doc === "max number of iterations")
+    assert(maxIter.defaultValue.get === 100)
+    assert(maxIter.parent.eq(solver))
+    assert(maxIter.toString === "maxIter: max number of iterations (default: 100)")
+    assert(inputCol.defaultValue === None)
+  }
+
+  test("param pair") {
+    val pair0 = maxIter -> 5
+    val pair1 = maxIter.w(5)
+    val pair2 = ParamPair(maxIter, 5)
+    for (pair <- Seq(pair0, pair1, pair2)) {
+      assert(pair.param.eq(maxIter))
+      assert(pair.value === 5)
+    }
+  }
+
+  test("param map") {
+    val map0 = ParamMap.empty
+
+    assert(!map0.contains(maxIter))
+    assert(map0(maxIter) === maxIter.defaultValue.get)
+    map0.put(maxIter, 10)
+    assert(map0.contains(maxIter))
+    assert(map0(maxIter) === 10)
+
+    assert(!map0.contains(inputCol))
+    intercept[NoSuchElementException] {
+      map0(inputCol)
+    }
+    map0.put(inputCol -> "input")
+    assert(map0.contains(inputCol))
+    assert(map0(inputCol) === "input")
+
+    val map1 = map0.copy
+    val map2 = ParamMap(maxIter -> 10, inputCol -> "input")
+    val map3 = new ParamMap()
+      .put(maxIter, 10)
+      .put(inputCol, "input")
+    val map4 = ParamMap.empty ++ map0
+    val map5 = ParamMap.empty
+    map5 ++= map0
+
+    for (m <- Seq(map1, map2, map3, map4, map5)) {
+      assert(m.contains(maxIter))
+      assert(m(maxIter) === 10)
+      assert(m.contains(inputCol))
+      assert(m(inputCol) === "input")
+    }
+  }
+
+  test("params") {
+    val params = solver.params
+    assert(params.size === 2)
+    assert(params(0).eq(inputCol), "params must be ordered by name")
+    assert(params(1).eq(maxIter))
+    assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
+    assert(solver.getParam("inputCol").eq(inputCol))
+    assert(solver.getParam("maxIter").eq(maxIter))
+    intercept[NoSuchMethodException] {
+      solver.getParam("abc")
+    }
+    assert(!solver.isSet(inputCol))
+    intercept[IllegalArgumentException] {
+      solver.validate()
+    }
+    solver.validate(ParamMap(inputCol -> "input"))
+    solver.setInputCol("input")
+    assert(solver.isSet(inputCol))
+    assert(solver.getInputCol === "input")
+    solver.validate()
+    intercept[IllegalArgumentException] {
+      solver.validate(ParamMap(maxIter -> -10))
+    }
+    solver.setMaxIter(-10)
+    intercept[IllegalArgumentException] {
+      solver.validate()
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
new file mode 100644
index 0000000..1a65883
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+/** A subclass of Params for testing. */
+class TestParams extends Params {
+
+  val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100))
+  def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
+  def getMaxIter: Int = get(maxIter)
+
+  val inputCol = new Param[String](this, "inputCol", "input column name")
+  def setInputCol(value: String): this.type = { set(inputCol, value); this }
+  def getInputCol: String = get(inputCol)
+
+  override def validate(paramMap: ParamMap) = {
+    val m = this.paramMap ++ paramMap
+    require(m(maxIter) >= 0)
+    require(m.contains(inputCol))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
new file mode 100644
index 0000000..72a334a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.sql.SchemaRDD
+
+class CrossValidatorSuite extends FunSuite with LocalSparkContext {
+
+  import sqlContext._
+
+  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+
+  test("cross validation with logistic regression") {
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 10))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setNumFolds(3)
+    val cvModel = cv.fit(dataset)
+    val bestParamMap = cvModel.bestModel.fittingParamMap
+    assert(bestParamMap(lr.regParam) === 0.001)
+    assert(bestParamMap(lr.maxIter) === 10)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
new file mode 100644
index 0000000..20aa100
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.param.{ParamMap, TestParams}
+
+class ParamGridBuilderSuite extends FunSuite {
+
+  val solver = new TestParams()
+  import solver.{inputCol, maxIter}
+
+  test("param grid builder") {
+    def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = {
+      assert(maps.size === expected.size)
+      maps.foreach { m =>
+        val tuple = (m(maxIter), m(inputCol))
+        assert(expected.contains(tuple))
+        expected.remove(tuple)
+      }
+      assert(expected.isEmpty)
+    }
+
+    val maps0 = new ParamGridBuilder()
+      .baseOn(maxIter -> 10)
+      .addGrid(inputCol, Array("input0", "input1"))
+      .build()
+    val expected0 = mutable.Set(
+      (10, "input0"),
+      (10, "input1"))
+    validateGrid(maps0, expected0)
+
+    val maps1 = new ParamGridBuilder()
+      .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten
+      .addGrid(maxIter, Array(10, 20))
+      .addGrid(inputCol, Array("input0", "input1"))
+      .build()
+    val expected1 = mutable.Set(
+      (10, "input0"),
+      (20, "input0"),
+      (10, "input1"),
+      (20, "input1"))
+    validateGrid(maps1, expected1)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38f9f2e1/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
index 7857d9e..4417d66 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -17,26 +17,17 @@
 
 package org.apache.spark.mllib.util
 
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, Suite}
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
 
 trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
-  @transient var sc: SparkContext = _
-
-  override def beforeAll() {
-    val conf = new SparkConf()
-      .setMaster("local")
-      .setAppName("test")
-    sc = new SparkContext(conf)
-    super.beforeAll()
-  }
+  @transient val sc = new SparkContext("local", "test")
+  @transient lazy val sqlContext = new SQLContext(sc)
 
   override def afterAll() {
-    if (sc != null) {
-      sc.stop()
-    }
+    sc.stop()
     super.afterAll()
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message