zeppelin-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject incubator-zeppelin git commit: ZEPPELIN-297 Dependency should be loaded in pyspark
Date Mon, 14 Sep 2015 06:13:59 GMT
Repository: incubator-zeppelin
Updated Branches:
  refs/heads/master d169284fd -> ca67374c2


ZEPPELIN-297 Dependency should be loaded in pyspark

This PR fixes https://issues.apache.org/jira/browse/ZEPPELIN-297

* [x] fix, by setting the context classloader to one that includes the classes from the
dependency loader when initializing the py4j gatewayserver
* [x] test

Author: Lee moon soo <moon@apache.org>

Closes #298 from Leemoonsoo/ZEPPELIN-297 and squashes the following commits:

0de89fe [Lee moon soo] Add logging
1e8f52a [Lee moon soo] Add test
163acfa [Lee moon soo] Set classloader for gatewayserver with classes from dependency loader


Project: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/commit/ca67374c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/tree/ca67374c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/diff/ca67374c

Branch: refs/heads/master
Commit: ca67374c2b0fe26f6bc6992910bb11d5658258e0
Parents: d169284
Author: Lee moon soo <moon@apache.org>
Authored: Mon Sep 14 14:04:40 2015 +0900
Committer: Lee moon soo <moon@apache.org>
Committed: Mon Sep 14 15:13:53 2015 +0900

----------------------------------------------------------------------
 .../zeppelin/spark/PySparkInterpreter.java      | 59 ++++++++++++++++++++
 .../zeppelin/rest/ZeppelinSparkClusterTest.java | 45 ++++++++++++++-
 2 files changed, 103 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/ca67374c/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
----------------------------------------------------------------------
diff --git a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
index 0e58729..d0e5fec 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
@@ -25,7 +25,10 @@ import java.io.IOException;
 import java.io.OutputStreamWriter;
 import java.io.PipedInputStream;
 import java.io.PipedOutputStream;
+import java.net.MalformedURLException;
 import java.net.ServerSocket;
+import java.net.URL;
+import java.net.URLClassLoader;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -51,6 +54,7 @@ import org.apache.zeppelin.interpreter.InterpreterResult;
 import org.apache.zeppelin.interpreter.InterpreterResult.Code;
 import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
 import org.apache.zeppelin.interpreter.WrappedInterpreter;
+import org.apache.zeppelin.spark.dep.DependencyContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -125,6 +129,44 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
 
   @Override
   public void open() {
+    DepInterpreter depInterpreter = getDepInterpreter();
+
+    // load libraries from Dependency Interpreter
+    URL [] urls = new URL[0];
+
+    if (depInterpreter != null) {
+      DependencyContext depc = depInterpreter.getDependencyContext();
+      if (depc != null) {
+        List<File> files = depc.getFiles();
+        List<URL> urlList = new LinkedList<URL>();
+        if (files != null) {
+          for (File f : files) {
+            try {
+              urlList.add(f.toURI().toURL());
+            } catch (MalformedURLException e) {
+              logger.error("Error", e);
+            }
+          }
+
+          urls = urlList.toArray(urls);
+        }
+      }
+    }
+
+    ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
+    try {
+      URLClassLoader newCl = new URLClassLoader(urls, oldCl);
+      Thread.currentThread().setContextClassLoader(newCl);
+      createGatewayServerAndStartScript();
+    } catch (Exception e) {
+      logger.error("Error", e);
+      throw new InterpreterException(e);
+    } finally {
+      Thread.currentThread().setContextClassLoader(oldCl);
+    }
+  }
+
+  private void createGatewayServerAndStartScript() {
     // create python script
     createPythonScript();
 
@@ -400,6 +442,23 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand
     }
   }
 
+  private DepInterpreter getDepInterpreter() {
+    InterpreterGroup intpGroup = getInterpreterGroup();
+    if (intpGroup == null) return null;
+    synchronized (intpGroup) {
+      for (Interpreter intp : intpGroup) {
+        if (intp.getClassName().equals(DepInterpreter.class.getName())) {
+          Interpreter p = intp;
+          while (p instanceof WrappedInterpreter) {
+            p = ((WrappedInterpreter) p).getInnerInterpreter();
+          }
+          return (DepInterpreter) p;
+        }
+      }
+    }
+    return null;
+  }
+
 
   @Override
   public void onProcessComplete(int exitValue) {

http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/ca67374c/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
----------------------------------------------------------------------
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
index aa2476a..d5006ee 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
@@ -18,8 +18,12 @@ package org.apache.zeppelin.rest;
 
 import static org.junit.Assert.assertEquals;
 
+import java.io.File;
 import java.io.IOException;
+import java.util.List;
 
+import org.apache.commons.io.FileUtils;
+import org.apache.zeppelin.interpreter.InterpreterSetting;
 import org.apache.zeppelin.notebook.Note;
 import org.apache.zeppelin.notebook.Paragraph;
 import org.apache.zeppelin.scheduler.Job.Status;
@@ -75,7 +79,6 @@ public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
   public void pySparkTest() throws IOException {
     // create new note
     Note note = ZeppelinServer.notebook.createNote();
-
     int sparkVersion = getSparkVersionNumber(note);
 
     if (isPyspark() && sparkVersion >= 12) {   // pyspark supported from 1.2.1
@@ -129,6 +132,46 @@ public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
     ZeppelinServer.notebook.removeNote(note.id());
   }
 
+  @Test
+  public void pySparkDepLoaderTest() throws IOException {
+    // create new note
+    Note note = ZeppelinServer.notebook.createNote();
+
+    if (isPyspark() && getSparkVersionNumber(note) >= 14) {
+      // restart spark interpreter
+      List<InterpreterSetting> settings =
+          ZeppelinServer.notebook.getBindedInterpreterSettings(note.id());
+
+      for (InterpreterSetting setting : settings) {
+        if (setting.getGroup().equals("spark")) {
+          ZeppelinServer.notebook.getInterpreterFactory().restart(setting.id());
+          break;
+        }
+      }
+
+      // load dep
+      Paragraph p0 = note.addParagraph();
+      p0.setText("%dep z.load(\"com.databricks:spark-csv_2.11:1.2.0\")");
+      note.run(p0.getId());
+      waitForFinish(p0);
+
+      // write test csv file
+      File tmpFile = File.createTempFile("test", "csv");
+      FileUtils.write(tmpFile, "a,b\n1,2");
+
+      // load data using libraries from dep loader
+      Paragraph p1 = note.addParagraph();
+      p1.setText("%pyspark\n" +
+        "from pyspark.sql import SQLContext\n" +
+        "print(sqlContext.read.format('com.databricks.spark.csv')" +
+        ".load('"+ tmpFile.getAbsolutePath() +"').count())");
+      note.run(p1.getId());
+
+      waitForFinish(p1);
+      assertEquals("2\n", p1.getResult().message());
+    }
+  }
+
   /**
    * Get spark version number as a numerical value.
    * eg. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ...


Mime
View raw message