Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id 8CD26200D3B for ; Fri, 10 Nov 2017 09:26:48 +0100 (CET) Received: by cust-asf.ponee.io (Postfix) id 8B672160BF2; Fri, 10 Nov 2017 08:26:48 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id 5A14E160BEE for ; Fri, 10 Nov 2017 09:26:47 +0100 (CET) Received: (qmail 7180 invoked by uid 500); 10 Nov 2017 08:26:46 -0000 Mailing-List: contact commits-help@flink.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@flink.apache.org Delivered-To: mailing list commits@flink.apache.org Received: (qmail 7170 invoked by uid 99); 10 Nov 2017 08:26:46 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 10 Nov 2017 08:26:46 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 5F1FDDFC2F; Fri, 10 Nov 2017 08:26:46 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: aljoscha@apache.org To: commits@flink.apache.org Message-Id: <8039462df0884d4bae8e916eedce8ead@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: flink git commit: [FLINK-8005] Set user-code class loader as context loader before snapshot Date: Fri, 10 Nov 2017 08:26:46 +0000 (UTC) archived-at: Fri, 10 Nov 2017 08:26:48 -0000 Repository: flink Updated Branches: refs/heads/release-1.4 005a87177 -> 2117eb77b [FLINK-8005] Set user-code class loader as context loader before snapshot During checkpointing, user code may dynamically load classes from the user code jar. This is a problem if the thread invoking the snapshot callbacks does not have the user code class loader set as its context class loader. This commit makes sure that the correct class loader is set. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/2117eb77 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/2117eb77 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/2117eb77 Branch: refs/heads/release-1.4 Commit: 2117eb77bb9d34da4288b5dd4455ef06c583ce7c Parents: 005a871 Author: gyao Authored: Wed Nov 8 11:46:45 2017 +0100 Committer: Aljoscha Krettek Committed: Fri Nov 10 09:26:37 2017 +0100 ---------------------------------------------------------------------- .../taskmanager/DispatcherThreadFactory.java | 24 ++- .../apache/flink/runtime/taskmanager/Task.java | 22 +- .../runtime/taskmanager/TaskAsyncCallTest.java | 206 +++++++++++++++---- .../flink/runtime/taskmanager/TaskStopTest.java | 157 -------------- 4 files changed, 204 insertions(+), 205 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/2117eb77/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java index 97060a8..543b159 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatcherThreadFactory.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.taskmanager; +import javax.annotation.Nullable; + import java.util.concurrent.ThreadFactory; /** @@ -29,21 +31,41 @@ public class DispatcherThreadFactory implements ThreadFactory { private final ThreadGroup group; private final String threadName; + + private final ClassLoader classLoader; /** * Creates a new thread factory. - * + * * @param group The group that the threads will be associated with. * @param threadName The name for the threads. */ public DispatcherThreadFactory(ThreadGroup group, String threadName) { + this(group, threadName, null); + } + + /** + * Creates a new thread factory. + * + * @param group The group that the threads will be associated with. + * @param threadName The name for the threads. + * @param classLoader The {@link ClassLoader} to be set as context class loader. + */ + public DispatcherThreadFactory( + ThreadGroup group, + String threadName, + @Nullable ClassLoader classLoader) { this.group = group; this.threadName = threadName; + this.classLoader = classLoader; } @Override public Thread newThread(Runnable r) { Thread t = new Thread(group, r, threadName); + if (classLoader != null) { + t.setContextClassLoader(classLoader); + } t.setDaemon(true); return t; } http://git-wip-us.apache.org/repos/asf/flink/blob/2117eb77/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 58dd9e3..2cb356c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -99,6 +99,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; /** * The Task represents one execution of a parallel subtask on a TaskManager. @@ -265,6 +266,12 @@ public class Task implements Runnable, TaskActions { private long taskCancellationTimeout; /** + * This class loader should be set as the context class loader of the threads in + * {@link #asyncCallDispatcher} because user code may dynamically load classes in all callbacks. + */ + private ClassLoader userCodeClassLoader; + + /** *

IMPORTANT: This constructor may not start any work that would need to * be undone in the case of a failing task deployment.

*/ @@ -563,7 +570,6 @@ public class Task implements Runnable, TaskActions { Map> distributedCacheEntries = new HashMap>(); AbstractInvokable invokable = null; - ClassLoader userCodeClassLoader; try { // ---------------------------- // Task Bootstrap - We periodically @@ -580,7 +586,7 @@ public class Task implements Runnable, TaskActions { // this may involve downloading the job's JAR files and/or classes LOG.info("Loading JAR files for task {}.", this); - userCodeClassLoader = createUserCodeClassloader(libraryCache); + userCodeClassLoader = createUserCodeClassloader(); final ExecutionConfig executionConfig = serializedExecutionConfig.deserializeValue(userCodeClassLoader); if (executionConfig.getTaskCancellationInterval() >= 0) { @@ -865,7 +871,7 @@ public class Task implements Runnable, TaskActions { } } - private ClassLoader createUserCodeClassloader(LibraryCacheManager libraryCache) throws Exception { + private ClassLoader createUserCodeClassloader() throws Exception { long startDownloadTime = System.currentTimeMillis(); // triggers the download of all missing jar files from the job manager @@ -1342,15 +1348,19 @@ public class Task implements Runnable, TaskActions { if (executionState != ExecutionState.RUNNING) { return; } - + // get ourselves a reference on the stack that cannot be concurrently modified ExecutorService executor = this.asyncCallDispatcher; if (executor == null) { // first time use, initialize + checkState(userCodeClassLoader != null, "userCodeClassLoader must not be null"); executor = Executors.newSingleThreadExecutor( - new DispatcherThreadFactory(TASK_THREADS_GROUP, "Async calls on " + taskNameWithSubtask)); + new DispatcherThreadFactory( + TASK_THREADS_GROUP, + "Async calls on " + taskNameWithSubtask, + userCodeClassLoader)); this.asyncCallDispatcher = executor; - + // double-check for execution state, and make sure we clean up after ourselves // if we created the dispatcher while the task was concurrently canceled if (executionState != ExecutionState.RUNNING) { http://git-wip-us.apache.org/repos/asf/flink/blob/2117eb77/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index d925e4d..5045606 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -48,6 +48,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; +import org.apache.flink.runtime.jobgraph.tasks.StoppableTask; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; @@ -58,10 +59,19 @@ import org.apache.flink.util.SerializedValue; import org.junit.Before; import org.junit.Test; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.concurrent.Executor; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.everyItem; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isOneOf; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; @@ -69,61 +79,76 @@ import static org.mockito.Mockito.when; public class TaskAsyncCallTest { - private static final int NUM_CALLS = 1000; - + /** Number of expected checkpoints. */ + private static int numCalls; + + /** Triggered at the beginning of {@link CheckpointsInOrderInvokable#invoke()}. */ private static OneShotLatch awaitLatch; + + /** + * Triggered when {@link CheckpointsInOrderInvokable#triggerCheckpoint(CheckpointMetaData, CheckpointOptions)} + * was called {@link #numCalls} times. + */ private static OneShotLatch triggerLatch; + /** + * Triggered when {@link CheckpointsInOrderInvokable#notifyCheckpointComplete(long)} + * was called {@link #numCalls} times. + */ + private static OneShotLatch notifyCheckpointCompleteLatch; + + /** Triggered on {@link ContextClassLoaderInterceptingInvokable#stop()}}. */ + private static OneShotLatch stopLatch; + + private static final List classLoaders = Collections.synchronizedList(new ArrayList<>()); + @Before public void createQueuesAndActors() { + numCalls = 1000; + awaitLatch = new OneShotLatch(); triggerLatch = new OneShotLatch(); + notifyCheckpointCompleteLatch = new OneShotLatch(); + stopLatch = new OneShotLatch(); + + classLoaders.clear(); } // ------------------------------------------------------------------------ // Tests // ------------------------------------------------------------------------ - + @Test - public void testCheckpointCallsInOrder() { - try { - Task task = createTask(); + public void testCheckpointCallsInOrder() throws Exception { + Task task = createTask(CheckpointsInOrderInvokable.class); + try (TaskCleaner ignored = new TaskCleaner(task)) { task.startTaskThread(); - + awaitLatch.await(); - - for (int i = 1; i <= NUM_CALLS; i++) { + + for (int i = 1; i <= numCalls; i++) { task.triggerCheckpointBarrier(i, 156865867234L, CheckpointOptions.forCheckpoint()); } - + triggerLatch.await(); - + assertFalse(task.isCanceledOrFailed()); ExecutionState currentState = task.getExecutionState(); - if (currentState != ExecutionState.RUNNING && currentState != ExecutionState.FINISHED) { - fail("Task should be RUNNING or FINISHED, but is " + currentState); - } - - task.cancelExecution(); - task.getExecutingThread().join(); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + assertThat(currentState, isOneOf(ExecutionState.RUNNING, ExecutionState.FINISHED)); } } @Test - public void testMixedAsyncCallsInOrder() { - try { - Task task = createTask(); + public void testMixedAsyncCallsInOrder() throws Exception { + Task task = createTask(CheckpointsInOrderInvokable.class); + try (TaskCleaner ignored = new TaskCleaner(task)) { task.startTaskThread(); awaitLatch.await(); - for (int i = 1; i <= NUM_CALLS; i++) { + for (int i = 1; i <= numCalls; i++) { task.triggerCheckpointBarrier(i, 156865867234L, CheckpointOptions.forCheckpoint()); task.notifyCheckpointComplete(i); } @@ -131,26 +156,62 @@ public class TaskAsyncCallTest { triggerLatch.await(); assertFalse(task.isCanceledOrFailed()); + ExecutionState currentState = task.getExecutionState(); - if (currentState != ExecutionState.RUNNING && currentState != ExecutionState.FINISHED) { - fail("Task should be RUNNING or FINISHED, but is " + currentState); - } + assertThat(currentState, isOneOf(ExecutionState.RUNNING, ExecutionState.FINISHED)); + } + } - task.cancelExecution(); - task.getExecutingThread().join(); + @Test + public void testThrowExceptionIfStopInvokedWithNotStoppableTask() throws Exception { + Task task = createTask(CheckpointsInOrderInvokable.class); + try (TaskCleaner ignored = new TaskCleaner(task)) { + task.startTaskThread(); + awaitLatch.await(); + + try { + task.stopExecution(); + fail("Expected exception not thrown"); + } catch (UnsupportedOperationException e) { + assertThat(e.getMessage(), containsString("Stopping not supported by task")); + } } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); + } + + /** + * Asserts that {@link StatefulTask#triggerCheckpoint(CheckpointMetaData, CheckpointOptions)}, + * {@link StatefulTask#notifyCheckpointComplete(long)}, and {@link StoppableTask#stop()} are + * invoked by a thread whose context class loader is set to the user code class loader. + */ + @Test + public void testSetsUserCodeClassLoader() throws Exception { + numCalls = 1; + + Task task = createTask(ContextClassLoaderInterceptingInvokable.class); + try (TaskCleaner ignored = new TaskCleaner(task)) { + task.startTaskThread(); + + awaitLatch.await(); + + task.triggerCheckpointBarrier(1, 1, CheckpointOptions.forCheckpoint()); + task.notifyCheckpointComplete(1); + task.stopExecution(); + + triggerLatch.await(); + notifyCheckpointCompleteLatch.await(); + stopLatch.await(); + + assertThat(classLoaders, hasSize(greaterThanOrEqualTo(3))); + assertThat(classLoaders, everyItem(instanceOf(TestUserCodeClassLoader.class))); } } - - private static Task createTask() throws Exception { + + private Task createTask(Class invokableClass) throws Exception { BlobCacheService blobService = new BlobCacheService(mock(PermanentBlobCache.class), mock(TransientBlobCache.class)); LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(ClassLoader.getSystemClassLoader()); + when(libCache.getClassLoader(any(JobID.class))).thenReturn(new TestUserCodeClassLoader()); ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); @@ -178,7 +239,7 @@ public class TaskAsyncCallTest { "Test Task", 1, 1, - CheckpointsInOrderInvokable.class.getName(), + invokableClass.getName(), new Configuration()); return new Task( @@ -221,13 +282,17 @@ public class TaskAsyncCallTest { // wait forever (until canceled) synchronized (this) { - while (error == null && lastCheckpointId < NUM_CALLS) { + while (error == null && lastCheckpointId < numCalls) { wait(); } } - - triggerLatch.trigger(); + if (error != null) { + // exit method prematurely due to error but make sure that the tests can finish + triggerLatch.trigger(); + notifyCheckpointCompleteLatch.trigger(); + stopLatch.trigger(); + throw error; } } @@ -239,7 +304,7 @@ public class TaskAsyncCallTest { public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) { lastCheckpointId++; if (checkpointMetaData.getCheckpointId() == lastCheckpointId) { - if (lastCheckpointId == NUM_CALLS) { + if (lastCheckpointId == numCalls) { triggerLatch.trigger(); } } @@ -269,7 +334,66 @@ public class TaskAsyncCallTest { synchronized (this) { notifyAll(); } + } else if (lastCheckpointId == numCalls) { + notifyCheckpointCompleteLatch.trigger(); } } } + + /** + * This is an {@link AbstractInvokable} that stores the context class loader of the invoking + * thread in a static field so that tests can assert on the class loader instances. + * + * @see #testSetsUserCodeClassLoader() + */ + public static class ContextClassLoaderInterceptingInvokable extends CheckpointsInOrderInvokable implements StoppableTask { + + @Override + public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) { + classLoaders.add(Thread.currentThread().getContextClassLoader()); + + return super.triggerCheckpoint(checkpointMetaData, checkpointOptions); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) { + classLoaders.add(Thread.currentThread().getContextClassLoader()); + + super.notifyCheckpointComplete(checkpointId); + } + + @Override + public void stop() { + classLoaders.add(Thread.currentThread().getContextClassLoader()); + stopLatch.trigger(); + } + + } + + /** + * A {@link ClassLoader} that delegates everything to {@link ClassLoader#getSystemClassLoader()}. + * + * @see #testSetsUserCodeClassLoader() + */ + private static class TestUserCodeClassLoader extends ClassLoader { + public TestUserCodeClassLoader() { + super(ClassLoader.getSystemClassLoader()); + } + } + + private static class TaskCleaner implements AutoCloseable { + + private final Task task; + + private TaskCleaner(Task task) { + this.task = task; + } + + @Override + public void close() throws Exception { + task.cancelExecution(); + task.getExecutingThread().join(5000); + } + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/2117eb77/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java deleted file mode 100644 index d062def..0000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.runtime.taskmanager; - -import org.apache.flink.api.common.JobID; -import org.apache.flink.api.common.TaskInfo; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.blob.BlobCacheService; -import org.apache.flink.runtime.blob.PermanentBlobCache; -import org.apache.flink.runtime.blob.TransientBlobCache; -import org.apache.flink.runtime.broadcast.BroadcastVariableManager; -import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; -import org.apache.flink.runtime.clusterframework.types.AllocationID; -import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; -import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; -import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; -import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.executiongraph.JobInformation; -import org.apache.flink.runtime.executiongraph.TaskInformation; -import org.apache.flink.runtime.filecache.FileCache; -import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.runtime.io.network.NetworkEnvironment; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; -import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; -import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; -import org.apache.flink.runtime.jobgraph.tasks.StoppableTask; -import org.apache.flink.runtime.memory.MemoryManager; -import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; -import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -import java.lang.reflect.Field; -import java.util.Collections; -import java.util.concurrent.Executor; - -import scala.concurrent.duration.FiniteDuration; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -@RunWith(PowerMockRunner.class) -@PrepareForTest({ TaskDeploymentDescriptor.class, JobID.class, FiniteDuration.class }) -public class TaskStopTest { - private Task task; - - public void doMocking(AbstractInvokable taskMock) throws Exception { - - TaskInfo taskInfoMock = mock(TaskInfo.class); - when(taskInfoMock.getTaskNameWithSubtasks()).thenReturn("dummyName"); - - TaskManagerRuntimeInfo tmRuntimeInfo = mock(TaskManagerRuntimeInfo.class); - when(tmRuntimeInfo.getConfiguration()).thenReturn(new Configuration()); - - TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); - when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); - - BlobCacheService blobService = - new BlobCacheService(mock(PermanentBlobCache.class), mock(TransientBlobCache.class)); - - task = new Task( - mock(JobInformation.class), - new TaskInformation( - new JobVertexID(), - "test task name", - 1, - 1, - "foobar", - new Configuration()), - mock(ExecutionAttemptID.class), - mock(AllocationID.class), - 0, - 0, - Collections.emptyList(), - Collections.emptyList(), - 0, - mock(TaskStateSnapshot.class), - mock(MemoryManager.class), - mock(IOManager.class), - mock(NetworkEnvironment.class), - mock(BroadcastVariableManager.class), - mock(TaskManagerActions.class), - mock(InputSplitProvider.class), - mock(CheckpointResponder.class), - blobService, - mock(LibraryCacheManager.class), - mock(FileCache.class), - tmRuntimeInfo, - taskMetricGroup, - mock(ResultPartitionConsumableNotifier.class), - mock(PartitionProducerStateChecker.class), - mock(Executor.class)); - Field f = task.getClass().getDeclaredField("invokable"); - f.setAccessible(true); - f.set(task, taskMock); - - Field f2 = task.getClass().getDeclaredField("executionState"); - f2.setAccessible(true); - f2.set(task, ExecutionState.RUNNING); - } - - @Test(timeout = 20000) - public void testStopExecution() throws Exception { - StoppableTestTask taskMock = new StoppableTestTask(); - doMocking(taskMock); - - task.stopExecution(); - - while (!taskMock.stopCalled) { - Thread.sleep(100); - } - } - - @Test(expected = RuntimeException.class) - public void testStopExecutionFail() throws Exception { - AbstractInvokable taskMock = mock(AbstractInvokable.class); - doMocking(taskMock); - - task.stopExecution(); - } - - private final static class StoppableTestTask extends AbstractInvokable implements StoppableTask { - public volatile boolean stopCalled = false; - - @Override - public void invoke() throws Exception { - } - - @Override - public void stop() { - this.stopCalled = true; - } - } - -}