spark-commits mailing list archives

From j...@apache.org
Subject git commit: [SPARK-3578] Fix upper bound in GraphGenerators.sampleLogNormal
Date Mon, 22 Sep 2014 20:47:54 GMT
Repository: spark
Updated Branches:
  refs/heads/master 56dae30ca -> f9d6220c7


[SPARK-3578] Fix upper bound in GraphGenerators.sampleLogNormal

GraphGenerators.sampleLogNormal is supposed to return an integer strictly less than maxVal.
However, it violates this guarantee. It generates its return value as follows:

```scala
var X: Double = maxVal

while (X >= maxVal) {
  val Z = rand.nextGaussian()
  X = math.exp(mu + sigma*Z)
}
math.round(X.toFloat)
```

When X is sampled close to (but less than) maxVal, the loop exits because X < maxVal,
but the rounded result can equal maxVal, which violates the guarantee.
For example, if maxVal is 5 and X is 4.9, then X < maxVal, but `math.round(X.toFloat)`
is 5.

This PR instead rounds X down with `math.floor` after the loop. Because floor(X) <= X < maxVal,
the condition is guaranteed to hold for the return value.
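
The diff in this commit replaces `math.round(X.toFloat)` with `math.floor(X).toInt`. The
following is a minimal stand-alone sketch of that behavior, not the GraphX source itself:
the object name, the helper names `roundNearest` and `roundDown`, the explicit `rand`
parameter, and the value mu = 4.0 are assumptions made for illustration (sigma = 1.3 and
maxVal = 100 are taken from the test suite in the diff).

```scala
import scala.util.Random

object SampleLogNormalSketch {
  // Old behavior: round to nearest, which can round up to maxVal.
  def roundNearest(x: Double): Int = math.round(x.toFloat)

  // Behavior after this commit: round down, so the result never exceeds x.
  def roundDown(x: Double): Int = math.floor(x).toInt

  // Hypothetical stand-alone version of the sampling loop with the floor fix.
  def sampleLogNormal(mu: Double, sigma: Double, maxVal: Int, rand: Random): Int = {
    var x: Double = maxVal
    // Resample until the log-normal draw falls strictly below maxVal.
    while (x >= maxVal) {
      val z = rand.nextGaussian()
      x = math.exp(mu + sigma * z)
    }
    roundDown(x)
  }

  def main(args: Array[String]): Unit = {
    println(roundNearest(4.9)) // 5: equals maxVal = 5, violating the bound
    println(roundDown(4.9))    // 4: stays strictly below maxVal = 5

    // mu = 4.0 is an assumed value; sigma and maxVal follow the test suite.
    val rand = new Random(12345)
    val allBelow = (1 to 1000).forall(_ => sampleLogNormal(4.0, 1.3, 100, rand) < 100)
    println(allBelow) // true
  }
}
```

Rounding down keeps the check inside the loop unchanged: any X that exits the loop satisfies
X < maxVal, and flooring can only lower it further.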

Author: Ankur Dave <ankurdave@gmail.com>

Closes #2439 from ankurdave/SPARK-3578 and squashes the following commits:

f6655e5 [Ankur Dave] Go back to math.floor
5900c22 [Ankur Dave] Round X in loop condition
6fd5fb1 [Ankur Dave] Run sampleLogNormal bounds check 1000 times
1638598 [Ankur Dave] Round down in sampleLogNormal to guarantee upper bound


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f9d6220c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f9d6220c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f9d6220c

Branch: refs/heads/master
Commit: f9d6220c792b779be385f3022d146911a22c2130
Parents: 56dae30
Author: Ankur Dave <ankurdave@gmail.com>
Authored: Mon Sep 22 13:47:43 2014 -0700
Committer: Joseph E. Gonzalez <joseph.e.gonzalez@gmail.com>
Committed: Mon Sep 22 13:47:43 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/graphx/util/GraphGenerators.scala  | 2 +-
 .../org/apache/spark/graphx/util/GraphGeneratorsSuite.scala   | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f9d6220c/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
index b830928..8a13c74 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
@@ -118,7 +118,7 @@ object GraphGenerators {
       val Z = rand.nextGaussian()
       X = math.exp(mu + sigma*Z)
     }
-    math.round(X.toFloat)
+    math.floor(X).toInt
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/f9d6220c/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
----------------------------------------------------------------------
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
index b346d4d..3abefbe 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
@@ -64,8 +64,11 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext {
     val sigma = 1.3
     val maxVal = 100
 
-    val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal)
-    assert(dstId < maxVal)
+    val trials = 1000
+    for (i <- 1 to trials) {
+      val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal)
+      assert(dstId < maxVal)
+    }
 
     val dstId_round1 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345)
     val dstId_round2 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345)



