Return-Path: X-Original-To: apmail-spark-commits-archive@minotaur.apache.org Delivered-To: apmail-spark-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 25DC410875 for ; Fri, 3 Jan 2014 07:17:55 +0000 (UTC) Received: (qmail 44025 invoked by uid 500); 3 Jan 2014 07:17:11 -0000 Delivered-To: apmail-spark-commits-archive@spark.apache.org Received: (qmail 43889 invoked by uid 500); 3 Jan 2014 07:17:07 -0000 Mailing-List: contact commits-help@spark.incubator.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@spark.incubator.apache.org Delivered-To: mailing list commits@spark.incubator.apache.org Received: (qmail 43538 invoked by uid 99); 3 Jan 2014 07:16:53 -0000 Received: from athena.apache.org (HELO athena.apache.org) (140.211.11.136) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 03 Jan 2014 07:16:53 +0000 X-ASF-Spam-Status: No, hits=-1998.4 required=5.0 tests=ALL_TRUSTED,FB_GET_MEDS,RP_MATCHES_RCVD X-Spam-Check-By: apache.org Received: from [140.211.11.3] (HELO mail.apache.org) (140.211.11.3) by apache.org (qpsmtpd/0.29) with SMTP; Fri, 03 Jan 2014 07:16:46 +0000 Received: (qmail 42058 invoked by uid 99); 3 Jan 2014 07:16:15 -0000 Received: from tyr.zones.apache.org (HELO tyr.zones.apache.org) (140.211.11.114) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 03 Jan 2014 07:16:15 +0000 Received: by tyr.zones.apache.org (Postfix, from userid 65534) id D3693915492; Fri, 3 Jan 2014 07:16:14 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: pwendell@apache.org To: commits@spark.incubator.apache.org Date: Fri, 03 Jan 2014 07:16:37 -0000 Message-Id: <5024368686724539a40974421b8bd48c@git.apache.org> In-Reply-To: References: X-Mailer: ASF-Git Admin Mailer Subject: [25/32] Using name yarn-alpha/yarn instead of yarn-2.0/yarn-2.2 X-Virus-Checked: Checked by ClamAV on apache.org http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ebdfa6bb/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala ---------------------------------------------------------------------- diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala new file mode 100644 index 0000000..a750668 --- /dev/null +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.{InetAddress, UnknownHostException, URI} +import java.nio.ByteBuffer + +import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap +import scala.collection.mutable.Map + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil} +import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.io.DataOutputBuffer +import org.apache.hadoop.mapred.Master +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, Records} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils +import org.apache.spark.deploy.SparkHadoopUtil + + +/** + * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The + * Client submits an application to the global ResourceManager to launch Spark's ApplicationMaster, + * which will launch a Spark master process and negotiate resources throughout its duration. + */ +class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { + + var rpc: YarnRPC = YarnRPC.create(conf) + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + private val SPARK_STAGING: String = ".sparkStaging" + private val distCacheMgr = new ClientDistributedCacheManager() + private val sparkConf = new SparkConf + + + // Staging directory is private! -> rwx-------- + val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700: Short) + // App files are world-wide readable and owner writable -> rw-r--r-- + val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644: Short) + + def this(args: ClientArguments) = this(new Configuration(), args) + + def runApp(): ApplicationId = { + validateArgs() + // Initialize and start the client service. + init(yarnConf) + start() + + // Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers). + logClusterResourceDetails() + + // Prepare to submit a request to the ResourcManager (specifically its ApplicationsManager (ASM) + // interface). + + // Get a new client application. + val newApp = super.createApplication() + val newAppResponse = newApp.getNewApplicationResponse() + val appId = newAppResponse.getApplicationId() + + verifyClusterResources(newAppResponse) + + // Set up resource and environment variables. + val appStagingDir = getAppStagingDir(appId) + val localResources = prepareLocalResources(appStagingDir) + val launchEnv = setupLaunchEnv(localResources, appStagingDir) + val amContainer = createContainerLaunchContext(newAppResponse, localResources, launchEnv) + + // Set up an application submission context. + val appContext = newApp.getApplicationSubmissionContext() + appContext.setApplicationName(args.appName) + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(amContainer) + + // Memory for the ApplicationMaster. + val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] + memoryResource.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + appContext.setResource(memoryResource) + + // Finally, submit and monitor the application. + submitApp(appContext) + appId + } + + def run() { + val appId = runApp() + monitorApplication(appId) + System.exit(0) + } + + // TODO(harvey): This could just go in ClientArguments. + def validateArgs() = { + Map( + (System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!", + (args.userJar == null) -> "Error: You must specify a user jar!", + (args.userClass == null) -> "Error: You must specify a user class!", + (args.numWorkers <= 0) -> "Error: You must specify atleast 1 worker!", + (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be" + + "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD), + (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size" + + "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString) + ).foreach { case(cond, errStr) => + if (cond) { + logError(errStr) + args.printUsageAndExit(1) + } + } + } + + def getAppStagingDir(appId: ApplicationId): String = { + SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR + } + + def logClusterResourceDetails() { + val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics + logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " + + clusterMetrics.getNumNodeManagers) + + val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue) + logInfo("""Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s, + queueApplicationCount = %s, queueChildQueueCount = %s""".format( + queueInfo.getQueueName, + queueInfo.getCurrentCapacity, + queueInfo.getMaximumCapacity, + queueInfo.getApplications.size, + queueInfo.getChildQueues.size)) + } + + def verifyClusterResources(app: GetNewApplicationResponse) = { + val maxMem = app.getMaximumResourceCapability().getMemory() + logInfo("Max mem capabililty of a single resource in this cluster " + maxMem) + + // If we have requested more then the clusters max for a single resource then exit. + if (args.workerMemory > maxMem) { + logError("Required worker memory (%d MB), is above the max threshold (%d MB) of this cluster.". + format(args.workerMemory, maxMem)) + System.exit(1) + } + val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD + if (amMem > maxMem) { + logError("Required AM memory (%d) is above the max threshold (%d) of this cluster". + format(args.amMemory, maxMem)) + System.exit(1) + } + + // We could add checks to make sure the entire cluster has enough resources but that involves + // getting all the node reports and computing ourselves. + } + + /** See if two file systems are the same or not. */ + private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + if (srcUri.getScheme() == null) { + return false + } + if (!srcUri.getScheme().equals(dstUri.getScheme())) { + return false + } + var srcHost = srcUri.getHost() + var dstHost = dstUri.getHost() + if ((srcHost != null) && (dstHost != null)) { + try { + srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() + dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() + } catch { + case e: UnknownHostException => + return false + } + if (!srcHost.equals(dstHost)) { + return false + } + } else if (srcHost == null && dstHost != null) { + return false + } else if (srcHost != null && dstHost == null) { + return false + } + //check for ports + if (srcUri.getPort() != dstUri.getPort()) { + return false + } + return true + } + + /** Copy the file into HDFS if needed. */ + private def copyRemoteFile( + dstDir: Path, + originalPath: Path, + replication: Short, + setPerms: Boolean = false): Path = { + val fs = FileSystem.get(conf) + val remoteFs = originalPath.getFileSystem(conf) + var newPath = originalPath + if (! compareFs(remoteFs, fs)) { + newPath = new Path(dstDir, originalPath.getName()) + logInfo("Uploading " + originalPath + " to " + newPath) + FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) + fs.setReplication(newPath, replication) + if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) + } + // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific + // version shows the specific version in the distributed cache configuration + val qualPath = fs.makeQualified(newPath) + val fc = FileContext.getFileContext(qualPath.toUri(), conf) + val destPath = fc.resolvePath(qualPath) + destPath + } + + def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + // Upload Spark and the application JAR to the remote file system if necessary. Add them as + // local resources to the application master. + val fs = FileSystem.get(conf) + + val delegTokenRenewer = Master.getMasterPrincipal(conf) + if (UserGroupInformation.isSecurityEnabled()) { + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + logError("Can't get Master Kerberos principal for use as renewer") + System.exit(1) + } + } + val dst = new Path(fs.getHomeDirectory(), appStagingDir) + val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort + + if (UserGroupInformation.isSecurityEnabled()) { + val dstFs = dst.getFileSystem(conf) + dstFs.addDelegationTokens(delegTokenRenewer, credentials) + } + + val localResources = HashMap[String, LocalResource]() + FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) + + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + + Map( + Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, + Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF") + ).foreach { case(destName, _localPath) => + val localPath: String = if (_localPath != null) _localPath.trim() else "" + if (! localPath.isEmpty()) { + var localURI = new URI(localPath) + // If not specified assume these are in the local filesystem to keep behavior like Hadoop + if (localURI.getScheme() == null) { + localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString) + } + val setPermissions = if (destName.equals(Client.APP_JAR)) true else false + val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + destName, statCache) + } + } + + // Handle jars local to the ApplicationMaster. + if ((args.addJars != null) && (!args.addJars.isEmpty())){ + args.addJars.split(',').foreach { case file: String => + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + // Only add the resource to the Spark ApplicationMaster. + val appMasterOnly = true + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + linkname, statCache, appMasterOnly) + } + } + + // Handle any distributed cache files + if ((args.files != null) && (!args.files.isEmpty())){ + args.files.split(',').foreach { case file: String => + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + linkname, statCache) + } + } + + // Handle any distributed cache archives + if ((args.archives != null) && (!args.archives.isEmpty())) { + args.archives.split(',').foreach { case file:String => + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, + linkname, statCache) + } + } + + UserGroupInformation.getCurrentUser().addCredentials(credentials) + localResources + } + + def setupLaunchEnv( + localResources: HashMap[String, LocalResource], + stagingDir: String): HashMap[String, String] = { + logInfo("Setting up the launch environment") + val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null) + + val env = new HashMap[String, String]() + + Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env) + env("SPARK_YARN_MODE") = "true" + env("SPARK_YARN_STAGING_DIR") = stagingDir + + // Set the environment variables to be passed on to the Workers. + distCacheMgr.setDistFilesEnv(env) + distCacheMgr.setDistArchivesEnv(env) + + // Allow users to specify some environment variables. + Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) + + // Add each SPARK_* key to the environment. + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + + env + } + + def userArgsToString(clientArgs: ClientArguments): String = { + val prefix = " --args " + val args = clientArgs.userArgs + val retval = new StringBuilder() + for (arg <- args){ + retval.append(prefix).append(" '").append(arg).append("' ") + } + retval.toString + } + + def createContainerLaunchContext( + newApp: GetNewApplicationResponse, + localResources: HashMap[String, LocalResource], + env: HashMap[String, String]): ContainerLaunchContext = { + logInfo("Setting up container launch context") + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) + amContainer.setLocalResources(localResources) + amContainer.setEnvironment(env) + + // TODO: Need a replacement for the following code to fix -Xmx? + // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() + // var amMemory = ((args.amMemory / minResMemory) * minResMemory) + + // ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - + // YarnAllocationHandler.MEMORY_OVERHEAD) + + // Extra options for the JVM + var JAVA_OPTS = "" + + // Add Xmx for AM memory + JAVA_OPTS += "-Xmx" + args.amMemory + "m" + + val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + JAVA_OPTS += " -Djava.io.tmpdir=" + tmpDir + + // TODO: Remove once cpuset version is pushed out. + // The context is, default gc for server class machines ends up using all cores to do gc - + // hence if there are multiple containers in same node, Spark GC affects all other containers' + // performance (which can be that of other Spark containers) + // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in + // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset + // of cores on a node. + val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && + java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC")) + if (useConcurrentAndIncrementalGC) { + // In our expts, using (default) throughput collector has severe perf ramifications in + // multi-tenant machines + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } + + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += " " + env("SPARK_JAVA_OPTS") + } + + // Command for the ApplicationMaster + var javaCommand = "java" + val javaHome = System.getenv("JAVA_HOME") + if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) { + javaCommand = Environment.JAVA_HOME.$() + "/bin/java" + } + + val commands = List[String]( + javaCommand + + " -server " + + JAVA_OPTS + + " " + args.amClass + + " --class " + args.userClass + + " --jar " + args.userJar + + userArgsToString(args) + + " --worker-memory " + args.workerMemory + + " --worker-cores " + args.workerCores + + " --num-workers " + args.numWorkers + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + + logInfo("Command for starting the Spark ApplicationMaster: " + commands(0)) + amContainer.setCommands(commands) + + // Setup security tokens. + val dob = new DataOutputBuffer() + credentials.writeTokenStorageToStream(dob) + amContainer.setTokens(ByteBuffer.wrap(dob.getData())) + + amContainer + } + + def submitApp(appContext: ApplicationSubmissionContext) = { + // Submit the application to the applications manager. + logInfo("Submitting application to ASM") + super.submitApplication(appContext) + } + + def monitorApplication(appId: ApplicationId): Boolean = { + val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) + + while (true) { + Thread.sleep(interval) + val report = super.getApplicationReport(appId) + + logInfo("Application report from ASM: \n" + + "\t application identifier: " + appId.toString() + "\n" + + "\t appId: " + appId.getId() + "\n" + + "\t clientToAMToken: " + report.getClientToAMToken() + "\n" + + "\t appDiagnostics: " + report.getDiagnostics() + "\n" + + "\t appMasterHost: " + report.getHost() + "\n" + + "\t appQueue: " + report.getQueue() + "\n" + + "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + + "\t appStartTime: " + report.getStartTime() + "\n" + + "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + + "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + + "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + + "\t appUser: " + report.getUser() + ) + + val state = report.getYarnApplicationState() + val dsStatus = report.getFinalApplicationStatus() + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + return true + } + } + true + } +} + +object Client { + val SPARK_JAR: String = "spark.jar" + val APP_JAR: String = "app.jar" + val LOG4J_PROP: String = "log4j.properties" + + def main(argStrings: Array[String]) { + // Set an env variable indicating we are running in YARN mode. + // Note: anything env variable with SPARK_ prefix gets propagated to all (remote) processes - + // see Client#setupLaunchEnv(). + System.setProperty("SPARK_YARN_MODE", "true") + + val args = new ClientArguments(argStrings) + + (new Client(args)).run() + } + + // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps + def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) { + for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim) + } + } + + def populateClasspath(conf: Configuration, addLog4j: Boolean, env: HashMap[String, String]) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$()) + // If log4j present, ensure ours overrides all others + if (addLog4j) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + LOG4J_PROP) + } + // Normally the users app.jar is last in case conflicts with spark jars + val userClasspathFirst = new SparkConf().get("spark.yarn.user.classpath.first", "false") + .toBoolean + if (userClasspathFirst) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + APP_JAR) + } + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + SPARK_JAR) + Client.populateHadoopClasspath(conf, env) + + if (!userClasspathFirst) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + APP_JAR) + } + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + "*") + } +} http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ebdfa6bb/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala ---------------------------------------------------------------------- diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala new file mode 100644 index 0000000..4d9cca0 --- /dev/null +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.Socket +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import akka.actor._ +import akka.remote._ +import akka.actor.Terminated +import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.SplitInfo +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging { + + def this(args: ApplicationMasterArguments) = this(args, new Configuration()) + + private var appAttemptId: ApplicationAttemptId = _ + private var reporterThread: Thread = _ + private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + private var yarnAllocator: YarnAllocationHandler = _ + private var driverClosed:Boolean = false + + private var amClient: AMRMClient[ContainerRequest] = _ + private val sparkConf = new SparkConf + + val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, + conf = sparkConf)._1 + var actor: ActorRef = _ + + // This actor just working as a monitor to watch on Driver Actor. + class MonitorActor(driverUrl: String) extends Actor { + + var driver: ActorSelection = _ + + override def preStart() { + logInfo("Listen to driver: " + driverUrl) + driver = context.actorSelection(driverUrl) + // Send a hello message thus the connection is actually established, thus we can monitor Lifecycle Events. + driver ! "Hello" + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receive = { + case x: DisassociatedEvent => + logInfo("Driver terminated or disconnected! Shutting down. $x") + driverClosed = true + } + } + + def run() { + + amClient = AMRMClient.createAMRMClient() + amClient.init(yarnConf) + amClient.start() + + appAttemptId = getApplicationAttemptId() + registerApplicationMaster() + + waitForSparkMaster() + + // Allocate all containers + allocateWorkers() + + // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout + // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. + + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + // must be <= timeoutInterval/ 2. + // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. + // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. + val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + reporterThread = launchReporterThread(interval) + + // Wait for the reporter thread to Finish. + reporterThread.join() + + finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) + actorSystem.shutdown() + + logInfo("Exited") + System.exit(0) + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + appAttemptId + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + // TODO:(Raymond) Find out Spark UI address and fill in here? + amClient.registerApplicationMaster(Utils.localHostName(), 0, "") + } + + private def waitForSparkMaster() { + logInfo("Waiting for Spark driver to be reachable.") + var driverUp = false + val hostport = args.userArgs(0) + val (driverHost, driverPort) = Utils.parseHostPort(hostport) + while(!driverUp) { + try { + val socket = new Socket(driverHost, driverPort) + socket.close() + logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) + driverUp = true + } catch { + case e: Exception => + logError("Failed to connect to driver at %s:%s, retrying ...". + format(driverHost, driverPort)) + Thread.sleep(100) + } + } + sparkConf.set("spark.driver.host", driverHost) + sparkConf.set("spark.driver.port", driverPort.toString) + + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( + driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME) + + actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") + } + + + private def allocateWorkers() { + + // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = + scala.collection.immutable.Map() + + yarnAllocator = YarnAllocationHandler.newAllocator( + yarnConf, + amClient, + appAttemptId, + args, + preferredNodeLocationData, + sparkConf) + + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? + // TODO: Handle container failure + + yarnAllocator.addResourceRequests(args.numWorkers) + while(yarnAllocator.getNumWorkersRunning < args.numWorkers) { + yarnAllocator.allocateResources() + Thread.sleep(100) + } + + logInfo("All workers have launched.") + + } + + // TODO: We might want to extend this to allocate more containers in case they die ! + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (!driverClosed) { + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning - + yarnAllocator.getNumPendingAllocate + if (missingWorkerCount > 0) { + logInfo("Allocating %d containers to make up for (potentially) lost containers". + format(missingWorkerCount)) + yarnAllocator.addResourceRequests(missingWorkerCount) + } + sendProgress() + Thread.sleep(sleepTime) + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + t + } + + private def sendProgress() { + logDebug("Sending progress") + // simulated with an allocate request with no nodes requested ... + yarnAllocator.allocateResources() + } + + def finishApplicationMaster(status: FinalApplicationStatus) { + logInfo("finish ApplicationMaster with " + status) + amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */) + } + +} + + +object WorkerLauncher { + def main(argStrings: Array[String]) { + val args = new ApplicationMasterArguments(argStrings) + new WorkerLauncher(args).run() + } +} http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ebdfa6bb/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala ---------------------------------------------------------------------- diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala new file mode 100644 index 0000000..9f5523c --- /dev/null +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.URI +import java.nio.ByteBuffer +import java.security.PrivilegedExceptionAction + +import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.DataOutputBuffer +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.records.impl.pb.ProtoUtils +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.client.api.NMClient +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} + +import org.apache.spark.Logging + + +class WorkerRunnable( + container: Container, + conf: Configuration, + masterAddress: String, + slaveId: String, + hostname: String, + workerMemory: Int, + workerCores: Int) + extends Runnable with Logging { + + var rpc: YarnRPC = YarnRPC.create(conf) + var nmClient: NMClient = _ + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run = { + logInfo("Starting Worker Container") + nmClient = NMClient.createNMClient() + nmClient.init(yarnConf) + nmClient.start() + startContainer + } + + def startContainer = { + logInfo("Setting up ContainerLaunchContext") + + val ctx = Records.newRecord(classOf[ContainerLaunchContext]) + .asInstanceOf[ContainerLaunchContext] + + val localResources = prepareLocalResources + ctx.setLocalResources(localResources) + + val env = prepareEnvironment + ctx.setEnvironment(env) + + // Extra options for the JVM + var JAVA_OPTS = "" + // Set the JVM memory + val workerMemoryString = workerMemory + "m" + JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " " + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + + JAVA_OPTS += " -Djava.io.tmpdir=" + + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " " + + // Commenting it out for now - so that people can refer to the properties if required. Remove + // it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence + // if there are multiple containers in same node, spark gc effects all other containers + // performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in + // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset + // of cores on a node. +/* + else { + // If no java_opts specified, default to using -XX:+CMSIncrementalMode + // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont + // want to mess with it. + // In our expts, using (default) throughput collector has severe perf ramnifications in + // multi-tennent machines + // The options are based on + // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } +*/ + + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + val dob = new DataOutputBuffer() + credentials.writeTokenStorageToStream(dob) + ctx.setTokens(ByteBuffer.wrap(dob.getData())) + + var javaCommand = "java" + val javaHome = System.getenv("JAVA_HOME") + if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) { + javaCommand = Environment.JAVA_HOME.$() + "/bin/java" + } + + val commands = List[String](javaCommand + + " -server " + + // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. + // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in + // an inconsistent state. + // TODO: If the OOM is not recoverable by rescheduling it on different node, then do + // 'something' to fail job ... akin to blacklisting trackers in mapred ? + " -XX:OnOutOfMemoryError='kill %p' " + + JAVA_OPTS + + " org.apache.spark.executor.CoarseGrainedExecutorBackend " + + masterAddress + " " + + slaveId + " " + + hostname + " " + + workerCores + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Setting up worker with commands: " + commands) + ctx.setCommands(commands) + + // Send the start request to the ContainerManager + nmClient.startContainer(container, ctx) + } + + private def setupDistributedCache( + file: String, + rtype: LocalResourceType, + localResources: HashMap[String, LocalResource], + timestamp: String, + size: String, + vis: String) = { + val uri = new URI(file) + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(rtype) + amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis)) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri)) + amJarRsrc.setTimestamp(timestamp.toLong) + amJarRsrc.setSize(size.toLong) + localResources(uri.getFragment()) = amJarRsrc + } + + def prepareLocalResources: HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val localResources = HashMap[String, LocalResource]() + + if (System.getenv("SPARK_YARN_CACHE_FILES") != null) { + val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') + val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') + val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',') + for( i <- 0 to distFiles.length - 1) { + setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i), + fileSizes(i), visibilities(i)) + } + } + + if (System.getenv("SPARK_YARN_CACHE_ARCHIVES") != null) { + val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',') + val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',') + val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',') + for( i <- 0 to distArchives.length - 1) { + setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources, + timeStamps(i), fileSizes(i), visibilities(i)) + } + } + + logInfo("Prepared Local resources " + localResources) + localResources + } + + def prepareEnvironment: HashMap[String, String] = { + val env = new HashMap[String, String]() + + Client.populateClasspath(yarnConf, System.getenv("SPARK_YARN_LOG4J_PATH") != null, env) + + // Allow users to specify some environment variables + Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) + + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + env + } + +} http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ebdfa6bb/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala ---------------------------------------------------------------------- diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala new file mode 100644 index 0000000..8a9a73f --- /dev/null +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -0,0 +1,694 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.lang.{Boolean => JBoolean} +import java.util.{Collections, Set => JSet} +import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.ApplicationMasterProtocol +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId +import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus} +import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest} +import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.apache.hadoop.yarn.util.{RackResolver, Records} + + +object AllocationType extends Enumeration { + type AllocationType = Value + val HOST, RACK, ANY = Value +} + +// TODO: +// Too many params. +// Needs to be mt-safe +// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should +// make it more proactive and decoupled. + +// Note that right now, we assume all node asks as uniform in terms of capabilities and priority +// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for +// more info on how we are requesting for containers. +private[yarn] class YarnAllocationHandler( + val conf: Configuration, + val amClient: AMRMClient[ContainerRequest], + val appAttemptId: ApplicationAttemptId, + val maxWorkers: Int, + val workerMemory: Int, + val workerCores: Int, + val preferredHostToCount: Map[String, Int], + val preferredRackToCount: Map[String, Int], + val sparkConf: SparkConf) + extends Logging { + // These three are locked on allocatedHostToContainersMap. Complementary data structures + // allocatedHostToContainersMap : containers which are running : host, Set + // allocatedContainerToHostMap: container to host mapping. + private val allocatedHostToContainersMap = + new HashMap[String, collection.mutable.Set[ContainerId]]() + + private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + + // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an + // allocated node) + // As with the two data structures above, tightly coupled with them, and to be locked on + // allocatedHostToContainersMap + private val allocatedRackCount = new HashMap[String, Int]() + + // Containers which have been released. + private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() + // Containers to be released in next request to RM + private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + + // Number of container requests that have been sent to, but not yet allocated by the + // ApplicationMaster. + private val numPendingAllocate = new AtomicInteger() + private val numWorkersRunning = new AtomicInteger() + // Used to generate a unique id per worker + private val workerIdCounter = new AtomicInteger() + private val lastResponseId = new AtomicInteger() + private val numWorkersFailed = new AtomicInteger() + + def getNumPendingAllocate: Int = numPendingAllocate.intValue + + def getNumWorkersRunning: Int = numWorkersRunning.intValue + + def getNumWorkersFailed: Int = numWorkersFailed.intValue + + def isResourceConstraintSatisfied(container: Container): Boolean = { + container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + } + + def releaseContainer(container: Container) { + val containerId = container.getId + pendingReleaseContainers.put(containerId, true) + amClient.releaseAssignedContainer(containerId) + } + + def allocateResources() { + // We have already set the container request. Poll the ResourceManager for a response. + // This doubles as a heartbeat if there are no pending container requests. + val progressIndicator = 0.1f + val allocateResponse = amClient.allocate(progressIndicator) + + val allocatedContainers = allocateResponse.getAllocatedContainers() + if (allocatedContainers.size > 0) { + var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) + + if (numPendingAllocateNow < 0) { + numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) + } + + logDebug(""" + Allocated containers: %d + Current worker count: %d + Containers released: %s + Containers to-be-released: %s + Cluster resources: %s + """.format( + allocatedContainers.size, + numWorkersRunning.get(), + releasedContainerList, + pendingReleaseContainers, + allocateResponse.getAvailableResources)) + + val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (container <- allocatedContainers) { + if (isResourceConstraintSatisfied(container)) { + // Add the accepted `container` to the host's list of already accepted, + // allocated containers + val host = container.getNodeId.getHost + val containersForHost = hostToContainers.getOrElseUpdate(host, + new ArrayBuffer[Container]()) + containersForHost += container + } else { + // Release container, since it doesn't satisfy resource constraints. + releaseContainer(container) + } + } + + // Find the appropriate containers to use. + // TODO: Cleanup this group-by... + val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (candidateHost <- hostToContainers.keySet) { + val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) + val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) + + val remainingContainersOpt = hostToContainers.get(candidateHost) + assert(remainingContainersOpt.isDefined) + var remainingContainers = remainingContainersOpt.get + + if (requiredHostCount >= remainingContainers.size) { + // Since we have <= required containers, add all remaining containers to + // `dataLocalContainers`. + dataLocalContainers.put(candidateHost, remainingContainers) + // There are no more free containers remaining. + remainingContainers = null + } else if (requiredHostCount > 0) { + // Container list has more containers than we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. + val (dataLocal, remaining) = remainingContainers.splitAt( + remainingContainers.size - requiredHostCount) + dataLocalContainers.put(candidateHost, dataLocal) + + // Invariant: remainingContainers == remaining + + // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. + // Add each container in `remaining` to list of containers to release. If we have an + // insufficient number of containers, then the next allocation cycle will reallocate + // (but won't treat it as data local). + // TODO(harvey): Rephrase this comment some more. + for (container <- remaining) releaseContainer(container) + remainingContainers = null + } + + // For rack local containers + if (remainingContainers != null) { + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + if (rack != null) { + val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) + val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - + rackLocalContainers.getOrElse(rack, List()).size + + if (requiredRackCount >= remainingContainers.size) { + // Add all remaining containers to to `dataLocalContainers`. + dataLocalContainers.put(rack, remainingContainers) + remainingContainers = null + } else if (requiredRackCount > 0) { + // Container list has more containers that we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. + val (rackLocal, remaining) = remainingContainers.splitAt( + remainingContainers.size - requiredRackCount) + val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, + new ArrayBuffer[Container]()) + + existingRackLocal ++= rackLocal + + remainingContainers = remaining + } + } + } + + if (remainingContainers != null) { + // Not all containers have been consumed - add them to the list of off-rack containers. + offRackContainers.put(candidateHost, remainingContainers) + } + } + + // Now that we have split the containers into various groups, go through them in order: + // first host-local, then rack-local, and finally off-rack. + // Note that the list we create below tries to ensure that not all containers end up within + // a host if there is a sufficiently large number of hosts/containers. + val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) + + // Run each of the allocated containers. + for (container <- allocatedContainersToProcess) { + val numWorkersRunningNow = numWorkersRunning.incrementAndGet() + val workerHostname = container.getNodeId.getHost + val containerId = container.getId + + val workerMemoryOverhead = (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + assert(container.getResource.getMemory >= workerMemoryOverhead) + + if (numWorkersRunningNow > maxWorkers) { + logInfo("""Ignoring container %s at host %s, since we already have the required number of + containers for it.""".format(containerId, workerHostname)) + releaseContainer(container) + numWorkersRunning.decrementAndGet() + } else { + val workerId = workerIdCounter.incrementAndGet().toString + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( + sparkConf.get("spark.driver.host"), + sparkConf.get("spark.driver.port"), + CoarseGrainedSchedulerBackend.ACTOR_NAME) + + logInfo("Launching container %s for on host %s".format(containerId, workerHostname)) + + // To be safe, remove the container from `pendingReleaseContainers`. + pendingReleaseContainers.remove(containerId) + + val rack = YarnAllocationHandler.lookupRack(conf, workerHostname) + allocatedHostToContainersMap.synchronized { + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, + new HashSet[ContainerId]()) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, workerHostname) + + if (rack != null) { + allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) + } + } + logInfo("Launching WorkerRunnable. driverUrl: %s, workerHostname: %s".format(driverUrl, workerHostname)) + val workerRunnable = new WorkerRunnable( + container, + conf, + driverUrl, + workerId, + workerHostname, + workerMemory, + workerCores) + new Thread(workerRunnable).start() + } + } + logDebug(""" + Finished allocating %s containers (from %s originally). + Current number of workers running: %d, + releasedContainerList: %s, + pendingReleaseContainers: %s + """.format( + allocatedContainersToProcess, + allocatedContainers, + numWorkersRunning.get(), + releasedContainerList, + pendingReleaseContainers)) + } + + val completedContainers = allocateResponse.getCompletedContainersStatuses() + if (completedContainers.size > 0) { + logDebug("Completed %d containers".format(completedContainers.size)) + + for (completedContainer <- completedContainers) { + val containerId = completedContainer.getContainerId + + if (pendingReleaseContainers.containsKey(containerId)) { + // YarnAllocationHandler already marked the container for release, so remove it from + // `pendingReleaseContainers`. + pendingReleaseContainers.remove(containerId) + } else { + // Decrement the number of workers running. The next iteration of the ApplicationMaster's + // reporting thread will take care of allocating. + numWorkersRunning.decrementAndGet() + logInfo("Completed container %s (state: %s, exit status: %s)".format( + containerId, + completedContainer.getState, + completedContainer.getExitStatus())) + // Hadoop 2.2.X added a ContainerExitStatus we should switch to use + // there are some exit status' we shouldn't necessarily count against us, but for + // now I think its ok as none of the containers are expected to exit + if (completedContainer.getExitStatus() != 0) { + logInfo("Container marked as failed: " + containerId) + numWorkersFailed.incrementAndGet() + } + } + + allocatedHostToContainersMap.synchronized { + if (allocatedContainerToHostMap.containsKey(containerId)) { + val hostOpt = allocatedContainerToHostMap.get(containerId) + assert(hostOpt.isDefined) + val host = hostOpt.get + + val containerSetOpt = allocatedHostToContainersMap.get(host) + assert(containerSetOpt.isDefined) + val containerSet = containerSetOpt.get + + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) + } + + allocatedContainerToHostMap.remove(containerId) + + // TODO: Move this part outside the synchronized block? + val rack = YarnAllocationHandler.lookupRack(conf, host) + if (rack != null) { + val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 + if (rackCount > 0) { + allocatedRackCount.put(rack, rackCount) + } else { + allocatedRackCount.remove(rack) + } + } + } + } + } + logDebug(""" + Finished processing %d completed containers. + Current number of workers running: %d, + releasedContainerList: %s, + pendingReleaseContainers: %s + """.format( + completedContainers.size, + numWorkersRunning.get(), + releasedContainerList, + pendingReleaseContainers)) + } + } + + def createRackResourceRequests( + hostContainers: ArrayBuffer[ContainerRequest] + ): ArrayBuffer[ContainerRequest] = { + // Generate modified racks and new set of hosts under it before issuing requests. + val rackToCounts = new HashMap[String, Int]() + + for (container <- hostContainers) { + val candidateHost = container.getNodes.last + assert(YarnAllocationHandler.ANY_HOST != candidateHost) + + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + if (rack != null) { + var count = rackToCounts.getOrElse(rack, 0) + count += 1 + rackToCounts.put(rack, count) + } + } + + val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size) + for ((rack, count) <- rackToCounts) { + requestedContainers ++= createResourceRequests( + AllocationType.RACK, + rack, + count, + YarnAllocationHandler.PRIORITY) + } + + requestedContainers + } + + def allocatedContainersOnHost(host: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedHostToContainersMap.getOrElse(host, Set()).size + } + retval + } + + def allocatedContainersOnRack(rack: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedRackCount.getOrElse(rack, 0) + } + retval + } + + def addResourceRequests(numWorkers: Int) { + val containerRequests: List[ContainerRequest] = + if (numWorkers <= 0 || preferredHostToCount.isEmpty) { + logDebug("numWorkers: " + numWorkers + ", host preferences: " + + preferredHostToCount.isEmpty) + createResourceRequests( + AllocationType.ANY, + resource = null, + numWorkers, + YarnAllocationHandler.PRIORITY).toList + } else { + // Request for all hosts in preferred nodes and for numWorkers - + // candidates.size, request by default allocation policy. + val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size) + for ((candidateHost, candidateCount) <- preferredHostToCount) { + val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) + + if (requiredCount > 0) { + hostContainerRequests ++= createResourceRequests( + AllocationType.HOST, + candidateHost, + requiredCount, + YarnAllocationHandler.PRIORITY) + } + } + val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests( + hostContainerRequests).toList + + val anyContainerRequests = createResourceRequests( + AllocationType.ANY, + resource = null, + numWorkers, + YarnAllocationHandler.PRIORITY) + + val containerRequestBuffer = new ArrayBuffer[ContainerRequest]( + hostContainerRequests.size + rackContainerRequests.size() + anyContainerRequests.size) + + containerRequestBuffer ++= hostContainerRequests + containerRequestBuffer ++= rackContainerRequests + containerRequestBuffer ++= anyContainerRequests + containerRequestBuffer.toList + } + + for (request <- containerRequests) { + amClient.addContainerRequest(request) + } + + if (numWorkers > 0) { + numPendingAllocate.addAndGet(numWorkers) + logInfo("Will Allocate %d worker containers, each with %d memory".format( + numWorkers, + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))) + } else { + logDebug("Empty allocation request ...") + } + + for (request <- containerRequests) { + val nodes = request.getNodes + var hostStr = if (nodes == null || nodes.isEmpty) { + "Any" + } else { + nodes.last + } + logInfo("Container request (host: %s, priority: %s, capability: %s".format( + hostStr, + request.getPriority().getPriority, + request.getCapability)) + } + } + + private def createResourceRequests( + requestType: AllocationType.AllocationType, + resource: String, + numWorkers: Int, + priority: Int + ): ArrayBuffer[ContainerRequest] = { + + // If hostname is specified, then we need at least two requests - node local and rack local. + // There must be a third request, which is ANY. That will be specially handled. + requestType match { + case AllocationType.HOST => { + assert(YarnAllocationHandler.ANY_HOST != resource) + val hostname = resource + val nodeLocal = constructContainerRequests( + Array(hostname), + racks = null, + numWorkers, + priority) + + // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler. + YarnAllocationHandler.populateRackInfo(conf, hostname) + nodeLocal + } + case AllocationType.RACK => { + val rack = resource + constructContainerRequests(hosts = null, Array(rack), numWorkers, priority) + } + case AllocationType.ANY => constructContainerRequests( + hosts = null, racks = null, numWorkers, priority) + case _ => throw new IllegalArgumentException( + "Unexpected/unsupported request type: " + requestType) + } + } + + private def constructContainerRequests( + hosts: Array[String], + racks: Array[String], + numWorkers: Int, + priority: Int + ): ArrayBuffer[ContainerRequest] = { + + val memoryResource = Records.newRecord(classOf[Resource]) + memoryResource.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + + val prioritySetting = Records.newRecord(classOf[Priority]) + prioritySetting.setPriority(priority) + + val requests = new ArrayBuffer[ContainerRequest]() + for (i <- 0 until numWorkers) { + requests += new ContainerRequest(memoryResource, hosts, racks, prioritySetting) + } + requests + } +} + +object YarnAllocationHandler { + + val ANY_HOST = "*" + // All requests are issued with same priority : we do not (yet) have any distinction between + // request types (like map/reduce in hadoop for example) + val PRIORITY = 1 + + // Additional memory overhead - in mb. + val MEMORY_OVERHEAD = 384 + + // Host to rack map - saved from allocation requests. We are expecting this not to change. + // Note that it is possible for this to change : and ResurceManager will indicate that to us via + // update response to allocate. But we are punting on handling that for now. + private val hostToRack = new ConcurrentHashMap[String, String]() + private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() + + + def newAllocator( + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + sparkConf: SparkConf + ): YarnAllocationHandler = { + new YarnAllocationHandler( + conf, + amClient, + appAttemptId, + args.numWorkers, + args.workerMemory, + args.workerCores, + Map[String, Int](), + Map[String, Int](), + sparkConf) + } + + def newAllocator( + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + map: collection.Map[String, + collection.Set[SplitInfo]], + sparkConf: SparkConf + ): YarnAllocationHandler = { + val (hostToSplitCount, rackToSplitCount) = generateNodeToWeight(conf, map) + new YarnAllocationHandler( + conf, + amClient, + appAttemptId, + args.numWorkers, + args.workerMemory, + args.workerCores, + hostToSplitCount, + rackToSplitCount, + sparkConf) + } + + def newAllocator( + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + maxWorkers: Int, + workerMemory: Int, + workerCores: Int, + map: collection.Map[String, collection.Set[SplitInfo]], + sparkConf: SparkConf + ): YarnAllocationHandler = { + val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + new YarnAllocationHandler( + conf, + amClient, + appAttemptId, + maxWorkers, + workerMemory, + workerCores, + hostToCount, + rackToCount, + sparkConf) + } + + // A simple method to copy the split info map. + private def generateNodeToWeight( + conf: Configuration, + input: collection.Map[String, collection.Set[SplitInfo]] + ): (Map[String, Int], Map[String, Int]) = { + + if (input == null) { + return (Map[String, Int](), Map[String, Int]()) + } + + val hostToCount = new HashMap[String, Int] + val rackToCount = new HashMap[String, Int] + + for ((host, splits) <- input) { + val hostCount = hostToCount.getOrElse(host, 0) + hostToCount.put(host, hostCount + splits.size) + + val rack = lookupRack(conf, host) + if (rack != null){ + val rackCount = rackToCount.getOrElse(host, 0) + rackToCount.put(host, rackCount + splits.size) + } + } + + (hostToCount.toMap, rackToCount.toMap) + } + + def lookupRack(conf: Configuration, host: String): String = { + if (!hostToRack.contains(host)) { + populateRackInfo(conf, host) + } + hostToRack.get(host) + } + + def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { + Option(rackToHostSet.get(rack)).map { set => + val convertedSet: collection.mutable.Set[String] = set + // TODO: Better way to get a Set[String] from JSet. + convertedSet.toSet + } + } + + def populateRackInfo(conf: Configuration, hostname: String) { + Utils.checkHost(hostname) + + if (!hostToRack.containsKey(hostname)) { + // If there are repeated failures to resolve, all to an ignore list. + val rackInfo = RackResolver.resolve(conf, hostname) + if (rackInfo != null && rackInfo.getNetworkLocation != null) { + val rack = rackInfo.getNetworkLocation + hostToRack.put(hostname, rack) + if (! rackToHostSet.containsKey(rack)) { + rackToHostSet.putIfAbsent(rack, + Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) + } + rackToHostSet.get(rack).add(hostname) + + // TODO(harvey): Figure out what this comment means... + // Since RackResolver caches, we are disabling this for now ... + } /* else { + // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... + hostToRack.put(hostname, null) + } */ + } + } +}