zeppelin-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zjf...@apache.org
Subject zeppelin git commit: ZEPPELIN-3374. Improvement on PySparkInterpreter
Date Sun, 01 Apr 2018 06:31:23 GMT
Repository: zeppelin
Updated Branches:
  refs/heads/master 8238b711c -> aefc7ea39


ZEPPELIN-3374. Improvement on PySparkInterpreter

### What is this PR for?
A few improvements on PySparkInterpreter.
1. Refactor PySparkInterpreter to make it more readable
2. The code completion feature is totally broken; fix it in this PR
3. Reuse the same test cases as IPySparkInterpreter.

### What type of PR is it?
[Bug Fix | Improvement | Refactoring]

### Todos
* [ ] - Task

### What is the Jira issue?
* https://issues.apache.org/jira/browse/ZEPPELIN-3374

### How should this be tested?
* CI pass

### Screenshots (if appropriate)

Code completion before

![completion_before](https://user-images.githubusercontent.com/164491/38160504-ee1ea3a8-34f1-11e8-9aab-baf98962aae3.gif)

Code completion after
![completion_after](https://user-images.githubusercontent.com/164491/38160505-eff108b0-34f1-11e8-88eb-03c51cfb96de.gif)

### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need documentation? No

Author: Jeff Zhang <zjffdu@apache.org>

Closes #2901 from zjffdu/ZEPPELIN-3374 and squashes the following commits:

c22078d [Jeff Zhang] ZEPPELIN-3374. Improvement on PySparkInterpreter


Project: http://git-wip-us.apache.org/repos/asf/zeppelin/repo
Commit: http://git-wip-us.apache.org/repos/asf/zeppelin/commit/aefc7ea3
Tree: http://git-wip-us.apache.org/repos/asf/zeppelin/tree/aefc7ea3
Diff: http://git-wip-us.apache.org/repos/asf/zeppelin/diff/aefc7ea3

Branch: refs/heads/master
Commit: aefc7ea395fed5a93e69942725a19b203a3574e2
Parents: 8238b71
Author: Jeff Zhang <zjffdu@apache.org>
Authored: Fri Mar 30 11:12:08 2018 +0800
Committer: Jeff Zhang <zjffdu@apache.org>
Committed: Sun Apr 1 14:31:12 2018 +0800

----------------------------------------------------------------------
 .../zeppelin/spark/PySparkInterpreter.java      | 290 +++++++++----------
 .../main/resources/python/zeppelin_pyspark.py   |  86 +++---
 .../zeppelin/spark/IPySparkInterpreterTest.java | 127 ++++----
 .../zeppelin/spark/PySparkInterpreterTest.java  |  39 +--
 4 files changed, 259 insertions(+), 283 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/zeppelin/blob/aefc7ea3/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
index d97bb51..809e883 100644
--- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
@@ -41,6 +41,7 @@ import org.apache.zeppelin.interpreter.InterpreterResultMessage;
 import org.apache.zeppelin.interpreter.InvalidHookException;
 import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
 import org.apache.zeppelin.interpreter.WrappedInterpreter;
+import org.apache.zeppelin.interpreter.remote.RemoteInterpreterUtils;
 import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
 import org.apache.zeppelin.interpreter.util.InterpreterOutputStream;
 import org.apache.zeppelin.spark.dep.SparkDependencyContext;
@@ -53,11 +54,8 @@ import java.io.ByteArrayOutputStream;
 import java.io.File;
 import java.io.FileOutputStream;
 import java.io.IOException;
-import java.io.OutputStreamWriter;
 import java.io.PipedInputStream;
-import java.io.PipedOutputStream;
 import java.net.MalformedURLException;
-import java.net.ServerSocket;
 import java.net.URL;
 import java.net.URLClassLoader;
 import java.util.LinkedList;
@@ -66,65 +64,26 @@ import java.util.Map;
 import java.util.Properties;
 
 /**
- *
+ *  Interpreter for PySpark, it is the first implementation of interpreter for PySpark, so
with less
+ *  features compared to IPySparkInterpreter, but requires less prerequisites than
+ *  IPySparkInterpreter, only python is required.
  */
 public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler {
   private static final Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class);
+  private static final int MAX_TIMEOUT_SEC = 10;
+
   private GatewayServer gatewayServer;
   private DefaultExecutor executor;
-  private int port;
+  // used to forward output from python process to InterpreterOutput
   private InterpreterOutputStream outputStream;
-  private BufferedWriter ins;
-  private PipedInputStream in;
-  private ByteArrayOutputStream input;
   private String scriptPath;
-  boolean pythonscriptRunning = false;
-  private static final int MAX_TIMEOUT_SEC = 10;
-  private long pythonPid;
-
+  private boolean pythonscriptRunning = false;
+  private long pythonPid = -1;
   private IPySparkInterpreter iPySparkInterpreter;
+  private SparkInterpreter sparkInterpreter;
 
   public PySparkInterpreter(Properties property) {
     super(property);
-
-    pythonPid = -1;
-    try {
-      File scriptFile = File.createTempFile("zeppelin_pyspark-", ".py");
-      scriptPath = scriptFile.getAbsolutePath();
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-  }
-
-  private void createPythonScript() throws InterpreterException {
-    ClassLoader classLoader = getClass().getClassLoader();
-    File out = new File(scriptPath);
-
-    if (out.exists() && out.isDirectory()) {
-      throw new InterpreterException("Can't create python script " + out.getAbsolutePath());
-    }
-
-    try {
-      FileOutputStream outStream = new FileOutputStream(out);
-      IOUtils.copy(
-          classLoader.getResourceAsStream("python/zeppelin_pyspark.py"),
-          outStream);
-      outStream.close();
-    } catch (IOException e) {
-      throw new InterpreterException(e);
-    }
-
-    try {
-      FileOutputStream outStream = new FileOutputStream(out.getParent() + "/zeppelin_context.py");
-      IOUtils.copy(
-          classLoader.getResourceAsStream("python/zeppelin_context.py"),
-          outStream);
-      outStream.close();
-    } catch (IOException e) {
-      throw new InterpreterException(e);
-    }
-
-    LOGGER.info("File {} created", scriptPath);
   }
 
   @Override
@@ -151,6 +110,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     InterpreterGroup intpGroup = getInterpreterGroup();
     if (intpGroup != null && intpGroup.getInterpreterHookRegistry() != null) {
       try {
+        // just for unit test I believe (zjffdu)
         registerHook(HookType.POST_EXEC_DEV.getName(), "__zeppelin__._displayhook()");
       } catch (InvalidHookException e) {
         throw new InterpreterException(e);
@@ -199,48 +159,118 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
     try {
       URLClassLoader newCl = new URLClassLoader(urls, oldCl);
+      LOGGER.info("urls:" + urls);
+      for (URL url : urls) {
+        LOGGER.info("url:" + url);
+      }
       Thread.currentThread().setContextClassLoader(newCl);
+      // must create spark interpreter after ClassLoader is set, otherwise the additional
jars
+      // can not be loaded by spark repl.
+      this.sparkInterpreter = getSparkInterpreter();
       createGatewayServerAndStartScript();
-    } catch (Exception e) {
-      LOGGER.error("Error", e);
-      throw new InterpreterException(e);
+    } catch (IOException e) {
+      LOGGER.error("Fail to open PySparkInterpreter", e);
+      throw new InterpreterException("Fail to open PySparkInterpreter", e);
     } finally {
       Thread.currentThread().setContextClassLoader(oldCl);
     }
   }
 
-  private Map setupPySparkEnv() throws IOException, InterpreterException {
-    Map env = EnvironmentUtils.getProcEnvironment();
+  private void createGatewayServerAndStartScript() throws IOException {
+    // start gateway server in JVM side
+    int port = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces();
+    gatewayServer = new GatewayServer(this, port);
+    gatewayServer.start();
+
+    // launch python process to connect to the gateway server in JVM side
+    createPythonScript();
+    String pythonExec = getPythonExec(getProperties());
+    LOGGER.info("PythonExec: " + pythonExec);
+    CommandLine cmd = CommandLine.parse(pythonExec);
+    cmd.addArgument(scriptPath, false);
+    cmd.addArgument(Integer.toString(port), false);
+    cmd.addArgument(Integer.toString(sparkInterpreter.getSparkVersion().toNumber()), false);
+    executor = new DefaultExecutor();
+    outputStream = new InterpreterOutputStream(LOGGER);
+    PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream);
+    executor.setStreamHandler(streamHandler);
+    executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT));
+
+    Map<String, String> env = setupPySparkEnv();
+    executor.execute(cmd, env, this);
+    pythonscriptRunning = true;
+  }
+
+  private void createPythonScript() throws IOException {
+    FileOutputStream pysparkScriptOutput = null;
+    FileOutputStream zeppelinContextOutput = null;
+    try {
+      // copy zeppelin_pyspark.py
+      File scriptFile = File.createTempFile("zeppelin_pyspark-", ".py");
+      this.scriptPath = scriptFile.getAbsolutePath();
+      pysparkScriptOutput = new FileOutputStream(scriptFile);
+      IOUtils.copy(
+          getClass().getClassLoader().getResourceAsStream("python/zeppelin_pyspark.py"),
+          pysparkScriptOutput);
+
+      // copy zeppelin_context.py to the same folder of zeppelin_pyspark.py
+      zeppelinContextOutput = new FileOutputStream(scriptFile.getParent() + "/zeppelin_context.py");
+      IOUtils.copy(
+          getClass().getClassLoader().getResourceAsStream("python/zeppelin_context.py"),
+          zeppelinContextOutput);
+      LOGGER.info("PySpark script {} {} is created",
+          scriptPath, scriptFile.getParent() + "/zeppelin_context.py");
+    } finally {
+      if (pysparkScriptOutput != null) {
+        try {
+          pysparkScriptOutput.close();
+        } catch (IOException e) {
+          // ignore
+        }
+      }
+      if (zeppelinContextOutput != null) {
+        try {
+          zeppelinContextOutput.close();
+        } catch (IOException e) {
+          // ignore
+        }
+      }
+    }
+  }
 
+  private Map<String, String> setupPySparkEnv() throws IOException {
+    Map<String, String> env = EnvironmentUtils.getProcEnvironment();
     // only set PYTHONPATH in local or yarn-client mode.
     // yarn-cluster will setup PYTHONPATH automatically.
-    SparkConf conf = getSparkConf();
+    SparkConf conf = null;
+    try {
+      conf = getSparkConf();
+    } catch (InterpreterException e) {
+      throw new IOException(e);
+    }
     if (!conf.get("spark.submit.deployMode", "client").equals("cluster")) {
       if (!env.containsKey("PYTHONPATH")) {
         env.put("PYTHONPATH", PythonUtils.sparkPythonPath());
       } else {
-        env.put("PYTHONPATH", PythonUtils.sparkPythonPath());
+        env.put("PYTHONPATH", PythonUtils.sparkPythonPath() + ":" + env.get("PYTHONPATH"));
       }
     }
 
     // get additional class paths when using SPARK_SUBMIT and not using YARN-CLIENT
     // also, add all packages to PYTHONPATH since there might be transitive dependencies
     if (SparkInterpreter.useSparkSubmit() &&
-        !getSparkInterpreter().isYarnMode()) {
-
-      String sparkSubmitJars = getSparkConf().get("spark.jars").replace(",", ":");
-
-      if (!"".equals(sparkSubmitJars)) {
-        env.put("PYTHONPATH", env.get("PYTHONPATH") + sparkSubmitJars);
+        !sparkInterpreter.isYarnMode()) {
+      String sparkSubmitJars = conf.get("spark.jars").replace(",", ":");
+      if (!StringUtils.isEmpty(sparkSubmitJars)) {
+        env.put("PYTHONPATH", env.get("PYTHONPATH") + ":" + sparkSubmitJars);
       }
     }
 
-    LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH"));
-
     // set PYSPARK_PYTHON
-    if (getSparkConf().contains("spark.pyspark.python")) {
-      env.put("PYSPARK_PYTHON", getSparkConf().get("spark.pyspark.python"));
+    if (conf.contains("spark.pyspark.python")) {
+      env.put("PYSPARK_PYTHON", conf.get("spark.pyspark.python"));
     }
+    LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH"));
     return env;
   }
 
@@ -258,66 +288,6 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     return pythonExec;
   }
 
-  private void createGatewayServerAndStartScript() throws InterpreterException {
-    // create python script
-    createPythonScript();
-
-    port = findRandomOpenPortOnAllLocalInterfaces();
-
-    gatewayServer = new GatewayServer(this, port);
-    gatewayServer.start();
-
-    String pythonExec = getPythonExec(getProperties());
-    LOGGER.info("pythonExec: " + pythonExec);
-    CommandLine cmd = CommandLine.parse(pythonExec);
-    cmd.addArgument(scriptPath, false);
-    cmd.addArgument(Integer.toString(port), false);
-    cmd.addArgument(Integer.toString(getSparkInterpreter().getSparkVersion().toNumber()),
false);
-    executor = new DefaultExecutor();
-    outputStream = new InterpreterOutputStream(LOGGER);
-    PipedOutputStream ps = new PipedOutputStream();
-    in = null;
-    try {
-      in = new PipedInputStream(ps);
-    } catch (IOException e1) {
-      throw new InterpreterException(e1);
-    }
-    ins = new BufferedWriter(new OutputStreamWriter(ps));
-
-    input = new ByteArrayOutputStream();
-
-    PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, outputStream, in);
-    executor.setStreamHandler(streamHandler);
-    executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT));
-
-    try {
-      Map env = setupPySparkEnv();
-      executor.execute(cmd, env, this);
-      pythonscriptRunning = true;
-    } catch (IOException e) {
-      throw new InterpreterException(e);
-    }
-
-
-    try {
-      input.write("import sys, getopt\n".getBytes());
-      ins.flush();
-    } catch (IOException e) {
-      throw new InterpreterException(e);
-    }
-  }
-
-  private int findRandomOpenPortOnAllLocalInterfaces() throws InterpreterException {
-    int port;
-    try (ServerSocket socket = new ServerSocket(0);) {
-      port = socket.getLocalPort();
-      socket.close();
-    } catch (IOException e) {
-      throw new InterpreterException(e);
-    }
-    return port;
-  }
-
   @Override
   public void close() throws InterpreterException {
     if (iPySparkInterpreter != null) {
@@ -325,25 +295,30 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
       return;
     }
     executor.getWatchdog().destroyProcess();
-    new File(scriptPath).delete();
     gatewayServer.shutdown();
   }
 
-  PythonInterpretRequest pythonInterpretRequest = null;
+  private PythonInterpretRequest pythonInterpretRequest = null;
+  private Integer statementSetNotifier = new Integer(0);
+  private String statementOutput = null;
+  private boolean statementError = false;
+  private Integer statementFinishedNotifier = new Integer(0);
 
   /**
-   *
+   * Request send to Python Daemon
    */
   public class PythonInterpretRequest {
     public String statements;
     public String jobGroup;
     public String jobDescription;
+    public boolean isForCompletion;
 
     public PythonInterpretRequest(String statements, String jobGroup,
-        String jobDescription) {
+        String jobDescription, boolean isForCompletion) {
       this.statements = statements;
       this.jobGroup = jobGroup;
       this.jobDescription = jobDescription;
+      this.isForCompletion = isForCompletion;
     }
 
     public String statements() {
@@ -357,10 +332,13 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     public String jobDescription() {
       return jobDescription;
     }
-  }
 
-  Integer statementSetNotifier = new Integer(0);
+    public boolean isForCompletion() {
+      return isForCompletion;
+    }
+  }
 
+  // called by Python Process
   public PythonInterpretRequest getStatements() {
     synchronized (statementSetNotifier) {
       while (pythonInterpretRequest == null) {
@@ -375,10 +353,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     }
   }
 
-  String statementOutput = null;
-  boolean statementError = false;
-  Integer statementFinishedNotifier = new Integer(0);
-
+  // called by Python Process
   public void setStatementsFinished(String out, boolean error) {
     synchronized (statementFinishedNotifier) {
       LOGGER.debug("Setting python statement output: " + out + ", error: " + error);
@@ -388,9 +363,10 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     }
   }
 
-  boolean pythonScriptInitialized = false;
-  Integer pythonScriptInitializeNotifier = new Integer(0);
+  private boolean pythonScriptInitialized = false;
+  private Integer pythonScriptInitializeNotifier = new Integer(0);
 
+  // called by Python Process
   public void onPythonScriptInitialized(long pid) {
     pythonPid = pid;
     synchronized (pythonScriptInitializeNotifier) {
@@ -400,6 +376,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     }
   }
 
+  // called by Python Process
   public void appendOutput(String message) throws IOException {
     LOGGER.debug("Output from python process: " + message);
     outputStream.getInterpreterOutput().write(message);
@@ -412,7 +389,6 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
       return iPySparkInterpreter.interpret(st, context);
     }
 
-    SparkInterpreter sparkInterpreter = getSparkInterpreter();
     if (sparkInterpreter.isUnsupportedSparkVersion()) {
       return new InterpreterResult(Code.ERROR, "Spark "
           + sparkInterpreter.getSparkVersion().toString() + " is not supported");
@@ -420,7 +396,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     sparkInterpreter.populateSparkWebUrl(context);
 
     if (!pythonscriptRunning) {
-      return new InterpreterResult(Code.ERROR, "python process not running"
+      return new InterpreterResult(Code.ERROR, "python process not running "
           + outputStream.toString());
     }
 
@@ -432,6 +408,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
           && pythonscriptRunning
           && System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000)
{
         try {
+          LOGGER.info("Wait for PythonScript running");
           pythonScriptInitializeNotifier.wait(1000);
         } catch (InterruptedException e) {
           e.printStackTrace();
@@ -451,32 +428,34 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     if (pythonscriptRunning == false) {
       // python script failed to initialize and terminated
       errorMessage.add(new InterpreterResultMessage(
-          InterpreterResult.Type.TEXT, "failed to start pyspark"));
+          InterpreterResult.Type.TEXT, "Failed to start PySpark"));
       return new InterpreterResult(Code.ERROR, errorMessage);
     }
     if (pythonScriptInitialized == false) {
       // timeout. didn't get initialized message
       errorMessage.add(new InterpreterResultMessage(
-          InterpreterResult.Type.TEXT, "pyspark is not responding"));
+          InterpreterResult.Type.TEXT, "Failed to initialize PySpark"));
       return new InterpreterResult(Code.ERROR, errorMessage);
     }
 
+    //TODO(zjffdu) remove this as PySpark is supported starting from spark 1.2s
     if (!sparkInterpreter.getSparkVersion().isPysparkSupported()) {
       errorMessage.add(new InterpreterResultMessage(
           InterpreterResult.Type.TEXT,
           "pyspark " + sparkInterpreter.getSparkContext().version() + " is not supported"));
       return new InterpreterResult(Code.ERROR, errorMessage);
     }
+
     String jobGroup = Utils.buildJobGroupId(context);
     String jobDesc = "Started by: " + Utils.getUserName(context.getAuthenticationInfo());
 
-    SparkZeppelinContext __zeppelin__ = sparkInterpreter.getZeppelinContext();
-    __zeppelin__.setInterpreterContext(context);
-    __zeppelin__.setGui(context.getGui());
-    __zeppelin__.setNoteGui(context.getNoteGui());
+    SparkZeppelinContext z = sparkInterpreter.getZeppelinContext();
+    z.setInterpreterContext(context);
+    z.setGui(context.getGui());
+    z.setNoteGui(context.getNoteGui());
     InterpreterContext.set(context);
 
-    pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup, jobDesc);
+    pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup, jobDesc, false);
     statementOutput = null;
 
     synchronized (statementSetNotifier) {
@@ -495,13 +474,11 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     if (statementError) {
       return new InterpreterResult(Code.ERROR, statementOutput);
     } else {
-
       try {
         context.out.flush();
       } catch (IOException e) {
         throw new InterpreterException(e);
       }
-
       return new InterpreterResult(Code.SUCCESS);
     }
   }
@@ -558,14 +535,14 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     }
     String completionString = getCompletionTargetString(buf, cursor);
     String completionCommand = "completion.getCompletion('" + completionString + "')";
+    LOGGER.debug("completionCommand: " + completionCommand);
 
     //start code for completion
-    SparkInterpreter sparkInterpreter = getSparkInterpreter();
     if (sparkInterpreter.isUnsupportedSparkVersion() || pythonscriptRunning == false) {
       return new LinkedList<>();
     }
 
-    pythonInterpretRequest = new PythonInterpretRequest(completionCommand, "", "");
+    pythonInterpretRequest = new PythonInterpretRequest(completionCommand, "", "", true);
     statementOutput = null;
 
     synchronized (statementSetNotifier) {
@@ -596,7 +573,6 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
       completionList = gson.fromJson(statementOutput, String[].class);
     }
     //end code for completion
-
     if (completionList == null) {
       return new LinkedList<>();
     }
@@ -604,6 +580,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     List<InterpreterCompletion> results = new LinkedList<>();
     for (String name: completionList) {
       results.add(new InterpreterCompletion(name, name, StringUtils.EMPTY));
+      LOGGER.debug("completion: " + name);
     }
     return results;
   }
@@ -753,4 +730,9 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     pythonscriptRunning = false;
     LOGGER.error("python process failed", e);
   }
+
+  // Called by Python Process, used for debugging purpose
+  public void logPythonOutput(String message) {
+    LOGGER.debug("Python Process Output: " + message);
+  }
 }

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/aefc7ea3/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py
index 614c516..1352318 100644
--- a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py
+++ b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py
@@ -62,46 +62,33 @@ class PySparkCompletion:
   def __init__(self, interpreterObject):
     self.interpreterObject = interpreterObject
 
-  def getGlobalCompletion(self):
-    objectDefList = []
+  def getGlobalCompletion(self, text_value):
+    completions = [completion for completion in list(globals().keys()) if completion.startswith(text_value)]
+    return completions
+
+  def getMethodCompletion(self, objName, methodName):
+    execResult = locals()
     try:
-      for completionItem in list(globals().keys()):
-        objectDefList.append(completionItem)
+      exec("{} = dir({})".format("objectDefList", objName), globals(), execResult)
     except:
       return None
     else:
-      return objectDefList
+      objectDefList = execResult['objectDefList']
+      return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)]
 
-  def getMethodCompletion(self, text_value):
-    execResult = locals()
+  def getCompletion(self, text_value):
     if text_value == None:
       return None
-    completion_target = text_value
-    try:
-      if len(completion_target) <= 0:
-        return None
-      if text_value[-1] == ".":
-        completion_target = text_value[:-1]
-      exec("{} = dir({})".format("objectDefList", completion_target), globals(), execResult)
-    except:
-      return None
-    else:
-      return list(execResult['objectDefList'])
 
+    dotPos = text_value.find(".")
+    if dotPos == -1:
+      objName = text_value
+      completionList = self.getGlobalCompletion(objName)
+    else:
+      objName = text_value[:dotPos]
+      methodName = text_value[dotPos + 1:]
+      completionList = self.getMethodCompletion(objName, methodName)
 
-  def getCompletion(self, text_value):
-    completionList = set()
-
-    globalCompletionList = self.getGlobalCompletion()
-    if globalCompletionList != None:
-      for completionItem in list(globalCompletionList):
-        completionList.add(completionItem)
-
-    if text_value != None:
-      objectCompletionList = self.getMethodCompletion(text_value)
-      if objectCompletionList != None:
-        for completionItem in list(objectCompletionList):
-          completionList.add(completionItem)
     if len(completionList) <= 0:
       self.interpreterObject.setStatementsFinished("", False)
     else:
@@ -130,7 +117,6 @@ intp = gateway.entry_point
 output = Logger()
 sys.stdout = output
 sys.stderr = output
-intp.onPythonScriptInitialized(os.getpid())
 
 jsc = intp.getJavaSparkContext()
 
@@ -195,13 +181,16 @@ __zeppelin__._setup_matplotlib()
 _zcUserQueryNameSpace["z"] = z
 _zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__
 
+intp.onPythonScriptInitialized(os.getpid())
+
 while True :
   req = intp.getStatements()
   try:
     stmts = req.statements().split("\n")
     jobGroup = req.jobGroup()
     jobDesc = req.jobDescription()
-    
+    isForCompletion = req.isForCompletion()
+
     # Get post-execute hooks
     try:
       global_hook = intp.getHook('post_exec_dev')
@@ -214,9 +203,10 @@ while True :
       user_hook = None
       
     nhooks = 0
-    for hook in (global_hook, user_hook):
-      if hook:
-        nhooks += 1
+    if not isForCompletion:
+      for hook in (global_hook, user_hook):
+        if hook:
+          nhooks += 1
 
     if stmts:
       # use exec mode to compile the statements except the last statement,
@@ -226,9 +216,9 @@ while True :
       to_run_hooks = []
       if (nhooks > 0):
         to_run_hooks = code.body[-nhooks:]
+
       to_run_exec, to_run_single = (code.body[:-(nhooks + 1)],
                                     [code.body[-(nhooks + 1)]])
-
       try:
         for node in to_run_exec:
           mod = ast.Module([node])
@@ -245,19 +235,23 @@ while True :
           code = compile(mod, '<stdin>', 'exec')
           exec(code, _zcUserQueryNameSpace)
 
-        intp.setStatementsFinished("", False)
+        if not isForCompletion:
+          # only call it when it is not for code completion. code completion will call it
in
+          # PySparkCompletion.getCompletion
+          intp.setStatementsFinished("", False)
       except Py4JJavaError:
         # raise it to outside try except
         raise
       except:
-        exception = traceback.format_exc()
-        m = re.search("File \"<stdin>\", line (\d+).*", exception)
-        if m:
-          line_no = int(m.group(1))
-          intp.setStatementsFinished(
-            "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception,
True)
-        else:
-          intp.setStatementsFinished(exception, True)
+        if not isForCompletion:
+          exception = traceback.format_exc()
+          m = re.search("File \"<stdin>\", line (\d+).*", exception)
+          if m:
+            line_no = int(m.group(1))
+            intp.setStatementsFinished(
+              "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception,
True)
+          else:
+            intp.setStatementsFinished(exception, True)
     else:
       intp.setStatementsFinished("", False)
 

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/aefc7ea3/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java
b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java
index 8d08117..2cc11ac 100644
--- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java
@@ -97,101 +97,124 @@ public class IPySparkInterpreterTest {
   public void testBasics() throws InterruptedException, IOException, InterpreterException
{
     // all the ipython test should pass too.
     IPythonInterpreterTest.testInterpreter(iPySparkInterpreter);
+    testPySpark(iPySparkInterpreter, mockRemoteEventClient);
 
+  }
+
+  public static void testPySpark(final Interpreter interpreter, RemoteEventClient mockRemoteEventClient)
+      throws InterpreterException, IOException, InterruptedException {
     // rdd
-    InterpreterContext context = getInterpreterContext();
-    InterpreterResult result = iPySparkInterpreter.interpret("sc.version", context);
+    InterpreterContext context = getInterpreterContext(mockRemoteEventClient);
+    InterpreterResult result = interpreter.interpret("sc.version", context);
     Thread.sleep(100);
     assertEquals(InterpreterResult.Code.SUCCESS, result.code());
     String sparkVersion = context.out.toInterpreterResultMessage().get(0).getData();
     // spark url is sent
     verify(mockRemoteEventClient).onMetaInfosReceived(any(Map.class));
 
-    context = getInterpreterContext();
-    result = iPySparkInterpreter.interpret("sc.range(1,10).sum()", context);
+    context = getInterpreterContext(mockRemoteEventClient);
+    result = interpreter.interpret("sc.range(1,10).sum()", context);
     Thread.sleep(100);
     assertEquals(InterpreterResult.Code.SUCCESS, result.code());
     List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage();
-    assertEquals("45", interpreterResultMessages.get(0).getData());
+    assertEquals("45", interpreterResultMessages.get(0).getData().trim());
     // spark job url is sent
     verify(mockRemoteEventClient).onParaInfosReceived(any(String.class), any(String.class),
any(Map.class));
 
     // spark sql
-    context = getInterpreterContext();
+    context = getInterpreterContext(mockRemoteEventClient);
     if (!isSpark2(sparkVersion)) {
-      result = iPySparkInterpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()",
context);
+      result = interpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()",
context);
       assertEquals(InterpreterResult.Code.SUCCESS, result.code());
       interpreterResultMessages = context.out.toInterpreterResultMessage();
       assertEquals(
           "+---+---+\n" +
-          "| _1| _2|\n" +
-          "+---+---+\n" +
-          "|  1|  a|\n" +
-          "|  2|  b|\n" +
-          "+---+---+\n\n", interpreterResultMessages.get(0).getData());
+              "| _1| _2|\n" +
+              "+---+---+\n" +
+              "|  1|  a|\n" +
+              "|  2|  b|\n" +
+              "+---+---+", interpreterResultMessages.get(0).getData().trim());
 
-      context = getInterpreterContext();
-      result = iPySparkInterpreter.interpret("z.show(df)", context);
+      context = getInterpreterContext(mockRemoteEventClient);
+      result = interpreter.interpret("z.show(df)", context);
       assertEquals(InterpreterResult.Code.SUCCESS, result.code());
       interpreterResultMessages = context.out.toInterpreterResultMessage();
       assertEquals(
           "_1	_2\n" +
-          "1	a\n" +
-          "2	b\n", interpreterResultMessages.get(0).getData());
+              "1	a\n" +
+              "2	b", interpreterResultMessages.get(0).getData().trim());
     } else {
-      result = iPySparkInterpreter.interpret("df = spark.createDataFrame([(1,'a'),(2,'b')])\ndf.show()",
context);
+      result = interpreter.interpret("df = spark.createDataFrame([(1,'a'),(2,'b')])\ndf.show()",
context);
       assertEquals(InterpreterResult.Code.SUCCESS, result.code());
       interpreterResultMessages = context.out.toInterpreterResultMessage();
       assertEquals(
           "+---+---+\n" +
-          "| _1| _2|\n" +
-          "+---+---+\n" +
-          "|  1|  a|\n" +
-          "|  2|  b|\n" +
-          "+---+---+\n\n", interpreterResultMessages.get(0).getData());
+              "| _1| _2|\n" +
+              "+---+---+\n" +
+              "|  1|  a|\n" +
+              "|  2|  b|\n" +
+              "+---+---+", interpreterResultMessages.get(0).getData().trim());
 
-      context = getInterpreterContext();
-      result = iPySparkInterpreter.interpret("z.show(df)", context);
+      context = getInterpreterContext(mockRemoteEventClient);
+      result = interpreter.interpret("z.show(df)", context);
       assertEquals(InterpreterResult.Code.SUCCESS, result.code());
       interpreterResultMessages = context.out.toInterpreterResultMessage();
       assertEquals(
           "_1	_2\n" +
-          "1	a\n" +
-          "2	b\n", interpreterResultMessages.get(0).getData());
+              "1	a\n" +
+              "2	b", interpreterResultMessages.get(0).getData().trim());
     }
     // cancel
-    final InterpreterContext context2 = getInterpreterContext();
-
-    Thread thread = new Thread() {
-      @Override
-      public void run() {
-        InterpreterResult result = iPySparkInterpreter.interpret("import time\nsc.range(1,10).foreach(lambda
x: time.sleep(1))", context2);
-        assertEquals(InterpreterResult.Code.ERROR, result.code());
-        List<InterpreterResultMessage> interpreterResultMessages = null;
-        try {
-          interpreterResultMessages = context2.out.toInterpreterResultMessage();
-          assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt"));
-        } catch (IOException e) {
-          e.printStackTrace();
-        }
-      }
-    };
-    thread.start();
+    if (interpreter instanceof IPySparkInterpreter) {
+      final InterpreterContext context2 = getInterpreterContext(mockRemoteEventClient);
 
+      Thread thread = new Thread() {
+        @Override
+        public void run() {
+          InterpreterResult result = null;
+          try {
+            result = interpreter.interpret("import time\nsc.range(1,10).foreach(lambda x:
time.sleep(1))", context2);
+          } catch (InterpreterException e) {
+            e.printStackTrace();
+          }
+          assertEquals(InterpreterResult.Code.ERROR, result.code());
+          List<InterpreterResultMessage> interpreterResultMessages = null;
+          try {
+            interpreterResultMessages = context2.out.toInterpreterResultMessage();
+            assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt"));
+          } catch (IOException e) {
+            e.printStackTrace();
+          }
+        }
+      };
+      thread.start();
 
-    // sleep 1 second to wait for the spark job starts
-    Thread.sleep(1000);
-    iPySparkInterpreter.cancel(context);
-    thread.join();
+      // sleep 1 second to wait for the spark job starts
+      Thread.sleep(1000);
+      interpreter.cancel(context);
+      thread.join();
+    }
 
     // completions
-    List<InterpreterCompletion> completions = iPySparkInterpreter.completion("sc.ran",
6, getInterpreterContext());
+    List<InterpreterCompletion> completions = interpreter.completion("sc.ran", 6, getInterpreterContext(mockRemoteEventClient));
     assertEquals(1, completions.size());
     assertEquals("range", completions.get(0).getValue());
 
+    completions = interpreter.completion("sc.", 3, getInterpreterContext(mockRemoteEventClient));
+    assertTrue(completions.size() > 0);
+    completions.contains(new InterpreterCompletion("range", "range", ""));
+
+    completions = interpreter.completion("1+1\nsc.", 7, getInterpreterContext(mockRemoteEventClient));
+    assertTrue(completions.size() > 0);
+    completions.contains(new InterpreterCompletion("range", "range", ""));
+
+    completions = interpreter.completion("s", 1, getInterpreterContext(mockRemoteEventClient));
+    assertTrue(completions.size() > 0);
+    completions.contains(new InterpreterCompletion("sc", "sc", ""));
+
     // pyspark streaming
-    context = getInterpreterContext();
-    result = iPySparkInterpreter.interpret(
+    context = getInterpreterContext(mockRemoteEventClient);
+    result = interpreter.interpret(
         "from pyspark.streaming import StreamingContext\n" +
             "import time\n" +
             "ssc = StreamingContext(sc, 1)\n" +
@@ -212,11 +235,11 @@ public class IPySparkInterpreterTest {
     assertTrue(interpreterResultMessages.get(0).getData().contains("(0, 100)"));
   }
 
-  private boolean isSpark2(String sparkVersion) {
+  private static boolean isSpark2(String sparkVersion) {
     return sparkVersion.startsWith("'2.") || sparkVersion.startsWith("u'2.");
   }
 
-  private InterpreterContext getInterpreterContext() {
+  private static InterpreterContext getInterpreterContext(RemoteEventClient mockRemoteEventClient)
{
     InterpreterContext context = new InterpreterContext(
         "noteId",
         "paragraphId",

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/aefc7ea3/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java
b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java
index 00972b4..e228c7e 100644
--- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java
@@ -20,6 +20,7 @@ package org.apache.zeppelin.spark;
 import org.apache.zeppelin.display.AngularObjectRegistry;
 import org.apache.zeppelin.display.GUI;
 import org.apache.zeppelin.interpreter.*;
+import org.apache.zeppelin.interpreter.remote.RemoteEventClient;
 import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
 import org.apache.zeppelin.resource.LocalResourcePool;
 import org.apache.zeppelin.user.AuthenticationInfo;
@@ -31,11 +32,15 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Properties;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
 import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
 @FixMethodOrder(MethodSorters.NAME_ASCENDING)
 public class PySparkInterpreterTest {
@@ -47,6 +52,7 @@ public class PySparkInterpreterTest {
   static PySparkInterpreter pySparkInterpreter;
   static InterpreterGroup intpGroup;
   static InterpreterContext context;
+  private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class);
 
   private static Properties getPySparkTestProperties() throws IOException {
     Properties p = new Properties();
@@ -101,8 +107,6 @@ public class PySparkInterpreterTest {
     intpGroup.get("note").add(pySparkInterpreter);
     pySparkInterpreter.setInterpreterGroup(intpGroup);
     pySparkInterpreter.open();
-
-
   }
 
   @AfterClass
@@ -112,35 +116,8 @@ public class PySparkInterpreterTest {
   }
 
   @Test
-  public void testBasicIntp() throws InterpreterException {
-    if (getSparkVersionNumber() > 11) {
-      assertEquals(InterpreterResult.Code.SUCCESS,
-        pySparkInterpreter.interpret("a = 1\n", context).code());
-    }
-
-    InterpreterResult result = pySparkInterpreter.interpret(
-        "from pyspark.streaming import StreamingContext\n" +
-            "import time\n" +
-            "ssc = StreamingContext(sc, 1)\n" +
-            "rddQueue = []\n" +
-            "for i in range(5):\n" +
-            "    rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)]\n"
+
-            "inputStream = ssc.queueStream(rddQueue)\n" +
-            "mappedStream = inputStream.map(lambda x: (x % 10, 1))\n" +
-            "reducedStream = mappedStream.reduceByKey(lambda a, b: a + b)\n" +
-            "reducedStream.pprint()\n" +
-            "ssc.start()\n" +
-            "time.sleep(6)\n" +
-            "ssc.stop(stopSparkContext=False, stopGraceFully=True)", context);
-    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
-  }
-
-  @Test
-  public void testCompletion() throws InterpreterException {
-    if (getSparkVersionNumber() > 11) {
-      List<InterpreterCompletion> completions = pySparkInterpreter.completion("sc.",
"sc.".length(), null);
-      assertTrue(completions.size() > 0);
-    }
+  public void testBasicIntp() throws InterpreterException, InterruptedException, IOException
{
+    IPySparkInterpreterTest.testPySpark(pySparkInterpreter, mockRemoteEventClient);
   }
 
   @Test


Mime
View raw message