zeppelin-commits mailing list archives

From zjf...@apache.org
Subject [zeppelin] branch master updated: [ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter
Date Tue, 31 Dec 2019 01:30:15 GMT
This is an automated email from the ASF dual-hosted git repository.

zjffdu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/zeppelin.git


The following commit(s) were added to refs/heads/master by this push:
     new 737d162  [ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter
737d162 is described below

commit 737d1626d073351dca3c3cc508d0dbea773c4e43
Author: Jeff Zhang <zjffdu@apache.org>
AuthorDate: Sun Dec 29 23:31:19 2019 +0800

    [ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter
    
    ### What is this PR for?
    
    Use the SqlSplitter in `zeppelin-interpreter` to split the paragraph text into individual
    sql statements and execute each of them in SparkSqlInterpreter. Nothing changes for
    existing single sql statement paragraphs; a paragraph containing multiple sql statements
    now displays one result per statement.
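    
    For reference, a minimal self-contained sketch of the new splitting step
    (the SqlSplitter class, its import path and the splitSql signature are
    taken from the diff below; the semicolon and "--" comment handling is
    inferred from the new tests, so treat this as an illustration rather
    than the exact implementation):
    
        import java.util.List;
        import org.apache.zeppelin.interpreter.util.SqlSplitter;
    
        public class SplitSketch {
          public static void main(String[] args) {
            SqlSplitter splitter = new SqlSplitter();
            // Statements are separated by ';'; "--" line comments carry no
            // statement of their own, so a comments-only paragraph splits
            // into an empty list (SUCCESS with zero result messages).
            List<String> sqls = splitter.splitSql(
                "select * --comment_1\nfrom gr;select count(1) from gr");
            System.out.println(sqls.size()); // 2 statements expected
          }
        }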
    
    ### What type of PR is it?
    [Feature]
    
    ### Todos
    * [ ] - Task
    
    ### What is the Jira issue?
    * https://issues.apache.org/jira/browse/ZEPPELIN-4522
    
    ### How should this be tested?
    * CI pass
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need to be updated? No
    * Are there breaking changes for older versions? No
    * Does this need documentation? No
    
    Author: Jeff Zhang <zjffdu@apache.org>
    
    Closes #3579 from zjffdu/ZEPPELIN-4522 and squashes the following commits:
    
    eda573649 [Jeff Zhang] fix failed test
    68d5a30c8 [Jeff Zhang] Add test for no sql but just 2 comments
    4ff15e4fb [Jeff Zhang] address comment
    bc3c1feff [Jeff Zhang] [ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter
---
 .../zeppelin/python/IPythonInterpreterTest.java    |  6 ++--
 .../apache/zeppelin/spark/SparkSqlInterpreter.java | 42 ++++++++++++++--------
 .../zeppelin/spark/SparkSqlInterpreterTest.java    | 41 +++++++++++++++++++++
 .../org/apache/zeppelin/spark/Spark1Shims.java     |  2 +-
 .../org/apache/zeppelin/spark/Spark2Shims.java     |  2 +-
 .../zeppelin/interpreter/InterpreterOutput.java    |  7 ++++
 .../interpreter/InterpreterResultTest.java         |  2 +-
 7 files changed, 82 insertions(+), 20 deletions(-)

diff --git a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java
index b0a8ba6..0f302e9 100644
--- a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java
+++ b/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java
@@ -298,13 +298,13 @@ public class IPythonInterpreterTest extends BasePythonInterpreterTest {
         "df.hvplot()", context);
     assertEquals(InterpreterResult.Code.SUCCESS, result.code());
     interpreterResultMessages = context.out.toInterpreterResultMessage();
-    assertEquals(5, interpreterResultMessages.size());
+    assertEquals(4, interpreterResultMessages.size());
+    assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(0).getType());
     assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(1).getType());
     assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(2).getType());
     assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(3).getType());
-    assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(4).getType());
     // docs_json is the source data of plotting which bokeh would use to render the plotting.
-    assertTrue(interpreterResultMessages.get(4).getData().contains("docs_json"));
+    assertTrue(interpreterResultMessages.get(3).getData().contains("docs_json"));
   }
 
 
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
index 4e63760..f6372dd 100644
--- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
@@ -28,6 +28,7 @@ import org.apache.zeppelin.interpreter.InterpreterException;
 import org.apache.zeppelin.interpreter.InterpreterResult;
 import org.apache.zeppelin.interpreter.InterpreterResult.Code;
 import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
+import org.apache.zeppelin.interpreter.util.SqlSplitter;
 import org.apache.zeppelin.scheduler.Scheduler;
 import org.apache.zeppelin.scheduler.SchedulerFactory;
 import org.slf4j.Logger;
@@ -44,6 +45,7 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
   private Logger logger = LoggerFactory.getLogger(SparkSqlInterpreter.class);
 
   private SparkInterpreter sparkInterpreter;
+  private SqlSplitter sqlSplitter;
 
   public SparkSqlInterpreter(Properties property) {
     super(property);
@@ -52,6 +54,7 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
   @Override
   public void open() throws InterpreterException {
     this.sparkInterpreter = getInterpreterInTheSameSessionByClassName(SparkInterpreter.class);
+    this.sqlSplitter = new SqlSplitter();
   }
 
   public boolean concurrentSQL() {
@@ -82,26 +85,37 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
     sparkInterpreter.getZeppelinContext().setInterpreterContext(context);
     SQLContext sqlc = sparkInterpreter.getSQLContext();
     SparkContext sc = sqlc.sparkContext();
+
+    StringBuilder builder = new StringBuilder();
+    List<String> sqls = sqlSplitter.splitSql(st);
+    int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit",
+            "" + sparkInterpreter.getZeppelinContext().getMaxResult()));
+
     sc.setLocalProperty("spark.scheduler.pool", context.getLocalProperties().get("pool"));
     sc.setJobGroup(Utils.buildJobGroupId(context), Utils.buildJobDesc(context), false);
-
+    String curSql = null;
     try {
-      Method method = sqlc.getClass().getMethod("sql", String.class);
-      int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit",
-              "" + sparkInterpreter.getZeppelinContext().getMaxResult()));
-      String msg = sparkInterpreter.getZeppelinContext().showData(
-          method.invoke(sqlc, st), maxResult);
-      sc.clearJobGroup();
-      return new InterpreterResult(Code.SUCCESS, msg);
+      for (String sql : sqls) {
+        curSql = sql;
+        String result = sparkInterpreter.getZeppelinContext().showData(sqlc.sql(sql), maxResult);
+        builder.append(result);
+      }
     } catch (Exception e) {
-      if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace"))) {
-        return new InterpreterResult(Code.ERROR, ExceptionUtils.getStackTrace(e));
+      builder.append("\n%text Error happens in sql: " + curSql + "\n");
+      if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace", "false"))) {
+        builder.append(ExceptionUtils.getStackTrace(e));
+      } else {
+        logger.error("Invocation target exception", e);
+        String msg = e.getCause().getMessage()
+                + "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace";
+        builder.append(msg);
       }
-      logger.error("Invocation target exception", e);
-      String msg = e.getCause().getMessage()
-              + "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace";
-      return new InterpreterResult(Code.ERROR, msg);
+      return new InterpreterResult(Code.ERROR, builder.toString());
+    } finally {
+      sc.clearJobGroup();
     }
+
+    return new InterpreterResult(Code.SUCCESS, builder.toString());
   }
 
   @Override
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
index cab5b1b..c3f245b 100644
--- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
@@ -171,6 +171,47 @@ public class SparkSqlInterpreterTest {
   }
 
   @Test
+  public void testMultipleStatements() throws InterpreterException {
+    sparkInterpreter.interpret("case class P(age:Int)", context);
+    sparkInterpreter.interpret(
+            "val gr = sc.parallelize(Seq(P(1),P(2),P(3),P(4)))",
+            context);
+    sparkInterpreter.interpret("gr.toDF.registerTempTable(\"gr\")", context);
+
+    // Two correct sql
+    InterpreterResult ret = sqlInterpreter.interpret(
+            "select * --comment_1\nfrom gr;select count(1) from gr", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, ret.code());
+    assertEquals(ret.message().toString(), 2, ret.message().size());
+    assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(0).getType());
+    assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(1).getType());
+
+    // One correct sql + One invalid sql
+    ret = sqlInterpreter.interpret("select * from gr;invalid_sql", context);
+    assertEquals(InterpreterResult.Code.ERROR, ret.code());
+    assertEquals(ret.message().toString(), 2, ret.message().size());
+    assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(0).getType());
+    if (sparkInterpreter.getSparkVersion().isSpark2()) {
+      assertTrue(ret.message().toString(), ret.message().get(1).getData().contains("ParseException"));
+    }
+    
+    // One correct sql + One invalid sql + One valid sql (skipped)
+    ret = sqlInterpreter.interpret("select * from gr;invalid_sql; select count(1) from gr",
context);
+    assertEquals(InterpreterResult.Code.ERROR, ret.code());
+    assertEquals(ret.message().toString(), 2, ret.message().size());
+    assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(0).getType());
+    if (sparkInterpreter.getSparkVersion().isSpark2()) {
+      assertTrue(ret.message().toString(), ret.message().get(1).getData().contains("ParseException"));
+    }
+
+    // Two comments, no sql
+    ret = sqlInterpreter.interpret(
+            "--comment_1\n--comment_2", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, ret.code());
+    assertEquals(ret.message().toString(), 0, ret.message().size());
+  }
+
+  @Test
   public void testConcurrentSQL() throws InterpreterException, InterruptedException {
     if (sparkInterpreter.getSparkVersion().isSpark2()) {
       sparkInterpreter.interpret("spark.udf.register(\"sleep\", (e:Int) => {Thread.sleep(e*1000);
e})", context);
diff --git a/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java b/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
index 8e60ed0..6119647 100644
--- a/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
+++ b/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
@@ -70,7 +70,7 @@ public class Spark1Shims extends SparkShims {
       // fetch maxResult+1 rows so that we can check whether it is larger than zeppelin.spark.maxResult
       List<Row> rows = df.takeAsList(maxResult + 1);
       StringBuilder msg = new StringBuilder();
-      msg.append("%table ");
+      msg.append("\n%table ");
       msg.append(StringUtils.join(columns, "\t"));
       msg.append("\n");
       boolean isLargerThanMaxResult = rows.size() > maxResult;
diff --git a/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java b/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
index a7304c5..b7b1cf9 100644
--- a/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
+++ b/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
@@ -71,7 +71,7 @@ public class Spark2Shims extends SparkShims {
       // fetch maxResult+1 rows so that we can check whether it is larger than zeppelin.spark.maxResult
       List<Row> rows = df.takeAsList(maxResult + 1);
       StringBuilder msg = new StringBuilder();
-      msg.append("%table ");
+      msg.append("\n%table ");
       msg.append(StringUtils.join(columns, "\t"));
       msg.append("\n");
       boolean isLargerThanMaxResult = rows.size() > maxResult;
diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterOutput.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterOutput.java
index 8853227..f85e535 100644
--- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterOutput.java
+++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterOutput.java
@@ -17,6 +17,7 @@
 package org.apache.zeppelin.interpreter;
 
 
+import org.apache.commons.lang.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -328,6 +329,12 @@ public class InterpreterOutput extends OutputStream {
     List<InterpreterResultMessage> list = new LinkedList<>();
     synchronized (resultMessageOutputs) {
       for (InterpreterResultMessageOutput out : resultMessageOutputs) {
+        if (out.toInterpreterResultMessage().getType() == InterpreterResult.Type.TEXT &&
+                StringUtils.isBlank(out.toInterpreterResultMessage().getData())) {
+          // skip blank text, because when printing table data we usually need to print '%text \n'
+          // first to separate it from the previous other kinds of data, e.g. z.show(df)
+          continue;
+        }
         list.add(out.toInterpreterResultMessage());
       }
     }
diff --git a/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterResultTest.java b/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterResultTest.java
index a8ff1bf..84805ac 100644
--- a/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterResultTest.java
+++ b/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterResultTest.java
@@ -33,7 +33,7 @@ public class InterpreterResultTest {
     result = new InterpreterResult(InterpreterResult.Code.SUCCESS, "%this is a TEXT type");
     assertEquals("No magic", InterpreterResult.Type.TEXT, result.message().get(0).getType());
     result = new InterpreterResult(InterpreterResult.Code.SUCCESS, "%\n");
-    assertEquals("No magic", InterpreterResult.Type.TEXT, result.message().get(0).getType());
+    assertEquals(0, result.message().size());
   }
 
   @Test
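
For reference, a usage sketch based on the new tests above (the temp table
`gr` comes from the test setup; substitute any registered table):

    %sql
    select * from gr;
    select count(1) from gr

Each statement is rendered as its own %table result. If a statement fails,
the results of the statements before it are kept, a %text error message for
the failing statement is appended, and the remaining statements are skipped.
The new InterpreterOutput check filters out the blank %text messages created
by the leading "\n%table " separator, so a comments-only paragraph yields
zero result messages.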

