From: zjffdu@apache.org
To: commits@zeppelin.apache.org
Reply-To: dev@zeppelin.apache.org
Subject: zeppelin git commit: ZEPPELIN-3374. Improvement on PySparkInterpreter
Date: Sun, 1 Apr 2018 06:31:23 +0000 (UTC)

Repository: zeppelin
Updated Branches:
  refs/heads/master 8238b711c -> aefc7ea39


ZEPPELIN-3374. Improvement on PySparkInterpreter

### What is this PR for?
A few improvements on PySparkInterpreter:
1. Refactor PySparkInterpreter to make it more readable.
2. The code completion feature is totally broken; fix it in this PR.
3. Reuse the same test cases as IPySparkInterpreter.

### What type of PR is it?
[Bug Fix | Improvement | Refactoring]

### Todos
* [ ] - Task

### What is the Jira issue?
* https://issues.apache.org/jira/browse/ZEPPELIN-3374

### How should this be tested?
* CI pass

### Screenshots (if appropriate)
Code completion before
![completion_before](https://user-images.githubusercontent.com/164491/38160504-ee1ea3a8-34f1-11e8-9aab-baf98962aae3.gif)

Code completion after
![completion_after](https://user-images.githubusercontent.com/164491/38160505-eff108b0-34f1-11e8-88eb-03c51cfb96de.gif)

### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need documentation? No

Author: Jeff Zhang

Closes #2901 from zjffdu/ZEPPELIN-3374 and squashes the following commits:

c22078d [Jeff Zhang] ZEPPELIN-3374. Improvement on PySparkInterpreter
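For context on the completion fix shown in the screenshots: the reworked zeppelin_pyspark.py resolves a fragment such as `sc.ran` by splitting on the first dot, then filters either the global namespace (no dot) or `dir()` of the named object (dot present) by the typed prefix, and sends the matching names back to the JVM side. The snippet below is a minimal standalone sketch of that lookup under those assumptions, not the interpreter code itself; `resolve_completions` and `FakeSparkContext` are hypothetical names used only for illustration.

```python
# Minimal sketch of the prefix-based completion lookup (illustrative only;
# resolve_completions and FakeSparkContext are hypothetical, not Zeppelin API).
def resolve_completions(fragment, namespace):
    if not fragment:
        return []
    dot = fragment.find(".")
    if dot == -1:
        # No dot: complete against top-level names in the namespace.
        return sorted(name for name in namespace if name.startswith(fragment))
    obj_name, attr_prefix = fragment[:dot], fragment[dot + 1:]
    obj = namespace.get(obj_name)
    if obj is None:
        return []
    # Dot present: complete against attributes of the resolved object.
    return sorted(attr for attr in dir(obj) if attr.startswith(attr_prefix))


if __name__ == "__main__":
    class FakeSparkContext(object):
        def range(self, start, end):
            return list(range(start, end))

    ns = {"sc": FakeSparkContext(), "sqlContext": None}
    print(resolve_completions("sc.ran", ns))  # ['range']
    print(resolve_completions("s", ns))       # ['sc', 'sqlContext']
```

In the actual change, a completion request is marked with a new `isForCompletion` flag so that it skips the post-execution hooks and does not emit a normal paragraph result; the matches are returned through `setStatementsFinished` and decoded on the Java side.

The refactor also changes how PYTHONPATH is assembled for the launched python process: an existing PYTHONPATH is now preserved and `spark.jars` entries are appended with ":" separators instead of overwriting it. Below is a rough sketch of that assembly with made-up paths; the real logic lives in setupPySparkEnv() in PySparkInterpreter.java.

```python
# Rough sketch of the PYTHONPATH assembly after the change; paths are made up.
def build_pythonpath(spark_python_path, existing_pythonpath, spark_jars):
    # Keep any PYTHONPATH the user already set instead of overwriting it.
    path = spark_python_path
    if existing_pythonpath:
        path += ":" + existing_pythonpath
    # Append spark.jars (comma separated) so transitive deps are importable.
    if spark_jars:
        path += ":" + spark_jars.replace(",", ":")
    return path


print(build_pythonpath("/opt/spark/python:/opt/spark/python/lib/py4j-src.zip",
                       "/home/user/libs", "a.jar,b.jar"))
```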
Project: http://git-wip-us.apache.org/repos/asf/zeppelin/repo
Commit: http://git-wip-us.apache.org/repos/asf/zeppelin/commit/aefc7ea3
Tree: http://git-wip-us.apache.org/repos/asf/zeppelin/tree/aefc7ea3
Diff: http://git-wip-us.apache.org/repos/asf/zeppelin/diff/aefc7ea3

Branch: refs/heads/master
Commit: aefc7ea395fed5a93e69942725a19b203a3574e2
Parents: 8238b71
Author: Jeff Zhang
Authored: Fri Mar 30 11:12:08 2018 +0800
Committer: Jeff Zhang
Committed: Sun Apr 1 14:31:12 2018 +0800

----------------------------------------------------------------------
 .../zeppelin/spark/PySparkInterpreter.java      | 290 +++++++++----------
 .../main/resources/python/zeppelin_pyspark.py   |  86 +++---
 .../zeppelin/spark/IPySparkInterpreterTest.java | 127 ++++----
 .../zeppelin/spark/PySparkInterpreterTest.java  |  39 +--
 4 files changed, 259 insertions(+), 283 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/zeppelin/blob/aefc7ea3/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
index d97bb51..809e883 100644
--- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
@@ -41,6 +41,7 @@ import org.apache.zeppelin.interpreter.InterpreterResultMessage;
 import org.apache.zeppelin.interpreter.InvalidHookException;
 import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
 import org.apache.zeppelin.interpreter.WrappedInterpreter;
+import org.apache.zeppelin.interpreter.remote.RemoteInterpreterUtils;
 import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
 import org.apache.zeppelin.interpreter.util.InterpreterOutputStream;
 import org.apache.zeppelin.spark.dep.SparkDependencyContext;
@@ -53,11 +54,8 @@ import java.io.ByteArrayOutputStream;
 import java.io.File;
 import java.io.FileOutputStream;
 import java.io.IOException;
-import java.io.OutputStreamWriter;
 import java.io.PipedInputStream;
-import java.io.PipedOutputStream;
 import java.net.MalformedURLException;
-import java.net.ServerSocket;
 import java.net.URL;
 import java.net.URLClassLoader;
 import java.util.LinkedList;
@@ -66,65 +64,26 @@ import java.util.Map;
 import java.util.Properties;
 /**
- *
+ * Interpreter for PySpark, it is the first implementation of interpreter for PySpark, so with less
+ * features compared to IPySparkInterpreter, but requires less prerequisites than
+ * IPySparkInterpreter, only python is required.
*/ public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler { private static final Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class); + private static final int MAX_TIMEOUT_SEC = 10; + private GatewayServer gatewayServer; private DefaultExecutor executor; - private int port; + // used to forward output from python process to InterpreterOutput private InterpreterOutputStream outputStream; - private BufferedWriter ins; - private PipedInputStream in; - private ByteArrayOutputStream input; private String scriptPath; - boolean pythonscriptRunning = false; - private static final int MAX_TIMEOUT_SEC = 10; - private long pythonPid; - + private boolean pythonscriptRunning = false; + private long pythonPid = -1; private IPySparkInterpreter iPySparkInterpreter; + private SparkInterpreter sparkInterpreter; public PySparkInterpreter(Properties property) { super(property); - - pythonPid = -1; - try { - File scriptFile = File.createTempFile("zeppelin_pyspark-", ".py"); - scriptPath = scriptFile.getAbsolutePath(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private void createPythonScript() throws InterpreterException { - ClassLoader classLoader = getClass().getClassLoader(); - File out = new File(scriptPath); - - if (out.exists() && out.isDirectory()) { - throw new InterpreterException("Can't create python script " + out.getAbsolutePath()); - } - - try { - FileOutputStream outStream = new FileOutputStream(out); - IOUtils.copy( - classLoader.getResourceAsStream("python/zeppelin_pyspark.py"), - outStream); - outStream.close(); - } catch (IOException e) { - throw new InterpreterException(e); - } - - try { - FileOutputStream outStream = new FileOutputStream(out.getParent() + "/zeppelin_context.py"); - IOUtils.copy( - classLoader.getResourceAsStream("python/zeppelin_context.py"), - outStream); - outStream.close(); - } catch (IOException e) { - throw new InterpreterException(e); - } - - LOGGER.info("File {} created", scriptPath); } @Override @@ -151,6 +110,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand InterpreterGroup intpGroup = getInterpreterGroup(); if (intpGroup != null && intpGroup.getInterpreterHookRegistry() != null) { try { + // just for unit test I believe (zjffdu) registerHook(HookType.POST_EXEC_DEV.getName(), "__zeppelin__._displayhook()"); } catch (InvalidHookException e) { throw new InterpreterException(e); @@ -199,48 +159,118 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand ClassLoader oldCl = Thread.currentThread().getContextClassLoader(); try { URLClassLoader newCl = new URLClassLoader(urls, oldCl); + LOGGER.info("urls:" + urls); + for (URL url : urls) { + LOGGER.info("url:" + url); + } Thread.currentThread().setContextClassLoader(newCl); + // must create spark interpreter after ClassLoader is set, otherwise the additional jars + // can not be loaded by spark repl. 
+ this.sparkInterpreter = getSparkInterpreter(); createGatewayServerAndStartScript(); - } catch (Exception e) { - LOGGER.error("Error", e); - throw new InterpreterException(e); + } catch (IOException e) { + LOGGER.error("Fail to open PySparkInterpreter", e); + throw new InterpreterException("Fail to open PySparkInterpreter", e); } finally { Thread.currentThread().setContextClassLoader(oldCl); } } - private Map setupPySparkEnv() throws IOException, InterpreterException { - Map env = EnvironmentUtils.getProcEnvironment(); + private void createGatewayServerAndStartScript() throws IOException { + // start gateway server in JVM side + int port = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces(); + gatewayServer = new GatewayServer(this, port); + gatewayServer.start(); + + // launch python process to connect to the gateway server in JVM side + createPythonScript(); + String pythonExec = getPythonExec(getProperties()); + LOGGER.info("PythonExec: " + pythonExec); + CommandLine cmd = CommandLine.parse(pythonExec); + cmd.addArgument(scriptPath, false); + cmd.addArgument(Integer.toString(port), false); + cmd.addArgument(Integer.toString(sparkInterpreter.getSparkVersion().toNumber()), false); + executor = new DefaultExecutor(); + outputStream = new InterpreterOutputStream(LOGGER); + PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream); + executor.setStreamHandler(streamHandler); + executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT)); + + Map env = setupPySparkEnv(); + executor.execute(cmd, env, this); + pythonscriptRunning = true; + } + + private void createPythonScript() throws IOException { + FileOutputStream pysparkScriptOutput = null; + FileOutputStream zeppelinContextOutput = null; + try { + // copy zeppelin_pyspark.py + File scriptFile = File.createTempFile("zeppelin_pyspark-", ".py"); + this.scriptPath = scriptFile.getAbsolutePath(); + pysparkScriptOutput = new FileOutputStream(scriptFile); + IOUtils.copy( + getClass().getClassLoader().getResourceAsStream("python/zeppelin_pyspark.py"), + pysparkScriptOutput); + + // copy zeppelin_context.py to the same folder of zeppelin_pyspark.py + zeppelinContextOutput = new FileOutputStream(scriptFile.getParent() + "/zeppelin_context.py"); + IOUtils.copy( + getClass().getClassLoader().getResourceAsStream("python/zeppelin_context.py"), + zeppelinContextOutput); + LOGGER.info("PySpark script {} {} is created", + scriptPath, scriptFile.getParent() + "/zeppelin_context.py"); + } finally { + if (pysparkScriptOutput != null) { + try { + pysparkScriptOutput.close(); + } catch (IOException e) { + // ignore + } + } + if (zeppelinContextOutput != null) { + try { + zeppelinContextOutput.close(); + } catch (IOException e) { + // ignore + } + } + } + } + private Map setupPySparkEnv() throws IOException { + Map env = EnvironmentUtils.getProcEnvironment(); // only set PYTHONPATH in local or yarn-client mode. // yarn-cluster will setup PYTHONPATH automatically. 
- SparkConf conf = getSparkConf(); + SparkConf conf = null; + try { + conf = getSparkConf(); + } catch (InterpreterException e) { + throw new IOException(e); + } if (!conf.get("spark.submit.deployMode", "client").equals("cluster")) { if (!env.containsKey("PYTHONPATH")) { env.put("PYTHONPATH", PythonUtils.sparkPythonPath()); } else { - env.put("PYTHONPATH", PythonUtils.sparkPythonPath()); + env.put("PYTHONPATH", PythonUtils.sparkPythonPath() + ":" + env.get("PYTHONPATH")); } } // get additional class paths when using SPARK_SUBMIT and not using YARN-CLIENT // also, add all packages to PYTHONPATH since there might be transitive dependencies if (SparkInterpreter.useSparkSubmit() && - !getSparkInterpreter().isYarnMode()) { - - String sparkSubmitJars = getSparkConf().get("spark.jars").replace(",", ":"); - - if (!"".equals(sparkSubmitJars)) { - env.put("PYTHONPATH", env.get("PYTHONPATH") + sparkSubmitJars); + !sparkInterpreter.isYarnMode()) { + String sparkSubmitJars = conf.get("spark.jars").replace(",", ":"); + if (!StringUtils.isEmpty(sparkSubmitJars)) { + env.put("PYTHONPATH", env.get("PYTHONPATH") + ":" + sparkSubmitJars); } } - LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH")); - // set PYSPARK_PYTHON - if (getSparkConf().contains("spark.pyspark.python")) { - env.put("PYSPARK_PYTHON", getSparkConf().get("spark.pyspark.python")); + if (conf.contains("spark.pyspark.python")) { + env.put("PYSPARK_PYTHON", conf.get("spark.pyspark.python")); } + LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH")); return env; } @@ -258,66 +288,6 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand return pythonExec; } - private void createGatewayServerAndStartScript() throws InterpreterException { - // create python script - createPythonScript(); - - port = findRandomOpenPortOnAllLocalInterfaces(); - - gatewayServer = new GatewayServer(this, port); - gatewayServer.start(); - - String pythonExec = getPythonExec(getProperties()); - LOGGER.info("pythonExec: " + pythonExec); - CommandLine cmd = CommandLine.parse(pythonExec); - cmd.addArgument(scriptPath, false); - cmd.addArgument(Integer.toString(port), false); - cmd.addArgument(Integer.toString(getSparkInterpreter().getSparkVersion().toNumber()), false); - executor = new DefaultExecutor(); - outputStream = new InterpreterOutputStream(LOGGER); - PipedOutputStream ps = new PipedOutputStream(); - in = null; - try { - in = new PipedInputStream(ps); - } catch (IOException e1) { - throw new InterpreterException(e1); - } - ins = new BufferedWriter(new OutputStreamWriter(ps)); - - input = new ByteArrayOutputStream(); - - PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, outputStream, in); - executor.setStreamHandler(streamHandler); - executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT)); - - try { - Map env = setupPySparkEnv(); - executor.execute(cmd, env, this); - pythonscriptRunning = true; - } catch (IOException e) { - throw new InterpreterException(e); - } - - - try { - input.write("import sys, getopt\n".getBytes()); - ins.flush(); - } catch (IOException e) { - throw new InterpreterException(e); - } - } - - private int findRandomOpenPortOnAllLocalInterfaces() throws InterpreterException { - int port; - try (ServerSocket socket = new ServerSocket(0);) { - port = socket.getLocalPort(); - socket.close(); - } catch (IOException e) { - throw new InterpreterException(e); - } - return port; - } - @Override public void close() throws InterpreterException { if (iPySparkInterpreter != null) { @@ -325,25 
+295,30 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand return; } executor.getWatchdog().destroyProcess(); - new File(scriptPath).delete(); gatewayServer.shutdown(); } - PythonInterpretRequest pythonInterpretRequest = null; + private PythonInterpretRequest pythonInterpretRequest = null; + private Integer statementSetNotifier = new Integer(0); + private String statementOutput = null; + private boolean statementError = false; + private Integer statementFinishedNotifier = new Integer(0); /** - * + * Request send to Python Daemon */ public class PythonInterpretRequest { public String statements; public String jobGroup; public String jobDescription; + public boolean isForCompletion; public PythonInterpretRequest(String statements, String jobGroup, - String jobDescription) { + String jobDescription, boolean isForCompletion) { this.statements = statements; this.jobGroup = jobGroup; this.jobDescription = jobDescription; + this.isForCompletion = isForCompletion; } public String statements() { @@ -357,10 +332,13 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand public String jobDescription() { return jobDescription; } - } - Integer statementSetNotifier = new Integer(0); + public boolean isForCompletion() { + return isForCompletion; + } + } + // called by Python Process public PythonInterpretRequest getStatements() { synchronized (statementSetNotifier) { while (pythonInterpretRequest == null) { @@ -375,10 +353,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand } } - String statementOutput = null; - boolean statementError = false; - Integer statementFinishedNotifier = new Integer(0); - + // called by Python Process public void setStatementsFinished(String out, boolean error) { synchronized (statementFinishedNotifier) { LOGGER.debug("Setting python statement output: " + out + ", error: " + error); @@ -388,9 +363,10 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand } } - boolean pythonScriptInitialized = false; - Integer pythonScriptInitializeNotifier = new Integer(0); + private boolean pythonScriptInitialized = false; + private Integer pythonScriptInitializeNotifier = new Integer(0); + // called by Python Process public void onPythonScriptInitialized(long pid) { pythonPid = pid; synchronized (pythonScriptInitializeNotifier) { @@ -400,6 +376,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand } } + // called by Python Process public void appendOutput(String message) throws IOException { LOGGER.debug("Output from python process: " + message); outputStream.getInterpreterOutput().write(message); @@ -412,7 +389,6 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand return iPySparkInterpreter.interpret(st, context); } - SparkInterpreter sparkInterpreter = getSparkInterpreter(); if (sparkInterpreter.isUnsupportedSparkVersion()) { return new InterpreterResult(Code.ERROR, "Spark " + sparkInterpreter.getSparkVersion().toString() + " is not supported"); @@ -420,7 +396,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand sparkInterpreter.populateSparkWebUrl(context); if (!pythonscriptRunning) { - return new InterpreterResult(Code.ERROR, "python process not running" + return new InterpreterResult(Code.ERROR, "python process not running " + outputStream.toString()); } @@ -432,6 +408,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand && 
pythonscriptRunning && System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) { try { + LOGGER.info("Wait for PythonScript running"); pythonScriptInitializeNotifier.wait(1000); } catch (InterruptedException e) { e.printStackTrace(); @@ -451,32 +428,34 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand if (pythonscriptRunning == false) { // python script failed to initialize and terminated errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, "failed to start pyspark")); + InterpreterResult.Type.TEXT, "Failed to start PySpark")); return new InterpreterResult(Code.ERROR, errorMessage); } if (pythonScriptInitialized == false) { // timeout. didn't get initialized message errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, "pyspark is not responding")); + InterpreterResult.Type.TEXT, "Failed to initialize PySpark")); return new InterpreterResult(Code.ERROR, errorMessage); } + //TODO(zjffdu) remove this as PySpark is supported starting from spark 1.2s if (!sparkInterpreter.getSparkVersion().isPysparkSupported()) { errorMessage.add(new InterpreterResultMessage( InterpreterResult.Type.TEXT, "pyspark " + sparkInterpreter.getSparkContext().version() + " is not supported")); return new InterpreterResult(Code.ERROR, errorMessage); } + String jobGroup = Utils.buildJobGroupId(context); String jobDesc = "Started by: " + Utils.getUserName(context.getAuthenticationInfo()); - SparkZeppelinContext __zeppelin__ = sparkInterpreter.getZeppelinContext(); - __zeppelin__.setInterpreterContext(context); - __zeppelin__.setGui(context.getGui()); - __zeppelin__.setNoteGui(context.getNoteGui()); + SparkZeppelinContext z = sparkInterpreter.getZeppelinContext(); + z.setInterpreterContext(context); + z.setGui(context.getGui()); + z.setNoteGui(context.getNoteGui()); InterpreterContext.set(context); - pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup, jobDesc); + pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup, jobDesc, false); statementOutput = null; synchronized (statementSetNotifier) { @@ -495,13 +474,11 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand if (statementError) { return new InterpreterResult(Code.ERROR, statementOutput); } else { - try { context.out.flush(); } catch (IOException e) { throw new InterpreterException(e); } - return new InterpreterResult(Code.SUCCESS); } } @@ -558,14 +535,14 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand } String completionString = getCompletionTargetString(buf, cursor); String completionCommand = "completion.getCompletion('" + completionString + "')"; + LOGGER.debug("completionCommand: " + completionCommand); //start code for completion - SparkInterpreter sparkInterpreter = getSparkInterpreter(); if (sparkInterpreter.isUnsupportedSparkVersion() || pythonscriptRunning == false) { return new LinkedList<>(); } - pythonInterpretRequest = new PythonInterpretRequest(completionCommand, "", ""); + pythonInterpretRequest = new PythonInterpretRequest(completionCommand, "", "", true); statementOutput = null; synchronized (statementSetNotifier) { @@ -596,7 +573,6 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand completionList = gson.fromJson(statementOutput, String[].class); } //end code for completion - if (completionList == null) { return new LinkedList<>(); } @@ -604,6 +580,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand 
List results = new LinkedList<>(); for (String name: completionList) { results.add(new InterpreterCompletion(name, name, StringUtils.EMPTY)); + LOGGER.debug("completion: " + name); } return results; } @@ -753,4 +730,9 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand pythonscriptRunning = false; LOGGER.error("python process failed", e); } + + // Called by Python Process, used for debugging purpose + public void logPythonOutput(String message) { + LOGGER.debug("Python Process Output: " + message); + } } http://git-wip-us.apache.org/repos/asf/zeppelin/blob/aefc7ea3/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py index 614c516..1352318 100644 --- a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py +++ b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py @@ -62,46 +62,33 @@ class PySparkCompletion: def __init__(self, interpreterObject): self.interpreterObject = interpreterObject - def getGlobalCompletion(self): - objectDefList = [] + def getGlobalCompletion(self, text_value): + completions = [completion for completion in list(globals().keys()) if completion.startswith(text_value)] + return completions + + def getMethodCompletion(self, objName, methodName): + execResult = locals() try: - for completionItem in list(globals().keys()): - objectDefList.append(completionItem) + exec("{} = dir({})".format("objectDefList", objName), globals(), execResult) except: return None else: - return objectDefList + objectDefList = execResult['objectDefList'] + return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)] - def getMethodCompletion(self, text_value): - execResult = locals() + def getCompletion(self, text_value): if text_value == None: return None - completion_target = text_value - try: - if len(completion_target) <= 0: - return None - if text_value[-1] == ".": - completion_target = text_value[:-1] - exec("{} = dir({})".format("objectDefList", completion_target), globals(), execResult) - except: - return None - else: - return list(execResult['objectDefList']) + dotPos = text_value.find(".") + if dotPos == -1: + objName = text_value + completionList = self.getGlobalCompletion(objName) + else: + objName = text_value[:dotPos] + methodName = text_value[dotPos + 1:] + completionList = self.getMethodCompletion(objName, methodName) - def getCompletion(self, text_value): - completionList = set() - - globalCompletionList = self.getGlobalCompletion() - if globalCompletionList != None: - for completionItem in list(globalCompletionList): - completionList.add(completionItem) - - if text_value != None: - objectCompletionList = self.getMethodCompletion(text_value) - if objectCompletionList != None: - for completionItem in list(objectCompletionList): - completionList.add(completionItem) if len(completionList) <= 0: self.interpreterObject.setStatementsFinished("", False) else: @@ -130,7 +117,6 @@ intp = gateway.entry_point output = Logger() sys.stdout = output sys.stderr = output -intp.onPythonScriptInitialized(os.getpid()) jsc = intp.getJavaSparkContext() @@ -195,13 +181,16 @@ __zeppelin__._setup_matplotlib() _zcUserQueryNameSpace["z"] = z _zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ +intp.onPythonScriptInitialized(os.getpid()) + while True : req = intp.getStatements() try: stmts 
= req.statements().split("\n") jobGroup = req.jobGroup() jobDesc = req.jobDescription() - + isForCompletion = req.isForCompletion() + # Get post-execute hooks try: global_hook = intp.getHook('post_exec_dev') @@ -214,9 +203,10 @@ while True : user_hook = None nhooks = 0 - for hook in (global_hook, user_hook): - if hook: - nhooks += 1 + if not isForCompletion: + for hook in (global_hook, user_hook): + if hook: + nhooks += 1 if stmts: # use exec mode to compile the statements except the last statement, @@ -226,9 +216,9 @@ while True : to_run_hooks = [] if (nhooks > 0): to_run_hooks = code.body[-nhooks:] + to_run_exec, to_run_single = (code.body[:-(nhooks + 1)], [code.body[-(nhooks + 1)]]) - try: for node in to_run_exec: mod = ast.Module([node]) @@ -245,19 +235,23 @@ while True : code = compile(mod, '', 'exec') exec(code, _zcUserQueryNameSpace) - intp.setStatementsFinished("", False) + if not isForCompletion: + # only call it when it is not for code completion. code completion will call it in + # PySparkCompletion.getCompletion + intp.setStatementsFinished("", False) except Py4JJavaError: # raise it to outside try except raise except: - exception = traceback.format_exc() - m = re.search("File \"\", line (\d+).*", exception) - if m: - line_no = int(m.group(1)) - intp.setStatementsFinished( - "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True) - else: - intp.setStatementsFinished(exception, True) + if not isForCompletion: + exception = traceback.format_exc() + m = re.search("File \"\", line (\d+).*", exception) + if m: + line_no = int(m.group(1)) + intp.setStatementsFinished( + "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True) + else: + intp.setStatementsFinished(exception, True) else: intp.setStatementsFinished("", False) http://git-wip-us.apache.org/repos/asf/zeppelin/blob/aefc7ea3/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java index 8d08117..2cc11ac 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java @@ -97,101 +97,124 @@ public class IPySparkInterpreterTest { public void testBasics() throws InterruptedException, IOException, InterpreterException { // all the ipython test should pass too. 
IPythonInterpreterTest.testInterpreter(iPySparkInterpreter); + testPySpark(iPySparkInterpreter, mockRemoteEventClient); + } + + public static void testPySpark(final Interpreter interpreter, RemoteEventClient mockRemoteEventClient) + throws InterpreterException, IOException, InterruptedException { // rdd - InterpreterContext context = getInterpreterContext(); - InterpreterResult result = iPySparkInterpreter.interpret("sc.version", context); + InterpreterContext context = getInterpreterContext(mockRemoteEventClient); + InterpreterResult result = interpreter.interpret("sc.version", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); String sparkVersion = context.out.toInterpreterResultMessage().get(0).getData(); // spark url is sent verify(mockRemoteEventClient).onMetaInfosReceived(any(Map.class)); - context = getInterpreterContext(); - result = iPySparkInterpreter.interpret("sc.range(1,10).sum()", context); + context = getInterpreterContext(mockRemoteEventClient); + result = interpreter.interpret("sc.range(1,10).sum()", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); List interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals("45", interpreterResultMessages.get(0).getData()); + assertEquals("45", interpreterResultMessages.get(0).getData().trim()); // spark job url is sent verify(mockRemoteEventClient).onParaInfosReceived(any(String.class), any(String.class), any(Map.class)); // spark sql - context = getInterpreterContext(); + context = getInterpreterContext(mockRemoteEventClient); if (!isSpark2(sparkVersion)) { - result = iPySparkInterpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context); + result = interpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); interpreterResultMessages = context.out.toInterpreterResultMessage(); assertEquals( "+---+---+\n" + - "| _1| _2|\n" + - "+---+---+\n" + - "| 1| a|\n" + - "| 2| b|\n" + - "+---+---+\n\n", interpreterResultMessages.get(0).getData()); + "| _1| _2|\n" + + "+---+---+\n" + + "| 1| a|\n" + + "| 2| b|\n" + + "+---+---+", interpreterResultMessages.get(0).getData().trim()); - context = getInterpreterContext(); - result = iPySparkInterpreter.interpret("z.show(df)", context); + context = getInterpreterContext(mockRemoteEventClient); + result = interpreter.interpret("z.show(df)", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); interpreterResultMessages = context.out.toInterpreterResultMessage(); assertEquals( "_1 _2\n" + - "1 a\n" + - "2 b\n", interpreterResultMessages.get(0).getData()); + "1 a\n" + + "2 b", interpreterResultMessages.get(0).getData().trim()); } else { - result = iPySparkInterpreter.interpret("df = spark.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context); + result = interpreter.interpret("df = spark.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); interpreterResultMessages = context.out.toInterpreterResultMessage(); assertEquals( "+---+---+\n" + - "| _1| _2|\n" + - "+---+---+\n" + - "| 1| a|\n" + - "| 2| b|\n" + - "+---+---+\n\n", interpreterResultMessages.get(0).getData()); + "| _1| _2|\n" + + "+---+---+\n" + + "| 1| a|\n" + + "| 2| b|\n" + + "+---+---+", interpreterResultMessages.get(0).getData().trim()); - context = getInterpreterContext(); - result = iPySparkInterpreter.interpret("z.show(df)", 
context); + context = getInterpreterContext(mockRemoteEventClient); + result = interpreter.interpret("z.show(df)", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); interpreterResultMessages = context.out.toInterpreterResultMessage(); assertEquals( "_1 _2\n" + - "1 a\n" + - "2 b\n", interpreterResultMessages.get(0).getData()); + "1 a\n" + + "2 b", interpreterResultMessages.get(0).getData().trim()); } // cancel - final InterpreterContext context2 = getInterpreterContext(); - - Thread thread = new Thread() { - @Override - public void run() { - InterpreterResult result = iPySparkInterpreter.interpret("import time\nsc.range(1,10).foreach(lambda x: time.sleep(1))", context2); - assertEquals(InterpreterResult.Code.ERROR, result.code()); - List interpreterResultMessages = null; - try { - interpreterResultMessages = context2.out.toInterpreterResultMessage(); - assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt")); - } catch (IOException e) { - e.printStackTrace(); - } - } - }; - thread.start(); + if (interpreter instanceof IPySparkInterpreter) { + final InterpreterContext context2 = getInterpreterContext(mockRemoteEventClient); + Thread thread = new Thread() { + @Override + public void run() { + InterpreterResult result = null; + try { + result = interpreter.interpret("import time\nsc.range(1,10).foreach(lambda x: time.sleep(1))", context2); + } catch (InterpreterException e) { + e.printStackTrace(); + } + assertEquals(InterpreterResult.Code.ERROR, result.code()); + List interpreterResultMessages = null; + try { + interpreterResultMessages = context2.out.toInterpreterResultMessage(); + assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt")); + } catch (IOException e) { + e.printStackTrace(); + } + } + }; + thread.start(); - // sleep 1 second to wait for the spark job starts - Thread.sleep(1000); - iPySparkInterpreter.cancel(context); - thread.join(); + // sleep 1 second to wait for the spark job starts + Thread.sleep(1000); + interpreter.cancel(context); + thread.join(); + } // completions - List completions = iPySparkInterpreter.completion("sc.ran", 6, getInterpreterContext()); + List completions = interpreter.completion("sc.ran", 6, getInterpreterContext(mockRemoteEventClient)); assertEquals(1, completions.size()); assertEquals("range", completions.get(0).getValue()); + completions = interpreter.completion("sc.", 3, getInterpreterContext(mockRemoteEventClient)); + assertTrue(completions.size() > 0); + completions.contains(new InterpreterCompletion("range", "range", "")); + + completions = interpreter.completion("1+1\nsc.", 7, getInterpreterContext(mockRemoteEventClient)); + assertTrue(completions.size() > 0); + completions.contains(new InterpreterCompletion("range", "range", "")); + + completions = interpreter.completion("s", 1, getInterpreterContext(mockRemoteEventClient)); + assertTrue(completions.size() > 0); + completions.contains(new InterpreterCompletion("sc", "sc", "")); + // pyspark streaming - context = getInterpreterContext(); - result = iPySparkInterpreter.interpret( + context = getInterpreterContext(mockRemoteEventClient); + result = interpreter.interpret( "from pyspark.streaming import StreamingContext\n" + "import time\n" + "ssc = StreamingContext(sc, 1)\n" + @@ -212,11 +235,11 @@ public class IPySparkInterpreterTest { assertTrue(interpreterResultMessages.get(0).getData().contains("(0, 100)")); } - private boolean isSpark2(String sparkVersion) { + private static boolean isSpark2(String sparkVersion) { 
return sparkVersion.startsWith("'2.") || sparkVersion.startsWith("u'2."); } - private InterpreterContext getInterpreterContext() { + private static InterpreterContext getInterpreterContext(RemoteEventClient mockRemoteEventClient) { InterpreterContext context = new InterpreterContext( "noteId", "paragraphId", http://git-wip-us.apache.org/repos/asf/zeppelin/blob/aefc7ea3/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java index 00972b4..e228c7e 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java @@ -20,6 +20,7 @@ package org.apache.zeppelin.spark; import org.apache.zeppelin.display.AngularObjectRegistry; import org.apache.zeppelin.display.GUI; import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.remote.RemoteEventClient; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.resource.LocalResourcePool; import org.apache.zeppelin.user.AuthenticationInfo; @@ -31,11 +32,15 @@ import java.io.IOException; import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Properties; import java.util.regex.Matcher; import java.util.regex.Pattern; import static org.junit.Assert.*; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; @FixMethodOrder(MethodSorters.NAME_ASCENDING) public class PySparkInterpreterTest { @@ -47,6 +52,7 @@ public class PySparkInterpreterTest { static PySparkInterpreter pySparkInterpreter; static InterpreterGroup intpGroup; static InterpreterContext context; + private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class); private static Properties getPySparkTestProperties() throws IOException { Properties p = new Properties(); @@ -101,8 +107,6 @@ public class PySparkInterpreterTest { intpGroup.get("note").add(pySparkInterpreter); pySparkInterpreter.setInterpreterGroup(intpGroup); pySparkInterpreter.open(); - - } @AfterClass @@ -112,35 +116,8 @@ public class PySparkInterpreterTest { } @Test - public void testBasicIntp() throws InterpreterException { - if (getSparkVersionNumber() > 11) { - assertEquals(InterpreterResult.Code.SUCCESS, - pySparkInterpreter.interpret("a = 1\n", context).code()); - } - - InterpreterResult result = pySparkInterpreter.interpret( - "from pyspark.streaming import StreamingContext\n" + - "import time\n" + - "ssc = StreamingContext(sc, 1)\n" + - "rddQueue = []\n" + - "for i in range(5):\n" + - " rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)]\n" + - "inputStream = ssc.queueStream(rddQueue)\n" + - "mappedStream = inputStream.map(lambda x: (x % 10, 1))\n" + - "reducedStream = mappedStream.reduceByKey(lambda a, b: a + b)\n" + - "reducedStream.pprint()\n" + - "ssc.start()\n" + - "time.sleep(6)\n" + - "ssc.stop(stopSparkContext=False, stopGraceFully=True)", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - } - - @Test - public void testCompletion() throws InterpreterException { - if (getSparkVersionNumber() > 11) { - List completions = pySparkInterpreter.completion("sc.", 
"sc.".length(), null); - assertTrue(completions.size() > 0); - } + public void testBasicIntp() throws InterpreterException, InterruptedException, IOException { + IPySparkInterpreterTest.testPySpark(pySparkInterpreter, mockRemoteEventClient); } @Test