flink-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sachingoel0101 <...@git.apache.org>
Subject [GitHub] flink pull request: [FLINK-1719] [ml] Add Multinomial Naive Bayes ...
Date Fri, 25 Sep 2015 15:42:44 GMT
Github user sachingoel0101 commented on a diff in the pull request:

    https://github.com/apache/flink/pull/1156#discussion_r40443743
  
    --- Diff: flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/MultinomialNaiveBayes.scala
---
    @@ -0,0 +1,900 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.ml.classification
    +
    +import java.{lang, util}
    +
    +import org.apache.flink.api.common.functions._
    +import org.apache.flink.api.scala._
    +import org.apache.flink.configuration.Configuration
    +import org.apache.flink.core.fs.FileSystem.WriteMode
    +import org.apache.flink.ml.common.{ParameterMap, Parameter}
    +import org.apache.flink.ml.pipeline.{PredictDataSetOperation, FitOperation, Predictor}
    +import org.apache.flink.util.Collector
    +
    +import scala.collection.JavaConverters._
    +import scala.collection.mutable
    +import scala.collection.mutable.ListBuffer
    +import scala.collection.mutable.Map
    +
    +/**
    + * Multinomial Naive Bayes classifier for text documents.
    + *
    + * While building the model different approaches need to be compared.
    + * For that purpose the fit parameters are used. Every possibility that might enhance
    + * the implementation can be chosen separately by using the following list of parameters:
    + *
    + * Possibility 1: way of calculating the document count
    + *  P1 = 0 -> use .count() to get the count of all documents
    + *  P1 = 1 -> use a reducer and a mapper to create a broadcast data set containing the count of
    + *    all documents
    + *
    + * Possibility 2: all words in class (order of operators)
    + *    If P2 = 1 improves the speed, many other calculations must switch their operators, too.
    + *  P2 = 0 -> first the reducer, then the mapper
    + *  P2 = 1 -> first the mapper, then the reducer
    + *
    + * Possibility 3: way of calculating P(w|c)
    + *  P3 = 0 -> join singleWordsInClass and allWordsInClass to wordsInClass data set
    + *  P3 = 1 -> work on singleWordsInClass data set and broadcast allWordsInClass data set
    + *
    + * Schneider/Rennie 1: ignore/reduce word frequency information
    + *  SR1 = 0 -> word frequency information is not ignored
    + *  SR1 = 1 -> word frequency information is ignored (Schneider's approach)
    + *  SR1 = 2 -> word frequency information is reduced (Rennie's approach)
    + *
    + * Schneider 1: ignore P(c_j) in cMAP formula
    + *  S1 = 0 -> normal cMAP formula
    + *  S1 = 1 -> cMAP without P(c_j)
    + *
    + * Rennie 1: transform document frequency
    + *  R1 = 0 -> normal formula
    + *  R1 = 1 -> apply inverse document frequency
    + * Note: if R1 = 1 and SR1 = 2, both approaches get applied.
    + */
    +class MultinomialNaiveBayes extends Predictor[MultinomialNaiveBayes] {
    +
    +  import MultinomialNaiveBayes._
    +
    +  // The model that stores all information related to one specific word
    +  // (class name -> word -> log P(w|c))
    +  var wordRelatedModelData: Option[DataSet[(String, String, Double)]] = None
    +
    +  // The model that stores all information related to one specific class
    +  // (class name -> log P(c) -> log P(w|c) of a word not in that class)
    +  var classRelatedModelData: Option[DataSet[(String, Double, Double)]] = None
    +
    +  // Additional information needed by some of the improvements; filled by fit when R1 = 1
    +  // or via setImprovementDataSet
    +  // (word -> log(number of documents in all classes / number of documents with that word))
    +  var improvementData: Option[DataSet[(String, Double)]] = None
    +
    +  // ============================== Parameter configuration ========================================
    +
    +  /** Sets Possibility 1, the way of calculating the document count (0 or 1). */
    +  def setP1(value: Int): MultinomialNaiveBayes = {
    +    parameters.add(P1, value)
    +    this
    +  }
    +
    +  /** Sets Possibility 2, the order of operators when counting the words of a class (0 or 1). */
    +  def setP2(value: Int): MultinomialNaiveBayes = {
    +    parameters.add(P2, value)
    +    this
    +  }
    +
    +  /** Sets Possibility 3, the way of calculating P(w|c) (0 or 1). */
    +  def setP3(value: Int): MultinomialNaiveBayes = {
    +    parameters.add(P3, value)
    +    this
    +  }
    +
    +  /** Sets Schneider/Rennie 1, the handling of word frequency information (0, 1 or 2). */
    +  def setSR1(value: Int): MultinomialNaiveBayes = {
    +    parameters.add(SR1, value)
    +    this
    +  }
    +
    +  /** Sets Schneider 1, whether P(c_j) is ignored in the cMAP formula (0 or 1). */
    +  def setS1(value: Int): MultinomialNaiveBayes = {
    +    parameters.add(S1, value)
    +    this
    +  }
    +
    +  /** Sets Rennie 1, whether the inverse document frequency is applied (0 or 1). */
    +  def setR1(value: Int): MultinomialNaiveBayes = {
    +    parameters.add(R1, value)
    +    this
    +  }
    +
    +  // =============================================== Methods =======================================
    +
    +  /**
    +   * Saves already existing model data created by the NaiveBayes algorithm to the designated
    +   * locations as "|"-separated CSV files, overwriting existing files. The saved data is a
    +   * representation of the [[wordRelatedModelData]] and [[classRelatedModelData]].
    +   * Nothing is written unless both model data sets are available.
    +   * @param wordRelated the save location for the wordRelated data
    +   * @param classRelated the save location for the classRelated data
    +   * @throws NoSuchElementException if the model has not been trained or set yet
    +   */
    +  def saveModelDataSet(wordRelated: String, classRelated: String) : Unit = {
    +    (wordRelatedModelData, classRelatedModelData) match {
    +      case (Some(wordData), Some(classData)) =>
    +        wordData.writeAsCsv(wordRelated, "\n", "|", WriteMode.OVERWRITE)
    +        classData.writeAsCsv(classRelated, "\n", "|", WriteMode.OVERWRITE)
    +      case _ =>
    +        // same exception type as the former Option.get, but with an actionable message
    +        throw new NoSuchElementException("The model data sets are not available. Train the " +
    +          "model with fit or set it with setModelDataSet before saving it.")
    +    }
    +  }
    +
    +  /**
    +   * Saves the improvement data set to the designated location as a "|"-separated CSV file,
    +   * overwriting an existing file. The saved data is a representation of the
    +   * [[improvementData]] data set.
    +   * @param path the save location for the improvement data
    +   * @throws NoSuchElementException if the improvement data has not been computed or set yet
    +   */
    +  def saveImprovementDataSet(path: String) : Unit = {
    +    improvementData match {
    +      case Some(data) =>
    +        data.writeAsCsv(path, "\n", "|", WriteMode.OVERWRITE)
    +      case None =>
    +        throw new NoSuchElementException("The improvement data set is not available. Train " +
    +          "the model with R1 = 1 or set it with setImprovementDataSet before saving it.")
    +    }
    +  }
    +
    +  /**
    +   * Sets the [[wordRelatedModelData]] and the [[classRelatedModelData]] to the given data sets.
    +   * @param wordRelated the data set representing the wordRelated model
    +   * @param classRelated the data set representing the classRelated model
    +   */
    +  def setModelDataSet(wordRelated : DataSet[(String, String, Double)],
    +                      classRelated: DataSet[(String, Double, Double)]) : Unit = {
    +    this.wordRelatedModelData = Some(wordRelated)
    +    this.classRelatedModelData = Some(classRelated)
    +  }
    +
    +  /**
    +   * Sets the [[improvementData]] to the given data set.
    +   * @param impSet the data set representing the improvement data
    +   */
    +  def setImprovementDataSet(impSet : DataSet[(String, Double)]) : Unit = {
    +    this.improvementData = Some(impSet)
    +  }
    +
    +}
    +
    +object MultinomialNaiveBayes {
    +
    +  // ========================================== Parameters =========================================
    +  /**
    +   * Possibility 1, the way of calculating the document count:
    +   * 0 = `.count()` on the input, 1 = broadcast data set built by a reducer and a mapper.
    +   */
    +  case object P1 extends Parameter[Int] {
    +    override val defaultValue: Option[Int] = Option(0)
    +  }
    +
    +  /**
    +   * Possibility 2, the order of operators when counting all words of a class:
    +   * 0 = reducer first, 1 = mapper first.
    +   */
    +  case object P2 extends Parameter[Int] {
    +    override val defaultValue: Option[Int] = Option(0)
    +  }
    +
    +  /**
    +   * Possibility 3, the way of calculating P(w|c):
    +   * 0 = join with the per-class word counts, 1 = broadcast the per-class word counts.
    +   */
    +  case object P3 extends Parameter[Int] {
    +    override val defaultValue: Option[Int] = Option(0)
    +  }
    +
    +  /**
    +   * Schneider/Rennie 1, handling of word frequency information:
    +   * 0 = not ignored, 1 = ignored (Schneider's approach), 2 = reduced (Rennie's approach).
    +   */
    +  case object SR1 extends Parameter[Int] {
    +    override val defaultValue: Option[Int] = Option(0)
    +  }
    +
    +  /** Schneider 1, the cMAP formula: 0 = normal cMAP formula, 1 = cMAP without P(c_j). */
    +  case object S1 extends Parameter[Int] {
    +    override val defaultValue: Option[Int] = Option(0)
    +  }
    +
    +  /** Rennie 1, document frequency: 0 = normal formula, 1 = apply inverse document frequency. */
    +  case object R1 extends Parameter[Int] {
    +    override val defaultValue: Option[Int] = Option(0)
    +  }
    +
    +  // ======================================== Factory Methods ======================================
    +  /** Creates a new [[MultinomialNaiveBayes]] instance. */
    +  def apply(): MultinomialNaiveBayes = new MultinomialNaiveBayes
    +
    +  // ====================================== Operations =============================================
    +  /**
    +   * Trains the models to fit the training data. The resulting
    +   * [[MultinomialNaiveBayes.wordRelatedModelData]] and
    +   * [[MultinomialNaiveBayes.classRelatedModelData]] are stored in the [[MultinomialNaiveBayes]]
    +   * instance. If R1 = 1, [[MultinomialNaiveBayes.improvementData]] is stored as well.
    +   */
    +  implicit val fitNNB: FitOperation[MultinomialNaiveBayes, (String, String)] =
    +    new FitOperation[MultinomialNaiveBayes, (String, String)] {
    +
    +    /**
    +     * The [[FitOperation]] used to create the model. Requires an instance of
    +     * [[MultinomialNaiveBayes]], a [[ParameterMap]] and the input data set. This data set
    +     * maps (string -> string) containing (label -> text, words separated by ",").
    +     * @param instance of [[MultinomialNaiveBayes]] that receives the model data sets
    +     * @param fitParameters additional parameters, merged over the instance parameters
    +     * @param input the data set to process
    +     * @throws IllegalArgumentException if P1, P2, P3 or SR1 has an unsupported value
    +     */
    +    override def fit(instance: MultinomialNaiveBayes,
    +                     fitParameters: ParameterMap,
    +                     input: DataSet[(String, String)]): Unit = {
    +
    +      val resultingParameters = instance.parameters ++ fitParameters
    +
    +      val p1 = resultingParameters(P1)
    +      val p2 = resultingParameters(P2)
    +      val p3 = resultingParameters(P3)
    +      val sr1 = resultingParameters(SR1)
    +      val r1 = resultingParameters(R1)
    +
    +      // Fail fast: an unsupported value would otherwise leave an intermediate data set unset
    +      // and only surface later as a NullPointerException.
    +      require(p1 == 0 || p1 == 1, s"P1 must be 0 or 1, but was $p1.")
    +      require(p2 == 0 || p2 == 1, s"P2 must be 0 or 1, but was $p2.")
    +      require(p3 == 0 || p3 == 1, s"P3 must be 0 or 1, but was $p3.")
    +      require(sr1 >= 0 && sr1 <= 2, s"SR1 must be 0, 1 or 2, but was $sr1.")
    +
    +      // Count the amount of documents for each class.
    +      // 1. Map: replace the document text by a 1
    +      // 2. Group-Reduce: sum the 1s by class
    +      val documentsPerClass: DataSet[(String, Int)] = input
    +        .map { document => (document._1, 1) }
    +        .groupBy(0).sum(1) // (class name -> count of documents)
    +
    +      // Single-element data set holding the count of all documents. It is broadcast by
    +      // P1 = 1 and R1 = 1; lazy, so the operators are only created when one of them needs it.
    +      // 1. Reduce: add the document counts of all classes together
    +      // 2. Map: remove the field that contains the class name
    +      lazy val totalDocumentCount: DataSet[Double] = documentsPerClass
    +        .reduce((line1, line2) => (line1._1, line1._2 + line2._2))
    +        .map(line => line._2.toDouble) // (count of all documents)
    +
    +      // Count the amount of occurrences of each word for each class.
    +      // 1. FlatMap: split the document into its words and add a 1 to each tuple
    +      // 2. Group-Reduce: sum the 1s by class, word
    +      val singleWordsInClass: DataSet[(String, String, Int)] = input
    +        .flatMap(new SingleWordSplitter())
    +        .groupBy(0, 1).sum(2) // (class name -> word -> count of that word)
    +
    +      // POSSIBILITY 2 (order of operators) combined with SCHNEIDER/RENNIE 1:
    +      // for Schneider's approach (SR1 = 1) only the distinct words of a class are counted
    +      // (ndw(c_j)); Rennie's approach (SR1 = 2) counts like SR1 = 0.
    +      val allWordsInClass: DataSet[(String, Int)] = (p2, sr1) match {
    +        case (0, 0 | 2) =>
    +          // 1. Reduce: add the count for each word in a class together
    +          // 2. Map: remove the field that contains the word
    +          singleWordsInClass.groupBy(0).reduce {
    +            (singleWords1, singleWords2) =>
    +              (singleWords1._1, singleWords1._2, singleWords1._3 + singleWords2._3)
    +          }.map(singleWords => (singleWords._1, singleWords._3))
    +        case (0, 1) =>
    +          // 1. Map: set the word count to 1
    +          // 2. Reduce: add the 1s of a class together
    +          // 3. Map: remove the field that contains the word
    +          singleWordsInClass
    +            .map(singleWords => (singleWords._1, singleWords._2, 1))
    +            .groupBy(0).reduce {
    +              (singleWords1, singleWords2) =>
    +                (singleWords1._1, singleWords1._2, singleWords1._3 + singleWords2._3)
    +            }.map(singleWords => (singleWords._1, singleWords._3))
    +        case (1, 0 | 2) =>
    +          // 1. Map: remove the field that contains the word
    +          // 2. Reduce: add the count for each word in a class together
    +          singleWordsInClass.map(singleWords => (singleWords._1, singleWords._3))
    +            .groupBy(0).reduce {
    +              (singleWords1, singleWords2) => (singleWords1._1, singleWords1._2 + singleWords2._2)
    +            }
    +        case (1, 1) =>
    +          // 1. Map: remove the field that contains the word, set the word count to 1
    +          // 2. Reduce: add the 1s of a class together
    +          singleWordsInClass.map(singleWords => (singleWords._1, 1))
    +            .groupBy(0).reduce {
    +              (singleWords1, singleWords2) => (singleWords1._1, singleWords1._2 + singleWords2._2)
    +            }
    +      } // (class name -> count of all words, or of distinct words for SR1 = 1, in that class)
    +
    +      // POSSIBILITY 1: way of calculating the document count.
    +      // P(c) = count of documents of a class / count of all documents
    +      val pc: DataSet[(String, Double)] = p1 match { // (class name -> P(c))
    +        case 0 =>
    +          val documentsCount: Double = input.count() // count of all documents
    +          documentsPerClass.map(line => (line._1, line._2 / documentsCount))
    +        case 1 =>
    +          // divide by the only element of the broadcast totalDocumentCount data set
    +          documentsPerClass.map(new RichMapFunction[(String, Int), (String, Double)] {
    +
    +            var broadcastSet: util.List[Double] = null
    +
    +            override def open(config: Configuration): Unit = {
    +              broadcastSet = getRuntimeContext.getBroadcastVariable[Double]("documentCount")
    +              if (broadcastSet.size() != 1) {
    +                throw new RuntimeException("The document count data set used by p1 = 1 " +
    +                  "has the wrong size! Please use p1 = 0 if the problem can not be solved.")
    +              }
    +            }
    +
    +            override def map(value: (String, Int)): (String, Double) = {
    +              (value._1, value._2 / broadcastSet.get(0))
    +            }
    +          }).withBroadcastSet(totalDocumentCount, "documentCount")
    +      }
    +
    +      // Count of distinct words over all classes (size of the vocabulary).
    +      val vocabularyCount: Double = singleWordsInClass
    +        .map(tuple => (tuple._2, 1))
    +        .distinct(0)
    +        .count()
    +
    +      // P(w|c) for words that are not part of a class, needed for smoothing:
    +      // the P(w|c) formula with n(c_j, w_t) = 0.
    +      val pwcNotInClass: DataSet[(String, Double)] = allWordsInClass
    +        .map(line => (line._1, 1 / (line._2 + vocabularyCount))) // (class name -> P(w|c))
    +
    +      // SCHNEIDER/RENNIE 1: for Schneider's approach (SR1 = 1), P(w|c) uses
    +      // classname -> word -> number of documents of that class containing the word.
    +      // 1. FlatMap: class -> word -> 1 (one tuple for each document in which this word occurs)
    +      // 2. Group-Reduce: sum all 1s where the first two fields are equal
    +      val wordCountsInClass: DataSet[(String, String, Int)] =
    +        if (sr1 == 1) {
    +          input
    +            .flatMap(new SingleDistinctWordSplitter())
    +            .groupBy(0, 1)
    +            .reduce((line1, line2) => (line1._1, line1._2, line1._3 + line2._3))
    +        } else {
    +          singleWordsInClass
    +        }
    +
    +      // POSSIBILITY 3: way of calculating P(w|c) = (n(c_j, w_t) + 1) / (n(c_j) + |V|)
    +      val pwc: DataSet[(String, String, Double)] = p3 match { // (class name -> word -> P(w|c))
    +        case 0 =>
    +          // Join the word counts with allWordsInClass to use both for the calculation.
    +          val wordsInClass = wordCountsInClass
    +            .join(allWordsInClass).where(0).equalTo(0) {
    +              (single, all) => (single._1, single._2, single._3, all._2)
    +            } // (class name -> word -> count of that word -> count of all words in that class)
    +          wordsInClass.map(line =>
    +            (line._1, line._2, (line._3 + 1) / (line._4 + vocabularyCount)))
    +        case 1 =>
    +          // Broadcast allWordsInClass and look up the count of the record's class.
    +          wordCountsInClass.map(new RichMapFunction[(String, String, Int),
    +            (String, String, Double)] {
    +
    +            val broadcastMap: mutable.Map[String, Int] = mutable.Map[String, Int]()
    +
    +            override def open(config: Configuration): Unit = {
    +              getRuntimeContext
    +                .getBroadcastVariable[(String, Int)]("allWordsInClass").asScala
    +                .foreach(record => broadcastMap.put(record._1, record._2))
    +            }
    +
    +            override def map(value: (String, String, Int)): (String, String, Double) = {
    +              (value._1, value._2, (value._3 + 1) / (broadcastMap(value._1) + vocabularyCount))
    +            }
    +          }).withBroadcastSet(allWordsInClass, "allWordsInClass")
    +      }
    +
    +      // Store all word related information in one data set: logarithms of P(w|c).
    +      val wordRelatedModelData = pwc.map(line => (line._1, line._2, Math.log(line._3)))
    +
    +      // Store all class related information in one data set: join P(c) with P(w|c) of words
    +      // not in the class and take the logarithms of both.
    +      val classRelatedModelData = pc.join(pwcNotInClass)
    +        .where(0).equalTo(0) {
    +          (line1, line2) => (line1._1, Math.log(line1._2), Math.log(line2._2))
    +        } // (class name -> log(P(c)) -> log(P(w|c) not in class))
    +
    +      instance.wordRelatedModelData = Some(wordRelatedModelData)
    +      instance.classRelatedModelData = Some(classRelatedModelData)
    +
    +      // RENNIE 1: transform document frequency. Sets the improvementData set to
    +      // (word -> log(number of documents in all classes / number of documents with that word)).
    +      if (r1 == 1) {
    +        val wordCountTotal = input // (word -> count of documents with that word)
    +          .flatMap(new SingleDistinctWordSplitter())
    +          .map(line => (line._2, 1))
    +          .groupBy(0)
    +          .reduce((line1, line2) => (line1._1, line1._2 + line2._2))
    +
    +        val improvementData = wordCountTotal.map(new RichMapFunction[(String, Int),
    +          (String, Double)] {
    +
    +          var broadcastSet: util.List[Double] = null
    +
    +          override def open(config: Configuration): Unit = {
    +            broadcastSet = getRuntimeContext.getBroadcastVariable[Double]("totalDocumentCount")
    +            if (broadcastSet.size() != 1) {
    +              throw new RuntimeException("The total document count data set used by r1 = 1 " +
    +                "has the wrong size! Please use r1 = 0 if the problem can not be solved.")
    +            }
    +          }
    +
    +          override def map(value: (String, Int)): (String, Double) = {
    +            (value._1, Math.log(broadcastSet.get(0) / value._2))
    +          }
    +        }).withBroadcastSet(totalDocumentCount, "totalDocumentCount")
    +
    +        instance.improvementData = Some(improvementData)
    +      }
    +
    +    }
    +  }
    +
    +  // Model (String, String, Double, Double, Double)
    +  implicit def predictNNB = new PredictDataSetOperation[
    +    MultinomialNaiveBayes,
    +    (Int, String),
    +    (Int, String)]() {
    +
    +    override def predictDataSet(instance: MultinomialNaiveBayes,
    --- End diff --
    
    indentations.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message