spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From van...@apache.org
Subject spark git commit: [SPARK-20922][CORE] Add whitelist of classes that can be deserialized by the launcher.
Date Thu, 01 Jun 2017 21:44:47 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.2 b81a702f4 -> 4cba3b5a3


[SPARK-20922][CORE] Add whitelist of classes that can be deserialized by the launcher.

Blindly deserializing classes using Java serialization opens the code up to
issues in other libraries, since just deserializing data from a stream may
end up executing code (think readObject()).

Since the launcher protocol is pretty self-contained, there's just a handful
of classes it legitimately needs to deserialize, and they're in just two
packages, so add a filter that throws errors if classes from any other
package show up in the stream.

This also maintains backwards compatibility (the updated launcher code can
still communicate with the backend code in older Spark releases).

Tested with new and existing unit tests.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #18166 from vanzin/SPARK-20922.

(cherry picked from commit 8efc6e986554ae66eab93cd64a9035d716adbab0)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4cba3b5a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4cba3b5a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4cba3b5a

Branch: refs/heads/branch-2.2
Commit: 4cba3b5a350f4d477466fc73b32cbd653eee8405
Parents: b81a702
Author: Marcelo Vanzin <vanzin@cloudera.com>
Authored: Thu Jun 1 14:44:34 2017 -0700
Committer: Marcelo Vanzin <vanzin@cloudera.com>
Committed: Thu Jun 1 14:44:42 2017 -0700

----------------------------------------------------------------------
 .../launcher/FilteredObjectInputStream.java     | 53 +++++++++++
 .../spark/launcher/LauncherConnection.java      |  3 +-
 .../spark/launcher/LauncherServerSuite.java     | 92 ++++++++++++++------
 3 files changed, 121 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4cba3b5a/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java
----------------------------------------------------------------------
diff --git a/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java
new file mode 100644
index 0000000..4d254a0
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.InputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectStreamClass;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * An object input stream that only allows classes used by the launcher protocol to be in the
+ * serialized stream. See SPARK-20922.
+ */
+class FilteredObjectInputStream extends ObjectInputStream {
+
+  private static final List<String> ALLOWED_PACKAGES = Arrays.asList(
+    "org.apache.spark.launcher.",
+    "java.lang.");
+
+  FilteredObjectInputStream(InputStream is) throws IOException {
+    super(is);
+  }
+
+  @Override
+  protected Class<?> resolveClass(ObjectStreamClass desc)
+      throws IOException, ClassNotFoundException {
+
+    boolean isValid = ALLOWED_PACKAGES.stream().anyMatch(p -> desc.getName().startsWith(p));
+    if (!isValid) {
+      throw new IllegalArgumentException(
+        String.format("Unexpected class in stream: %s", desc.getName()));
+    }
+    return super.resolveClass(desc);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4cba3b5a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
----------------------------------------------------------------------
diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
index eec2649..b4a8719 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
@@ -20,7 +20,6 @@ package org.apache.spark.launcher;
 import java.io.Closeable;
 import java.io.EOFException;
 import java.io.IOException;
-import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.net.Socket;
 import java.util.logging.Level;
@@ -53,7 +52,7 @@ abstract class LauncherConnection implements Closeable, Runnable {
   @Override
   public void run() {
     try {
-      ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
+      FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream());
       while (!closed) {
         Message msg = (Message) in.readObject();
         handle(msg);

http://git-wip-us.apache.org/repos/asf/spark/blob/4cba3b5a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
----------------------------------------------------------------------
diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
index 12f1a0c..03c2934 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
@@ -19,8 +19,11 @@ package org.apache.spark.launcher;
 
 import java.io.Closeable;
 import java.io.IOException;
+import java.io.ObjectInputStream;
 import java.net.InetAddress;
 import java.net.Socket;
+import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.Semaphore;
@@ -120,31 +123,7 @@ public class LauncherServerSuite extends BaseSuite {
       Socket s = new Socket(InetAddress.getLoopbackAddress(),
         LauncherServer.getServerInstance().getPort());
       client = new TestClient(s);
-
-      // Try a few times since the client-side socket may not reflect the server-side close
-      // immediately.
-      boolean helloSent = false;
-      int maxTries = 10;
-      for (int i = 0; i < maxTries; i++) {
-        try {
-          if (!helloSent) {
-            client.send(new Hello(handle.getSecret(), "1.4.0"));
-            helloSent = true;
-          } else {
-            client.send(new SetAppId("appId"));
-          }
-          fail("Expected exception caused by connection timeout.");
-        } catch (IllegalStateException | IOException e) {
-          // Expected.
-          break;
-        } catch (AssertionError e) {
-          if (i < maxTries - 1) {
-            Thread.sleep(100);
-          } else {
-            throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
-          }
-        }
-      }
+      waitForError(client, handle.getSecret());
     } finally {
       SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
       kill(handle);
@@ -183,6 +162,25 @@ public class LauncherServerSuite extends BaseSuite {
     }
   }
 
+  @Test
+  public void testStreamFiltering() throws Exception {
+    ChildProcAppHandle handle = LauncherServer.newAppHandle();
+    TestClient client = null;
+    try {
+      Socket s = new Socket(InetAddress.getLoopbackAddress(),
+        LauncherServer.getServerInstance().getPort());
+
+      client = new TestClient(s);
+      client.send(new EvilPayload());
+      waitForError(client, handle.getSecret());
+      assertEquals(0, EvilPayload.EVIL_BIT);
+    } finally {
+      kill(handle);
+      close(client);
+      client.clientThread.join();
+    }
+  }
+
   private void kill(SparkAppHandle handle) {
     if (handle != null) {
       handle.kill();
@@ -199,6 +197,35 @@ public class LauncherServerSuite extends BaseSuite {
     }
   }
 
+  /**
+   * Try a few times to get a client-side error, since the client-side socket may not reflect the
+   * server-side close immediately.
+   */
+  private void waitForError(TestClient client, String secret) throws Exception {
+    boolean helloSent = false;
+    int maxTries = 10;
+    for (int i = 0; i < maxTries; i++) {
+      try {
+        if (!helloSent) {
+          client.send(new Hello(secret, "1.4.0"));
+          helloSent = true;
+        } else {
+          client.send(new SetAppId("appId"));
+        }
+        fail("Expected error but message went through.");
+      } catch (IllegalStateException | IOException e) {
+        // Expected.
+        break;
+      } catch (AssertionError e) {
+        if (i < maxTries - 1) {
+          Thread.sleep(100);
+        } else {
+          throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
+        }
+      }
+    }
+  }
+
   private static class TestClient extends LauncherConnection {
 
     final BlockingQueue<Message> inbound;
@@ -220,4 +247,19 @@ public class LauncherServerSuite extends BaseSuite {
 
   }
 
+  private static class EvilPayload extends LauncherProtocol.Message {
+
+    static int EVIL_BIT = 0;
+
+    // This field should cause the launcher server to throw an error and not deserialize the
+    // message.
+    private List<String> notAllowedField = Arrays.asList("disallowed");
+
+    private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
+      stream.defaultReadObject();
+      EVIL_BIT = 1;
+    }
+
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message