spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From viirya <...@git.apache.org>
Subject [GitHub] spark pull request #20146: [SPARK-11215][ML] Add multiple columns support to StringIndexer
Date Mon, 23 Apr 2018 13:26:02 GMT
Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20146#discussion_r183393215
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
    @@ -130,21 +161,53 @@ class StringIndexer @Since("1.4.0") (
       @Since("1.4.0")
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    +
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
    +
       @Since("2.0.0")
       override def fit(dataset: Dataset[_]): StringIndexerModel = {
         transformSchema(dataset.schema, logging = true)
    -    val values = dataset.na.drop(Array($(inputCol)))
    -      .select(col($(inputCol)).cast(StringType))
    -      .rdd.map(_.getString(0))
    -    val labels = $(stringOrderType) match {
    -      case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2)
    -        .map(_._1).toArray
    -      case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2)
    -        .map(_._1).toArray
    -      case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _)
    -      case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _)
    -    }
    -    copyValues(new StringIndexerModel(uid, labels).setParent(this))
    +
    +    val (inputCols, _) = getInOutCols()
    +    val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]())
    +
    +    // Counts by the string values in the dataset.
    +    val countByValueArray = dataset.na.drop(inputCols)
    +      .select(inputCols.map(col(_).cast(StringType)): _*)
    +      .rdd.treeAggregate(zeroState)(
    --- End diff --
    
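    I think this per-column value counting should be doable with a Spark SQL `Aggregator` (on the Dataset directly) instead of dropping to the RDD API and using `treeAggregate`.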


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message