spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-11217][ML] save/load for non-meta estimators and transformers
Date Fri, 06 Nov 2015 22:51:20 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 52e921c7c -> e7e3bfda3


[SPARK-11217][ML] save/load for non-meta estimators and transformers

This PR implements the default save/load for non-meta estimators and transformers using the
JSON serialization of param values. The saved metadata includes:

* class name
* uid
* timestamp
* paramMap

The save/load interface is similar to DataFrames. We use the current active context by default,
which should be sufficient for most use cases.

~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")

Instance.load("path")
~~~

The param handling is different from the design doc. We didn't save default and user-set params
separately, and when we load it back, all parameters are user-set. This does cause issues.
But it also causes other issues if we modify the default params.

TODOs:

* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers

cc jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9454 from mengxr/SPARK-11217.

(cherry picked from commit c447c9d54603890db7399fb80adc9fae40b71f64)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e7e3bfda
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e7e3bfda
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e7e3bfda

Branch: refs/heads/branch-1.6
Commit: e7e3bfda315eb7809c8bdffa61a30b70a680611c
Parents: 52e921c
Author: Xiangrui Meng <meng@databricks.com>
Authored: Fri Nov 6 14:51:03 2015 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Fri Nov 6 14:51:15 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Binarizer.scala |  11 +-
 .../org/apache/spark/ml/param/params.scala      |   2 +-
 .../org/apache/spark/ml/util/ReadWrite.scala    | 220 +++++++++++++++++++
 .../ml/util/JavaDefaultReadWriteSuite.java      |  74 +++++++
 .../spark/ml/feature/BinarizerSuite.scala       |  11 +-
 .../spark/ml/util/DefaultReadWriteTest.scala    | 110 ++++++++++
 .../apache/spark/ml/util/TempDirectory.scala    |  45 ++++
 7 files changed, 469 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e7e3bfda/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index edad754..e5c2557 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.BinaryAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
  */
 @Experimental
 final class Binarizer(override val uid: String)
-  extends Transformer with HasInputCol with HasOutputCol {
+  extends Transformer with Writable with HasInputCol with HasOutputCol {
 
   def this() = this(Identifiable.randomUID("binarizer"))
 
@@ -86,4 +86,11 @@ final class Binarizer(override val uid: String)
   }
 
   override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
+
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object Binarizer extends Readable[Binarizer] {
+
+  override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e7e3bfda/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 8361406..c932570 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable {
   /**
    * Sets a parameter in the embedded param map.
    */
-  protected final def set[T](param: Param[T], value: T): this.type = {
+  final def set[T](param: Param[T], value: T): this.type = {
     set(param -> value)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e7e3bfda/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
new file mode 100644
index 0000000..ea790e0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
+
+/**
+ * Trait for [[Writer]] and [[Reader]].
+ */
+private[util] sealed trait BaseReadWrite {
+  private var optionSQLContext: Option[SQLContext] = None
+
+  /**
+   * Sets the SQL context to use for saving/loading.
+   */
+  @Since("1.6.0")
+  def context(sqlContext: SQLContext): this.type = {
+    optionSQLContext = Option(sqlContext)
+    this
+  }
+
+  /**
+   * Returns the user-specified SQL context or the default.
+   */
+  protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
+    SQLContext.getOrCreate(SparkContext.getOrCreate())
+  }
+}
+
+/**
+ * Abstract class for utility classes that can save ML instances.
+ */
+@Experimental
+@Since("1.6.0")
+abstract class Writer extends BaseReadWrite {
+
+  protected var shouldOverwrite: Boolean = false
+
+  /**
+   * Saves the ML instances to the input path.
+   */
+  @Since("1.6.0")
+  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+  def save(path: String): Unit
+
+  /**
+   * Overwrites if the output path already exists.
+   */
+  @Since("1.6.0")
+  def overwrite(): this.type = {
+    shouldOverwrite = true
+    this
+  }
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+}
+
+/**
+ * Trait for classes that provide [[Writer]].
+ */
+@Since("1.6.0")
+trait Writable {
+
+  /**
+   * Returns a [[Writer]] instance for this ML instance.
+   */
+  @Since("1.6.0")
+  def write: Writer
+
+  /**
+   * Saves this ML instance to the input path, a shortcut of `write.save(path)`.
+   */
+  @Since("1.6.0")
+  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+  def save(path: String): Unit = write.save(path)
+}
+
+/**
+ * Abstract class for utility classes that can load ML instances.
+ * @tparam T ML instance type
+ */
+@Experimental
+@Since("1.6.0")
+abstract class Reader[T] extends BaseReadWrite {
+
+  /**
+   * Loads the ML component from the input path.
+   */
+  @Since("1.6.0")
+  def load(path: String): T
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+}
+
+/**
+ * Trait for objects that provide [[Reader]].
+ * @tparam T ML instance type
+ */
+@Experimental
+@Since("1.6.0")
+trait Readable[T] {
+
+  /**
+   * Returns a [[Reader]] instance for this class.
+   */
+  @Since("1.6.0")
+  def read: Reader[T]
+
+  /**
+   * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
+   */
+  @Since("1.6.0")
+  def load(path: String): T = read.load(path)
+}
+
+/**
+ * Default [[Writer]] implementation for transformers and estimators that contain basic
+ * (json4s-serializable) params and no data. This will not handle more complex params or types with
+ * data (e.g., models with coefficients).
+ * @param instance object to save
+ */
+private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {
+
+  /**
+   * Saves the ML component to the input path.
+   */
+  override def save(path: String): Unit = {
+    val sc = sqlContext.sparkContext
+
+    val hadoopConf = sc.hadoopConfiguration
+    val fs = FileSystem.get(hadoopConf)
+    val p = new Path(path)
+    if (fs.exists(p)) {
+      if (shouldOverwrite) {
+        logInfo(s"Path $path already exists. It will be overwritten.")
+        // TODO: Revert back to the original content if save is not successful.
+        fs.delete(p, true)
+      } else {
+        throw new IOException(
+          s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
+      }
+    }
+
+    val uid = instance.uid
+    val cls = instance.getClass.getName
+    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val jsonParams = params.map { case ParamPair(p, v) =>
+      p.name -> parse(p.jsonEncode(v))
+    }.toList
+    val metadata = ("class" -> cls) ~
+      ("timestamp" -> System.currentTimeMillis()) ~
+      ("uid" -> uid) ~
+      ("paramMap" -> jsonParams)
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataJson = compact(render(metadata))
+    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+  }
+}
+
+/**
+ * Default [[Reader]] implementation for transformers and estimators that contain basic
+ * (json4s-serializable) params and no data. This will not handle more complex params or types with
+ * data (e.g., models with coefficients).
+ * @tparam T ML instance type
+ */
+private[ml] class DefaultParamsReader[T] extends Reader[T] {
+
+  /**
+   * Loads the ML component from the input path.
+   */
+  override def load(path: String): T = {
+    implicit val format = DefaultFormats
+    val sc = sqlContext.sparkContext
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataStr = sc.textFile(metadataPath, 1).first()
+    val metadata = parse(metadataStr)
+    val cls = Utils.classForName((metadata \ "class").extract[String])
+    val uid = (metadata \ "uid").extract[String]
+    val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
+    (metadata \ "paramMap") match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          val param = instance.getParam(paramName)
+          val value = param.jsonDecode(compact(render(jsonValue)))
+          instance.set(param, value)
+        }
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
+    }
+    instance.asInstanceOf[T]
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e7e3bfda/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
new file mode 100644
index 0000000..c395380
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.util.Utils;
+
+public class JavaDefaultReadWriteSuite {
+
+  JavaSparkContext jsc = null;
+  File tempDir = null;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
+    tempDir = Utils.createTempDir(
+      System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
+  }
+
+  @After
+  public void tearDown() {
+    if (jsc != null) {
+      jsc.stop();
+      jsc = null;
+    }
+    Utils.deleteRecursively(tempDir);
+  }
+
+  @Test
+  public void testDefaultReadWrite() throws IOException {
+    String uid = "my_params";
+    MyParams instance = new MyParams(uid);
+    instance.set(instance.intParam(), 2);
+    String outputPath = new File(tempDir, uid).getPath();
+    instance.save(outputPath);
+    try {
+      instance.save(outputPath);
+      Assert.fail(
+        "Write without overwrite enabled should fail if the output directory already exists.");
+    } catch (IOException e) {
+      // expected
+    }
+    SQLContext sqlContext = new SQLContext(jsc);
+    instance.write().context(sqlContext).overwrite().save(outputPath);
+    MyParams newInstance = MyParams.load(outputPath);
+    Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
+    Assert.assertEquals("Params should be preserved.",
+      2, newInstance.getOrDefault(newInstance.intParam()));
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e7e3bfda/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 2086043..9dfa143 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
-class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   @transient var data: Array[Double] = _
 
@@ -66,4 +67,12 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(x === y, "The feature value is not correct after binarization.")
     }
   }
+
+  test("read/write") {
+    val binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+      .setThreshold(0.1)
+    testDefaultReadWrite(binarizer)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e7e3bfda/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
new file mode 100644
index 0000000..4545b0f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.{File, IOException}
+
+import org.scalatest.Suite
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
+
+  /**
+   * Checks "overwrite" option and params.
+   * @param instance ML instance to test saving/loading
+   * @tparam T ML instance type
+   */
+  def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = {
+    val uid = instance.uid
+    val path = new File(tempDir, uid).getPath
+
+    instance.save(path)
+    intercept[IOException] {
+      instance.save(path)
+    }
+    instance.write.overwrite().save(path)
+    val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]]
+    val newInstance = loader.load(path)
+
+    assert(newInstance.uid === instance.uid)
+    instance.params.foreach { p =>
+      if (instance.isDefined(p)) {
+        (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+          case (Array(values), Array(newValues)) =>
+            assert(values === newValues, s"Values do not match on param ${p.name}.")
+          case (value, newValue) =>
+            assert(value === newValue, s"Values do not match on param ${p.name}.")
+        }
+      } else {
+        assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
+      }
+    }
+
+    val load = instance.getClass.getMethod("load", classOf[String])
+    val another = load.invoke(instance, path).asInstanceOf[T]
+    assert(another.uid === instance.uid)
+  }
+}
+
+class MyParams(override val uid: String) extends Params with Writable {
+
+  final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc")
+  final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+  final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc")
+  final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc")
+  final val longParam: LongParam = new LongParam(this, "longParam", "doc")
+  final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc")
+  final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc")
+  final val doubleArrayParam: DoubleArrayParam =
+    new DoubleArrayParam(this, "doubleArrayParam", "doc")
+  final val stringArrayParam: StringArrayParam =
+    new StringArrayParam(this, "stringArrayParam", "doc")
+
+  setDefault(intParamWithDefault -> 0)
+  set(intParam -> 1)
+  set(floatParam -> 2.0f)
+  set(doubleParam -> 3.0)
+  set(longParam -> 4L)
+  set(stringParam -> "5")
+  set(intArrayParam -> Array(6, 7))
+  set(doubleArrayParam -> Array(8.0, 9.0))
+  set(stringArrayParam -> Array("10", "11"))
+
+  override def copy(extra: ParamMap): Params = defaultCopy(extra)
+
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object MyParams extends Readable[MyParams] {
+
+  override def read: Reader[MyParams] = new DefaultParamsReader[MyParams]
+
+  override def load(path: String): MyParams = read.load(path)
+}
+
+class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
+
+  test("default read/write") {
+    val myParams = new MyParams("my_params")
+    testDefaultReadWrite(myParams)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e7e3bfda/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala
new file mode 100644
index 0000000..2742026
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.File
+
+import org.scalatest.{BeforeAndAfterAll, Suite}
+
+import org.apache.spark.util.Utils
+
+/**
+ * Trait that creates a temporary directory before all tests and deletes it after all.
+ */
+trait TempDirectory extends BeforeAndAfterAll { self: Suite =>
+
+  private var _tempDir: File = _
+
+  /** Returns the temporary directory as a [[File]] instance. */
+  protected def tempDir: File = _tempDir
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    _tempDir = Utils.createTempDir(this.getClass.getName)
+  }
+
+  override def afterAll(): Unit = {
+    Utils.deleteRecursively(_tempDir)
+    super.afterAll()
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message