spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SPARK-5472][SQL] A JDBC data source for Spark SQL.
Date Tue, 03 Feb 2015 03:50:41 GMT
Repository: spark
Updated Branches:
  refs/heads/master 1bcd46574 -> 8f471a66d


[SPARK-5472][SQL] A JDBC data source for Spark SQL.

This pull request contains a Spark SQL data source that can pull data from, and can put data into, a JDBC database.

I have tested both read and write support with H2, MySQL, and Postgres.  It would surprise me if both read and write support worked flawlessly out-of-the-box for any other database; different databases have different names for different JDBC data types and different meanings for SQL types with the same name.  However, this code is designed (see `DriverQuirks.scala`) to make it *relatively* painless to add support for another database by augmenting the type mapping contained in this PR.

Author: Tor Myklebust <tmyklebu@gmail.com>

Closes #4261 from tmyklebu/master and squashes the following commits:

cf167ce [Tor Myklebust] Work around other Java tests ruining TestSQLContext.
67893bf [Tor Myklebust] Move the jdbcRDD methods into SQLContext itself.
585f95b [Tor Myklebust] Dependencies go into the project's pom.xml.
829d5ba [Tor Myklebust] Merge branch 'master' of https://github.com/apache/spark
41647ef [Tor Myklebust] Hide a couple things that don't need to be public.
7318aea [Tor Myklebust] Fix scalastyle warnings.
a09eeac [Tor Myklebust] JDBC data source for Spark SQL.
176bb98 [Tor Myklebust] Add test deps for JDBC support.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f471a66
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f471a66
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f471a66

Branch: refs/heads/master
Commit: 8f471a66db0571a76a21c0d93312197fee16174a
Parents: 1bcd465
Author: Tor Myklebust <tmyklebu@gmail.com>
Authored: Mon Feb 2 19:50:14 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Mon Feb 2 19:50:14 2015 -0800

----------------------------------------------------------------------
 sql/core/pom.xml                                |  24 ++
 .../org/apache/spark/sql/jdbc/JDBCUtils.java    |  59 +++
 .../scala/org/apache/spark/sql/SQLContext.scala |  49 ++-
 .../apache/spark/sql/jdbc/DriverQuirks.scala    |  99 +++++
 .../org/apache/spark/sql/jdbc/JDBCRDD.scala     | 417 +++++++++++++++++++
 .../apache/spark/sql/jdbc/JDBCRelation.scala    | 133 ++++++
 .../spark/sql/jdbc/JavaJDBCTrampoline.scala     |  30 ++
 .../scala/org/apache/spark/sql/jdbc/jdbc.scala  | 235 +++++++++++
 .../org/apache/spark/sql/jdbc/JavaJDBCTest.java | 102 +++++
 .../org/apache/spark/sql/jdbc/DockerHacks.scala |  51 +++
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   | 248 +++++++++++
 .../apache/spark/sql/jdbc/JDBCWriteSuite.scala  | 107 +++++
 .../spark/sql/jdbc/MySQLIntegration.scala       | 235 +++++++++++
 .../spark/sql/jdbc/PostgresIntegration.scala    | 149 +++++++
 14 files changed, 1937 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/pom.xml
----------------------------------------------------------------------
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 3e9ef07..1a0c77d 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -76,6 +76,30 @@
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.h2database</groupId>
+      <artifactId>h2</artifactId>
+      <version>1.4.183</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>mysql</groupId>
+      <artifactId>mysql-connector-java</artifactId>
+      <version>5.1.34</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.postgresql</groupId>
+      <artifactId>postgresql</artifactId>
+      <version>9.3-1102-jdbc41</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.spotify</groupId>
+      <artifactId>docker-client</artifactId>
+      <version>2.7.5</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java b/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java
new file mode 100644
index 0000000..aa441b2
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc;
+
+import org.apache.spark.Partition;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.DataFrame;
+
+/**
+ * Java-friendly static helpers for reading JDBC tables into DataFrames and
+ * writing DataFrames out to JDBC tables.
+ */
+public class JDBCUtils {
+  /** Shared delegate that the write helpers below forward to. */
+  private static final JavaJDBCTrampoline trampoline = new JavaJDBCTrampoline();
+
+  /**
+   * Construct a DataFrame representing the JDBC table at the database
+   * specified by url with table name table.  The result has a single
+   * partition that carries no extra WHERE clause.
+   */
+  public static DataFrame jdbcRDD(SQLContext sql, String url, String table) {
+    Partition[] single = { new JDBCPartition(null, 0) };
+    return sql.baseRelationToDataFrame(new JDBCRelation(url, table, single, sql));
+  }
+
+  /**
+   * Construct a DataFrame representing the JDBC table at the database
+   * specified by url with table name table, partitioned by parts.
+   * Each entry of parts is an expression suitable for insertion into a
+   * WHERE clause; entry i defines partition i.
+   */
+  public static DataFrame jdbcRDD(SQLContext sql, String url, String table, String[] parts) {
+    Partition[] partitions = new Partition[parts.length];
+    int idx = 0;
+    for (String clause : parts) {
+      partitions[idx] = new JDBCPartition(clause, idx);
+      idx++;
+    }
+    return sql.baseRelationToDataFrame(new JDBCRelation(url, table, partitions, sql));
+  }
+
+  /**
+   * Create a JDBC table named table at url from the contents of rdd.
+   * Delegates to JavaJDBCTrampoline.createJDBCTable, which defines how
+   * allowExisting is honored.
+   */
+  public static void createJDBCTable(DataFrame rdd, String url, String table, boolean allowExisting) {
+    trampoline.createJDBCTable(rdd, url, table, allowExisting);
+  }
+
+  /**
+   * Insert the contents of rdd into the JDBC table named table at url.
+   * Delegates to JavaJDBCTrampoline.insertIntoJDBC, which defines the
+   * meaning of overwrite.
+   */
+  public static void insertIntoJDBC(DataFrame rdd, String url, String table, boolean overwrite) {
+    trampoline.insertIntoJDBC(rdd, url, table, overwrite);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index d0bbb5f..f4692b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -24,7 +24,7 @@ import scala.collection.immutable
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, Partition}
 import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
 import org.apache.spark.rdd.RDD
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.json._
+import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
 import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation, DDLParser, DataSourceStrategy}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -335,6 +336,51 @@ class SQLContext(@transient val sparkContext: SparkContext)
   }
 
   /**
+   * :: Experimental ::
+   * Construct a DataFrame representing the database table named `table`
+   * that is accessible via the JDBC URL `url`.  No partitioning
+   * information is supplied.
+   */
+  @Experimental
+  def jdbcRDD(url: String, table: String): DataFrame = {
+    val noPartitioning: JDBCPartitioningInfo = null
+    jdbcRDD(url, table, noPartitioning)
+  }
+
+  /**
+   * :: Experimental ::
+   * Construct a DataFrame representing the database table named `table`
+   * that is accessible via the JDBC URL `url`, partitioned on one column.
+   * `partitioning` names a column of integral type and gives the number of
+   * partitions plus advisory lower and upper bounds for that column.
+   */
+  @Experimental
+  def jdbcRDD(url: String, table: String, partitioning: JDBCPartitioningInfo):
+      DataFrame =
+    jdbcRDD(url, table, JDBCRelation.columnPartition(partitioning))
+
+  /**
+   * :: Experimental ::
+   * Construct a DataFrame representing the database table named `table`
+   * that is accessible via the JDBC URL `url`.  Each element of `theParts`
+   * is an expression suitable for inclusion in a WHERE clause; element i
+   * defines partition i of the result.
+   */
+  @Experimental
+  def jdbcRDD(url: String, table: String, theParts: Array[String]):
+      DataFrame = {
+    val parts: Array[Partition] = theParts.zipWithIndex.map { case (clause, idx) =>
+      JDBCPartition(clause, idx): Partition
+    }
+    jdbcRDD(url, table, parts)
+  }
+
+  /** Wrap the given JDBC partitions in a JDBCRelation and expose it as a DataFrame. */
+  private def jdbcRDD(url: String, table: String, parts: Array[Partition]):
+      DataFrame =
+    baseRelationToDataFrame(JDBCRelation(url, table, parts)(this))
+
+  /**
    * Registers the given RDD as a temporary table in the catalog.  Temporary tables exist only
    * during the lifetime of this instance of SQLContext.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
new file mode 100644
index 0000000..1704be7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.types._
+
+import java.sql.Types
+
+
+/**
+ * Encapsulates workarounds for the extensions, quirks, and bugs in various
+ * databases.  Lots of databases define types that aren't explicitly supported
+ * by the JDBC spec.  Some JDBC drivers also report inaccurate
+ * information---for instance, BIT(n>1) being reported as a BIT type is quite
+ * common, even though BIT in JDBC is meant for single-bit values.  Also, there
+ * does not appear to be a standard name for an unbounded string or binary
+ * type; we use BLOB and CLOB by default but override with database-specific
+ * alternatives when these are absent or do not behave correctly.
+ *
+ * Currently, the only thing DriverQuirks does is handle type mapping.
+ * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
+ * is used when writing to a JDBC table.  If `getCatalystType` returns `null`,
+ * the default type handling is used for the given JDBC type.  Similarly,
+ * if `getJDBCType` returns `(null, None)`, the default type handling is used
+ * for the given Catalyst type.
+ */
+private[sql] abstract class DriverQuirks {
+  /**
+   * Map a JDBC column type to a Catalyst type, or return null to fall back to
+   * the default mapping.  Implementations may record extra column information
+   * in `md`.
+   */
+  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType
+
+  /**
+   * Map a Catalyst type to a (database type name, java.sql.Types code) pair,
+   * or return (null, None) to fall back to the default mapping.
+   */
+  def getJDBCType(dt: DataType): (String, Option[Int])
+}
+
+private[sql] object DriverQuirks {
+  /**
+   * Fetch the DriverQuirks class corresponding to a given database url.
+   * Matching is by URL prefix; a URL that matches no known prefix (including
+   * one shorter than every prefix) gets NoQuirks.
+   */
+  def get(url: String): DriverQuirks = {
+    // startsWith, not substring(0, n): substring throws for URLs shorter than
+    // n characters, e.g. "jdbc:h2:mem:" against the 15-char Postgres prefix.
+    if (url.startsWith("jdbc:mysql")) {
+      new MySQLQuirks()
+    } else if (url.startsWith("jdbc:postgresql")) {
+      new PostgresQuirks()
+    } else {
+      new NoQuirks()
+    }
+  }
+}
+
+/** Quirks implementation that defers every mapping to the default handling. */
+private[sql] class NoQuirks extends DriverQuirks {
+  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
+    null
+  def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
+}
+
+/** Type-mapping overrides for PostgreSQL. */
+private[sql] class PostgresQuirks extends DriverQuirks {
+  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+    if (sqlType == Types.BIT && typeName == "bit" && size != 1) {
+      // bit(n) with n > 1 is reported as BIT; read it as bytes rather than a boolean.
+      BinaryType
+    } else if (sqlType == Types.OTHER && (typeName == "cidr" || typeName == "inet")) {
+      // Network address types are read as strings.
+      StringType
+    } else null
+  }
+
+  def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
+    case StringType => ("TEXT", Some(Types.CHAR))
+    case BinaryType => ("BYTEA", Some(Types.BINARY))
+    case BooleanType => ("BOOLEAN", Some(Types.BOOLEAN))
+    case _ => (null, None)
+  }
+}
+
+/** Type-mapping overrides for MySQL. */
+private[sql] class MySQLQuirks extends DriverQuirks {
+  def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+    if (sqlType == Types.VARBINARY && typeName == "BIT" && size != 1) {
+      // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
+      // byte arrays instead of longs.  The "binarylong" flag asks the reader to fold the
+      // column's bytes into a long.
+      md.putLong("binarylong", 1)
+      LongType
+    } else if (sqlType == Types.BIT && typeName == "TINYINT") {
+      BooleanType
+    } else null
+  }
+  def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
new file mode 100644
index 0000000..a2f9467
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet, ResultSetMetaData, SQLException}
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.NextIterator
+import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.sources._
+
+private[sql] object JDBCRDD extends Logging {
+  /**
+   * Maps a JDBC type to a Catalyst type.  This function is called only when
+   * the DriverQuirks class corresponding to your database driver returns null.
+   *
+   * @param sqlType - A field of java.sql.Types
+   * @return The Catalyst type corresponding to sqlType.
+   * @throws SQLException if sqlType has no Catalyst equivalent.
+   */
+  private def getCatalystType(sqlType: Int): DataType = {
+    val answer = sqlType match {
+      case java.sql.Types.ARRAY         => null
+      case java.sql.Types.BIGINT        => LongType
+      case java.sql.Types.BINARY        => BinaryType
+      case java.sql.Types.BIT           => BooleanType // Per JDBC; Quirks handles quirky drivers.
+      case java.sql.Types.BLOB          => BinaryType
+      case java.sql.Types.BOOLEAN       => BooleanType
+      case java.sql.Types.CHAR          => StringType
+      case java.sql.Types.CLOB          => StringType
+      case java.sql.Types.DATALINK      => null
+      case java.sql.Types.DATE          => DateType
+      case java.sql.Types.DECIMAL       => DecimalType.Unlimited
+      case java.sql.Types.DISTINCT      => null
+      case java.sql.Types.DOUBLE        => DoubleType
+      case java.sql.Types.FLOAT         => FloatType
+      case java.sql.Types.INTEGER       => IntegerType
+      case java.sql.Types.JAVA_OBJECT   => null
+      case java.sql.Types.LONGNVARCHAR  => StringType
+      case java.sql.Types.LONGVARBINARY => BinaryType
+      case java.sql.Types.LONGVARCHAR   => StringType
+      case java.sql.Types.NCHAR         => StringType
+      case java.sql.Types.NCLOB         => StringType
+      case java.sql.Types.NULL          => null
+      case java.sql.Types.NUMERIC       => DecimalType.Unlimited
+      case java.sql.Types.OTHER         => null
+      case java.sql.Types.REAL          => DoubleType
+      case java.sql.Types.REF           => StringType
+      case java.sql.Types.ROWID         => LongType
+      case java.sql.Types.SMALLINT      => IntegerType
+      case java.sql.Types.SQLXML        => StringType
+      case java.sql.Types.STRUCT        => StringType
+      case java.sql.Types.TIME          => TimestampType
+      case java.sql.Types.TIMESTAMP     => TimestampType
+      case java.sql.Types.TINYINT       => IntegerType
+      case java.sql.Types.VARBINARY     => BinaryType
+      case java.sql.Types.VARCHAR       => StringType
+      case _ => null
+    }
+
+    if (answer == null) throw new SQLException("Unsupported type " + sqlType)
+    answer
+  }
+
+  /**
+   * Takes a (url, table) specification and returns the table's Catalyst
+   * schema.
+   *
+   * @param url - The JDBC url to fetch information from.
+   * @param table - The table name of the desired table.  This may also be a
+   *   SQL query wrapped in parentheses.
+   *
+   * @return A StructType giving the table's Catalyst schema.
+   * @throws SQLException if the table specification is garbage.
+   * @throws SQLException if the table contains an unsupported type.
+   */
+  def resolveTable(url: String, table: String): StructType = {
+    val quirks = DriverQuirks.get(url)
+    val conn: Connection = DriverManager.getConnection(url)
+    try {
+      // WHERE 1=0 returns no rows; only the result-set metadata is needed.
+      val stmt = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
+      try {
+        val rs = stmt.executeQuery()
+        try {
+          val rsmd = rs.getMetaData
+          // JDBC column indices are 1-based.
+          val fields = (1 to rsmd.getColumnCount).map { col =>
+            val columnName = rsmd.getColumnName(col)
+            val dataType = rsmd.getColumnType(col)
+            val typeName = rsmd.getColumnTypeName(col)
+            val fieldSize = rsmd.getPrecision(col)
+            val nullable = rsmd.isNullable(col) != ResultSetMetaData.columnNoNulls
+            val metadata = new MetadataBuilder().putString("name", columnName)
+            // The quirks hook may add entries to `metadata`, so build it afterwards.
+            val quirkType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata)
+            val columnType = if (quirkType != null) quirkType else getCatalystType(dataType)
+            StructField(columnName, columnType, nullable, metadata.build())
+          }
+          new StructType(fields.toArray)
+        } finally {
+          rs.close()
+        }
+      } finally {
+        stmt.close()
+      }
+    } finally {
+      conn.close()
+    }
+  }
+
+  /**
+   * Prune all but the specified columns from the specified Catalyst schema.
+   *
+   * @param schema - The Catalyst schema of the master table
+   * @param columns - The list of desired columns
+   *
+   * @return A Catalyst schema corresponding to columns in the given order.
+   */
+  private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
+    val fieldMap = schema.fields.map(f => f.metadata.getString("name") -> f).toMap
+    new StructType(columns.map(name => fieldMap(name)))
+  }
+
+  /**
+   * Given a driver string and an url, return a function that loads the
+   * specified driver string then returns a connection to the JDBC url.
+   * getConnector is run on the driver code, while the function it returns
+   * is run on the executor.
+   *
+   * @param driver - The class name of the JDBC driver for the given url.
+   * @param url - The JDBC url to connect to.
+   *
+   * @return A function that loads the driver and connects to the url.
+   */
+  def getConnector(driver: String, url: String): () => Connection = {
+    () => {
+      try {
+        if (driver != null) Class.forName(driver)
+      } catch {
+        case e: ClassNotFoundException => {
+          logWarning(s"Couldn't find class $driver", e)
+        }
+      }
+      DriverManager.getConnection(url)
+    }
+  }
+
+  /**
+   * Build and return JDBCRDD from the given information.
+   *
+   * @param sc - Your SparkContext.
+   * @param schema - The Catalyst schema of the underlying database table.
+   * @param driver - The class name of the JDBC driver for the given url.
+   * @param url - The JDBC url to connect to.
+   * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
+   * @param requiredColumns - The names of the columns to SELECT.
+   * @param filters - The filters to include in all WHERE clauses.
+   * @param parts - An array of JDBCPartitions specifying partition ids and
+   *    per-partition WHERE clauses.
+   *
+   * @return An RDD representing "SELECT requiredColumns FROM fqTable".
+   */
+  def scanTable(sc: SparkContext,
+                schema: StructType,
+                driver: String,
+                url: String,
+                fqTable: String,
+                requiredColumns: Array[String],
+                filters: Array[Filter],
+                parts: Array[Partition]): RDD[Row] = {
+    val prunedSchema = pruneSchema(schema, requiredColumns)
+
+    new JDBCRDD(sc,
+        getConnector(driver, url),
+        prunedSchema,
+        fqTable,
+        requiredColumns,
+        filters,
+        parts)
+  }
+}
+
+/**
+ * An RDD representing a table in a database accessed via JDBC.  Both the
+ * driver code and the workers must be able to access the database; the driver
+ * needs to fetch the schema while the workers need to fetch the data.
+ */
+private[sql] class JDBCRDD(
+    sc: SparkContext,
+    getConnection: () => Connection,
+    schema: StructType,
+    fqTable: String,
+    columns: Array[String],
+    filters: Array[Filter],
+    partitions: Array[Partition])
+  extends RDD[Row](sc, Nil) {
+
+  /**
+   * Retrieve the list of partitions corresponding to this RDD.
+   */
+  override def getPartitions: Array[Partition] = partitions
+
+  /**
+   * `columns`, but as a String suitable for injection into a SQL query.
+   * An empty projection selects the constant 1 so the query stays valid.
+   */
+  private val columnList: String =
+    if (columns.isEmpty) "1" else columns.mkString(",")
+
+  /**
+   * Renders a filter's literal value as SQL.  Strings, dates and timestamps
+   * are emitted as single-quoted literals with embedded quotes doubled;
+   * splicing them in bare produced invalid SQL (`name = fred`) and let filter
+   * values inject arbitrary SQL.  Any other value is interpolated via toString.
+   * NOTE(review): MySQL's default mode also treats backslash as an escape
+   * inside quoted strings; confirm whether that needs handling as well.
+   */
+  private def compileValue(value: Any): Any = value match {
+    case s: String => "'" + s.replace("'", "''") + "'"
+    case t: java.sql.Timestamp => "'" + t + "'"
+    case d: java.sql.Date => "'" + d + "'"
+    case other => other
+  }
+
+  /**
+   * Turns a single Filter into a String representing a SQL expression.
+   * Returns null for an unhandled filter.
+   */
+  private def compileFilter(f: Filter): String = f match {
+    case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+    case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
+    case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
+    case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
+    case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
+    case _ => null
+  }
+
+  /**
+   * `filters`, compiled and joined with AND (without the WHERE keyword).
+   * Empty when no filter could be compiled.
+   */
+  private val filterConjunction: String =
+    filters.map(compileFilter).filter(_ != null).mkString(" AND ")
+
+  /**
+   * A WHERE clause representing both `filters`, if any, and the current partition.
+   * The partition clause is parenthesized because it is an arbitrary expression:
+   * since AND binds tighter than OR, `f AND a OR b` would mean `(f AND a) OR b`.
+   */
+  private def getWhereClause(part: JDBCPartition): String = {
+    val filterClause = Some(filterConjunction).filter(_.nonEmpty)
+    val partClause = Option(part.whereClause).map(w => s"($w)")
+    val conditions = filterClause.toSeq ++ partClause
+    if (conditions.isEmpty) "" else conditions.mkString("WHERE ", " AND ", "")
+  }
+
+  // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
+  // we don't have to potentially poke around in the Metadata once for every
+  // row.
+  // Is there a better way to do this?  I'd rather be using a type that
+  // contains only the tags I define.
+  abstract class JDBCConversion
+  case object BooleanConversion extends JDBCConversion
+  case object DateConversion extends JDBCConversion
+  case object DecimalConversion extends JDBCConversion
+  case object DoubleConversion extends JDBCConversion
+  case object FloatConversion extends JDBCConversion
+  case object IntegerConversion extends JDBCConversion
+  case object LongConversion extends JDBCConversion
+  case object BinaryLongConversion extends JDBCConversion
+  case object StringConversion extends JDBCConversion
+  case object TimestampConversion extends JDBCConversion
+  case object BinaryConversion extends JDBCConversion
+
+  /**
+   * Maps a StructType to a type tag list.
+   */
+  def getConversions(schema: StructType): Array[JDBCConversion] = {
+    schema.fields.map(sf => sf.dataType match {
+      case BooleanType           => BooleanConversion
+      case DateType              => DateConversion
+      case DecimalType.Unlimited => DecimalConversion
+      case DoubleType            => DoubleConversion
+      case FloatType             => FloatConversion
+      case IntegerType           => IntegerConversion
+      case LongType              =>
+        if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
+      case StringType            => StringConversion
+      case TimestampType         => TimestampConversion
+      case BinaryType            => BinaryConversion
+      case _                     => throw new IllegalArgumentException(s"Unsupported field $sf")
+    }).toArray
+  }
+
+  /**
+   * Runs the SQL query against the JDBC driver.
+   */
+  override def compute(thePart: Partition, context: TaskContext) = new Iterator[Row] {
+    var closed = false
+    var finished = false
+    var gotNext = false
+    var nextValue: Row = null
+
+    context.addTaskCompletionListener{ context => close() }
+    val part = thePart.asInstanceOf[JDBCPartition]
+    val conn = getConnection()
+
+    // H2's JDBC driver does not support the setSchema() method.  We pass a
+    // fully-qualified table name in the SELECT statement.  I don't know how to
+    // talk about a table in a completely portable way.
+
+    val myWhereClause = getWhereClause(part)
+
+    val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
+    val stmt = conn.prepareStatement(sqlText,
+        ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+    val rs = stmt.executeQuery()
+
+    val conversions = getConversions(schema)
+    val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
+
+    def getNext(): Row = {
+      if (rs.next()) {
+        var i = 0
+        while (i < conversions.length) {
+          val pos = i + 1
+          conversions(i) match {
+            case BooleanConversion    => mutableRow.setBoolean(i, rs.getBoolean(pos))
+            case DateConversion       => mutableRow.update(i, rs.getDate(pos))
+            case DecimalConversion    => mutableRow.update(i, rs.getBigDecimal(pos))
+            case DoubleConversion     => mutableRow.setDouble(i, rs.getDouble(pos))
+            case FloatConversion      => mutableRow.setFloat(i, rs.getFloat(pos))
+            case IntegerConversion    => mutableRow.setInt(i, rs.getInt(pos))
+            case LongConversion       => mutableRow.setLong(i, rs.getLong(pos))
+            case StringConversion     => mutableRow.setString(i, rs.getString(pos))
+            case TimestampConversion  => mutableRow.update(i, rs.getTimestamp(pos))
+            case BinaryConversion     => mutableRow.update(i, rs.getBytes(pos))
+            case BinaryLongConversion =>
+              // Fold the column's bytes into a long, first byte most significant.
+              // getBytes returns null for SQL NULL; wasNull below records the null.
+              val bytes = rs.getBytes(pos)
+              if (bytes != null) {
+                var ans = 0L
+                var j = 0
+                while (j < bytes.length) {
+                  ans = 256 * ans + (255 & bytes(j))
+                  j = j + 1
+                }
+                mutableRow.setLong(i, ans)
+              }
+          }
+          if (rs.wasNull) mutableRow.setNullAt(i)
+          i = i + 1
+        }
+        mutableRow
+      } else {
+        finished = true
+        null.asInstanceOf[Row]
+      }
+    }
+
+    /** Release the result set, statement and connection; safe to call more than once. */
+    def close() {
+      if (!closed) {
+        try {
+          if (null != rs && ! rs.isClosed()) {
+            rs.close()
+          }
+        } catch {
+          case e: Exception => logWarning("Exception closing resultset", e)
+        }
+        try {
+          if (null != stmt && ! stmt.isClosed()) {
+            stmt.close()
+          }
+        } catch {
+          case e: Exception => logWarning("Exception closing statement", e)
+        }
+        try {
+          if (null != conn && ! conn.isClosed()) {
+            conn.close()
+          }
+          logInfo("closed connection")
+        } catch {
+          case e: Exception => logWarning("Exception closing connection", e)
+        }
+        // Without this, the end-of-data close in hasNext and the task-completion
+        // listener both ran the full cleanup.
+        closed = true
+      }
+    }
+
+    override def hasNext: Boolean = {
+      if (!finished) {
+        if (!gotNext) {
+          nextValue = getNext()
+          if (finished) {
+            close()
+          }
+          gotNext = true
+        }
+      }
+      !finished
+    }
+
+    override def next(): Row = {
+      if (!hasNext) {
+        throw new NoSuchElementException("End of stream")
+      }
+      gotNext = false
+      nextValue
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
new file mode 100644
index 0000000..e09125e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import scala.collection.mutable.ArrayBuffer
+import java.sql.DriverManager
+
+import org.apache.spark.Partition
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources._
+
+/**
+ * One partition of a JDBCRDD: the rows selected by `whereClause`.
+ *
+ * @param whereClause SQL predicate for this partition's rows; JDBCRelation's
+ *                    columnPartition passes null when the whole table forms
+ *                    a single partition
+ * @param idx position of this partition within its RDD
+ */
+private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
+  override def index: Int = this.idx
+}
+
+/**
+ * Instructions on how to split a table scan among workers by value ranges of
+ * one column.
+ *
+ * @param column name of an integral-typed column to range-partition on
+ * @param lowerBound expected smallest value of `column` (advisory only)
+ * @param upperBound expected largest value of `column` (advisory only)
+ * @param numPartitions number of partitions to generate
+ */
+private[sql] case class JDBCPartitioningInfo(column: String, lowerBound: Long,
+    upperBound: Long, numPartitions: Int)
+
+private[sql] object JDBCRelation {
+  /**
+   * Given a partitioning schematic (a column of integral type, a number of
+   * partitions, and lower and upper bounds on the column's value), generate
+   * WHERE clauses for each partition so that each row in the table appears
+   * exactly once.  The bounds are advisory: incorrect values may make the
+   * partitioning lopsided, but no row is dropped or read twice.
+   *
+   * The first partition has no lower bound and also matches rows whose column
+   * is NULL.  The last partition has no upper bound.  Together they catch every
+   * row that falls outside [lowerBound, upperBound].
+   *
+   * @param partitioning column name (must refer to a column of integral type),
+   *                     advisory lower/upper bounds and the number of
+   *                     partitions.  Null, or a partition count of at most one,
+   *                     yields a single partition with a null WHERE clause.
+   * @return the partitions, indexed from 0
+   */
+  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
+    if (partitioning == null || partitioning.numPartitions <= 1) {
+      Array[Partition](JDBCPartition(null, 0))
+    } else {
+      val numPartitions = partitioning.numPartitions
+      val column = partitioning.column
+      // Dividing each bound before subtracting avoids overflow when the bounds
+      // are far apart; the price is a little roundoff in the stride.  Clamping
+      // at zero keeps inverted bounds (lowerBound > upperBound) from producing
+      // overlapping first and last partitions.
+      val stride: Long = math.max(0L,
+        partitioning.upperBound / numPartitions - partitioning.lowerBound / numPartitions)
+      // Value separating partition i - 1 from partition i.
+      def boundary(i: Int): Long = partitioning.lowerBound + i * stride
+
+      (0 until numPartitions).map { i =>
+        val lower = if (i > 0) Some(s"$column >= ${boundary(i)}") else None
+        val upper = if (i < numPartitions - 1) Some(s"$column < ${boundary(i + 1)}") else None
+        val whereClause = (lower, upper) match {
+          case (Some(l), Some(u)) => s"$l AND $u"
+          // Parenthesized in case the scan ANDs this clause with pushed-down filters.
+          case (None, Some(u)) => s"($u OR $column IS NULL)"
+          case (Some(l), None) => l
+          case (None, None) => null // unreachable: numPartitions >= 2 here
+        }
+        JDBCPartition(whereClause, i): Partition
+      }.toArray
+    }
+  }
+}
+
+private[sql] class DefaultSource extends RelationProvider {
+  /**
+   * Returns a new JDBCRelation for the table described by `parameters`.
+   *
+   * Required options: `url` and `dbtable`.  Optional: `driver`, the name of a
+   * class to load before connecting.  Also optional is the partitioning group
+   * `partitionColumn`, `lowerBound`, `upperBound`, `numPartitions`; once
+   * `partitionColumn` is given, the other three must be given too.
+   */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
+    val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
+
+    // Load the driver class if one was named (JDBC drivers conventionally
+    // register themselves with DriverManager when their class is loaded).
+    parameters.get("driver").foreach(driver => Class.forName(driver))
+
+    val partitionInfo = parameters.get("partitionColumn").map { column =>
+      if (!Seq("lowerBound", "upperBound", "numPartitions").forall(parameters.contains)) {
+        sys.error("Partitioning incompletely specified")
+      }
+      /** Parses a numeric option, naming the offending option if it is malformed. */
+      def numeric[T](key: String)(parse: String => T): T = {
+        val raw = parameters(key)
+        try parse(raw) catch {
+          case _: NumberFormatException =>
+            sys.error(s"Option '$key' must be an integer, but got '$raw'")
+        }
+      }
+      JDBCPartitioningInfo(column,
+        numeric("lowerBound")(_.toLong),
+        numeric("upperBound")(_.toLong),
+        numeric("numPartitions")(_.toInt))
+    }
+
+    // columnPartition treats a null partitioning as one unfiltered partition.
+    val parts = JDBCRelation.columnPartition(partitionInfo.orNull)
+    JDBCRelation(url, table, parts)(sqlContext)
+  }
+}
+
+/**
+ * A relation over one table of a JDBC database.  Scans pass the required
+ * columns, the pushed-down filters and the partitions to JDBCRDD.scanTable.
+ *
+ * @param url JDBC connection URL
+ * @param table table name; a parenthesized SELECT has also been used here
+ *              (see the "SQL query as table name" test)
+ * @param parts partitions to read, e.g. from JDBCRelation.columnPartition
+ */
+private[sql] case class JDBCRelation(
+    url: String,
+    table: String,
+    parts: Array[Partition])(@transient val sqlContext: SQLContext)
+  extends PrunedFilteredScan {
+
+  // Resolved once, when the relation is constructed.
+  override val schema = JDBCRDD.resolveTable(url, table)
+
+  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
+    // Class name of whichever registered driver accepts `url`; presumably
+    // passed along so the scan can load the same driver where it runs.
+    val driverClassName: String = DriverManager.getDriver(url).getClass.getCanonicalName
+    val context = sqlContext.sparkContext
+    JDBCRDD.scanTable(context, schema, driverClassName, url, table,
+      requiredColumns, filters, parts)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala
new file mode 100644
index 0000000..86bb67e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Plain-method wrappers around the implicit JDBCDataFrame operations, for
+ * callers (such as Java code) that cannot use Scala implicit conversions.
+ */
+private[jdbc] class JavaJDBCTrampoline {
+  /** Creates `table` at `url` and writes the rows of `rdd` into it. */
+  def createJDBCTable(
+      rdd: DataFrame, url: String, table: String, allowExisting: Boolean): Unit = {
+    rdd.createJDBCTable(url, table, allowExisting)
+  }
+
+  /** Inserts the rows of `rdd` into an existing `table`, truncating it first if `overwrite`. */
+  def insertIntoJDBC(
+      rdd: DataFrame, url: String, table: String, overwrite: Boolean): Unit = {
+    rdd.insertIntoJDBC(url, table, overwrite)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
new file mode 100644
index 0000000..34a83f0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.sql.{Connection, DriverManager, PreparedStatement}
+import org.apache.spark.{Logging, Partition}
+import org.apache.spark.sql._
+import org.apache.spark.sql.sources.LogicalRelation
+
+import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition}
+import org.apache.spark.sql.types._
+
+package object jdbc {
+  /** Helpers that write the partitions of a DataFrame to a JDBC table. */
+  object JDBCWriteDetails extends Logging {
+    /**
+     * Returns a PreparedStatement that inserts one row into `table` via
+     * `conn`, with one `?` placeholder per field of `rddSchema`.
+     */
+    private def insertStatement(conn: Connection, table: String, rddSchema: StructType):
+        PreparedStatement = {
+      // mkString keeps the parentheses balanced even for a schema with no fields.
+      val placeholders = Seq.fill(rddSchema.fields.length)("?").mkString(", ")
+      conn.prepareStatement(s"INSERT INTO $table VALUES ($placeholders)")
+    }
+
+    /**
+     * Saves a partition of a DataFrame to the JDBC database.  This is done in
+     * a single database transaction in order to avoid repeatedly inserting
+     * data as much as possible.
+     *
+     * It is still theoretically possible for rows in a DataFrame to be
+     * inserted into the database more than once if a stage somehow fails after
+     * the commit occurs but before the stage can return successfully.
+     *
+     * This is not a closure inside saveTable() because apparently cosmetic
+     * implementation changes elsewhere might easily render such a closure
+     * non-Serializable.  Instead, we explicitly close over all variables that
+     * are used.
+     *
+     * @param nullTypes java.sql.Types code passed to setNull for each field
+     * @return an empty iterator; the method is run only for its side effect
+     */
+    private[jdbc] def savePartition(url: String, table: String, iterator: Iterator[Row],
+        rddSchema: StructType, nullTypes: Array[Int]): Iterator[Byte] = {
+      val conn = DriverManager.getConnection(url)
+      var committed = false
+      try {
+        conn.setAutoCommit(false) // Everything in the same db transaction.
+        val stmt = insertStatement(conn, table, rddSchema)
+        try {
+          // The schema is the same for every row, so look the types up once.
+          val dataTypes = rddSchema.fields.map(_.dataType)
+          val numFields = dataTypes.length
+          while (iterator.hasNext) {
+            val row = iterator.next()
+            var i = 0
+            while (i < numFields) {
+              // JDBC parameter indices are 1-based; Row field indices are 0-based.
+              val pos = i + 1
+              if (row.isNullAt(i)) {
+                stmt.setNull(pos, nullTypes(i))
+              } else {
+                dataTypes(i) match {
+                  case IntegerType => stmt.setInt(pos, row.getInt(i))
+                  case LongType => stmt.setLong(pos, row.getLong(i))
+                  case DoubleType => stmt.setDouble(pos, row.getDouble(i))
+                  case FloatType => stmt.setFloat(pos, row.getFloat(i))
+                  case ShortType => stmt.setInt(pos, row.getShort(i))
+                  case ByteType => stmt.setInt(pos, row.getByte(i))
+                  case BooleanType => stmt.setBoolean(pos, row.getBoolean(i))
+                  case StringType => stmt.setString(pos, row.getString(i))
+                  case BinaryType => stmt.setBytes(pos, row.getAs[Array[Byte]](i))
+                  case TimestampType => stmt.setTimestamp(pos, row.getAs[java.sql.Timestamp](i))
+                  case DateType => stmt.setDate(pos, row.getAs[java.sql.Date](i))
+                  case DecimalType.Unlimited => stmt.setBigDecimal(pos,
+                      row.getAs[java.math.BigDecimal](i))
+                  case _ => throw new IllegalArgumentException(
+                      s"Can't translate non-null value for field $i")
+                }
+              }
+              i = i + 1
+            }
+            stmt.executeUpdate()
+          }
+        } finally {
+          stmt.close()
+        }
+        conn.commit()
+        committed = true
+      } finally {
+        if (!committed) {
+          // The stage must fail.  We got here through an exception path, so
+          // let that exception through unless rollback() or close() wants to
+          // report another problem.  close() runs even if rollback() throws,
+          // so the connection is not leaked.
+          try {
+            conn.rollback()
+          } finally {
+            conn.close()
+          }
+        } else {
+          // The stage must succeed.  We cannot propagate any exception close() might throw.
+          try {
+            conn.close()
+          } catch {
+            case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
+          }
+        }
+      }
+      Iterator.empty
+    }
+  }
+
+  /**
+   * Make it so that you can call createJDBCTable and insertIntoJDBC on a DataFrame.
+   *
+   * Table names are interpolated into SQL verbatim; do not pass untrusted input.
+   */
+  implicit class JDBCDataFrame(rdd: DataFrame) {
+    /**
+     * Compute the column-definition list for a CREATE TABLE of this DataFrame.
+     * The DriverQuirks for `url` may supply the SQL type of any column; a null
+     * type name from the quirks falls back to the generic mapping below.
+     */
+    private def schemaString(url: String): String = {
+      val quirks = DriverQuirks.get(url)
+      rdd.schema.fields.map { field =>
+        val typ: String = Option(quirks.getJDBCType(field.dataType)._1).getOrElse(
+          field.dataType match {
+            case IntegerType => "INTEGER"
+            case LongType => "BIGINT"
+            case DoubleType => "DOUBLE PRECISION"
+            case FloatType => "REAL"
+            case ShortType => "INTEGER"
+            case ByteType => "BYTE"
+            case BooleanType => "BIT(1)"
+            case StringType => "TEXT"
+            case BinaryType => "BLOB"
+            case TimestampType => "TIMESTAMP"
+            case DateType => "DATE"
+            case DecimalType.Unlimited => "DECIMAL(40,20)"
+            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+          })
+        val nullable = if (field.nullable) "" else "NOT NULL"
+        s"${field.name} $typ $nullable"
+      }.mkString(", ")
+    }
+
+    /**
+     * Saves the RDD to the database, one transaction per partition.
+     */
+    private def saveTable(url: String, table: String): Unit = {
+      val quirks = DriverQuirks.get(url)
+      // The java.sql.Types code each column uses when a value is null.
+      val nullTypes: Array[Int] = rdd.schema.fields.map { field =>
+        quirks.getJDBCType(field.dataType)._2.getOrElse(field.dataType match {
+          case IntegerType => java.sql.Types.INTEGER
+          case LongType => java.sql.Types.BIGINT
+          case DoubleType => java.sql.Types.DOUBLE
+          case FloatType => java.sql.Types.REAL
+          case ShortType => java.sql.Types.INTEGER
+          case ByteType => java.sql.Types.INTEGER
+          case BooleanType => java.sql.Types.BIT
+          case StringType => java.sql.Types.CLOB
+          case BinaryType => java.sql.Types.BLOB
+          case TimestampType => java.sql.Types.TIMESTAMP
+          case DateType => java.sql.Types.DATE
+          case DecimalType.Unlimited => java.sql.Types.DECIMAL
+          case _ => throw new IllegalArgumentException(
+              s"Can't translate null value for field $field")
+        })
+      }
+
+      val rddSchema = rdd.schema
+      rdd.mapPartitions(iterator => JDBCWriteDetails.savePartition(
+          url, table, iterator, rddSchema, nullTypes)).collect()
+    }
+
+    /** Runs one DDL/update statement on `conn`, always closing the statement. */
+    private def runStatement(conn: Connection, statement: String): Unit = {
+      val stmt = conn.prepareStatement(statement)
+      try {
+        stmt.executeUpdate()
+      } finally {
+        stmt.close()
+      }
+    }
+
+    /**
+     * Save this RDD to a JDBC database at `url` under the table name `table`.
+     * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements.
+     * If you pass `true` for `allowExisting`, it will drop any table with the
+     * given name; if you pass `false`, it will throw if the table already
+     * exists.
+     */
+    def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = {
+      val conn = DriverManager.getConnection(url)
+      try {
+        if (allowExisting) {
+          runStatement(conn, s"DROP TABLE IF EXISTS $table")
+        }
+        val schema = schemaString(url)
+        runStatement(conn, s"CREATE TABLE $table ($schema)")
+      } finally {
+        conn.close()
+      }
+      saveTable(url, table)
+    }
+
+    /**
+     * Save this RDD to a JDBC database at `url` under the table name `table`.
+     * Assumes the table already exists and has a compatible schema.  If you
+     * pass `true` for `overwrite`, it will `TRUNCATE` the table before
+     * performing the `INSERT`s.
+     *
+     * The table must already exist on the database.  It must have a schema
+     * that is compatible with the schema of this RDD; inserting the rows of
+     * the RDD in order via the simple statement
+     * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
+     */
+    def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = {
+      if (overwrite) {
+        val conn = DriverManager.getConnection(url)
+        try {
+          runStatement(conn, s"TRUNCATE TABLE $table")
+        } finally {
+          conn.close()
+        }
+      }
+      saveTable(url, table)
+    }
+  } // implicit class JDBCDataFrame
+} // package object jdbc

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java b/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java
new file mode 100644
index 0000000..80bd74f
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc;
+
+import org.junit.*;
+import static org.junit.Assert.*;
+import java.sql.Connection;
+import java.sql.DriverManager;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.api.java.*;
+import org.apache.spark.sql.test.TestSQLContext$;
+
+/**
+ * Checks the Java-facing JDBC helpers (JDBCUtils) against an in-memory H2
+ * database whose tables are recreated before every test.
+ */
+public class JavaJDBCTest {
+  static String url = "jdbc:h2:mem:testdb1";
+
+  static Connection conn = null;
+
+  // This variable will always be null if TestSQLContext is intact when running
+  // these tests.  Some Java tests do not play nicely with others, however;
+  // they create a SparkContext of their own at startup and stop it at exit.
+  // This renders TestSQLContext inoperable, meaning we have to do the same
+  // thing.  If this variable is nonnull, that means we allocated a
+  // SparkContext of our own and that we need to stop it at teardown.
+  static JavaSparkContext localSparkContext = null;
+
+  static SQLContext sql = TestSQLContext$.MODULE$;
+
+  /** Executes one SQL statement on {@code conn}, always closing the statement. */
+  private static void execute(String statement) throws Exception {
+    java.sql.PreparedStatement stmt = conn.prepareStatement(statement);
+    try {
+      stmt.executeUpdate();
+    } finally {
+      stmt.close();
+    }
+  }
+
+  @Before
+  public void beforeTest() throws Exception {
+    if (SparkEnv.get() == null) { // A previous test destroyed TestSQLContext.
+      localSparkContext = new JavaSparkContext("local", "JavaAPISuite");
+      sql = new SQLContext(localSparkContext);
+    }
+    Class.forName("org.h2.Driver");
+    conn = DriverManager.getConnection(url);
+    execute("create schema test");
+    execute("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)");
+    execute("insert into test.people values ('fred', 1)");
+    execute("insert into test.people values ('mary', 2)");
+    execute("insert into test.people values ('joe', 3)");
+    conn.commit();
+  }
+
+  @After
+  public void afterTest() throws Exception {
+    // Close the connection even if stopping the SparkContext fails: beforeTest
+    // re-runs "create schema test", which presumably only succeeds once the
+    // previous in-memory database has gone away with its last connection.
+    try {
+      if (localSparkContext != null) {
+        localSparkContext.stop();
+        localSparkContext = null;
+      }
+    } finally {
+      try {
+        conn.close();
+      } finally {
+        conn = null;
+      }
+    }
+  }
+
+  @Test
+  public void basicTest() throws Exception {
+    DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE");
+    Row[] rows = rdd.collect();
+    assertEquals(3, rows.length);
+  }
+
+  @Test
+  public void partitioningTest() throws Exception {
+    String[] parts = new String[2];
+    parts[0] = "THEID < 2";
+    parts[1] = "THEID = 2"; // Deliberately forget about one of them.
+    DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE", parts);
+    Row[] rows = rdd.collect();
+    assertEquals(2, rows.length);
+  }
+
+  @Test
+  public void writeTest() throws Exception {
+    DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE");
+    JDBCUtils.createJDBCTable(rdd, url, "TEST.PEOPLECOPY", false);
+    DataFrame rdd2 = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLECOPY");
+    Row[] rows = rdd2.collect();
+    assertEquals(3, rows.length);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
new file mode 100644
index 0000000..f332cb3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import scala.collection.mutable.MutableList
+
+import com.spotify.docker.client._
+
+/**
+ * A factory and morgue for DockerClient objects.  In the DockerClient we use,
+ * calling close() closes the desired DockerClient but also renders all other
+ * DockerClients inoperable.  This is inconvenient if we have more than one
+ * open, such as during tests.  Clients handed to close() are therefore parked
+ * as "zombies".  They are only really closed once every client obtained from
+ * get() has been returned.
+ */
+object DockerClientFactory {
+  /** Number of clients handed out by get() and not yet passed to close(). */
+  var numClients: Int = 0
+  /** Returned clients waiting for numClients to drop back to zero. */
+  val zombies = new MutableList[DockerClient]()
+
+  /** Builds a new client from the environment and counts it as open. */
+  def get(): DockerClient = {
+    this.synchronized {
+      // Build first so that a failed build does not leave numClients inflated.
+      val client = DefaultDockerClient.fromEnv.build()
+      numClients = numClients + 1
+      client
+    }
+  }
+
+  /**
+   * Parks `dc`.  Once no clients remain open, closes every parked client.
+   * Each parked client gets a close() attempt even if an earlier one throws;
+   * the first such failure is rethrown afterwards.
+   */
+  def close(dc: DockerClient): Unit = {
+    import scala.util.control.NonFatal
+    this.synchronized {
+      numClients = numClients - 1
+      zombies += dc
+      if (numClients == 0) {
+        val toClose = zombies.toList
+        zombies.clear()
+        val failures = toClose.flatMap { client =>
+          try {
+            client.close()
+            None
+          } catch {
+            case NonFatal(e) => Some(e)
+          }
+        }
+        failures.headOption.foreach(e => throw e)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
new file mode 100644
index 0000000..d25c139
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import org.apache.spark.sql.test._
+import org.scalatest.{FunSuite, BeforeAndAfter}
+import java.sql.DriverManager
+import TestSQLContext._
+
+/**
+ * Reads H2 tables through the JDBC data source, both via temporary Spark SQL
+ * tables and via the jdbcRDD API.  The H2 tables and the temporary tables
+ * are recreated before every test.
+ */
+class JDBCSuite extends FunSuite with BeforeAndAfter {
+  val url = "jdbc:h2:mem:testdb0"
+  var conn: java.sql.Connection = null
+
+  val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
+
+  /** Runs one DDL/DML statement over `conn`, always closing the statement. */
+  private def execute(statement: String): Unit = {
+    val stmt = conn.prepareStatement(statement)
+    try stmt.executeUpdate() finally stmt.close()
+  }
+
+  /** Registers `name` as a temporary Spark SQL table backed by the JDBC data source. */
+  private def createTempJdbcTable(name: String, options: String): Unit = {
+    sql(s"CREATE TEMPORARY TABLE $name USING org.apache.spark.sql.jdbc OPTIONS ($options)")
+  }
+
+  before {
+    Class.forName("org.h2.Driver")
+    conn = DriverManager.getConnection(url)
+    execute("create schema test")
+    execute("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)")
+    execute("insert into test.people values ('fred', 1)")
+    execute("insert into test.people values ('mary', 2)")
+    execute("insert into test.people values ('joe', 3)")
+    conn.commit()
+
+    createTempJdbcTable("foobar", s"url '$url', dbtable 'TEST.PEOPLE'")
+    createTempJdbcTable("parts", s"url '$url', dbtable 'TEST.PEOPLE', " +
+      "partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3'")
+
+    execute("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, "
+      + "d SMALLINT, e BIGINT)")
+    execute("insert into test.inttypes values (1, false, 3, 4, 1234567890123)")
+    execute("insert into test.inttypes values (null, null, null, null, null)")
+    conn.commit()
+    createTempJdbcTable("inttypes", s"url '$url', dbtable 'TEST.INTTYPES'")
+
+    execute("create table test.strtypes (a BINARY(20), b VARCHAR(20), "
+      + "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)")
+    val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)")
+    try {
+      stmt.setBytes(1, testBytes)
+      stmt.setString(2, "Sensitive")
+      stmt.setString(3, "Insensitive")
+      stmt.setString(4, "Twenty-byte CHAR")
+      stmt.setBytes(5, testBytes)
+      stmt.setString(6, "I am a clob!")
+      stmt.executeUpdate()
+    } finally {
+      stmt.close()
+    }
+    createTempJdbcTable("strtypes", s"url '$url', dbtable 'TEST.STRTYPES'")
+
+    execute("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)")
+    execute("insert into test.timetypes values ('12:34:56', "
+      + "'1996-01-01', '2002-02-20 11:22:33.543543543')")
+    conn.commit()
+    createTempJdbcTable("timetypes", s"url '$url', dbtable 'TEST.TIMETYPES'")
+
+    execute("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))")
+    execute("insert into test.flttypes values ("
+      + "1.0000000000000002220446049250313080847263336181640625, "
+      + "1.00000011920928955078125, "
+      + "123456789012345.543215432154321)")
+    conn.commit()
+    createTempJdbcTable("flttypes", s"url '$url', dbtable 'TEST.FLTTYPES'")
+
+    // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
+  }
+
+  after {
+    conn.close()
+  }
+
+  test("SELECT *") {
+    assert(sql("SELECT * FROM foobar").collect().size == 3)
+  }
+
+  test("SELECT * WHERE (simple predicates)") {
+    assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0)
+    assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2)
+    assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1)
+  }
+
+  test("SELECT first field") {
+    val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _)
+    assert(names.size == 3)
+    assert(names(0).equals("fred"))
+    assert(names(1).equals("joe"))
+    assert(names(2).equals("mary"))
+  }
+
+  test("SELECT second field") {
+    val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _)
+    assert(ids.size == 3)
+    assert(ids(0) == 1)
+    assert(ids(1) == 2)
+    assert(ids(2) == 3)
+  }
+
+  test("SELECT * partitioned") {
+    assert(sql("SELECT * FROM parts").collect().size == 3)
+  }
+
+  test("SELECT WHERE (simple predicates) partitioned") {
+    assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size == 0)
+    assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size == 2)
+    assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size == 1)
+  }
+
+  test("SELECT second field partitioned") {
+    val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _)
+    assert(ids.size == 3)
+    assert(ids(0) == 1)
+    assert(ids(1) == 2)
+    assert(ids(2) == 3)
+  }
+
+  test("Basic API") {
+    assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE").collect.size == 3)
+  }
+
+  test("Partitioning via JDBCPartitioningInfo API") {
+    val parts = JDBCPartitioningInfo("THEID", 0, 4, 3)
+    assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3)
+  }
+
+  test("Partitioning via list-of-where-clauses API") {
+    val parts = Array[String]("THEID < 2", "THEID >= 2")
+    assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3)
+  }
+
+  test("H2 integral types") {
+    val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect()
+    assert(rows.size == 1)
+    val row = rows(0)
+    assert(row.getInt(0) == 1)
+    assert(!row.getBoolean(1))
+    assert(row.getInt(2) == 3)
+    assert(row.getInt(3) == 4)
+    assert(row.getLong(4) == 1234567890123L)
+  }
+
+  test("H2 null entries") {
+    val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect()
+    assert(rows.size == 1)
+    val row = rows(0)
+    for (i <- 0 until 5) {
+      assert(row.isNullAt(i), s"column $i should be null")
+    }
+  }
+
+  test("H2 string types") {
+    val rows = sql("SELECT * FROM strtypes").collect()
+    assert(rows.size == 1)
+    val row = rows(0)
+    assert(row.getAs[Array[Byte]](0).sameElements(testBytes))
+    assert(row.getString(1).equals("Sensitive"))
+    assert(row.getString(2).equals("Insensitive"))
+    assert(row.getString(3).equals("Twenty-byte CHAR"))
+    assert(row.getAs[Array[Byte]](4).sameElements(testBytes))
+    assert(row.getString(5).equals("I am a clob!"))
+  }
+
+  test("H2 time types") {
+    val rows = sql("SELECT * FROM timetypes").collect()
+    assert(rows.size == 1)
+    val time = rows(0).getAs[java.sql.Timestamp](0)
+    assert(time.getHours == 12)
+    assert(time.getMinutes == 34)
+    assert(time.getSeconds == 56)
+    val date = rows(0).getAs[java.sql.Date](1)
+    assert(date.getYear == 96)
+    assert(date.getMonth == 0)
+    assert(date.getDate == 1)
+    val timestamp = rows(0).getAs[java.sql.Timestamp](2)
+    assert(timestamp.getYear == 102)
+    assert(timestamp.getMonth == 1)
+    assert(timestamp.getDate == 20)
+    assert(timestamp.getHours == 11)
+    assert(timestamp.getMinutes == 22)
+    assert(timestamp.getSeconds == 33)
+    assert(timestamp.getNanos == 543543543)
+  }
+
+  test("H2 floating-point types") {
+    val rows = sql("SELECT * FROM flttypes").collect()
+    assert(rows.size == 1)
+    assert(rows(0).getDouble(0) == 1.00000000000000022) // Yes, I meant ==.
+    assert(rows(0).getDouble(1) == 1.00000011920928955) // Yes, I meant ==.
+    assert(rows(0).getAs[BigDecimal](2)
+        .equals(new BigDecimal("123456789012345.54321543215432100000")))
+  }
+
+  test("SQL query as table name") {
+    createTempJdbcTable("hack",
+      s"url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)'")
+    val rows = sql("SELECT * FROM hack").collect()
+    assert(rows.size == 1)
+    assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==.
+    // B is a REAL column, so H2 presumably evaluates B*B in single precision.
+    // That drops the tiny 2^-46 term of the exact square (about 1.4e-14),
+    // hence the tolerance rather than ==.
+    assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
new file mode 100644
index 0000000..e581ac9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.test._
+import org.scalatest.{FunSuite, BeforeAndAfter}
+import java.sql.DriverManager
+import TestSQLContext._
+
+/** Exercises createJDBCTable and insertIntoJDBC against an in-memory H2 database. */
+class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
+  val url = "jdbc:h2:mem:testdb2"
+  var conn: java.sql.Connection = null
+
+  // Each test recreates schema TEST, which relies on H2 discarding the in-memory database
+  // once `conn` (the connection keeping it alive) is closed in `after`.
+  before {
+    Class.forName("org.h2.Driver")
+    conn = DriverManager.getConnection(url)
+    conn.prepareStatement("create schema test").executeUpdate()
+  }
+
+  after { conn.close() }
+
+  val sc = TestSQLContext.sparkContext
+
+  val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
+  val arr1x2 = Array[Row](Row.apply("fred", 3))
+  val schema2 = StructType(
+      StructField("name", StringType) ::
+      StructField("id", IntegerType) :: Nil)
+
+  val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2))
+  val schema3 = StructType(
+      StructField("name", StringType) ::
+      StructField("id", IntegerType) ::
+      StructField("seq", IntegerType) :: Nil)
+
+  test("Basic CREATE") {
+    val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+
+    srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false)
+    assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count)
+    assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").collect()(0).length)
+  }
+
+  test("CREATE with overwrite") {
+    val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+    val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+
+    srdd.createJDBCTable(url, "TEST.DROPTEST", false)
+    assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
+    assert(3 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length)
+
+    srdd2.createJDBCTable(url, "TEST.DROPTEST", true)
+    assert(1 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
+    assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length)
+  }
+
+  test("CREATE then INSERT to append") {
+    val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+    val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+
+    srdd.createJDBCTable(url, "TEST.APPENDTEST", false)
+    srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false)
+    assert(3 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").count)
+    assert(2 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").collect()(0).length)
+  }
+
+  test("CREATE then INSERT to truncate") {
+    val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+    val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+
+    srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false)
+    srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
+    assert(1 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").count)
+    assert(2 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").collect()(0).length)
+  }
+
+  test("Incompatible INSERT to append") {
+    val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+    val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+    // Appending (overwrite = false, as the name says) 3-column rows to a 2-column table fails.
+    srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false)
+    intercept[org.apache.spark.SparkException] {
+      srdd2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", false)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala
new file mode 100644
index 0000000..89920f2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import java.sql.{Date, DriverManager, Timestamp}
+import com.spotify.docker.client.{DefaultDockerClient, DockerClient}
+import com.spotify.docker.client.messages.ContainerConfig
+import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore}
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.sql._
+import org.apache.spark.sql.test._
+import TestSQLContext._
+
+import org.apache.spark.sql.jdbc._
+
+/** Pulls and starts a `mysql` Docker container on construction; `close()` removes it again. */
+class MySQLDatabase {
+  val docker: DockerClient = DockerClientFactory.get()
+  val containerId = {
+    println("Pulling mysql")
+    docker.pull("mysql")
+    println("Configuring container")
+    val config =
+      ContainerConfig.builder().image("mysql").env("MYSQL_ROOT_PASSWORD=rootpass").build()
+    println("Creating container")
+    val id = docker.createContainer(config).id
+    println("Starting container " + id)
+    docker.startContainer(id)
+    id
+  }
+  val ip = docker.inspectContainer(containerId).networkSettings.ipAddress
+
+  def close() {
+    try {
+      println("Killing container " + containerId)
+      docker.killContainer(containerId)
+      println("Removing container " + containerId)
+      docker.removeContainer(containerId)
+    } catch {
+      case e: Exception =>
+        println(e)
+        println("You may need to clean this up manually.")
+        throw e
+    } finally { // Release the Docker client even if tearing down the container failed.
+      println("Closing docker client")
+      DockerClientFactory.close(docker)
+    }
+  }
+}
+
+/**
+ * Integration tests for the JDBC data source against a MySQL server running in Docker.
+ * Marked @Ignore; running it requires a Docker daemon that can pull and start `mysql`.
+ */
+@Ignore class MySQLIntegration extends FunSuite with BeforeAndAfterAll {
+  var ip: String = null // container IP, assigned at the end of beforeAll
+
+  /** JDBC URL for database `db` (default `mysql`) on the server at `ip`, logged in as root. */
+  def url(ip: String): String = url(ip, "mysql")
+  def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass"
+
+  /** Retries `url(ip)` every 250 ms until it connects; throws once `maxMillis` has elapsed. */
+  def waitForDatabase(ip: String, maxMillis: Long) {
+    println("Waiting for database to start up.")
+    val deadline = System.currentTimeMillis() + maxMillis
+    var lastException: java.sql.SQLException = null
+    while (true) {
+      if (System.currentTimeMillis() > deadline) {
+        throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException)
+      }
+      try {
+        java.sql.DriverManager.getConnection(url(ip)).close()
+        println("Database is up.")
+        return
+      } catch {
+        case e: java.sql.SQLException =>
+          lastException = e
+          java.lang.Thread.sleep(250)
+      }
+    }
+  }
+
+  /** Creates database `foo` with the tables and rows the tests below read back. */
+  def setupDatabase(ip: String) {
+    val conn = java.sql.DriverManager.getConnection(url(ip))
+    try {
+      conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+      conn.prepareStatement("CREATE TABLE foo.tbl (x INTEGER, y TEXT(8))").executeUpdate()
+      conn.prepareStatement("INSERT INTO foo.tbl VALUES (42,'fred')").executeUpdate()
+      conn.prepareStatement("INSERT INTO foo.tbl VALUES (17,'dave')").executeUpdate()
+
+      conn.prepareStatement("CREATE TABLE foo.numbers (onebit BIT(1), tenbits BIT(10), "
+          + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
+          + "dbl DOUBLE)").executeUpdate()
+      conn.prepareStatement("INSERT INTO foo.numbers VALUES (b'0', b'1000100101', "
+          + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
+          + "42.75, 1.0000000000000002)").executeUpdate()
+
+      conn.prepareStatement("CREATE TABLE foo.dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
+          + "yr YEAR)").executeUpdate()
+      conn.prepareStatement("INSERT INTO foo.dates VALUES ('1991-11-09', '13:31:24', "
+          + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate()
+
+      // TODO: Test locale conversion for strings.
+      conn.prepareStatement("CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, "
+          + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
+          ).executeUpdate()
+      conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', "
+          + "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate()
+    } finally {
+      conn.close()
+    }
+  }
+
+  var db: MySQLDatabase = null // null until started, and again if beforeAll tore it down
+
+  /** Starts MySQL in Docker, waits up to 60 s for connections, then loads the fixtures. */
+  override def beforeAll() {
+    // If you load the MySQL driver here, DriverManager will deadlock.  The
+    // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres
+    // and H2 drivers.
+    //Class.forName("com.mysql.jdbc.Driver")
+
+    db = new MySQLDatabase()
+    try {
+      waitForDatabase(db.ip, 60000)
+      setupDatabase(db.ip)
+    } catch {
+      // Some ScalaTest versions skip afterAll when beforeAll throws; don't leak the container.
+      case e: Exception =>
+        try db.close() catch { case _: Exception => } // close() prints its own failure
+        db = null
+        throw e
+    }
+    ip = db.ip
+  }
+
+  /** Stops and removes the container unless beforeAll already tore it down after a failure. */
+  override def afterAll() {
+    if (db != null) db.close()
+  }
+
+  test("Basic test") {
+    val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "tbl")
+    val rows = rdd.collect
+    assert(rows.length == 2)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    // INTEGER -> Integer, TEXT(8) -> String.
+    assert(types == Seq("class java.lang.Integer", "class java.lang.String"))
+  }
+
+  test("Numeric types") {
+    val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers")
+    val rows = rdd.collect
+    assert(rows.length == 1)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    // BIT(1) -> Boolean, BIT(10) and BIGINT -> Long, DECIMAL -> BigDecimal, FLOAT -> Double.
+    assert(types == Seq("class java.lang.Boolean", "class java.lang.Long",
+      "class java.lang.Integer", "class java.lang.Integer", "class java.lang.Integer",
+      "class java.lang.Long", "class java.math.BigDecimal", "class java.lang.Double",
+      "class java.lang.Double"))
+    assert(rows(0).getBoolean(0) == false)
+    assert(rows(0).getLong(1) == 0x225) // b'1000100101'
+    assert(rows(0).getInt(2) == 17)
+    assert(rows(0).getInt(3) == 77777)
+    assert(rows(0).getInt(4) == 123456789)
+    assert(rows(0).getLong(5) == 123456789012345L)
+    val bd = new BigDecimal("123456789012345.12345678901234500000")
+    assert(rows(0).getAs[BigDecimal](6).equals(bd))
+    assert(rows(0).getDouble(7) == 42.75)
+    assert(rows(0).getDouble(8) == 1.0000000000000002)
+  }
+
+  test("Date types") {
+    val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates")
+    val rows = rdd.collect
+    assert(rows.length == 1)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types == Seq("class java.sql.Date", "class java.sql.Timestamp",
+      "class java.sql.Timestamp", "class java.sql.Timestamp", "class java.sql.Date"))
+    // TIME comes back on 1970-01-01 and YEAR as January 1st of that year.
+    assert(rows(0).getAs[Date](0).equals(new Date(91, 10, 9)))
+    assert(rows(0).getAs[Timestamp](1).equals(new Timestamp(70, 0, 1, 13, 31, 24, 0)))
+    assert(rows(0).getAs[Timestamp](2).equals(new Timestamp(96, 0, 1, 1, 23, 45, 0)))
+    assert(rows(0).getAs[Timestamp](3).equals(new Timestamp(109, 1, 13, 23, 31, 30, 0)))
+    assert(rows(0).getAs[Date](4).equals(new Date(101, 0, 1)))
+  }
+
+  test("String types") {
+    val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings")
+    val rows = rdd.collect
+    assert(rows.length == 1)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    // The six character columns map to String; BINARY, VARBINARY and BLOB to Array[Byte].
+    assert(types == Seq.fill(6)("class java.lang.String") ++ Seq.fill(3)("class [B"))
+    assert(rows(0).getString(0).equals("the"))
+    assert(rows(0).getString(1).equals("quick"))
+    assert(rows(0).getString(2).equals("brown"))
+    assert(rows(0).getString(3).equals("fox"))
+    assert(rows(0).getString(4).equals("jumps"))
+    assert(rows(0).getString(5).equals("over"))
+    // BINARY(4) pads 'the' with a trailing zero byte.
+    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0)))
+    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121)))
+    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103)))
+  }
+
+  test("Basic write test") {
+    // Only checks that copying each table back into MySQL does not throw.
+    val rdd1 = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers")
+    val rdd2 = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates")
+    val rdd3 = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings")
+    rdd1.createJDBCTable(url(ip, "foo"), "numberscopy", false)
+    rdd2.createJDBCTable(url(ip, "foo"), "datescopy", false)
+    rdd3.createJDBCTable(url(ip, "foo"), "stringscopy", false)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8f471a66/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala
new file mode 100644
index 0000000..c174d7a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import org.apache.spark.sql.test._
+import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore}
+import java.sql.DriverManager
+import TestSQLContext._
+import com.spotify.docker.client.{DefaultDockerClient, DockerClient}
+import com.spotify.docker.client.messages.ContainerConfig
+
+/** Pulls and starts a `postgres` Docker container on construction; `close()` removes it again. */
+class PostgresDatabase {
+  val docker: DockerClient = DockerClientFactory.get()
+  val containerId = {
+    println("Pulling postgres")
+    docker.pull("postgres")
+    println("Configuring container")
+    val config =
+      ContainerConfig.builder().image("postgres").env("POSTGRES_PASSWORD=rootpass").build()
+    println("Creating container")
+    val id = docker.createContainer(config).id
+    println("Starting container " + id)
+    docker.startContainer(id)
+    id
+  }
+  val ip = docker.inspectContainer(containerId).networkSettings.ipAddress
+
+  def close() {
+    try {
+      println("Killing container " + containerId)
+      docker.killContainer(containerId)
+      println("Removing container " + containerId)
+      docker.removeContainer(containerId)
+    } catch {
+      case e: Exception =>
+        println(e)
+        println("You may need to clean this up manually.")
+        throw e
+    } finally { // Release the Docker client even if tearing down the container failed.
+      println("Closing docker client")
+      DockerClientFactory.close(docker)
+    }
+  }
+}
+
+/** JDBC data source tests against PostgreSQL in a Docker container (requires a Docker daemon). */
+@Ignore class PostgresIntegration extends FunSuite with BeforeAndAfterAll {
+  lazy val db = new PostgresDatabase()
+  private var started = false // true while beforeAll's container is up; guards afterAll
+
+  /** JDBC URL for the `postgres` database on the server at `ip`, as user postgres. */
+  def url(ip: String) = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass"
+
+  /** Retries `url(ip)` every 250 ms until it connects; throws once `maxMillis` has elapsed. */
+  def waitForDatabase(ip: String, maxMillis: Long) {
+    val deadline = System.currentTimeMillis() + maxMillis
+    var lastException: java.sql.SQLException = null
+    while (true) {
+      if (System.currentTimeMillis() > deadline) {
+        throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException)
+      }
+      try {
+        java.sql.DriverManager.getConnection(url(ip)).close()
+        println("Database is up.")
+        return
+      } catch {
+        case e: java.sql.SQLException =>
+          lastException = e
+          java.lang.Thread.sleep(250)
+      }
+    }
+  }
+
+  /** Creates database `foo` and table `bar` with one row covering the column types under test. */
+  def setupDatabase(ip: String) {
+    val conn = DriverManager.getConnection(url(ip))
+    try {
+      conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+      conn.setCatalog("foo") // NOTE(review): PgJDBC may ignore this; the tests read `postgres`.
+      conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
+          + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+      conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+          + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+    } finally {
+      conn.close()
+    }
+  }
+
+  override def beforeAll() {
+    println("Waiting for database to start up.")
+    val database = db // started outside the try: if this throws there is nothing to clean up
+    started = true
+    try {
+      waitForDatabase(database.ip, 60000)
+      println("Setting up database.")
+      setupDatabase(database.ip)
+    } catch {
+      case e: Exception => // afterAll may be skipped when beforeAll throws: don't leak the container
+        started = false
+        try database.close() catch { case _: Exception => } // close() prints its own failure
+        throw e
+    }
+  }
+
+  override def afterAll() { if (started) db.close() }
+
+  test("Type mapping for various types") {
+    val rows = TestSQLContext.jdbcRDD(url(db.ip), "public.bar").collect()
+    assert(rows.length == 1)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types == Seq("class java.lang.String", "class java.lang.Integer",
+      "class java.lang.Double", "class java.lang.Long", "class java.lang.Boolean", "class [B",
+      "class [B", "class java.lang.Boolean", "class java.lang.String", "class java.lang.String"))
+    assert(rows(0).getString(0).equals("hello"))
+    assert(rows(0).getInt(1) == 42)
+    assert(rows(0).getDouble(2) == 1.25)
+    assert(rows(0).getLong(3) == 123456789012345L)
+    assert(rows(0).getBoolean(4) == false)
+    // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's...
+    assert(rows(0).getAs[Array[Byte]](5).sameElements("1000100101".getBytes("US-ASCII")))
+    assert(rows(0).getAs[Array[Byte]](6).sameElements(Array(0xDE, 0xAD, 0xBE, 0xEF).map(_.toByte)))
+    assert(rows(0).getBoolean(7) == true)
+    assert(rows(0).getString(8) == "172.16.0.42")
+    assert(rows(0).getString(9) == "192.168.0.0/16")
+  }
+
+  test("Basic write test") {
+    val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar")
+    rdd.createJDBCTable(url(db.ip), "public.barcopy", false)
+    // Test only that it doesn't bomb out.
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message