mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dlyubi...@apache.org
Subject [08/32] mahout git commit: MAHOUT-1570: Flink: calculating ncol, nrow; colSum, colMean, norm methods
Date Tue, 20 Oct 2015 05:36:51 GMT
MAHOUT-1570: Flink: calculating ncol, nrow; colSum, colMean, norm methods


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/df1db7cc
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/df1db7cc
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/df1db7cc

Branch: refs/heads/flink-binding
Commit: df1db7cc775ed5e10c6416e033e25e430ffdd171
Parents: 522f3d5
Author: Alexey Grigorev <alexey.s.grigoriev@gmail.com>
Authored: Tue May 26 16:17:14 2015 +0200
Committer: Alexey Grigorev <alexey.s.grigoriev@gmail.com>
Committed: Fri Sep 25 17:41:45 2015 +0200

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      | 30 +++++++++--
 .../drm/CheckpointedFlinkDrm.scala              | 52 +++++++++++++-------
 2 files changed, 62 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/df1db7cc/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 03d1a9c..6696152 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -44,6 +44,8 @@ import org.apache.mahout.math.drm.logical.OpRbind
 import org.apache.mahout.math.drm.logical.OpMapBlock
 import org.apache.mahout.math.drm.logical.OpRowRange
 import org.apache.mahout.math.drm.logical.OpTimesRightMatrix
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.functions.ReduceFunction
 
 object FlinkEngine extends DistributedEngine {
 
@@ -119,15 +121,37 @@ object FlinkEngine extends DistributedEngine {
   def translate[K: ClassTag](oper: DrmLike[K]): DataSet[K] = ???
 
   /** Engine-specific colSums implementation based on a checkpoint. */
-  override def colSums[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ???
+  override def colSums[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
+    val sum = drm.ds.map(new MapFunction[(K, Vector), Vector] {
+      def map(tuple: (K, Vector)): Vector = tuple._2
+    }).reduce(new ReduceFunction[Vector] {
+      def reduce(v1: Vector, v2: Vector) = v1 + v2
+    })
+
+    val list = CheckpointedFlinkDrm.flinkCollect(sum, "FlinkEngine colSums()")
+    list.head
+  }
 
   /** Engine-specific numNonZeroElementsPerColumn implementation based on a checkpoint. */
  override def numNonZeroElementsPerColumn[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ???
 
   /** Engine-specific colMeans implementation based on a checkpoint. */
-  override def colMeans[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ???
+  override def colMeans[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
+    drm.colSums() / drm.nrow
+  }
 
-  override def norm[K: ClassTag](drm: CheckpointedDrm[K]): Double = ???
+  override def norm[K: ClassTag](drm: CheckpointedDrm[K]): Double = {
+    val sumOfSquares = drm.ds.map(new MapFunction[(K, Vector), Double] {
+      def map(tuple: (K, Vector)): Double = tuple match {
+        case (idx, vec) => vec dot vec
+      }
+    }).reduce(new ReduceFunction[Double] {
+      def reduce(v1: Double, v2: Double) = v1 + v2
+    })
+
+    val list = CheckpointedFlinkDrm.flinkCollect(sumOfSquares, "FlinkEngine norm()")
+    list.head
+  }
 
   /** Broadcast support */
  override def drmBroadcast(v: Vector)(implicit dc: DistributedContext): BCast[Vector] = ???

http://git-wip-us.apache.org/repos/asf/mahout/blob/df1db7cc/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
index c19920f..e7d9dcd 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
@@ -5,7 +5,6 @@ import org.apache.mahout.math.drm._
 import org.apache.mahout.math.scalabindings._
 import RLikeOps._
 import org.apache.mahout.flinkbindings._
-
 import org.apache.mahout.math.drm.CheckpointedDrm
 import org.apache.mahout.math.Matrix
 import org.apache.mahout.flinkbindings.FlinkDistributedContext
@@ -17,8 +16,10 @@ import org.apache.mahout.math.DenseMatrix
 import org.apache.mahout.math.SparseMatrix
 import org.apache.flink.api.java.io.LocalCollectionOutputFormat
 import java.util.ArrayList
-
 import scala.collection.JavaConverters._
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.functions.ReduceFunction
+import org.apache.flink.api.java.DataSet
 
 class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
   private var _nrow: Long = CheckpointedFlinkDrm.UNKNOWN,
@@ -27,20 +28,31 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
   override protected[mahout] val partitioningTag: Long = Random.nextLong(),
   private var _canHaveMissingRows: Boolean = false) extends CheckpointedDrm[K] {
 
-  lazy val nrow = if (_nrow >= 0) _nrow else computeNRow
-  lazy val ncol = if (_ncol >= 0) _ncol else computeNCol
+  lazy val nrow: Long = if (_nrow >= 0) _nrow else computeNRow
+  lazy val ncol: Int = if (_ncol >= 0) _ncol else computeNCol
+
+  protected def computeNRow: Long = { 
+    val count = ds.map(new MapFunction[DrmTuple[K], Long] {
+      def map(value: DrmTuple[K]): Long = 1L
+    }).reduce(new ReduceFunction[Long] {
+      def reduce(a1: Long, a2: Long) = a1 + a2
+    })
+
+    val list = CheckpointedFlinkDrm.flinkCollect(count, "CheckpointedFlinkDrm computeNRow()")
+    list.head
+  }
 
-  protected def computeNRow = ???
-  protected def computeNCol = ??? /*{
-  TODO: find out how to get one value
+  protected def computeNCol: Int = {
     val max = ds.map(new MapFunction[DrmTuple[K], Int] {
       def map(value: DrmTuple[K]): Int = value._2.length
     }).reduce(new ReduceFunction[Int] {
       def reduce(a1: Int, a2: Int) = Math.max(a1, a2)
     })
-    
-    max
-  }*/
+
+    val list = CheckpointedFlinkDrm.flinkCollect(max, "CheckpointedFlinkDrm computeNCol()")
+    list.head
+  }
+
   def keyClassTag: ClassTag[K] = implicitly[ClassTag[K]]
 
   def cache() = {
@@ -57,12 +69,7 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
   def checkpoint(cacheHint: CacheHint.CacheHint): CheckpointedDrm[K] = this
 
   def collect: Matrix = {
-    val dataJavaList = new ArrayList[DrmTuple[K]]
-    val outputFormat = new LocalCollectionOutputFormat[DrmTuple[K]](dataJavaList)
-    ds.output(outputFormat)
-    val data = dataJavaList.asScala
-    ds.getExecutionEnvironment.execute("Checkpointed Flink Drm collect()")
-
+    val data = CheckpointedFlinkDrm.flinkCollect(ds, "Checkpointed Flink Drm collect()")
     val isDense = data.forall(_._2.isDense)
 
     val m = if (isDense) {
@@ -99,5 +106,16 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
 }
 
 object CheckpointedFlinkDrm {
-  val UNKNOWN = -1;
+  val UNKNOWN = -1
+
+  // needed for backwards compatibility with flink 0.8.1
+  def flinkCollect[K](dataset: DataSet[K], jobName: String = "flinkCollect()"): List[K] = {
+    val dataJavaList = new ArrayList[K]
+    val outputFormat = new LocalCollectionOutputFormat[K](dataJavaList)
+    dataset.output(outputFormat)
+    val data = dataJavaList.asScala
+    dataset.getExecutionEnvironment.execute(jobName)
+    data.toList
+  }
+
 }
\ No newline at end of file


Mime
View raw message