spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sethah <...@git.apache.org>
Subject [GitHub] spark pull request #17673: [SPARK-20372] [ML] Word2Vec Continuous Bag of Wor...
Date Thu, 05 Oct 2017 21:39:46 GMT
Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/17673#discussion_r143011936
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/Word2VecCBOWSolver.scala ---
    @@ -0,0 +1,344 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import com.github.fommil.netlib.BLAS.{getInstance => blas}
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.mllib.feature
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.util.random.XORShiftRandom
    +
    +object Word2VecCBOWSolver extends Logging {
    +  // learning rate is updated for every batch of size batchSize
    +  private val batchSize = 10000
    +
    +  // power to raise the unigram distribution with
    +  private val power = 0.75
    +
    +  private val MAX_EXP = 6
    +
    +  case class Vocabulary(
    +    totalWordCount: Long,
    +    vocabMap: Map[String, Int],
    +    unigramTable: Array[Int],
    +    samplingTable: Array[Float])
    +
    +  /**
    +   * This method implements the Word2Vec Continuous Bag Of Words (CBOW) model using
    +   * negative sampling optimization, with BLAS for vectorizing operations where applicable.
    +   * The algorithm is parallelized in the same way as the skip-gram based estimation.
    +   * We divide input data into N equally sized random partitions.
    +   * We then generate initial weights and broadcast them to the N partitions. This way
    +   * all the partitions start with the same initial weights. We then run N independent
    +   * estimations that each estimate a model on a partition. The weights learned
    +   * from each of the N models are averaged and then rebroadcast.
    +   * This process is repeated `maxIter` number of times.
    +   *
    +   * @param input An RDD of sentences. Each sentence is an iterable of word strings.
    +   * @return Estimated word2vec model
    +   */
    +  def fitCBOW[S <: Iterable[String]](
    +      word2Vec: Word2Vec,
    +      input: RDD[S]): feature.Word2VecModel = {
    +
    +    val negativeSamples = word2Vec.getNegativeSamples
    +    val sample = word2Vec.getSample
    +
    +    val Vocabulary(totalWordCount, vocabMap, uniTable, sampleTable) =
    +      generateVocab(input, word2Vec.getMinCount, sample, word2Vec.getUnigramTableSize)
    +    val vocabSize = vocabMap.size
    +
    +    assert(negativeSamples < vocabSize, s"Vocab size ($vocabSize) cannot be smaller" +
    +      s" than negative samples($negativeSamples)")
    +
    +    val seed = word2Vec.getSeed
    +    val initRandom = new XORShiftRandom(seed)
    +
    +    val vectorSize = word2Vec.getVectorSize
    +    val syn0Global = Array.fill(vocabSize * vectorSize)(initRandom.nextFloat - 0.5f)
    +    val syn1Global = Array.fill(vocabSize * vectorSize)(0.0f)
    +
    +    val sc = input.context
    +
    +    val vocabMapBroadcast = sc.broadcast(vocabMap)
    +    val unigramTableBroadcast = sc.broadcast(uniTable)
    +    val sampleTableBroadcast = sc.broadcast(sampleTable)
    +
    +    val windowSize = word2Vec.getWindowSize
    +    val maxSentenceLength = word2Vec.getMaxSentenceLength
    +    val numPartitions = word2Vec.getNumPartitions
    +
    +    val digitSentences = input.flatMap { sentence =>
    +      val wordIndexes = sentence.flatMap(vocabMapBroadcast.value.get)
    +      wordIndexes.grouped(maxSentenceLength).map(_.toArray)
    +    }.repartition(numPartitions).cache()
    +
    +    val learningRate = word2Vec.getStepSize
    +
    +    val wordsPerPartition = totalWordCount / numPartitions
    +
    +    logInfo(s"VocabSize: ${vocabMap.size}, TotalWordCount: $totalWordCount")
    +
    +    val maxIter = word2Vec.getMaxIter
    +    for {iteration <- 1 to maxIter} {
    --- End diff --
    
    parentheses


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message