spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mengxr <...@git.apache.org>
Subject [GitHub] spark pull request: [MLlib] [SPARK-2510]Word2Vec: Distributed Repr...
Date Sun, 03 Aug 2014 18:27:32 GMT
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/1719#discussion_r15735820
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala ---
    @@ -0,0 +1,414 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.feature
    +
    +import scala.util.Random
    +import scala.collection.mutable.ArrayBuffer
    +import scala.collection.mutable
    +
    +import com.github.fommil.netlib.BLAS.{getInstance => blas}
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.Logging
    +import org.apache.spark.rdd._
    +import org.apache.spark.SparkContext._
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
    +import org.apache.spark.HashPartitioner
    +
    +/**
    + * One word of the vocabulary together with its Huffman-tree encoding.
    + *
    + * @param word the word itself
    + * @param cn number of occurrences of the word in the training corpus
    + * @param point indices (offset by the vocabulary size) of the inner tree nodes on the
    + *              path from the root to this word
    + * @param code Huffman code bits of this word, root first
    + * @param codeLen number of valid entries in `code`
    + */
    +private case class VocabWord(
    +    var word: String,
    +    var cn: Int,
    +    var point: Array[Int],
    +    var code: Array[Int],
    +    var codeLen: Int)
    +
    +/**
    + * :: Experimental ::
    + * Word2Vec creates vector representation of words in a text corpus.
    + * The algorithm first constructs a vocabulary from the corpus
    + * and then learns vector representation of words in the vocabulary. 
    + * The vector representation can be used as features in 
    + * natural language processing and machine learning algorithms.
    + * 
    + * We use the skip-gram model and train it with the hierarchical softmax
    + * method. The variable names in the implementation match those of the
    + * original C implementation.
    + *
    + * For original C implementation, see https://code.google.com/p/word2vec/ 
    + * For research papers, see 
    + * Efficient Estimation of Word Representations in Vector Space
    + * and 
    + * Distributed Representations of Words and Phrases and their Compositionality.
    + * @param size vector dimension
    + * @param startingAlpha initial learning rate
    + * @param window context words from [-window, window]
    + * @param minCount minimum frequency for a word to be included in the vocabulary
    + * @param parallelism number of partitions to run Word2Vec (default 1)
    + * @param numIterations number of iterations (default 1)
    + */
    +@Experimental
    +class Word2Vec(
    +    val size: Int,
    +    val startingAlpha: Double,
    +    val window: Int,
    +    val minCount: Int,
    +    val parallelism:Int = 1,
    +    val numIterations:Int = 1) 
    +  extends Serializable with Logging {
    +  
    +  // Resolution and input range of the sigmoid lookup table built by createExpTable:
    +  // EXP_TABLE_SIZE samples covering [-MAX_EXP, MAX_EXP).
    +  private val EXP_TABLE_SIZE: Int = 1000
    +  private val MAX_EXP: Int = 6
    +  // Capacity of each word's Huffman `code` / `point` buffers.
    +  private val MAX_CODE_LENGTH: Int = 40
    +  private val MAX_SENTENCE_LENGTH: Int = 1000
    +  private val layer1Size: Int = size
    +  private val modelPartitionNum: Int = 100
    +
    +  // Vocabulary state filled in by learnVocab.
    +  // NOTE(review): trainWordsCount is an Int sum of word counts and can overflow on very
    +  // large corpora; consider Long.
    +  private var trainWordsCount: Int = 0
    +  private var vocabSize: Int = 0
    +  private var vocab: Array[VocabWord] = null
    +  private var vocabHash: mutable.HashMap[String, Int] = mutable.HashMap.empty[String, Int]
    +  // Current learning rate, initialised to startingAlpha.
    +  private var alpha: Double = startingAlpha
    +
    +  /**
    +   * Builds the vocabulary from a corpus of words.
    +   *
    +   * Counts every distinct word, keeps those occurring at least `minCount` times, and
    +   * sorts them by descending count. The result is stored in:
    +   *  - `vocab` and `vocabSize`;
    +   *  - `vocabHash` (word -> index in `vocab`);
    +   *  - `trainWordsCount` (total occurrences of the kept words).
    +   *
    +   * State left over from a previous call is discarded first, so the same instance can
    +   * be fit more than once.
    +   *
    +   * @param words every word occurrence in the training corpus
    +   */
    +  private def learnVocab(words: RDD[String]): Unit = {
    +    // Start from a clean slate. Otherwise a second fit() would keep stale entries in
    +    // vocabHash and double-count trainWordsCount.
    +    vocabHash = mutable.HashMap.empty[String, Int]
    +    trainWordsCount = 0
    +
    +    // Local copy so the filter closure does not capture (and serialize) `this`.
    +    val minCnt = minCount
    +    vocab = words.map(w => (w, 1))
    +      .reduceByKey(_ + _)
    +      // Drop rare words before anything is allocated for them or sent to the driver.
    +      .filter(_._2 >= minCnt)
    +      .collect()
    +      .sortWith((a, b) => a._2 > b._2)
    +      // The code/point buffers (filled later by createBinaryTree) are allocated on the
    +      // driver, after collect(), rather than built on executors and shipped back.
    +      .map { case (word, cn) =>
    +        VocabWord(
    +          word,
    +          cn,
    +          new Array[Int](MAX_CODE_LENGTH),
    +          new Array[Int](MAX_CODE_LENGTH),
    +          0)
    +      }
    +
    +    vocabSize = vocab.length
    +    var a = 0
    +    while (a < vocabSize) {
    +      vocabHash += vocab(a).word -> a
    +      // NOTE(review): Int accumulator; overflows if the kept words occur more than
    +      // Int.MaxValue times in total.
    +      trainWordsCount += vocab(a).cn
    +      a += 1
    +    }
    +    logInfo("trainWordsCount = " + trainWordsCount)
    +  }
    +
    +  /**
    +   * Precomputes the logistic sigmoid e^x / (e^x + 1) at EXP_TABLE_SIZE evenly spaced
    +   * points x = (2i / EXP_TABLE_SIZE - 1) * MAX_EXP, i.e. over [-MAX_EXP, MAX_EXP).
    +   *
    +   * @return the sampled sigmoid values, indexed by i
    +   */
    +  private def createExpTable(): Array[Double] =
    +    Array.tabulate(EXP_TABLE_SIZE) { idx =>
    +      val e = math.exp((2.0 * idx / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
    +      e / (e + 1)
    +    }
    +
    +  /**
    +   * Builds a Huffman tree over the vocabulary and writes each word's encoding into its
    +   * VocabWord (`code`, `point`, `codeLen`).
    +   *
    +   * Expects `vocab` to be sorted by descending count, as learnVocab produces it.
    +   * Node layout:
    +   *  - leaves occupy indices [0, vocabSize);
    +   *  - inner nodes occupy [vocabSize, 2 * vocabSize - 1);
    +   *  - the root is node 2 * vocabSize - 2.
    +   * Stored `point` entries are node indices shifted by -vocabSize.
    +   */
    +  private def createBinaryTree(): Unit = {
    +    val count = new Array[Long](vocabSize * 2 + 1)
    +    val binary = new Array[Int](vocabSize * 2 + 1)
    +    val parentNode = new Array[Int](vocabSize * 2 + 1)
    +    val code = new Array[Int](MAX_CODE_LENGTH)
    +    val point = new Array[Int](MAX_CODE_LENGTH)
    +    var a = 0
    +    while (a < vocabSize) {
    +      count(a) = vocab(a).cn
    +      a += 1
    +    }
    +    // Sentinel count for inner nodes that have not been created yet. It must exceed
    +    // every real word count (an Int, which can be larger than 1e9), otherwise an unbuilt
    +    // node could be chosen as a minimum. 1e15 matches the reference C implementation.
    +    while (a < 2 * vocabSize) {
    +      count(a) = 1e15.toLong
    +      a += 1
    +    }
    +
    +    // Two cursors: pos1 walks the leaves from least to most frequent, pos2 walks the
    +    // inner nodes in creation order (whose counts are non-decreasing).
    +    var pos1 = vocabSize - 1
    +    var pos2 = vocabSize
    +
    +    // Returns the index of the lower-count candidate and advances that cursor.
    +    def takeMin(): Int = {
    +      if (pos1 >= 0 && count(pos1) < count(pos2)) {
    +        val idx = pos1
    +        pos1 -= 1
    +        idx
    +      } else {
    +        val idx = pos2
    +        pos2 += 1
    +        idx
    +      }
    +    }
    +
    +    // Merge the two lowest-count nodes into a new inner node, vocabSize - 1 times.
    +    a = 0
    +    while (a < vocabSize - 1) {
    +      val min1i = takeMin()
    +      val min2i = takeMin()
    +      count(vocabSize + a) = count(min1i) + count(min2i)
    +      parentNode(min1i) = vocabSize + a
    +      parentNode(min2i) = vocabSize + a
    +      binary(min2i) = 1
    +      a += 1
    +    }
    +
    +    // Now assign a binary code to each vocabulary word. Walk from the leaf up to the
    +    // root, then store code and path root-first.
    +    // NOTE(review): paths deeper than the buffers allow (MAX_CODE_LENGTH) would
    +    // overflow `code` / `point`.
    +    a = 0
    +    while (a < vocabSize) {
    +      var b = a
    +      var i = 0
    +      while (b != vocabSize * 2 - 2) {
    +        code(i) = binary(b)
    +        point(i) = b
    +        i += 1
    +        b = parentNode(b)
    +      }
    +      vocab(a).codeLen = i
    +      // The first path entry is the root: (2 * vocabSize - 2) - vocabSize.
    +      vocab(a).point(0) = vocabSize - 2
    +      b = 0
    +      while (b < i) {
    +        vocab(a).code(i - b - 1) = code(b)
    +        vocab(a).point(i - b) = point(b) - vocabSize
    +        b += 1
    +      }
    +      a += 1
    +    }
    +  }
    +  
    +  /**
    +   * Computes the vector representation of each word in vocabulary.
    +   * @param dataset an RDD of word sequences (e.g. tokenized sentences)
    +   * @return a Word2VecModel
    +   */
    +
    +  def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
    +
    +    val words = dataset.flatMap(x => x)
    +
    +    learnVocab(words)
    +    
    +    createBinaryTree()
    +    
    +    val sc = dataset.context
    +
    +    val expTable = sc.broadcast(createExpTable())
    +    val V = sc.broadcast(vocab)
    +    val VHash = sc.broadcast(vocabHash)
    --- End diff --
    
    Same naming issue here: please rename this broadcast variable to `bcVocabHash`.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message