spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pwend...@apache.org
Subject [25/32] Using name yarn-alpha/yarn instead of yarn-2.0/yarn-2.2
Date Fri, 03 Jan 2014 07:16:37 GMT
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ebdfa6bb/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
----------------------------------------------------------------------
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
new file mode 100644
index 0000000..a750668
--- /dev/null
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -0,0 +1,523 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.{InetAddress, UnknownHostException, URI}
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil}
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.mapred.Master
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, Records}
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.util.Utils
+import org.apache.spark.deploy.SparkHadoopUtil
+
+
+/**
+ * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The
+ * Client submits an application to the global ResourceManager to launch Spark's ApplicationMaster,
+ * which will launch a Spark master process and negotiate resources throughout its duration.
+ *
+ * @param conf Hadoop configuration used for file-system access and as the base of `yarnConf`.
+ * @param args parsed client arguments describing the application to submit.
+ */
+class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
+
+  var rpc: YarnRPC = YarnRPC.create(conf)
+  val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+  val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+  private val SPARK_STAGING: String = ".sparkStaging"
+  private val distCacheMgr = new ClientDistributedCacheManager()
+  private val sparkConf = new SparkConf
+
+  // Permission bits are parsed from octal strings because octal literals such as `0700` are
+  // deprecated in Scala 2.10 and rejected by later compilers.
+  // Staging directory is private! -> rwx--------
+  val STAGING_DIR_PERMISSION: FsPermission =
+    FsPermission.createImmutable(Integer.parseInt("700", 8).toShort)
+  // App files are world-wide readable and owner writable -> rw-r--r--
+  val APP_FILE_PERMISSION: FsPermission =
+    FsPermission.createImmutable(Integer.parseInt("644", 8).toShort)
+
+  def this(args: ClientArguments) = this(new Configuration(), args)
+
+  /**
+   * Validates the arguments, starts this YARN client service, uploads the application's
+   * resources and submits the ApplicationMaster. Does not wait for the application to finish.
+   *
+   * @return the id the ResourceManager assigned to the submitted application.
+   */
+  def runApp(): ApplicationId = {
+    validateArgs()
+    // Initialize and start the client service.
+    init(yarnConf)
+    start()
+
+    // Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers).
+    logClusterResourceDetails()
+
+    // Prepare to submit a request to the ResourceManager (specifically its ApplicationsManager
+    // (ASM) interface).
+
+    // Get a new client application.
+    val newApp = super.createApplication()
+    val newAppResponse = newApp.getNewApplicationResponse()
+    val appId = newAppResponse.getApplicationId()
+
+    verifyClusterResources(newAppResponse)
+
+    // Set up resource and environment variables.
+    val appStagingDir = getAppStagingDir(appId)
+    val localResources = prepareLocalResources(appStagingDir)
+    val launchEnv = setupLaunchEnv(localResources, appStagingDir)
+    val amContainer = createContainerLaunchContext(newAppResponse, localResources, launchEnv)
+
+    // Set up an application submission context.
+    val appContext = newApp.getApplicationSubmissionContext()
+    appContext.setApplicationName(args.appName)
+    appContext.setQueue(args.amQueue)
+    appContext.setAMContainerSpec(amContainer)
+
+    // Memory for the ApplicationMaster: the requested heap plus the fixed overhead.
+    val memoryResource = Records.newRecord(classOf[Resource])
+    memoryResource.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+    appContext.setResource(memoryResource)
+
+    // Finally, submit the application; monitoring is left to the caller.
+    submitApp(appContext)
+    appId
+  }
+
+  /**
+   * Submits the application, blocks until it reaches a terminal state, then exits the JVM.
+   * NOTE(review): the exit status is always 0, even when the application FAILED or was KILLED.
+   */
+  def run() {
+    val appId = runApp()
+    monitorApplication(appId)
+    System.exit(0)
+  }
+
+  // TODO(harvey): This could just go in ClientArguments.
+  /**
+   * Checks required settings and argument bounds, in order. For the first violated check it logs
+   * the corresponding error and calls `args.printUsageAndExit(1)`.
+   */
+  def validateArgs() = {
+    // An ordered Seq of (failed, message) pairs. A Map keyed by the Boolean condition would keep
+    // at most two entries, silently dropping every failing check except the last one.
+    Seq(
+      (System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!",
+      (args.userJar == null) -> "Error: You must specify a user jar!",
+      (args.userClass == null) -> "Error: You must specify a user class!",
+      (args.numWorkers <= 0) -> "Error: You must specify at least 1 worker!",
+      (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be " +
+        "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD),
+      (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size " +
+        "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD)
+    ).foreach { case (failed, errStr) =>
+      if (failed) {
+        logError(errStr)
+        args.printUsageAndExit(1)
+      }
+    }
+  }
+
+  /** Returns the per-application staging path `.sparkStaging/<appId>/` (relative). */
+  def getAppStagingDir(appId: ApplicationId): String = {
+    SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR
+  }
+
+  /** Logs the NodeManager count and the capacity/usage of the queue the AM is submitted to. */
+  def logClusterResourceDetails() {
+    val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
+    logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " +
+      clusterMetrics.getNumNodeManagers)
+
+    val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
+    logInfo("""Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s,
+      queueApplicationCount = %s, queueChildQueueCount = %s""".format(
+        queueInfo.getQueueName,
+        queueInfo.getCurrentCapacity,
+        queueInfo.getMaximumCapacity,
+        queueInfo.getApplications.size,
+        queueInfo.getChildQueues.size))
+  }
+
+  /**
+   * Exits the JVM with status 1 when the worker memory, or the AM memory plus
+   * `YarnAllocationHandler.MEMORY_OVERHEAD`, exceeds the cluster's largest single resource.
+   */
+  def verifyClusterResources(app: GetNewApplicationResponse) = {
+    val maxMem = app.getMaximumResourceCapability().getMemory()
+    logInfo("Max mem capability of a single resource in this cluster " + maxMem)
+
+    // If we have requested more than the cluster's max for a single resource then exit.
+    if (args.workerMemory > maxMem) {
+      logError("Required worker memory (%d MB), is above the max threshold (%d MB) of this cluster.".
+        format(args.workerMemory, maxMem))
+      System.exit(1)
+    }
+    // The AM container is requested as heap + overhead, so that total is what must fit and what
+    // is reported.
+    val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+    if (amMem > maxMem) {
+      logError("Required AM memory (%d MB) is above the max threshold (%d MB) of this cluster.".
+        format(amMem, maxMem))
+      System.exit(1)
+    }
+
+    // We could add checks to make sure the entire cluster has enough resources but that involves
+    // getting all the node reports and computing ourselves.
+  }
+
+  /**
+   * Returns true when both file systems have the same scheme, host (compared by canonical host
+   * name) and port. A null source scheme, a host present on only one side, or a host that cannot
+   * be resolved all count as "different".
+   */
+  private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = {
+    val srcUri = srcFs.getUri()
+    val dstUri = destFs.getUri()
+
+    def sameHost: Boolean = (srcUri.getHost(), dstUri.getHost()) match {
+      case (null, null) => true
+      case (null, _) | (_, null) => false
+      case (srcHost, dstHost) =>
+        try {
+          InetAddress.getByName(srcHost).getCanonicalHostName() ==
+            InetAddress.getByName(dstHost).getCanonicalHostName()
+        } catch {
+          case _: UnknownHostException => false
+        }
+    }
+
+    srcUri.getScheme() != null &&
+      srcUri.getScheme() == dstUri.getScheme() &&
+      sameHost &&
+      srcUri.getPort() == dstUri.getPort()
+  }
+
+  /**
+   * Copies `originalPath` into `dstDir` on the default file system when it lives on a different
+   * file system, and returns the fully-qualified, symlink-resolved path of the resulting file.
+   *
+   * @param replication replication factor applied to the uploaded copy.
+   * @param setPerms when true, the uploaded copy gets APP_FILE_PERMISSION.
+   */
+  private def copyRemoteFile(
+      dstDir: Path,
+      originalPath: Path,
+      replication: Short,
+      setPerms: Boolean = false): Path = {
+    val fs = FileSystem.get(conf)
+    val remoteFs = originalPath.getFileSystem(conf)
+    var newPath = originalPath
+    if (!compareFs(remoteFs, fs)) {
+      newPath = new Path(dstDir, originalPath.getName())
+      logInfo("Uploading " + originalPath + " to " + newPath)
+      FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf)
+      fs.setReplication(newPath, replication)
+      if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION))
+    }
+    // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific
+    // version shows the specific version in the distributed cache configuration
+    val qualPath = fs.makeQualified(newPath)
+    val fc = FileContext.getFileContext(qualPath.toUri(), conf)
+    fc.resolvePath(qualPath)
+  }
+
+  /**
+   * Uploads (when not already on the default file system) and registers as local resources the
+   * Spark jar (`SPARK_JAR`), the user jar, the optional log4j config (`SPARK_LOG4J_CONF`), and
+   * every entry of `args.addJars`, `args.files` and `args.archives`. With security enabled, it
+   * also adds delegation tokens for the staging file system to `credentials`, and exits the JVM
+   * if no Master Kerberos principal is available to act as renewer.
+   *
+   * @param appStagingDir staging directory, resolved against the file system's home directory.
+   * @return the registered local resources, keyed by link name.
+   */
+  def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
+    logInfo("Preparing Local resources")
+    // Upload Spark and the application JAR to the remote file system if necessary. Add them as
+    // local resources to the application master.
+    val fs = FileSystem.get(conf)
+
+    val delegTokenRenewer = Master.getMasterPrincipal(conf)
+    if (UserGroupInformation.isSecurityEnabled()) {
+      if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
+        logError("Can't get Master Kerberos principal for use as renewer")
+        System.exit(1)
+      }
+    }
+    val dst = new Path(fs.getHomeDirectory(), appStagingDir)
+    val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort
+
+    if (UserGroupInformation.isSecurityEnabled()) {
+      val dstFs = dst.getFileSystem(conf)
+      dstFs.addDelegationTokens(delegTokenRenewer, credentials)
+    }
+
+    val localResources = HashMap[String, LocalResource]()
+    FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION))
+
+    // Shared by every addResource call below (presumably a FileStatus lookup cache — see
+    // ClientDistributedCacheManager).
+    val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+
+    Map(
+      Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar,
+      Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF")
+    ).foreach { case (destName, _localPath) =>
+      val localPath: String = if (_localPath != null) _localPath.trim() else ""
+      if (!localPath.isEmpty()) {
+        var localURI = new URI(localPath)
+        // If not specified assume these are in the local filesystem to keep behavior like Hadoop
+        if (localURI.getScheme() == null) {
+          localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString)
+        }
+        // Only the user jar gets APP_FILE_PERMISSION applied after upload.
+        val setPermissions = destName == Client.APP_JAR
+        val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions)
+        distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+          destName, statCache)
+      }
+    }
+
+    // Handle jars local to the ApplicationMaster.
+    if ((args.addJars != null) && (!args.addJars.isEmpty())) {
+      args.addJars.split(',').foreach { file =>
+        val localURI = new URI(file.trim())
+        val localPath = new Path(localURI)
+        // The URI fragment, if present, is used as the link name.
+        val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+        val destPath = copyRemoteFile(dst, localPath, replication)
+        // Only add the resource to the Spark ApplicationMaster.
+        val appMasterOnly = true
+        distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+          linkname, statCache, appMasterOnly)
+      }
+    }
+
+    // Handle any distributed cache files
+    if ((args.files != null) && (!args.files.isEmpty())) {
+      args.files.split(',').foreach { file =>
+        val localURI = new URI(file.trim())
+        val localPath = new Path(localURI)
+        val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+        val destPath = copyRemoteFile(dst, localPath, replication)
+        distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+          linkname, statCache)
+      }
+    }
+
+    // Handle any distributed cache archives
+    if ((args.archives != null) && (!args.archives.isEmpty())) {
+      args.archives.split(',').foreach { file =>
+        val localURI = new URI(file.trim())
+        val localPath = new Path(localURI)
+        val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+        val destPath = copyRemoteFile(dst, localPath, replication)
+        distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE,
+          linkname, statCache)
+      }
+    }
+
+    UserGroupInformation.getCurrentUser().addCredentials(credentials)
+    localResources
+  }
+
+  /**
+   * Builds the ApplicationMaster environment. It contains:
+   *   - the classpath (including log4j when a log4j resource was registered);
+   *   - SPARK_YARN_MODE and SPARK_YARN_STAGING_DIR;
+   *   - the distributed-cache variables;
+   *   - entries from SPARK_YARN_USER_ENV;
+   *   - every local environment variable whose name starts with "SPARK".
+   */
+  def setupLaunchEnv(
+      localResources: HashMap[String, LocalResource],
+      stagingDir: String): HashMap[String, String] = {
+    logInfo("Setting up the launch environment")
+    val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null)
+
+    val env = new HashMap[String, String]()
+
+    Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env)
+    env("SPARK_YARN_MODE") = "true"
+    env("SPARK_YARN_STAGING_DIR") = stagingDir
+
+    // Set the environment variables to be passed on to the Workers.
+    distCacheMgr.setDistFilesEnv(env)
+    distCacheMgr.setDistArchivesEnv(env)
+
+    // Allow users to specify some environment variables.
+    Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
+
+    // Copy every variable whose name starts with "SPARK" (this overrides values set above).
+    System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v }
+
+    env
+  }
+
+  /**
+   * Renders the user's program arguments as repeated ` --args  '<arg>' ` fragments for the AM
+   * command line. NOTE(review): arguments are single-quoted without escaping embedded quotes.
+   */
+  def userArgsToString(clientArgs: ClientArguments): String = {
+    val prefix = " --args "
+    clientArgs.userArgs.map(arg => prefix + " '" + arg + "' ").mkString
+  }
+
+  /**
+   * Builds the launch context for the ApplicationMaster container. It sets the local resources
+   * and environment, the `java` command line (heap size, tmp dir, optional CMS GC flags,
+   * SPARK_JAVA_OPTS, AM arguments, stdout/stderr redirection) and the serialized security tokens.
+   */
+  def createContainerLaunchContext(
+      newApp: GetNewApplicationResponse,
+      localResources: HashMap[String, LocalResource],
+      env: HashMap[String, String]): ContainerLaunchContext = {
+    logInfo("Setting up container launch context")
+    val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
+    amContainer.setLocalResources(localResources)
+    amContainer.setEnvironment(env)
+
+    // TODO: Need a replacement for the following code to fix -Xmx?
+    // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
+    // var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
+    //  ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) -
+    //    YarnAllocationHandler.MEMORY_OVERHEAD)
+
+    // Extra options for the JVM
+    var JAVA_OPTS = ""
+
+    // Add Xmx for AM memory
+    JAVA_OPTS += "-Xmx" + args.amMemory + "m"
+
+    val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
+    JAVA_OPTS += " -Djava.io.tmpdir=" + tmpDir
+
+    // TODO: Remove once cpuset version is pushed out.
+    // The context is, default gc for server class machines ends up using all cores to do gc -
+    // hence if there are multiple containers in same node, Spark GC affects all other containers'
+    // performance (which can be that of other Spark containers)
+    // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in
+    // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset
+    // of cores on a node.
+    val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") &&
+      java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))
+    if (useConcurrentAndIncrementalGC) {
+      // In our expts, using (default) throughput collector has severe perf ramifications in
+      // multi-tenant machines
+      JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+      JAVA_OPTS += " -XX:+CMSIncrementalMode "
+      JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+      JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+      JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+    }
+
+    if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+      JAVA_OPTS += " " + env("SPARK_JAVA_OPTS")
+    }
+
+    // Command for the ApplicationMaster; use $JAVA_HOME/bin/java when JAVA_HOME is known.
+    var javaCommand = "java"
+    val javaHome = System.getenv("JAVA_HOME")
+    if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) {
+      javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
+    }
+
+    val commands = List[String](
+      javaCommand +
+      " -server " +
+      JAVA_OPTS +
+      " " + args.amClass +
+      " --class " + args.userClass +
+      " --jar " + args.userJar +
+      userArgsToString(args) +
+      " --worker-memory " + args.workerMemory +
+      " --worker-cores " + args.workerCores +
+      " --num-workers " + args.numWorkers +
+      " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+      " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+
+    logInfo("Command for starting the Spark ApplicationMaster: " + commands(0))
+    amContainer.setCommands(commands)
+
+    // Setup security tokens. DataOutputBuffer.getData() returns the whole backing array, which is
+    // normally larger than what was written; only the first getLength() bytes are valid.
+    val dob = new DataOutputBuffer()
+    credentials.writeTokenStorageToStream(dob)
+    amContainer.setTokens(ByteBuffer.wrap(dob.getData(), 0, dob.getLength()))
+
+    amContainer
+  }
+
+  /** Submits the application to the ResourceManager's applications manager. */
+  def submitApp(appContext: ApplicationSubmissionContext) = {
+    // Submit the application to the applications manager.
+    logInfo("Submitting application to ASM")
+    super.submitApplication(appContext)
+  }
+
+  /**
+   * Polls the application report every `spark.yarn.report.interval` ms (default 1000) and logs
+   * it. Returns true once the application is FINISHED, FAILED or KILLED.
+   */
+  def monitorApplication(appId: ApplicationId): Boolean = {
+    val interval = sparkConf.getLong("spark.yarn.report.interval", 1000)
+
+    while (true) {
+      Thread.sleep(interval)
+      val report = super.getApplicationReport(appId)
+
+      logInfo("Application report from ASM: \n" +
+        "\t application identifier: " + appId.toString() + "\n" +
+        "\t appId: " + appId.getId() + "\n" +
+        "\t clientToAMToken: " + report.getClientToAMToken() + "\n" +
+        "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
+        "\t appMasterHost: " + report.getHost() + "\n" +
+        "\t appQueue: " + report.getQueue() + "\n" +
+        "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+        "\t appStartTime: " + report.getStartTime() + "\n" +
+        "\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
+        "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
+        "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
+        "\t appUser: " + report.getUser()
+      )
+
+      val state = report.getYarnApplicationState()
+      if (state == YarnApplicationState.FINISHED ||
+        state == YarnApplicationState.FAILED ||
+        state == YarnApplicationState.KILLED) {
+        return true
+      }
+    }
+    // Unreachable; needed because `while` evaluates to Unit.
+    true
+  }
+}
+
+object Client {
+  val SPARK_JAR: String = "spark.jar"
+  val APP_JAR: String = "app.jar"
+  val LOG4J_PROP: String = "log4j.properties"
+
+  /** Command-line entry point: parses the client arguments and runs a [[Client]]. */
+  def main(argStrings: Array[String]) {
+    // Set a system property indicating we are running in YARN mode.
+    // Note: any env variable with SPARK_ prefix gets propagated to all (remote) processes -
+    // see Client#setupLaunchEnv().
+    System.setProperty("SPARK_YARN_MODE", "true")
+
+    val args = new ClientArguments(argStrings)
+
+    (new Client(args)).run()
+  }
+
+  // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
+  /**
+   * Appends each (trimmed) entry of `yarn.application.classpath` to CLASSPATH in `env`. When the
+   * key is unset or empty, YARN's built-in default application classpath is used.
+   */
+  def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
+    // Configuration.getStrings returns null for an unset or empty key; iterating it would NPE.
+    val classpathEntries = Option(conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH))
+      .getOrElse(YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)
+    for (c <- classpathEntries) {
+      Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
+    }
+  }
+
+  /**
+   * Builds CLASSPATH in `env`, in this order:
+   *   1. the container working directory;
+   *   2. log4j.properties, if `addLog4j`;
+   *   3. app.jar, if `spark.yarn.user.classpath.first` is true;
+   *   4. spark.jar;
+   *   5. the Hadoop classpath;
+   *   6. app.jar, if not already added in step 3;
+   *   7. `<pwd>/*`.
+   */
+  def populateClasspath(conf: Configuration, addLog4j: Boolean, env: HashMap[String, String]) {
+    // Appends one entry to the CLASSPATH variable in `env`.
+    def addClasspathEntry(path: String) {
+      Apps.addToEnvironment(env, Environment.CLASSPATH.name, path)
+    }
+    // Appends `<container working dir>/<name>` to CLASSPATH.
+    def addPwdEntry(name: String) {
+      addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + name)
+    }
+
+    addClasspathEntry(Environment.PWD.$())
+    // If log4j present, ensure ours overrides all others
+    if (addLog4j) addPwdEntry(LOG4J_PROP)
+    // Normally the user's app.jar is last in case it conflicts with the Spark jars.
+    val userClasspathFirst = new SparkConf().get("spark.yarn.user.classpath.first", "false")
+      .toBoolean
+    if (userClasspathFirst) addPwdEntry(APP_JAR)
+    addPwdEntry(SPARK_JAR)
+    Client.populateHadoopClasspath(conf, env)
+    if (!userClasspathFirst) addPwdEntry(APP_JAR)
+    addPwdEntry("*")
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ebdfa6bb/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
----------------------------------------------------------------------
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
new file mode 100644
index 0000000..4d9cca0
--- /dev/null
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.Socket
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import akka.actor._
+import akka.remote._
+import akka.actor.Terminated
+import org.apache.spark.{SparkConf, SparkContext, Logging}
+import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.scheduler.SplitInfo
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+
+/**
+ * A YARN ApplicationMaster that does not start user code itself. It:
+ *   - registers with the ResourceManager;
+ *   - waits until the driver at `args.userArgs(0)` (host:port) accepts TCP connections;
+ *   - allocates `args.numWorkers` worker containers;
+ *   - heartbeats the ResourceManager, re-requesting missing containers, until the driver
+ *     disconnects.
+ */
+class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+  def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+  private var appAttemptId: ApplicationAttemptId = _
+  private var reporterThread: Thread = _
+  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+  private var yarnAllocator: YarnAllocationHandler = _
+  // Set by MonitorActor on an actor-system thread and polled by the reporter thread and the
+  // allocation loop; without @volatile those threads may never observe the update.
+  @volatile private var driverClosed: Boolean = false
+
+  private var amClient: AMRMClient[ContainerRequest] = _
+  private val sparkConf = new SparkConf
+
+  val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
+    conf = sparkConf)._1
+  var actor: ActorRef = _
+
+  // This actor just works as a monitor that watches the driver actor.
+  class MonitorActor(driverUrl: String) extends Actor {
+
+    var driver: ActorSelection = _
+
+    override def preStart() {
+      logInfo("Listen to driver: " + driverUrl)
+      driver = context.actorSelection(driverUrl)
+      // Send a hello message so the connection is actually established and its remoting
+      // lifecycle events can be monitored.
+      driver ! "Hello"
+      context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+    }
+
+    override def receive = {
+      case x: DisassociatedEvent =>
+        logInfo(s"Driver terminated or disconnected! Shutting down. $x")
+        driverClosed = true
+    }
+  }
+
+  /**
+   * Runs the ApplicationMaster to completion. It registers with the RM, waits for the driver,
+   * allocates workers and heartbeats until the driver disconnects. It then unregisters with
+   * SUCCEEDED and exits the JVM with status 0.
+   */
+  def run() {
+
+    amClient = AMRMClient.createAMRMClient()
+    amClient.init(yarnConf)
+    amClient.start()
+
+    appAttemptId = getApplicationAttemptId()
+    registerApplicationMaster()
+
+    waitForSparkMaster()
+
+    // Allocate all containers
+    allocateWorkers()
+
+    // Launch a progress reporter thread, otherwise the RM expires this AM: progress must be sent
+    // before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
+
+    val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+    // must be <= timeoutInterval / 2.
+    // On the other hand, also ensure that we are reasonably responsive without causing too many
+    // requests to RM: so at least 1 minute or timeoutInterval / 10 - whichever is higher.
+    val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L))
+    reporterThread = launchReporterThread(interval)
+
+    // Wait for the reporter thread to finish; it stops once the driver has disconnected.
+    reporterThread.join()
+
+    finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+    actorSystem.shutdown()
+
+    logInfo("Exited")
+    System.exit(0)
+  }
+
+  /** Derives the attempt id from the container id YARN publishes in the environment. */
+  private def getApplicationAttemptId(): ApplicationAttemptId = {
+    val envs = System.getenv()
+    val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
+    val containerId = ConverterUtils.toContainerId(containerIdString)
+    val appAttemptId = containerId.getApplicationAttemptId()
+    logInfo("ApplicationAttemptId: " + appAttemptId)
+    appAttemptId
+  }
+
+  /** Registers this AM with the RM (no RPC port and no tracking URL are advertised). */
+  private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+    logInfo("Registering the ApplicationMaster")
+    // TODO:(Raymond) Find out Spark UI address and fill in here?
+    amClient.registerApplicationMaster(Utils.localHostName(), 0, "")
+  }
+
+  /**
+   * Blocks, retrying every 100 ms, until a TCP connection to the driver at `args.userArgs(0)`
+   * succeeds. It then records the driver host/port in `sparkConf` and starts the MonitorActor.
+   */
+  private def waitForSparkMaster() {
+    logInfo("Waiting for Spark driver to be reachable.")
+    var driverUp = false
+    val hostport = args.userArgs(0)
+    val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+    while (!driverUp) {
+      try {
+        val socket = new Socket(driverHost, driverPort)
+        socket.close()
+        logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
+        driverUp = true
+      } catch {
+        case e: Exception =>
+          logError("Failed to connect to driver at %s:%s, retrying ...".
+            format(driverHost, driverPort))
+          Thread.sleep(100)
+      }
+    }
+    sparkConf.set("spark.driver.host", driverHost)
+    sparkConf.set("spark.driver.port", driverPort.toString)
+
+    val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+      driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+    actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
+  }
+
+  /**
+   * Requests `args.numWorkers` containers and polls the allocator every 100 ms until that many
+   * are running, or until the driver disconnects.
+   */
+  private def allocateWorkers() {
+
+    // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now.
+    val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
+      scala.collection.immutable.Map()
+
+    yarnAllocator = YarnAllocationHandler.newAllocator(
+      yarnConf,
+      amClient,
+      appAttemptId,
+      args,
+      preferredNodeLocationData,
+      sparkConf)
+
+    logInfo("Allocating " + args.numWorkers + " workers.")
+    // Wait until all containers have launched
+    // TODO: This is a bit ugly. Can we make it nicer?
+    // TODO: Handle container failure
+
+    yarnAllocator.addResourceRequests(args.numWorkers)
+    // Also stop if the driver goes away; otherwise this loop could spin forever.
+    while (yarnAllocator.getNumWorkersRunning < args.numWorkers && !driverClosed) {
+      yarnAllocator.allocateResources()
+      Thread.sleep(100)
+    }
+
+    logInfo("All workers have launched.")
+
+  }
+
+  // TODO: We might want to extend this to allocate more containers in case they die !
+  /**
+   * Starts a daemon thread that, until the driver disconnects, re-requests any missing worker
+   * containers and heartbeats the RM, sleeping `_sleepTime` ms (clamped to >= 0) per iteration.
+   */
+  private def launchReporterThread(_sleepTime: Long): Thread = {
+    val sleepTime = math.max(0L, _sleepTime)
+
+    val t = new Thread {
+      override def run() {
+        while (!driverClosed) {
+          val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning -
+            yarnAllocator.getNumPendingAllocate
+          if (missingWorkerCount > 0) {
+            logInfo("Allocating %d containers to make up for (potentially) lost containers".
+              format(missingWorkerCount))
+            yarnAllocator.addResourceRequests(missingWorkerCount)
+          }
+          sendProgress()
+          Thread.sleep(sleepTime)
+        }
+      }
+    }
+    // setting to daemon status, though this is usually not a good idea.
+    t.setDaemon(true)
+    t.start()
+    logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+    t
+  }
+
+  /** Heartbeats the RM by issuing an allocate call through the allocator. */
+  private def sendProgress() {
+    logDebug("Sending progress")
+    // simulated with an allocate request with no nodes requested ...
+    yarnAllocator.allocateResources()
+  }
+
+  /** Unregisters this AM from the RM with the given final status. */
+  def finishApplicationMaster(status: FinalApplicationStatus) {
+    logInfo("finish ApplicationMaster with " + status)
+    amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */)
+  }
+
+}
+
+
+/** Command-line entry point: parses ApplicationMaster arguments and runs a WorkerLauncher. */
+object WorkerLauncher {
+  def main(argStrings: Array[String]) {
+    val parsedArgs = new ApplicationMasterArguments(argStrings)
+    val launcher = new WorkerLauncher(parsedArgs)
+    launcher.run()
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ebdfa6bb/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
----------------------------------------------------------------------
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
new file mode 100644
index 0000000..9f5523c
--- /dev/null
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+import java.nio.ByteBuffer
+import java.security.PrivilegedExceptionAction
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.records.impl.pb.ProtoUtils
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.client.api.NMClient
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
+
+import org.apache.spark.Logging
+
+
+/**
+ * Launches one Spark executor (`org.apache.spark.executor.CoarseGrainedExecutorBackend`) inside an
+ * already-allocated YARN container. `run()` starts an [[NMClient]], then builds the container
+ * launch context (local resources, environment, delegation tokens, JVM command line) and asks
+ * the NodeManager to start it.
+ *
+ * @param container     the allocated container to launch into
+ * @param conf          Hadoop configuration (wrapped in a [[YarnConfiguration]] for the NM client)
+ * @param masterAddress driver URL, passed as the first argument to the executor backend
+ * @param slaveId       executor id passed to the executor backend
+ * @param hostname      host name passed to the executor backend
+ * @param workerMemory  executor heap size in MB, used for both -Xms and -Xmx
+ * @param workerCores   core count passed to the executor backend
+ */
+class WorkerRunnable(
+    container: Container,
+    conf: Configuration,
+    masterAddress: String,
+    slaveId: String,
+    hostname: String,
+    workerMemory: Int,
+    workerCores: Int)
+  extends Runnable with Logging {
+
+  // NOTE(review): not referenced within this class; kept public for source compatibility.
+  var rpc: YarnRPC = YarnRPC.create(conf)
+  // Created and started in run(); startContainer submits the launch request through it.
+  var nmClient: NMClient = _
+  val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+  /** Creates and starts the NodeManager client, then launches the executor container. */
+  override def run(): Unit = {
+    logInfo("Starting Worker Container")
+    nmClient = NMClient.createNMClient()
+    nmClient.init(yarnConf)
+    nmClient.start()
+    startContainer
+  }
+
+  /**
+   * Builds the [[ContainerLaunchContext]] for the executor and asks the NodeManager to start it.
+   * Requires `nmClient` to have been initialised and started (which `run()` does).
+   *
+   * @return the value returned by `NMClient.startContainer`
+   */
+  def startContainer = {
+    logInfo("Setting up ContainerLaunchContext")
+
+    val ctx = Records.newRecord(classOf[ContainerLaunchContext])
+
+    val localResources = prepareLocalResources
+    ctx.setLocalResources(localResources)
+
+    val env = prepareEnvironment
+    ctx.setEnvironment(env)
+
+    // Extra options for the JVM
+    var JAVA_OPTS = ""
+    // Set the JVM memory
+    val workerMemoryString = workerMemory + "m"
+    JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
+    if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+      JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+    }
+
+    JAVA_OPTS += " -Djava.io.tmpdir=" +
+      new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " "
+
+    // Commenting it out for now - so that people can refer to the properties if required. Remove
+    // it once cpuset version is pushed out.
+    // The context is, default gc for server class machines end up using all cores to do gc - hence
+    // if there are multiple containers in same node, spark gc effects all other containers
+    // performance (which can also be other spark containers)
+    // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in
+    // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset
+    // of cores on a node.
+/*
+    else {
+      // If no java_opts specified, default to using -XX:+CMSIncrementalMode
+      // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont
+      // want to mess with it.
+      // In our expts, using (default) throughput collector has severe perf ramnifications in
+      // multi-tennent machines
+      // The options are based on
+      // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
+      JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+      JAVA_OPTS += " -XX:+CMSIncrementalMode "
+      JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+      JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+      JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+    }
+*/
+
+    // Ship the current user's credentials to the container. DataOutputBuffer.getData() returns
+    // the whole backing array, which is only valid up to getLength(), so wrap exactly that
+    // prefix rather than the trailing unused capacity.
+    val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+    val dob = new DataOutputBuffer()
+    credentials.writeTokenStorageToStream(dob)
+    ctx.setTokens(ByteBuffer.wrap(dob.getData(), 0, dob.getLength()))
+
+    // Prefer the container-side JAVA_HOME whenever one is known locally or set in `env`.
+    val javaHome = System.getenv("JAVA_HOME")
+    val javaCommand =
+      if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) {
+        Environment.JAVA_HOME.$() + "/bin/java"
+      } else {
+        "java"
+      }
+
+    val commands = List[String](javaCommand +
+      " -server " +
+      // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
+      // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in
+      // an inconsistent state.
+      // TODO: If the OOM is not recoverable by rescheduling it on different node, then do
+      // 'something' to fail job ... akin to blacklisting trackers in mapred ?
+      " -XX:OnOutOfMemoryError='kill %p' " +
+      JAVA_OPTS +
+      " org.apache.spark.executor.CoarseGrainedExecutorBackend " +
+      masterAddress + " " +
+      slaveId + " " +
+      hostname + " " +
+      workerCores +
+      " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+      " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+    logInfo("Setting up worker with commands: " + commands)
+    ctx.setCommands(commands)
+
+    // Send the start request to the ContainerManager
+    nmClient.startContainer(container, ctx)
+  }
+
+  /**
+   * Registers a single distributed-cache entry in `localResources`, keyed by the fragment of
+   * `file`'s URI.
+   *
+   * @param file      resource URI string
+   * @param rtype     resource type (FILE or ARCHIVE)
+   * @param timestamp modification time as a decimal string
+   * @param size      size in bytes as a decimal string
+   * @param vis       name of a [[LocalResourceVisibility]] constant
+   */
+  private def setupDistributedCache(
+      file: String,
+      rtype: LocalResourceType,
+      localResources: HashMap[String, LocalResource],
+      timestamp: String,
+      size: String,
+      vis: String) = {
+    val uri = new URI(file)
+    val resource = Records.newRecord(classOf[LocalResource])
+    resource.setType(rtype)
+    resource.setVisibility(LocalResourceVisibility.valueOf(vis))
+    resource.setResource(ConverterUtils.getYarnUrlFromURI(uri))
+    resource.setTimestamp(timestamp.toLong)
+    resource.setSize(size.toLong)
+    localResources(uri.getFragment()) = resource
+  }
+
+  /**
+   * Adds every comma-separated entry of environment variable `envName` to `localResources` as
+   * type `rtype`. The i-th entry's timestamp, size and visibility are read from the i-th
+   * element of `<envName>_TIME_STAMPS`, `<envName>_FILE_SIZES` and `<envName>_VISIBILITIES`.
+   * Does nothing when `envName` is unset.
+   */
+  private def addCachedResources(
+      envName: String,
+      rtype: LocalResourceType,
+      localResources: HashMap[String, LocalResource]) {
+    val paths = System.getenv(envName)
+    if (paths != null) {
+      val timeStamps = System.getenv(envName + "_TIME_STAMPS").split(',')
+      val fileSizes = System.getenv(envName + "_FILE_SIZES").split(',')
+      val visibilities = System.getenv(envName + "_VISIBILITIES").split(',')
+      for ((path, i) <- paths.split(',').zipWithIndex) {
+        setupDistributedCache(path, rtype, localResources, timeStamps(i), fileSizes(i),
+          visibilities(i))
+      }
+    }
+  }
+
+  /**
+   * Builds the container's local resources from the SPARK_YARN_CACHE_FILES* and
+   * SPARK_YARN_CACHE_ARCHIVES* environment variables of this process.
+   */
+  def prepareLocalResources: HashMap[String, LocalResource] = {
+    logInfo("Preparing Local resources")
+    val localResources = HashMap[String, LocalResource]()
+    addCachedResources("SPARK_YARN_CACHE_FILES", LocalResourceType.FILE, localResources)
+    addCachedResources("SPARK_YARN_CACHE_ARCHIVES", LocalResourceType.ARCHIVE, localResources)
+    logInfo("Prepared Local resources " + localResources)
+    localResources
+  }
+
+  /**
+   * Builds the container environment: the classpath from `Client.populateClasspath`, then the
+   * entries parsed from SPARK_YARN_USER_ENV, then a copy of every variable of this process whose
+   * name starts with "SPARK". Later steps overwrite earlier ones on key clashes.
+   */
+  def prepareEnvironment: HashMap[String, String] = {
+    val env = new HashMap[String, String]()
+
+    Client.populateClasspath(yarnConf, System.getenv("SPARK_YARN_LOG4J_PATH") != null, env)
+
+    // Allow users to specify some environment variables
+    Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
+
+    System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v }
+    env
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ebdfa6bb/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
----------------------------------------------------------------------
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
new file mode 100644
index 0000000..8a9a73f
--- /dev/null
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -0,0 +1,694 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.lang.{Boolean => JBoolean}
+import java.util.{Collections, Set => JSet}
+import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.util.Utils
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.ApplicationMasterProtocol
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId
+import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus}
+import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest}
+import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.util.{RackResolver, Records}
+
+
+/** Locality level of a container request: a specific host, a specific rack, or any node. */
+object AllocationType extends Enumeration {
+  type AllocationType = Value
+  val HOST = Value
+  val RACK = Value
+  val ANY = Value
+}
+
+// TODO:
+// Too many params.
+// Needs to be mt-safe
+// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should
+// make it more proactive and decoupled.
+
+// Note that right now, we assume all node asks are uniform in terms of capabilities and priority.
+// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for
+// more info on how we are requesting for containers.
+/**
+ * Requests, launches and tracks Spark worker containers through an [[AMRMClient]].
+ *
+ * `addResourceRequests` queues container asks with `amClient`; these are host- and rack-specific
+ * when `preferredHostToCount` is non-empty. `allocateResources` polls the ResourceManager,
+ * starts a [[WorkerRunnable]] thread for each accepted container and books completed containers.
+ * Every container asks for, and must offer, `workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD`
+ * MB.
+ *
+ * @param maxWorkers           upper bound on concurrently running workers
+ * @param workerMemory         executor heap in MB, excluding MEMORY_OVERHEAD
+ * @param workerCores          cores passed to each executor
+ * @param preferredHostToCount desired number of containers per host
+ * @param preferredRackToCount desired number of containers per rack
+ */
+private[yarn] class YarnAllocationHandler(
+    val conf: Configuration,
+    val amClient: AMRMClient[ContainerRequest],
+    val appAttemptId: ApplicationAttemptId,
+    val maxWorkers: Int,
+    val workerMemory: Int,
+    val workerCores: Int,
+    val preferredHostToCount: Map[String, Int],
+    val preferredRackToCount: Map[String, Int],
+    val sparkConf: SparkConf)
+  extends Logging {
+  // These three are locked on allocatedHostToContainersMap. Complementary data structures
+  // allocatedHostToContainersMap : containers which are running : host, Set<containerid>
+  // allocatedContainerToHostMap: container to host mapping.
+  private val allocatedHostToContainersMap =
+    new HashMap[String, collection.mutable.Set[ContainerId]]()
+
+  private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
+
+  // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an
+  // allocated node)
+  // As with the two data structures above, tightly coupled with them, and to be locked on
+  // allocatedHostToContainersMap
+  private val allocatedRackCount = new HashMap[String, Int]()
+
+  // Containers which have been released.
+  // NOTE(review): nothing in this class adds to this list; it is only logged.
+  private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
+  // Containers we asked the RM to release; an entry is dropped once the container completes.
+  private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
+
+  // Number of containers requested from the ResourceManager that have not been allocated yet.
+  private val numPendingAllocate = new AtomicInteger()
+  private val numWorkersRunning = new AtomicInteger()
+  // Used to generate a unique id per worker
+  private val workerIdCounter = new AtomicInteger()
+  private val lastResponseId = new AtomicInteger()
+  private val numWorkersFailed = new AtomicInteger()
+
+  def getNumPendingAllocate: Int = numPendingAllocate.intValue
+
+  def getNumWorkersRunning: Int = numWorkersRunning.intValue
+
+  def getNumWorkersFailed: Int = numWorkersFailed.intValue
+
+  /** True if `container` offers at least `workerMemory + MEMORY_OVERHEAD` MB of memory. */
+  def isResourceConstraintSatisfied(container: Container): Boolean = {
+    container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+  }
+
+  /** Records `container` as pending release and asks the ResourceManager to release it. */
+  def releaseContainer(container: Container) {
+    val containerId = container.getId
+    pendingReleaseContainers.put(containerId, true)
+    amClient.releaseAssignedContainer(containerId)
+  }
+
+  /**
+   * Polls the ResourceManager via `amClient.allocate`, which also serves as the AM heartbeat.
+   * It then launches newly allocated containers and books completed ones.
+   *
+   * Containers failing [[isResourceConstraintSatisfied]] are released. The rest are grouped per
+   * host into:
+   *  - data-local containers: up to the host's still-required count from `preferredHostToCount`.
+   *    Any surplus on such a host is released.
+   *  - rack-local containers: up to the rack's still-required count from `preferredRackToCount`.
+   *  - off-rack containers: everything else.
+   * The groups are launched in that order until `maxWorkers` workers are running; any further
+   * containers are released.
+   */
+  def allocateResources() {
+    // We have already set the container request. Poll the ResourceManager for a response.
+    // This doubles as a heartbeat if there are no pending container requests.
+    val progressIndicator = 0.1f
+    val allocateResponse = amClient.allocate(progressIndicator)
+
+    val allocatedContainers = allocateResponse.getAllocatedContainers()
+    if (allocatedContainers.size > 0) {
+      var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size)
+
+      if (numPendingAllocateNow < 0) {
+        // More containers arrived than we were waiting for: undo the over-decrement so the
+        // pending count does not stay negative.
+        numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow)
+      }
+
+      logDebug("""
+        Allocated containers: %d
+        Current worker count: %d
+        Containers released: %s
+        Containers to-be-released: %s
+        Cluster resources: %s
+        """.format(
+          allocatedContainers.size,
+          numWorkersRunning.get(),
+          releasedContainerList,
+          pendingReleaseContainers,
+          allocateResponse.getAvailableResources))
+
+      val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+      for (container <- allocatedContainers) {
+        if (isResourceConstraintSatisfied(container)) {
+          // Add the accepted `container` to the host's list of already accepted,
+          // allocated containers
+          val host = container.getNodeId.getHost
+          val containersForHost = hostToContainers.getOrElseUpdate(host,
+            new ArrayBuffer[Container]())
+          containersForHost += container
+        } else {
+          // Release container, since it doesn't satisfy resource constraints.
+          releaseContainer(container)
+        }
+      }
+
+      // Find the appropriate containers to use.
+      // TODO: Cleanup this group-by...
+      val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+      val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+      val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+      for (candidateHost <- hostToContainers.keySet) {
+        val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
+        val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
+
+        val remainingContainersOpt = hostToContainers.get(candidateHost)
+        assert(remainingContainersOpt.isDefined)
+        var remainingContainers = remainingContainersOpt.get
+
+        if (requiredHostCount >= remainingContainers.size) {
+          // Since we have <= required containers, add all remaining containers to
+          // `dataLocalContainers`.
+          dataLocalContainers.put(candidateHost, remainingContainers)
+          // There are no more free containers remaining.
+          remainingContainers = null
+        } else if (requiredHostCount > 0) {
+          // Container list has more containers than we need for data locality: keep the first
+          // `requiredHostCount` as data local and treat the rest as surplus.
+          val (dataLocal, surplus) = remainingContainers.splitAt(requiredHostCount)
+          dataLocalContainers.put(candidateHost, dataLocal)
+
+          // YARN has a nasty habit of allocating a ton of containers on a host - discourage this.
+          // Release every surplus container. If that leaves us short, the next allocation cycle
+          // will reallocate (but won't treat it as data local).
+          for (container <- surplus) releaseContainer(container)
+          remainingContainers = null
+        }
+
+        // For rack local containers
+        if (remainingContainers != null) {
+          val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+          if (rack != null) {
+            val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
+            val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
+              rackLocalContainers.getOrElse(rack, List()).size
+
+            if (requiredRackCount > 0) {
+              // Take at most `requiredRackCount` containers as rack local. Append them to any
+              // containers already picked for this rack from other hosts, so later hosts see the
+              // updated count. Whatever is left falls through to off-rack.
+              val (rackLocal, remaining) = remainingContainers.splitAt(requiredRackCount)
+              val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack,
+                new ArrayBuffer[Container]())
+              existingRackLocal ++= rackLocal
+
+              remainingContainers = if (remaining.isEmpty) null else remaining
+            }
+          }
+        }
+
+        if (remainingContainers != null) {
+          // Not all containers have been consumed - add them to the list of off-rack containers.
+          offRackContainers.put(candidateHost, remainingContainers)
+        }
+      }
+
+      // Now that we have split the containers into various groups, go through them in order:
+      // first host-local, then rack-local, and finally off-rack.
+      // Note that the list we create below tries to ensure that not all containers end up within
+      // a host if there is a sufficiently large number of hosts/containers.
+      val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size)
+      allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers)
+      allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers)
+      allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers)
+
+      // Run each of the allocated containers.
+      for (container <- allocatedContainersToProcess) {
+        val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
+        val workerHostname = container.getNodeId.getHost
+        val containerId = container.getId
+
+        val workerMemoryOverhead = (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+        assert(container.getResource.getMemory >= workerMemoryOverhead)
+
+        if (numWorkersRunningNow > maxWorkers) {
+          logInfo("""Ignoring container %s at host %s, since we already have the required number of
+            containers for it.""".format(containerId, workerHostname))
+          releaseContainer(container)
+          numWorkersRunning.decrementAndGet()
+        } else {
+          val workerId = workerIdCounter.incrementAndGet().toString
+          val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+            sparkConf.get("spark.driver.host"),
+            sparkConf.get("spark.driver.port"),
+            CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+          logInfo("Launching container %s for on host %s".format(containerId, workerHostname))
+
+          // To be safe, remove the container from `pendingReleaseContainers`.
+          pendingReleaseContainers.remove(containerId)
+
+          val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
+          allocatedHostToContainersMap.synchronized {
+            val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname,
+              new HashSet[ContainerId]())
+
+            containerSet += containerId
+            allocatedContainerToHostMap.put(containerId, workerHostname)
+
+            if (rack != null) {
+              allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+            }
+          }
+          logInfo("Launching WorkerRunnable. driverUrl: %s,  workerHostname: %s".format(driverUrl, workerHostname))
+          val workerRunnable = new WorkerRunnable(
+            container,
+            conf,
+            driverUrl,
+            workerId,
+            workerHostname,
+            workerMemory,
+            workerCores)
+          new Thread(workerRunnable).start()
+        }
+      }
+      logDebug("""
+        Finished allocating %s containers (from %s originally).
+        Current number of workers running: %d,
+        releasedContainerList: %s,
+        pendingReleaseContainers: %s
+        """.format(
+          allocatedContainersToProcess,
+          allocatedContainers,
+          numWorkersRunning.get(),
+          releasedContainerList,
+          pendingReleaseContainers))
+    }
+
+    val completedContainers = allocateResponse.getCompletedContainersStatuses()
+    if (completedContainers.size > 0) {
+      logDebug("Completed %d containers".format(completedContainers.size))
+
+      for (completedContainer <- completedContainers) {
+        val containerId = completedContainer.getContainerId
+
+        if (pendingReleaseContainers.containsKey(containerId)) {
+          // YarnAllocationHandler already marked the container for release, so remove it from
+          // `pendingReleaseContainers`.
+          pendingReleaseContainers.remove(containerId)
+        } else {
+          // Decrement the number of workers running. The next iteration of the ApplicationMaster's
+          // reporting thread will take care of allocating.
+          numWorkersRunning.decrementAndGet()
+          logInfo("Completed container %s (state: %s, exit status: %s)".format(
+            containerId,
+            completedContainer.getState,
+            completedContainer.getExitStatus()))
+          // Hadoop 2.2.X added a ContainerExitStatus we should switch to use
+          // there are some exit status' we shouldn't necessarily count against us, but for
+          // now I think its ok as none of the containers are expected to exit
+          if (completedContainer.getExitStatus() != 0) {
+            logInfo("Container marked as failed: " + containerId)
+            numWorkersFailed.incrementAndGet()
+          }
+        }
+
+        allocatedHostToContainersMap.synchronized {
+          if (allocatedContainerToHostMap.contains(containerId)) {
+            val hostOpt = allocatedContainerToHostMap.get(containerId)
+            assert(hostOpt.isDefined)
+            val host = hostOpt.get
+
+            val containerSetOpt = allocatedHostToContainersMap.get(host)
+            assert(containerSetOpt.isDefined)
+            val containerSet = containerSetOpt.get
+
+            // `containerSet` is the mutable set stored in the map, so removing from it updates
+            // the map in place; only drop the host entry once it becomes empty.
+            containerSet.remove(containerId)
+            if (containerSet.isEmpty) {
+              allocatedHostToContainersMap.remove(host)
+            }
+
+            allocatedContainerToHostMap.remove(containerId)
+
+            // TODO: Move this part outside the synchronized block?
+            val rack = YarnAllocationHandler.lookupRack(conf, host)
+            if (rack != null) {
+              val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
+              if (rackCount > 0) {
+                allocatedRackCount.put(rack, rackCount)
+              } else {
+                allocatedRackCount.remove(rack)
+              }
+            }
+          }
+        }
+      }
+      logDebug("""
+        Finished processing %d completed containers.
+        Current number of workers running: %d,
+        releasedContainerList: %s,
+        pendingReleaseContainers: %s
+        """.format(
+          completedContainers.size,
+          numWorkersRunning.get(),
+          releasedContainerList,
+          pendingReleaseContainers))
+    }
+  }
+
+  /**
+   * Builds rack-level requests for the racks of the hosts targeted by `hostContainers`: one
+   * request per host-level request whose host resolves to a rack.
+   */
+  def createRackResourceRequests(
+      hostContainers: ArrayBuffer[ContainerRequest]
+    ): ArrayBuffer[ContainerRequest] = {
+    // Generate modified racks and new set of hosts under it before issuing requests.
+    val rackToCounts = new HashMap[String, Int]()
+
+    for (container <- hostContainers) {
+      val candidateHost = container.getNodes.last
+      assert(YarnAllocationHandler.ANY_HOST != candidateHost)
+
+      val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+      if (rack != null) {
+        rackToCounts(rack) = rackToCounts.getOrElse(rack, 0) + 1
+      }
+    }
+
+    val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size)
+    for ((rack, count) <- rackToCounts) {
+      requestedContainers ++= createResourceRequests(
+        AllocationType.RACK,
+        rack,
+        count,
+        YarnAllocationHandler.PRIORITY)
+    }
+
+    requestedContainers
+  }
+
+  /** Number of tracked running containers on `host` (0 if none). */
+  def allocatedContainersOnHost(host: String): Int = allocatedHostToContainersMap.synchronized {
+    allocatedHostToContainersMap.get(host).map(_.size).getOrElse(0)
+  }
+
+  /** Number of tracked running containers on `rack` (0 if none). */
+  def allocatedContainersOnRack(rack: String): Int = allocatedHostToContainersMap.synchronized {
+    allocatedRackCount.getOrElse(rack, 0)
+  }
+
+  /**
+   * Queues requests for `numWorkers` more containers with `amClient`.
+   *
+   * When `numWorkers <= 0` or there are no host preferences, only "any node" requests are made.
+   * Otherwise the following are made:
+   *  - host-specific requests for each preferred host still short of its preferred count;
+   *  - rack requests for the racks of those hosts;
+   *  - `numWorkers` "any node" requests on top.
+   * Only `numWorkers` is added to the pending-allocation count.
+   */
+  def addResourceRequests(numWorkers: Int) {
+    val containerRequests: List[ContainerRequest] =
+      if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
+        logDebug("numWorkers: " + numWorkers + ", host preferences: " +
+          preferredHostToCount.isEmpty)
+        createResourceRequests(
+          AllocationType.ANY,
+          resource = null,
+          numWorkers,
+          YarnAllocationHandler.PRIORITY).toList
+      } else {
+        // Request for all hosts in preferred nodes and for numWorkers -
+        // candidates.size, request by default allocation policy.
+        val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size)
+        for ((candidateHost, candidateCount) <- preferredHostToCount) {
+          val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
+
+          if (requiredCount > 0) {
+            hostContainerRequests ++= createResourceRequests(
+              AllocationType.HOST,
+              candidateHost,
+              requiredCount,
+              YarnAllocationHandler.PRIORITY)
+          }
+        }
+        val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests(
+          hostContainerRequests).toList
+
+        val anyContainerRequests = createResourceRequests(
+          AllocationType.ANY,
+          resource = null,
+          numWorkers,
+          YarnAllocationHandler.PRIORITY)
+
+        val containerRequestBuffer = new ArrayBuffer[ContainerRequest](
+          hostContainerRequests.size + rackContainerRequests.size + anyContainerRequests.size)
+
+        containerRequestBuffer ++= hostContainerRequests
+        containerRequestBuffer ++= rackContainerRequests
+        containerRequestBuffer ++= anyContainerRequests
+        containerRequestBuffer.toList
+      }
+
+    for (request <- containerRequests) {
+      amClient.addContainerRequest(request)
+    }
+
+    if (numWorkers > 0) {
+      numPendingAllocate.addAndGet(numWorkers)
+      logInfo("Will Allocate %d worker containers, each with %d memory".format(
+        numWorkers,
+        (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)))
+    } else {
+      logDebug("Empty allocation request ...")
+    }
+
+    for (request <- containerRequests) {
+      val nodes = request.getNodes
+      val hostStr = if (nodes == null || nodes.isEmpty) {
+        "Any"
+      } else {
+        nodes.last
+      }
+      logInfo("Container request (host: %s, priority: %s, capability: %s".format(
+        hostStr,
+        request.getPriority().getPriority,
+        request.getCapability))
+    }
+  }
+
+  /**
+   * Creates `numWorkers` identical requests at the given locality level.
+   *
+   * @param requestType HOST (`resource` is a host name, whose rack info is also recorded), RACK
+   *                    (`resource` is a rack name) or ANY (`resource` is ignored)
+   * @throws IllegalArgumentException for any other request type
+   */
+  private def createResourceRequests(
+      requestType: AllocationType.AllocationType,
+      resource: String,
+      numWorkers: Int,
+      priority: Int
+    ): ArrayBuffer[ContainerRequest] = {
+
+    // If hostname is specified, then we need at least two requests - node local and rack local.
+    // There must be a third request, which is ANY. That will be specially handled.
+    requestType match {
+      case AllocationType.HOST => {
+        assert(YarnAllocationHandler.ANY_HOST != resource)
+        val hostname = resource
+        val nodeLocal = constructContainerRequests(
+          Array(hostname),
+          racks = null,
+          numWorkers,
+          priority)
+
+        // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler.
+        YarnAllocationHandler.populateRackInfo(conf, hostname)
+        nodeLocal
+      }
+      case AllocationType.RACK => {
+        val rack = resource
+        constructContainerRequests(hosts = null, Array(rack), numWorkers, priority)
+      }
+      case AllocationType.ANY => constructContainerRequests(
+        hosts = null, racks = null, numWorkers, priority)
+      case _ => throw new IllegalArgumentException(
+        "Unexpected/unsupported request type: " + requestType)
+    }
+  }
+
+  /**
+   * Builds `numWorkers` container requests (none if `numWorkers <= 0`), each asking for
+   * `workerMemory + MEMORY_OVERHEAD` MB at `priority` on the given hosts/racks (null = no
+   * constraint).
+   */
+  private def constructContainerRequests(
+      hosts: Array[String],
+      racks: Array[String],
+      numWorkers: Int,
+      priority: Int
+    ): ArrayBuffer[ContainerRequest] = {
+
+    val memoryResource = Records.newRecord(classOf[Resource])
+    memoryResource.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+
+    val prioritySetting = Records.newRecord(classOf[Priority])
+    prioritySetting.setPriority(priority)
+
+    ArrayBuffer.fill(numWorkers)(
+      new ContainerRequest(memoryResource, hosts, racks, prioritySetting))
+  }
+}
+
+object YarnAllocationHandler {
+
+  // Wildcard resource name ("*"); a HOST-level request must never use it.
+  val ANY_HOST = "*"
+  // All requests are issued with the same priority: we do not (yet) have any distinction between
+  // request types (like map/reduce in Hadoop, for example).
+  val PRIORITY = 1
+
+  // Additional memory overhead, in MB, added on top of the requested worker memory.
+  val MEMORY_OVERHEAD = 384
+
+  // Host to rack map, saved from allocation requests. We expect this not to change.
+  // Note that it can change: the ResourceManager will indicate that to us via the update
+  // response to allocate. But we are punting on handling that for now.
+  private val hostToRack = new ConcurrentHashMap[String, String]()
+  // Reverse index of hostToRack: rack -> hosts resolved to that rack so far.
+  private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+
+  /**
+   * Creates an allocator sized from `args`, with empty host and rack preference maps.
+   */
+  def newAllocator(
+      conf: Configuration,
+      amClient: AMRMClient[ContainerRequest],
+      appAttemptId: ApplicationAttemptId,
+      args: ApplicationMasterArguments,
+      sparkConf: SparkConf
+    ): YarnAllocationHandler = {
+    new YarnAllocationHandler(
+      conf,
+      amClient,
+      appAttemptId,
+      args.numWorkers,
+      args.workerMemory,
+      args.workerCores,
+      Map[String, Int](),
+      Map[String, Int](),
+      sparkConf)
+  }
+
+  /**
+   * Creates an allocator sized from `args`, whose host and rack preferences are the split
+   * counts derived from `map` (host -> splits on that host; may be null).
+   */
+  def newAllocator(
+      conf: Configuration,
+      amClient: AMRMClient[ContainerRequest],
+      appAttemptId: ApplicationAttemptId,
+      args: ApplicationMasterArguments,
+      map: collection.Map[String, collection.Set[SplitInfo]],
+      sparkConf: SparkConf
+    ): YarnAllocationHandler = {
+    val (hostToSplitCount, rackToSplitCount) = generateNodeToWeight(conf, map)
+    new YarnAllocationHandler(
+      conf,
+      amClient,
+      appAttemptId,
+      args.numWorkers,
+      args.workerMemory,
+      args.workerCores,
+      hostToSplitCount,
+      rackToSplitCount,
+      sparkConf)
+  }
+
+  /**
+   * Creates an allocator from explicit sizing values, whose host and rack preferences are the
+   * split counts derived from `map` (host -> splits on that host; may be null).
+   */
+  def newAllocator(
+      conf: Configuration,
+      amClient: AMRMClient[ContainerRequest],
+      appAttemptId: ApplicationAttemptId,
+      maxWorkers: Int,
+      workerMemory: Int,
+      workerCores: Int,
+      map: collection.Map[String, collection.Set[SplitInfo]],
+      sparkConf: SparkConf
+    ): YarnAllocationHandler = {
+    val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+    new YarnAllocationHandler(
+      conf,
+      amClient,
+      appAttemptId,
+      maxWorkers,
+      workerMemory,
+      workerCores,
+      hostToCount,
+      rackToCount,
+      sparkConf)
+  }
+
+  /**
+   * Converts a host -> splits map into per-host and per-rack split counts.
+   *
+   * Hosts whose rack cannot be resolved contribute to the host counts only.
+   *
+   * @param conf  Hadoop configuration used for rack resolution.
+   * @param input host -> splits located on that host; may be null.
+   * @return (host -> split count, rack -> split count); both empty when `input` is null.
+   */
+  private def generateNodeToWeight(
+      conf: Configuration,
+      input: collection.Map[String, collection.Set[SplitInfo]]
+    ): (Map[String, Int], Map[String, Int]) = {
+
+    if (input == null) {
+      (Map[String, Int](), Map[String, Int]())
+    } else {
+      val hostToCount = new HashMap[String, Int]
+      val rackToCount = new HashMap[String, Int]
+
+      for ((host, splits) <- input) {
+        hostToCount.put(host, hostToCount.getOrElse(host, 0) + splits.size)
+
+        val rack = lookupRack(conf, host)
+        if (rack != null) {
+          // Aggregate by rack, not by host: several hosts may share one rack.
+          rackToCount.put(rack, rackToCount.getOrElse(rack, 0) + splits.size)
+        }
+      }
+
+      (hostToCount.toMap, rackToCount.toMap)
+    }
+  }
+
+  /**
+   * Returns the rack of `host`, resolving and caching it on first lookup.
+   *
+   * @return the rack name, or null if the rack could not be resolved.
+   */
+  def lookupRack(conf: Configuration, host: String): String = {
+    // Must be containsKey: ConcurrentHashMap.contains(Object) is the legacy Hashtable method
+    // that searches *values*, so it would never find a cached host.
+    if (!hostToRack.containsKey(host)) {
+      populateRackInfo(conf, host)
+    }
+    hostToRack.get(host)
+  }
+
+  /**
+   * Returns an immutable copy of the hosts cached for `rack`, or None if no host has been
+   * resolved to that rack yet.
+   */
+  def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
+    Option(rackToHostSet.get(rack)).map { set =>
+      val convertedSet: collection.mutable.Set[String] = set
+      // TODO: Better way to get a Set[String] from JSet.
+      convertedSet.toSet
+    }
+  }
+
+  /**
+   * Validates `hostname` with Utils.checkHost. If the host is not cached yet, resolves its rack
+   * with RackResolver and records it in both the host -> rack cache and the rack -> hosts index.
+   * Hosts whose rack cannot be resolved are not cached, so a later call tries again.
+   */
+  def populateRackInfo(conf: Configuration, hostname: String): Unit = {
+    Utils.checkHost(hostname)
+
+    if (!hostToRack.containsKey(hostname)) {
+      // TODO: if resolution fails repeatedly, add the host to an ignore list.
+      val rackInfo = RackResolver.resolve(conf, hostname)
+      if (rackInfo != null && rackInfo.getNetworkLocation != null) {
+        val rack = rackInfo.getNetworkLocation
+        hostToRack.put(hostname, rack)
+        if (!rackToHostSet.containsKey(rack)) {
+          // putIfAbsent keeps whichever set wins a concurrent race, so no host is dropped.
+          rackToHostSet.putIfAbsent(rack,
+            Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
+        }
+        rackToHostSet.get(rack).add(hostname)
+      }
+      // Negative results are deliberately not cached, because RackResolver keeps its own cache.
+      // Note that ConcurrentHashMap rejects null values, so `hostToRack.put(hostname, null)`
+      // would throw a NullPointerException if this were ever re-enabled.
+    }
+  }
+}


Mime
View raw message