From: sachingoel0101
To: issues@flink.incubator.apache.org
Reply-To: issues@flink.incubator.apache.org
Subject: [GitHub] flink pull request: [FLINK-1719] [ml] Add Multinomial Naive Bayes ...
Message-Id: <20150925152803.2DCF5DFE5F@git1-us-west.apache.org>
Date: Fri, 25 Sep 2015 15:28:03 +0000 (UTC)

Github user sachingoel0101 commented on a diff in the pull request:

https://github.com/apache/flink/pull/1156#discussion_r40442050

--- Diff: flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/MultinomialNaiveBayes.scala ---
@@ -0,0 +1,900 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification
+
+import java.{lang, util}
+
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.scala._
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.fs.FileSystem.WriteMode
+import org.apache.flink.ml.common.{ParameterMap, Parameter}
+import org.apache.flink.ml.pipeline.{PredictDataSetOperation, FitOperation, Predictor}
+import org.apache.flink.util.Collector
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.Map
+
+/**
+ * While building the model, different approaches need to be compared.
+ * For that purpose the fitParameters are used. Every possibility that might enhance
+ * the implementation can be chosen separately by using the following list of parameters:
+ *
+ * Possibility 1: way of calculating the document count
+ *  P1 = 0 -> use .count() to get the count of all documents
+ *  P1 = 1 -> use a reducer and a mapper to create a broadcast data set containing the count of
+ *            all documents
+ *
+ * Possibility 2: all words in class (order of operators)
+ *  If P2 = 1 improves the speed, many other calculations must switch their operators, too.
+ *  P2 = 0 -> first the reducer, then the mapper
+ *  P2 = 1 -> first the mapper, then the reducer
+ *
+ * Possibility 3: way of calculating pwc
+ *  P3 = 0 -> join singleWordsInClass and allWordsInClass to wordsInClass data set
+ *  P3 = 1 -> work on singleWordsInClass data set and broadcast allWordsInClass data set
+ *
+ * Schneider/Rennie 1: ignore/reduce word frequency information
+ *  SR1 = 0 -> word frequency information is not ignored
+ *  SR1 = 1 -> word frequency information is ignored (Schneider's approach)
+ *  SR1 = 2 -> word frequency information is reduced (Rennie's approach)
+ *
+ * Schneider1: ignore P(c_j) in cMAP formula
+ *  S1 = 0 -> normal cMAP formula
+ *  S1 = 1 -> cMAP without P(c_j)
+ *
+ * Rennie1: transform document frequency
+ *  R1 = 0 -> normal formula
+ *  R1 = 1 -> apply inverse document frequency
+ *  Note: if R1 = 1 and SR1 = 2, both approaches get applied.
+ *
+ */
+class MultinomialNaiveBayes extends Predictor[MultinomialNaiveBayes] {
+
+  import MultinomialNaiveBayes._
+
+  // The model that stores all information related to one specific word
+  var wordRelatedModelData: Option[DataSet[(String, String, Double)]] =
+    None // (class name -> word -> log P(w|c))
+
+  // The model that stores all information related to one specific class
+  var classRelatedModelData: Option[DataSet[(String, Double, Double)]] =
+    None // (class name -> p(c) -> log p(w|c) not in class)
+
+  // A data set that stores additional information needed by some of the improvements
+  var improvementData: Option[DataSet[(String, Double)]] =
+    None // (word -> log number of documents in all classes / word frequency in all classes)
+
+  // ============================== Parameter configuration ========================================
+
+  def setP1(value: Int): MultinomialNaiveBayes = {
+    parameters.add(P1, value)
+    this
+  }
+
+  def setP2(value: Int): MultinomialNaiveBayes = {
+    parameters.add(P2, value)
+    this
+  }
+
+  def setP3(value: Int): MultinomialNaiveBayes = {
+    parameters.add(P3, value)
+    this
+  }
+
+  def setSR1(value: Int): MultinomialNaiveBayes = {
+    parameters.add(SR1, value)
+    this
+  }
+
+  def setS1(value: Int): MultinomialNaiveBayes = {
+    parameters.add(S1, value)
+    this
+  }
+
+  def setR1(value: Int): MultinomialNaiveBayes = {
+    parameters.add(R1, value)
+    this
+  }
+
+  // =============================================== Methods =======================================
+
+  /**
+   * Save already existing model data created by the NaiveBayes algorithm. Requires the designated
+   * locations. The saved data is a representation of the [[wordRelatedModelData]] and
+   * [[classRelatedModelData]].
+   * @param wordRelated the save location for the wordRelated data
+   * @param classRelated the save location for the classRelated data
+   */
+  def saveModelDataSet(wordRelated: String, classRelated: String) : Unit = {
+    wordRelatedModelData.get.writeAsCsv(wordRelated, "\n", "|", WriteMode.OVERWRITE)
+    classRelatedModelData.get.writeAsCsv(classRelated, "\n", "|", WriteMode.OVERWRITE)
+  }
+
+  /**
+   * Save the improvement data set. Requires the designated save location. The saved data is a
+   * representation of the [[improvementData]] data set.
+   * @param path the save location for the improvement data
+   */
+  def saveImprovementDataSet(path: String) : Unit = {
+    improvementData.get.writeAsCsv(path, "\n", "|", WriteMode.OVERWRITE)
+  }
+
+  /**
+   * Sets the [[wordRelatedModelData]] and the [[classRelatedModelData]] to the given data sets.
+   * @param wordRelated the data set representing the wordRelated model
+   * @param classRelated the data set representing the classRelated model
+   */
+  def setModelDataSet(wordRelated : DataSet[(String, String, Double)],
+                      classRelated: DataSet[(String, Double, Double)]) : Unit = {
+    this.wordRelatedModelData = Some(wordRelated)
+    this.classRelatedModelData = Some(classRelated)
+  }
+
+  def setImprovementDataSet(impSet : DataSet[(String, Double)]) : Unit = {
+    this.improvementData = Some(impSet)
+  }
+
+}
+
+object MultinomialNaiveBayes {
+
+  // ========================================== Parameters =========================================
+  case object P1 extends Parameter[Int] {
+    override val defaultValue: Option[Int] = Some(0)
+  }
+
+  case object P2 extends Parameter[Int] {
+    override val defaultValue: Option[Int] = Some(0)
+  }
+
+  case object P3 extends Parameter[Int] {
+    override val defaultValue: Option[Int] = Some(0)
+  }
+
+  case object SR1 extends Parameter[Int] {
+    override val defaultValue: Option[Int] = Some(0)
+  }
+
+  case object S1 extends Parameter[Int] {
+    override val defaultValue: Option[Int] = Some(0)
+  }
+
+  case object R1 extends Parameter[Int] {
+    override val defaultValue: Option[Int] = Some(0)
+  }
+
+  // ======================================== Factory Methods ======================================
+  def apply(): MultinomialNaiveBayes = {
+    new MultinomialNaiveBayes()
+  }
+
+  // ====================================== Operations =============================================
+  /**
+   * Trains the models to fit the training data. The resulting
+   * [[MultinomialNaiveBayes.wordRelatedModelData]] and
+   * [[MultinomialNaiveBayes.classRelatedModelData]] are stored in the [[MultinomialNaiveBayes]]
+   * instance.
+   */
+
+  implicit val fitNNB = new FitOperation[MultinomialNaiveBayes, (String, String)] {
+    /**
+     * The [[FitOperation]] used to create the model. Requires an instance of
+     * [[MultinomialNaiveBayes]], a [[ParameterMap]] and the input data set. This data set
+     * maps (string -> string) containing (label -> text, words separated by ",")
+     * @param instance the [[MultinomialNaiveBayes]] instance
+     * @param fitParameters additional parameters
+     * @param input the data set to be processed
+     */
+    override def fit(instance: MultinomialNaiveBayes,
+                     fitParameters: ParameterMap,
+                     input: DataSet[(String, String)]): Unit = {
+
+      val resultingParameters = instance.parameters ++ fitParameters
+
+      // Count the number of documents for each class.
+      // 1. Map: replace the document text by a 1
+      // 2. Group-Reduce: sum the 1s by class
+      val documentsPerClass: DataSet[(String, Int)] = input.map { input => (input._1, 1)}
+        .groupBy(0).sum(1) // (class name -> count of documents)
+
+      // Count the occurrences of each word for each class.
+      // 1. FlatMap: split the document into its words and add a 1 to each tuple
+      // 2. Group-Reduce: sum the 1s by class and word
+      var singleWordsInClass: DataSet[(String, String, Int)] = input
+        .flatMap(new SingleWordSplitter())
+        .groupBy(0, 1).sum(2) // (class name -> word -> count of that word)
+
+      // POSSIBILITY 2: all words in class (order of operators)
+      // SCHNEIDER/RENNIE 1: ignore/reduce word frequency information
+      // For Schneider's approach the allWordsInClass data set contains only distinct
+      // words: ndw(cj); nothing changes for Rennie's approach
+
+      val p2 = resultingParameters(P2)
+
--- End diff --

This line break is not needed.
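The counting steps in `fit` and the model fields declared in the class (log P(w|c) per class and word, p(c) per class, and the log probability of a word that does not occur in a class) can be illustrated without Flink. What follows is a minimal sketch, not the PR's implementation. Plain Scala collections stand in for `DataSet`. The names `NaiveBayesSketch`, `Doc`, `words`, `Model`, `train` and `predict` are hypothetical. Add-one (Laplace) smoothing is an assumption, because the excerpt ends before any probabilities are computed.

```scala
// Hypothetical, framework-free sketch of multinomial Naive Bayes training and prediction.
// Plain Scala collections stand in for Flink DataSets; add-one smoothing is assumed.
object NaiveBayesSketch {

  // (label, text), where the text holds words separated by "," as in the fit input
  type Doc = (String, String)

  // Assumption: words are split on "," and surrounding whitespace is dropped
  def words(text: String): Seq[String] =
    text.split(",").map(_.trim).filter(_.nonEmpty).toSeq

  // documentsPerClass: equivalent to mapping each document to (label, 1) and summing per label
  def documentsPerClass(docs: Seq[Doc]): Map[String, Int] =
    docs.groupBy(_._1).map { case (label, ds) => (label, ds.size) }

  // singleWordsInClass: one (label, word) hit per token, counted per (label, word)
  def singleWordsInClass(docs: Seq[Doc]): Map[(String, String), Int] =
    docs
      .flatMap { case (label, text) => words(text).map(w => (label, w)) }
      .groupBy(identity)
      .map { case (key, hits) => (key, hits.size) }

  // logPwc: log P(w|c) per (class, word)
  // classData: per class, P(c) and log P(w|c) for a word with zero count in that class
  final case class Model(
      logPwc: Map[(String, String), Double],
      classData: Map[String, (Double, Double)])

  def train(docs: Seq[Doc]): Model = {
    val docCounts = documentsPerClass(docs)
    val wordCounts = singleWordsInClass(docs)
    // Vocabulary size: distinct words over all classes
    val vocabularySize = wordCounts.keySet.map(_._2).size
    // allWordsInClass: total number of word occurrences per class
    val allWordsInClass: Map[String, Int] =
      wordCounts.groupBy(_._1._1).map { case (c, counts) => (c, counts.values.sum) }
    val totalDocs = docs.size.toDouble

    // log P(w|c) = log((count(w, c) + 1) / (allWordsInClass(c) + |V|))
    val logPwc = wordCounts.map { case ((c, w), n) =>
      ((c, w), math.log((n + 1.0) / (allWordsInClass(c) + vocabularySize)))
    }
    val classData = docCounts.map { case (c, n) =>
      val denominator = allWordsInClass.getOrElse(c, 0) + vocabularySize
      (c, (n / totalDocs, math.log(1.0 / denominator)))
    }
    Model(logPwc, classData)
  }

  // cMAP: argmax over classes of log P(c) + sum of log P(w|c).
  // usePrior = false drops log P(c_j), mirroring the S1 = 1 option.
  def predict(model: Model, text: String, usePrior: Boolean = true): String = {
    val scored = model.classData.map { case (c, (pc, unseen)) =>
      val prior = if (usePrior) math.log(pc) else 0.0
      (c, prior + words(text).map(w => model.logPwc.getOrElse((c, w), unseen)).sum)
    }
    scored.maxBy(_._2)._1
  }
}
```

With `docs = Seq(("sport", "ball,goal,team"), ("sport", "goal,match"), ("tech", "code,cpu,ram"), ("tech", "code,bug"))`, `predict(train(docs), "goal,team")` returns `"sport"`. Storing the zero-count log probability per class keeps the smoothed distribution normalized without materializing every (class, word) pair.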
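The SR1 and R1 options in the class documentation change how word counts enter the model. The hedged sketch below applies three per-document transforms to one document's term counts:

- binarization for SR1 = 1 (each distinct word counts once, the usual reading of "ignore word frequency information");
- log(1 + tf) for SR1 = 2;
- multiplication by log(N / df) for R1 = 1, where N is the number of documents and df is the number of documents containing the word.

The last two follow the transforms described by Rennie et al. Whether this PR computes them exactly this way is an assumption. `FrequencyTransformsSketch` and its methods are hypothetical names.

```scala
// Hypothetical sketch of per-document term-weight transforms selected by SR1 and R1.
// Binarization, log(1 + tf) and IDF scaling are assumed formulas, not taken from the PR.
object FrequencyTransformsSketch {

  // Raw term frequencies of one document whose words are separated by ","
  def termFrequencies(text: String): Map[String, Double] =
    text.split(",").map(_.trim).filter(_.nonEmpty)
      .groupBy(identity)
      .map { case (w, occurrences) => (w, occurrences.length.toDouble) }

  // SR1 = 0: keep counts; SR1 = 1: binarize (ignore frequency); SR1 = 2: log(1 + tf)
  def applySR1(tf: Map[String, Double], sr1: Int): Map[String, Double] = sr1 match {
    case 0 => tf
    case 1 => tf.map { case (w, _) => (w, 1.0) }
    case 2 => tf.map { case (w, n) => (w, math.log(1.0 + n)) }
    case other => throw new IllegalArgumentException(s"unsupported SR1 value: $other")
  }

  // Inverse document frequency: log(number of documents / documents containing the word)
  def idf(texts: Seq[String]): Map[String, Double] = {
    val n = texts.size.toDouble
    texts.flatMap(t => termFrequencies(t).keys)
      .groupBy(identity)
      .map { case (w, hits) => (w, math.log(n / hits.size)) }
  }

  // R1 = 1: scale each (possibly SR1-transformed) weight by the word's IDF; otherwise unchanged
  def applyR1(weights: Map[String, Double], idfs: Map[String, Double], r1: Int): Map[String, Double] =
    if (r1 == 1) weights.map { case (w, x) => (w, x * idfs.getOrElse(w, 0.0)) } else weights
}
```

Chaining `applyR1(applySR1(tf, 2), idfs, 1)` applies both transforms. This matches the note that R1 = 1 together with SR1 = 2 applies both approaches.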