spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-21174][SQL] Validate sampling fraction in logical operator level
Date Fri, 23 Jun 2017 01:27:40 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5b5a69bea -> b8a743b6a


[SPARK-21174][SQL] Validate sampling fraction in logical operator level

## What changes were proposed in this pull request?

Currently, the validation of the sampling fraction in Dataset is incomplete.
As an improvement, validate the sampling fraction at the logical operator level:
1) if sampling with replacement: the fraction should be nonnegative
2) otherwise: the fraction should be on the interval [0, 1]
Also add test cases for the validation.
Also add test cases for the validation.

## How was this patch tested?
integration tests

gatorsmile cloud-fan
Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Wang Gengliang <ltnwgl@gmail.com>

Closes #18387 from gengliangwang/sample_ratio_validate.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b8a743b6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b8a743b6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b8a743b6

Branch: refs/heads/master
Commit: b8a743b6a531432e57eb50ecff06798ebc19483e
Parents: 5b5a69b
Author: Wang Gengliang <ltnwgl@gmail.com>
Authored: Fri Jun 23 09:27:35 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Fri Jun 23 09:27:35 2017 +0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/parser/SqlBase.g4 |  2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala  |  3 +-
 .../plans/logical/basicLogicalOperators.scala   | 13 ++++
 .../scala/org/apache/spark/sql/Dataset.scala    |  3 -
 .../sql-tests/inputs/tablesample-negative.sql   | 14 +++++
 .../sql-tests/results/operators.sql.out         |  8 +--
 .../results/tablesample-negative.sql.out        | 62 ++++++++++++++++++++
 .../org/apache/spark/sql/DatasetSuite.scala     | 28 +++++++++
 8 files changed, 124 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b8a743b6/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index ef5648c..9456031 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -440,7 +440,7 @@ joinCriteria
 
 sample
     : TABLESAMPLE '('
-      ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT)
+      ( (negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT)
       | (expression sampleType=ROWS)
       | sampleType=BYTELENGTH_LITERAL
       | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON (identifier
| qualifiedName '(' ')'))?))

http://git-wip-us.apache.org/repos/asf/spark/blob/b8a743b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 500d999..315c672 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -636,7 +636,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with
Logging
 
       case SqlBaseParser.PERCENTLIT =>
         val fraction = ctx.percentage.getText.toDouble
-        sample(fraction / 100.0d)
+        val sign = if (ctx.negativeSign == null) 1 else -1
+        sample(sign * fraction / 100.0d)
 
       case SqlBaseParser.BYTELENGTH_LITERAL =>
         throw new ParseException(

http://git-wip-us.apache.org/repos/asf/spark/blob/b8a743b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 6878b6b..6e88b7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
+import org.apache.spark.util.random.RandomSampler
 
 /**
  * When planning take() or collect() operations, this special node that is inserted at the
top of
@@ -817,6 +818,18 @@ case class Sample(
     child: LogicalPlan)(
     val isTableSample: java.lang.Boolean = false) extends UnaryNode {
 
+  val eps = RandomSampler.roundingEpsilon
+  val fraction = upperBound - lowerBound
+  if (withReplacement) {
+    require(
+      fraction >= 0.0 - eps,
+      s"Sampling fraction ($fraction) must be nonnegative with replacement")
+  } else {
+    require(
+      fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+      s"Sampling fraction ($fraction) must be on interval [0, 1] without replacement")
+  }
+
   override def output: Seq[Attribute] = child.output
 
   override def computeStats(conf: SQLConf): Statistics = {

http://git-wip-us.apache.org/repos/asf/spark/blob/b8a743b6/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a2af9c2..767dad3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1806,9 +1806,6 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
-    require(fraction >= 0,
-      s"Fraction must be nonnegative, but got ${fraction}")
-
     withTypedPlan {
       Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b8a743b6/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql b/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql
new file mode 100644
index 0000000..72508f5
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql
@@ -0,0 +1,14 @@
+-- Negative testcases for tablesample
+CREATE DATABASE mydb1;
+USE mydb1;
+CREATE TABLE t1 USING parquet AS SELECT 1 AS i1;
+
+-- Negative tests: negative percentage
+SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT);
+
+-- Negative tests:  percentage over 100
+-- The TABLESAMPLE clause samples without replacement, so the value of PERCENT must not exceed
100
+SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT);
+
+-- reset
+DROP DATABASE mydb1 CASCADE;

http://git-wip-us.apache.org/repos/asf/spark/blob/b8a743b6/sql/core/src/test/resources/sql-tests/results/operators.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index 5cb6ed3..fec423f 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 56
+-- Number of queries: 57
 
 
 -- !query 0
@@ -462,9 +462,9 @@ struct<abs(-3.13):decimal(3,2),abs(CAST(-2.19 AS DOUBLE)):double>
 3.13	2.19
 
 
--- !query 55
+-- !query 56
 select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11)
--- !query 55 schema
+-- !query 56 schema
 struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(-
-1.11):decimal(3,2)>
--- !query 55 output
+-- !query 56 output
 -1.11	-1.11	1.11	1.11

http://git-wip-us.apache.org/repos/asf/spark/blob/b8a743b6/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out
new file mode 100644
index 0000000..35f3931
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out
@@ -0,0 +1,62 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+CREATE DATABASE mydb1
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+USE mydb1
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+CREATE TABLE t1 USING parquet AS SELECT 1 AS i1
+-- !query 2 schema
+struct<>
+-- !query 2 output
+
+
+
+-- !query 3
+SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT)
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+Sampling fraction (-0.01) must be on interval [0, 1](line 1, pos 24)
+
+== SQL ==
+SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT)
+------------------------^^^
+
+
+-- !query 4
+SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+Sampling fraction (1.01) must be on interval [0, 1](line 1, pos 24)
+
+== SQL ==
+SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT)
+------------------------^^^
+
+
+-- !query 5
+DROP DATABASE mydb1 CASCADE
+-- !query 5 schema
+struct<>
+-- !query 5 output
+

http://git-wip-us.apache.org/repos/asf/spark/blob/b8a743b6/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8eb381b..165176f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -457,6 +457,34 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       3, 17, 27, 58, 62)
   }
 
+  test("sample fraction should not be negative with replacement") {
+    val data = sparkContext.parallelize(1 to 2, 1).toDS()
+    val errMsg = intercept[IllegalArgumentException] {
+      data.sample(withReplacement = true, -0.1, 0)
+    }.getMessage
+    assert(errMsg.contains("Sampling fraction (-0.1) must be nonnegative with replacement"))
+
+    // Sampling fraction can be greater than 1 with replacement.
+    checkDataset(
+      data.sample(withReplacement = true, 1.05, seed = 13),
+      1, 2)
+  }
+
+  test("sample fraction should be on interval [0, 1] without replacement") {
+    val data = sparkContext.parallelize(1 to 2, 1).toDS()
+    val errMsg1 = intercept[IllegalArgumentException] {
+      data.sample(withReplacement = false, -0.1, 0)
+    }.getMessage()
+    assert(errMsg1.contains(
+      "Sampling fraction (-0.1) must be on interval [0, 1] without replacement"))
+
+    val errMsg2 = intercept[IllegalArgumentException] {
+      data.sample(withReplacement = false, 1.1, 0)
+    }.getMessage()
+    assert(errMsg2.contains(
+      "Sampling fraction (1.1) must be on interval [0, 1] without replacement"))
+  }
+
   test("SPARK-16686: Dataset.sample with seed results shouldn't depend on downstream usage")
{
     val simpleUdf = udf((n: Int) => {
       require(n != 1, "simpleUdf shouldn't see id=1!")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message