spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From hhbyyh <...@git.apache.org>
Subject [GitHub] spark pull request #16158: [SPARK-18724][ML] Add TuningSummary for TrainVali...
Date Mon, 11 Sep 2017 17:12:59 GMT
Github user hhbyyh commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16158#discussion_r138133273
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala ---
    @@ -85,6 +86,32 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
         instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
         instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
       }
    +
    +
    +  /**
    +   * Summary of grid search tuning in the format of DataFrame. Each row contains one
candidate
    +   * paramMap and the corresponding metric of trained model.
    +   */
    +  protected def getTuningSummaryDF(metrics: Array[Double]): DataFrame = {
    +    val params = $(estimatorParamMaps)
    +    require(params.nonEmpty, "estimator param maps should not be empty")
    +    require(params.length == metrics.length, "estimator param maps number should match
metrics")
    +    val metricName = $(evaluator) match {
    +      case b: BinaryClassificationEvaluator => b.getMetricName
    +      case m: MulticlassClassificationEvaluator => m.getMetricName
    +      case r: RegressionEvaluator => r.getMetricName
    +      case _ => "metrics"
    +    }
    +    val spark = SparkSession.builder().getOrCreate()
    +    val sc = spark.sparkContext
    +    val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++ Seq(metricName)
    +    val schema = new StructType(fields.map(name => StructField(name, StringType)).toArray)
    +    val rows = sc.parallelize(params.zip(metrics)).map { case (param, metric) =>
    +      val values = param.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString)
    +      Row.fromSeq(values)
    +    }
    --- End diff --
    
    OK



---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message