From commits-return-5775-archive-asf-public=cust-asf.ponee.io@tez.apache.org Mon Mar 5 16:55:05 2018 Return-Path: X-Original-To: archive-asf-public@cust-asf.ponee.io Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by mx-eu-01.ponee.io (Postfix) with SMTP id 70600180676 for ; Mon, 5 Mar 2018 16:55:04 +0100 (CET) Received: (qmail 52954 invoked by uid 500); 5 Mar 2018 15:55:03 -0000 Mailing-List: contact commits-help@tez.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@tez.apache.org Delivered-To: mailing list commits@tez.apache.org Received: (qmail 52945 invoked by uid 99); 5 Mar 2018 15:55:03 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Mon, 05 Mar 2018 15:55:03 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 9A9BCEEE26; Mon, 5 Mar 2018 15:55:02 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: jlowe@apache.org To: commits@tez.apache.org Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: tez git commit: TEZ-3897. Tez Local Mode hang for vertices with broadcast input. (Jonathan Eagles via jlowe) Date: Mon, 5 Mar 2018 15:55:02 +0000 (UTC) Repository: tez Updated Branches: refs/heads/master bb40cf5b8 -> c34e46c73 TEZ-3897. Tez Local Mode hang for vertices with broadcast input. (Jonathan Eagles via jlowe) Project: http://git-wip-us.apache.org/repos/asf/tez/repo Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/c34e46c7 Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/c34e46c7 Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/c34e46c7 Branch: refs/heads/master Commit: c34e46c73218bf21a0219f3004e20cbedaad92f4 Parents: bb40cf5 Author: Jason Lowe Authored: Mon Mar 5 09:53:11 2018 -0600 Committer: Jason Lowe Committed: Mon Mar 5 09:53:11 2018 -0600 ---------------------------------------------------------------------- .../app/launcher/LocalContainerLauncher.java | 19 +- .../dag/app/rm/LocalTaskSchedulerService.java | 185 ++++++++++++++----- .../tez/dag/app/rm/TestLocalTaskScheduler.java | 8 +- .../app/rm/TestLocalTaskSchedulerService.java | 94 +++++++++- 4 files changed, 243 insertions(+), 63 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/tez/blob/c34e46c7/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java index 9764daa..13e4115 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java @@ -94,9 +94,9 @@ public class LocalContainerLauncher extends DagContainerLauncher { int shufflePort = TezRuntimeUtils.INVALID_PORT; private DeletionTracker deletionTracker; - private final ConcurrentHashMap + private final ConcurrentHashMap> runningContainers = - new ConcurrentHashMap(); + new ConcurrentHashMap<>(); private final ConcurrentHashMap cacheManagers = new ConcurrentHashMap<>(); @@ -281,7 +281,7 @@ public class LocalContainerLauncher extends DagContainerLauncher { ListenableFuture runningTaskFuture = taskExecutorService.submit(createSubTask(tezChild, event.getContainerId())); RunningTaskCallback callback = new RunningTaskCallback(event.getContainerId()); - runningContainers.put(event.getContainerId(), callback); + runningContainers.put(event.getContainerId(), runningTaskFuture); Futures.addCallback(runningTaskFuture, callback, callbackExecutor); if (deletionTracker != null) { deletionTracker.addNodeShufflePort(event.getNodeId(), shufflePort); @@ -293,19 +293,16 @@ public class LocalContainerLauncher extends DagContainerLauncher { private void stop(ContainerStopRequest event) { // A stop_request will come in when a task completes and reports back or a preemption decision - // is made. Currently the LocalTaskScheduler does not support preemption. Also preemption - // will not work in local mode till Tez supports task preemption instead of container preemption. - RunningTaskCallback callback = + // is made. + ListenableFuture future = runningContainers.get(event.getContainerId()); - if (callback == null) { + if (future == null) { LOG.info("Ignoring stop request for containerId: " + event.getContainerId()); } else { LOG.info( - "Ignoring stop request for containerId {}. Relying on regular task shutdown for it to end", + "Stopping containerId: {}", event.getContainerId()); - // Allow the tezChild thread to run it's course. It'll receive a shutdown request from the - // AM eventually since the task and container will be unregistered. - // This will need to be fixed once interrupting tasks is supported. + future.cancel(true); } // Send this event to maintain regular control flow. This isn't of much use though. getContext().containerStopRequested(event.getContainerId()); http://git-wip-us.apache.org/repos/asf/tez/blob/c34e46c7/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java index 04e79a8..cc213cb 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/LocalTaskSchedulerService.java @@ -19,6 +19,9 @@ package org.apache.tez.dag.app.rm; import java.io.IOException; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.Map; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.PriorityBlockingQueue; @@ -29,6 +32,7 @@ import java.util.LinkedHashMap; import com.google.common.primitives.Ints; import org.apache.tez.common.TezUtils; +import org.apache.tez.serviceplugins.api.DagInfo; import org.apache.tez.serviceplugins.api.TaskScheduler; import org.apache.tez.serviceplugins.api.TaskSchedulerContext; import org.slf4j.Logger; @@ -51,19 +55,19 @@ public class LocalTaskSchedulerService extends TaskScheduler { private static final Logger LOG = LoggerFactory.getLogger(LocalTaskSchedulerService.class); final ContainerSignatureMatcher containerSignatureMatcher; - final LinkedBlockingQueue taskRequestQueue; + final LinkedBlockingQueue taskRequestQueue; final Configuration conf; AsyncDelegateRequestHandler taskRequestHandler; Thread asyncDelegateRequestThread; - final HashMap taskAllocations; + final HashMap taskAllocations; final String appTrackingUrl; final long customContainerAppId; public LocalTaskSchedulerService(TaskSchedulerContext taskSchedulerContext) { super(taskSchedulerContext); taskRequestQueue = new LinkedBlockingQueue<>(); - taskAllocations = new LinkedHashMap(); + taskAllocations = new LinkedHashMap<>(); this.appTrackingUrl = taskSchedulerContext.getAppTrackingUrl(); this.containerSignatureMatcher = taskSchedulerContext.getContainerSignatureMatcher(); this.customContainerAppId = taskSchedulerContext.getCustomClusterIdentifier(); @@ -98,6 +102,7 @@ public class LocalTaskSchedulerService extends TaskScheduler { @Override public void dagComplete() { + taskRequestHandler.dagComplete(); } @Override @@ -129,7 +134,7 @@ public class LocalTaskSchedulerService extends TaskScheduler { // in local mode every task is already container level local taskRequestHandler.addAllocateTaskRequest(task, capability, priority, clientCookie); } - + @Override public boolean deallocateTask(Object task, boolean taskSucceeded, TaskAttemptEndReason endReason, String diagnostics) { return taskRequestHandler.addDeallocateTaskRequest(task); @@ -137,6 +142,7 @@ public class LocalTaskSchedulerService extends TaskScheduler { @Override public Object deallocateContainer(ContainerId containerId) { + taskRequestHandler.addDeallocateContainerRequest(containerId); return null; } @@ -212,20 +218,14 @@ public class LocalTaskSchedulerService extends TaskScheduler { } } - static class TaskRequest implements Comparable { - // Higher prority than Priority.UNDEFINED - static final int HIGHEST_PRIORITY = -2; - Object task; - Priority priority; + static class SchedulerRequest { + } - public TaskRequest(Object task, Priority priority) { - this.task = task; - this.priority = priority; - } + static class TaskRequest extends SchedulerRequest { + final Object task; - @Override - public int compareTo(TaskRequest request) { - return request.priority.compareTo(this.priority); + public TaskRequest(Object task) { + this.task = task; } @Override @@ -239,9 +239,6 @@ public class LocalTaskSchedulerService extends TaskScheduler { TaskRequest that = (TaskRequest) o; - if (priority != null ? !priority.equals(that.priority) : that.priority != null) { - return false; - } if (task != null ? !task.equals(that.task) : that.task != null) { return false; } @@ -251,23 +248,29 @@ public class LocalTaskSchedulerService extends TaskScheduler { @Override public int hashCode() { - int result = 1; - result = 7841 * result + (task != null ? task.hashCode() : 0); - result = 7841 * result + (priority != null ? priority.hashCode() : 0); - return result; + return 7841 + (task != null ? task.hashCode() : 0); } } - static class AllocateTaskRequest extends TaskRequest { - Resource capability; - Object clientCookie; + static class AllocateTaskRequest extends TaskRequest implements Comparable { + final Priority priority; + final Resource capability; + final Object clientCookie; + final int vertexIndex; - public AllocateTaskRequest(Object task, Resource capability, Priority priority, - Object clientCookie) { - super(task, priority); + public AllocateTaskRequest(Object task, int vertexIndex, Resource capability, Priority priority, + Object clientCookie) { + super(task); + this.priority = priority; this.capability = capability; this.clientCookie = clientCookie; + this.vertexIndex = vertexIndex; + } + + @Override + public int compareTo(AllocateTaskRequest request) { + return request.priority.compareTo(this.priority); } @Override @@ -284,6 +287,10 @@ public class LocalTaskSchedulerService extends TaskScheduler { AllocateTaskRequest that = (AllocateTaskRequest) o; + if (priority != null ? !priority.equals(that.priority) : that.priority != null) { + return false; + } + if (capability != null ? !capability.equals(that.capability) : that.capability != null) { return false; } @@ -298,6 +305,7 @@ public class LocalTaskSchedulerService extends TaskScheduler { @Override public int hashCode() { int result = super.hashCode(); + result = 12329 * result + (priority != null ? priority.hashCode() : 0); result = 12329 * result + (capability != null ? capability.hashCode() : 0); result = 12329 * result + (clientCookie != null ? clientCookie.hashCode() : 0); return result; @@ -305,24 +313,43 @@ public class LocalTaskSchedulerService extends TaskScheduler { } static class DeallocateTaskRequest extends TaskRequest { - static final Priority DEALLOCATE_PRIORITY = Priority.newInstance(HIGHEST_PRIORITY); public DeallocateTaskRequest(Object task) { - super(task, DEALLOCATE_PRIORITY); + super(task); + } + } + + static class DeallocateContainerRequest extends SchedulerRequest { + final ContainerId containerId; + + public DeallocateContainerRequest(ContainerId containerId) { + this.containerId = containerId; + } + } + + static class AllocatedTask { + final AllocateTaskRequest request; + final Container container; + + AllocatedTask(AllocateTaskRequest request, Container container) { + this.request = request; + this.container = container; } } static class AsyncDelegateRequestHandler implements Runnable { - final LinkedBlockingQueue clientRequestQueue; + final LinkedBlockingQueue clientRequestQueue; final PriorityBlockingQueue taskRequestQueue; final LocalContainerFactory localContainerFactory; - final HashMap taskAllocations; + final HashMap taskAllocations; final TaskSchedulerContext taskSchedulerContext; + private final Object descendantsLock = new Object(); + private ArrayList vertexDescendants = null; final int MAX_TASKS; - AsyncDelegateRequestHandler(LinkedBlockingQueue clientRequestQueue, + AsyncDelegateRequestHandler(LinkedBlockingQueue clientRequestQueue, LocalContainerFactory localContainerFactory, - HashMap taskAllocations, + HashMap taskAllocations, TaskSchedulerContext taskSchedulerContext, Configuration conf) { this.clientRequestQueue = clientRequestQueue; @@ -334,10 +361,33 @@ public class LocalTaskSchedulerService extends TaskScheduler { this.taskRequestQueue = new PriorityBlockingQueue<>(); } + void dagComplete() { + synchronized (descendantsLock) { + vertexDescendants = null; + } + } + private void ensureVertexDescendants() { + synchronized (descendantsLock) { + if (vertexDescendants == null) { + DagInfo info = taskSchedulerContext.getCurrentDagInfo(); + if (info == null) { + throw new IllegalStateException("Scheduling tasks but no current DAG info?"); + } + int numVertices = info.getTotalVertices(); + ArrayList descendants = new ArrayList<>(numVertices); + for (int i = 0; i < numVertices; ++i) { + descendants.add(info.getVertexDescendants(i)); + } + vertexDescendants = descendants; + } + } + } + public void addAllocateTaskRequest(Object task, Resource capability, Priority priority, Object clientCookie) { try { - clientRequestQueue.put(new AllocateTaskRequest(task, capability, priority, clientCookie)); + int vertexIndex = taskSchedulerContext.getVertexIndexForTask(task); + clientRequestQueue.put(new AllocateTaskRequest(task, vertexIndex, capability, priority, clientCookie)); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } @@ -352,10 +402,22 @@ public class LocalTaskSchedulerService extends TaskScheduler { return true; } + public void addDeallocateContainerRequest(ContainerId containerId) { + try { + clientRequestQueue.put(new DeallocateContainerRequest(containerId)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + boolean shouldProcess() { return !taskRequestQueue.isEmpty() && taskAllocations.size() < MAX_TASKS; } + boolean shouldPreempt() { + return !taskRequestQueue.isEmpty() && taskAllocations.size() >= MAX_TASKS; + } + @Override public void run() { while (!Thread.currentThread().isInterrupted()) { @@ -368,13 +430,19 @@ public class LocalTaskSchedulerService extends TaskScheduler { void dispatchRequest() { try { - TaskRequest request = clientRequestQueue.take(); + SchedulerRequest request = clientRequestQueue.take(); if (request instanceof AllocateTaskRequest) { taskRequestQueue.put((AllocateTaskRequest)request); + if (shouldPreempt()) { + maybePreempt((AllocateTaskRequest) request); + } } else if (request instanceof DeallocateTaskRequest) { deallocateTask((DeallocateTaskRequest)request); } + else if (request instanceof DeallocateContainerRequest) { + preemptTask((DeallocateContainerRequest)request); + } else { LOG.error("Unknown task request message: " + request); } @@ -383,12 +451,29 @@ public class LocalTaskSchedulerService extends TaskScheduler { } } + void maybePreempt(AllocateTaskRequest request) { + Priority priority = request.priority; + for (Map.Entry entry : taskAllocations.entrySet()) { + AllocatedTask allocatedTask = entry.getValue(); + Container container = allocatedTask.container; + if (priority.compareTo(allocatedTask.container.getPriority()) > 0) { + Object task = entry.getKey(); + ensureVertexDescendants(); + if (vertexDescendants.get(request.vertexIndex).get(allocatedTask.request.vertexIndex)) { + LOG.info("Preempting task/container for task/priority:" + task + "/" + container + + " for " + request.task + "/" + priority); + taskSchedulerContext.preemptContainer(allocatedTask.container.getId()); + } + } + } + } + void allocateTask() { try { AllocateTaskRequest request = taskRequestQueue.take(); Container container = localContainerFactory.createContainer(request.capability, request.priority); - taskAllocations.put(request.task, container); + taskAllocations.put(request.task, new AllocatedTask(request, container)); taskSchedulerContext.taskAllocated(request.task, request.clientCookie, container); } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -396,24 +481,34 @@ public class LocalTaskSchedulerService extends TaskScheduler { } void deallocateTask(DeallocateTaskRequest request) { - Container container = taskAllocations.remove(request.task); - if (container != null) { - taskSchedulerContext.containerBeingReleased(container.getId()); + AllocatedTask allocatedTask = taskAllocations.remove(request.task); + if (allocatedTask != null) { + taskSchedulerContext.containerBeingReleased(allocatedTask.container.getId()); } else { - boolean deallocationBeforeAllocation = false; Iterator iter = taskRequestQueue.iterator(); while (iter.hasNext()) { TaskRequest taskRequest = iter.next(); if (taskRequest.task.equals(request.task)) { iter.remove(); - deallocationBeforeAllocation = true; LOG.info("Deallocation request before allocation for task:" + request.task); break; } } - if (!deallocationBeforeAllocation) { - throw new TezUncheckedException("Unable to find and remove task " + request.task + " from task allocations"); + } + } + + void preemptTask(DeallocateContainerRequest request) { + LOG.info("Trying to preempt: " + request.containerId); + Iterator> entries = taskAllocations.entrySet().iterator(); + while (entries.hasNext()) { + Map.Entry entry = entries.next(); + Container container = entry.getValue().container; + if (container.getId().equals(request.containerId)) { + entries.remove(); + Object task = entry.getKey(); + LOG.info("Preempting task/container:" + task + "/" + container); + taskSchedulerContext.containerBeingReleased(container.getId()); } } } http://git-wip-us.apache.org/repos/asf/tez/blob/c34e46c7/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java index 36505c2..d7b516a 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskScheduler.java @@ -29,13 +29,13 @@ import org.junit.Test; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.hadoop.yarn.api.records.ApplicationId; -import org.apache.hadoop.yarn.api.records.Container; import org.apache.hadoop.yarn.api.records.Priority; import org.apache.tez.dag.api.TezConfiguration; +import org.apache.tez.dag.app.rm.LocalTaskSchedulerService.AllocatedTask; import org.apache.tez.dag.app.rm.LocalTaskSchedulerService.AsyncDelegateRequestHandler; import org.apache.tez.dag.app.rm.LocalTaskSchedulerService.LocalContainerFactory; -import org.apache.tez.dag.app.rm.LocalTaskSchedulerService.TaskRequest; +import org.apache.tez.dag.app.rm.LocalTaskSchedulerService.SchedulerRequest; public class TestLocalTaskScheduler { @@ -56,8 +56,8 @@ public class TestLocalTaskScheduler { LocalContainerFactory containerFactory = new LocalContainerFactory(appAttemptId, 1000); - HashMap taskAllocations = new LinkedHashMap(); - LinkedBlockingQueue clientRequestQueue = new LinkedBlockingQueue<>(); + HashMap taskAllocations = new LinkedHashMap<>(); + LinkedBlockingQueue clientRequestQueue = new LinkedBlockingQueue<>(); // Object under test AsyncDelegateRequestHandler requestHandler = http://git-wip-us.apache.org/repos/asf/tez/blob/c34e46c7/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java index c2daf84..70e31f3 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestLocalTaskSchedulerService.java @@ -18,20 +18,25 @@ package org.apache.tez.dag.app.rm; +import java.util.BitSet; import java.util.HashMap; import java.util.concurrent.LinkedBlockingQueue; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.hadoop.yarn.api.records.ApplicationId; -import org.apache.hadoop.yarn.api.records.Container; +import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.api.records.Priority; import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.dag.app.dag.Task; import org.apache.tez.dag.app.rm.TestLocalTaskSchedulerService.MockLocalTaskSchedulerSerivce.MockAsyncDelegateRequestHandler; +import org.apache.tez.serviceplugins.api.DagInfo; import org.apache.tez.serviceplugins.api.TaskSchedulerContext; import org.junit.Assert; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -138,6 +143,82 @@ public class TestLocalTaskSchedulerService { taskSchedulerService.shutdown(); } + @Test + public void preemptDescendantsOnly() { + + final int MAX_TASKS = 2; + TezConfiguration tezConf = new TezConfiguration(); + tezConf.setInt(TezConfiguration.TEZ_AM_INLINE_TASK_EXECUTION_MAX_TASKS, MAX_TASKS); + + ApplicationId appId = ApplicationId.newInstance(2000, 1); + ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1); + Long parentTask1 = new Long(1); + Long parentTask2 = new Long(2); + Long childTask1 = new Long(3); + Long grandchildTask1 = new Long(4); + + TaskSchedulerContext + mockContext = TestTaskSchedulerHelpers.setupMockTaskSchedulerContext("", 0, "", true, + appAttemptId, 1000l, null, tezConf); + when(mockContext.getVertexIndexForTask(parentTask1)).thenReturn(0); + when(mockContext.getVertexIndexForTask(parentTask2)).thenReturn(0); + when(mockContext.getVertexIndexForTask(childTask1)).thenReturn(1); + when(mockContext.getVertexIndexForTask(grandchildTask1)).thenReturn(2); + + DagInfo mockDagInfo = mock(DagInfo.class); + when(mockDagInfo.getTotalVertices()).thenReturn(3); + BitSet vertex1Descendants = new BitSet(); + vertex1Descendants.set(1); + vertex1Descendants.set(2); + BitSet vertex2Descendants = new BitSet(); + vertex2Descendants.set(2); + BitSet vertex3Descendants = new BitSet(); + when(mockDagInfo.getVertexDescendants(0)).thenReturn(vertex1Descendants); + when(mockDagInfo.getVertexDescendants(1)).thenReturn(vertex2Descendants); + when(mockDagInfo.getVertexDescendants(2)).thenReturn(vertex3Descendants); + when(mockContext.getCurrentDagInfo()).thenReturn(mockDagInfo); + + Priority priority1 = Priority.newInstance(1); + Priority priority2 = Priority.newInstance(2); + Priority priority3 = Priority.newInstance(3); + Priority priority4 = Priority.newInstance(4); + Resource resource = Resource.newInstance(1024, 1); + + MockLocalTaskSchedulerSerivce taskSchedulerService = new MockLocalTaskSchedulerSerivce(mockContext); + + // The mock context need to send a deallocate container request to the scheduler service + Answer answer = new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + ContainerId containerId = invocation.getArgumentAt(0, ContainerId.class); + taskSchedulerService.deallocateContainer(containerId); + return null; + } + }; + doAnswer(answer).when(mockContext).preemptContainer(any(ContainerId.class)); + + taskSchedulerService.initialize(); + taskSchedulerService.start(); + taskSchedulerService.startRequestHandlerThread(); + + MockAsyncDelegateRequestHandler requestHandler = taskSchedulerService.getRequestHandler(); + taskSchedulerService.allocateTask(parentTask1, resource, null, null, priority1, null, null); + taskSchedulerService.allocateTask(childTask1, resource, null, null, priority3, null, null); + taskSchedulerService.allocateTask(grandchildTask1, resource, null, null, priority4, null, null); + requestHandler.drainRequest(3); + + // We should not preempt if we have not reached max task allocations + Assert.assertEquals("Wrong number of allocate tasks", MAX_TASKS, requestHandler.allocateCount); + Assert.assertTrue("Another allocation should not fit", !requestHandler.shouldProcess()); + + // Next task allocation should preempt + taskSchedulerService.allocateTask(parentTask2, Resource.newInstance(1024, 1), null, null, priority2, null, null); + requestHandler.drainRequest(5); + + // All allocated tasks should have been removed + Assert.assertEquals("Wrong number of preempted tasks", 1, requestHandler.preemptCount); + } + static class MockLocalTaskSchedulerSerivce extends LocalTaskSchedulerService { private MockAsyncDelegateRequestHandler requestHandler; @@ -173,12 +254,13 @@ public class TestLocalTaskSchedulerService { public int allocateCount = 0; public int deallocateCount = 0; + public int preemptCount = 0; public int dispatchCount = 0; MockAsyncDelegateRequestHandler( - LinkedBlockingQueue taskRequestQueue, + LinkedBlockingQueue taskRequestQueue, LocalContainerFactory localContainerFactory, - HashMap taskAllocations, + HashMap taskAllocations, TaskSchedulerContext appClientDelegate, Configuration conf) { super(taskRequestQueue, localContainerFactory, taskAllocations, appClientDelegate, conf); @@ -211,6 +293,12 @@ public class TestLocalTaskSchedulerService { super.deallocateTask(request); deallocateCount++; } + + @Override + void preemptTask(DeallocateContainerRequest request) { + super.preemptTask(request); + preemptCount++; + } } } }