kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From damian...@apache.org
Subject kafka git commit: HOTFIX: state transition cherry picking
Date Tue, 15 Aug 2017 14:35:51 GMT
Repository: kafka
Updated Branches:
  refs/heads/0.11.0 fb47e213e -> 77b81c02b


HOTFIX: state transition cherry picking

Cherry picked from https://github.com/apache/kafka/pull/3432

Author: Eno Thereska <eno.thereska@gmail.com>

Reviewers: Damian Guy <damian.guy@gmail.com>, Guozhang Wang <wangguoz@gmail.com>

Closes #3622 from enothereska/KAFKA-5571-0.11


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/77b81c02
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/77b81c02
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/77b81c02

Branch: refs/heads/0.11.0
Commit: 77b81c02b98b979ca6ecbfcfe4f620f4ab7ee240
Parents: fb47e21
Author: Eno Thereska <eno.thereska@gmail.com>
Authored: Tue Aug 15 15:35:43 2017 +0100
Committer: Damian Guy <damian.guy@gmail.com>
Committed: Tue Aug 15 15:35:43 2017 +0100

----------------------------------------------------------------------
 .../org/apache/kafka/streams/KafkaStreams.java  | 373 +++++++++++++------
 .../processor/internals/GlobalStreamThread.java | 132 ++++++-
 .../processor/internals/StreamThread.java       |  97 +++--
 .../ThreadStateTransitionValidator.java         |  24 ++
 .../apache/kafka/streams/KafkaStreamsTest.java  | 166 ++++++++-
 .../internals/GlobalStreamThreadTest.java       |  46 ++-
 .../processor/internals/StreamThreadTest.java   | 127 ++++---
 7 files changed, 744 insertions(+), 221 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/77b81c02/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
index 0c7c598..6056fa6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
@@ -46,6 +46,7 @@ import org.apache.kafka.streams.processor.internals.StateDirectory;
 import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.apache.kafka.streams.processor.internals.StreamsKafkaClient;
 import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
+import org.apache.kafka.streams.processor.internals.ThreadStateTransitionValidator;
 import org.apache.kafka.streams.state.HostInfo;
 import org.apache.kafka.streams.state.QueryableStoreType;
 import org.apache.kafka.streams.state.StreamsMetadata;
@@ -75,6 +76,10 @@ import java.util.concurrent.TimeUnit;
 
 import static org.apache.kafka.common.utils.Utils.getHost;
 import static org.apache.kafka.common.utils.Utils.getPort;
+import static org.apache.kafka.streams.KafkaStreams.State.ERROR;
+import static org.apache.kafka.streams.KafkaStreams.State.NOT_RUNNING;
+import static org.apache.kafka.streams.KafkaStreams.State.PENDING_SHUTDOWN;
+import static org.apache.kafka.streams.KafkaStreams.State.RUNNING;
 import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE;
 import static org.apache.kafka.streams.StreamsConfig.PROCESSING_GUARANTEE_CONFIG;
 
@@ -154,9 +159,9 @@ public class KafkaStreams {
      *         |       +-----+--------+
      *         |             |
      *         |             v
-     *         |       +-----+--------+
-     *         +<----- | Rebalancing  | <----+
-     *         |       +--------------+      |
+     *         |       +-----+--------+ <-+
+     *         +<----- | Rebalancing  | --+
+     *         |       +--------------+ <----+
      *         |                             |
      *         |                             |
      *         |       +--------------+      |
@@ -165,18 +170,31 @@ public class KafkaStreams {
      *         |             |
      *         |             v
      *         |       +-----+--------+
-     *         +-----> | Pending      |
-     *                 | Shutdown     |
-     *                 +-----+--------+
-     *                       |
-     *                       v
-     *                 +-----+--------+
-     *                 | Not Running  |
+     *         +-----> | Pending      |<----+
+     *         |       | Shutdown     |     |
+     *         |       +-----+--------+     |
+     *         |             |              |
+     *         |             v              |
+     *         |       +-----+--------+     |
+     *         |       | Not Running  |     |
+     *         |       +--------------+     |
+     *         |                            |
+     *         |       +--------------+     |
+     *         +-----> | Error        |-----+
      *                 +--------------+
+     *
+     *
      * </pre>
+     * Note the following:
+     * - Any state can go to PENDING_SHUTDOWN and subsequently NOT_RUNNING.
+     * - It is theoretically possible for a thread to always be in the PARTITIONS_REVOKED state
+     * (see {@code StreamThread} state diagram) and hence it is possible that this instance is always
+     * in the REBALANCING state.
+     * - Of special importance: If the global stream thread dies, or all stream threads die (or both) then
+     * the instance will be in the ERROR state. The user will need to close it.
      */
     public enum State {
-        CREATED(1, 2, 3), RUNNING(2, 3), REBALANCING(1, 2, 3), PENDING_SHUTDOWN(4), NOT_RUNNING;
+        CREATED(1, 2, 3, 5), REBALANCING(1, 2, 3, 5), RUNNING(1, 3, 5), PENDING_SHUTDOWN(4), NOT_RUNNING, ERROR(3);
 
         private final Set<Integer> validTransitions = new HashSet<>();
 
@@ -196,12 +214,9 @@ public class KafkaStreams {
     }
 
     private final Object stateLock = new Object();
-
     private volatile State state = State.CREATED;
-
     private KafkaStreams.StateListener stateListener = null;
 
-
     /**
      * Listen to {@link State} change events.
      */
@@ -224,28 +239,55 @@ public class KafkaStreams {
         stateListener = listener;
     }
 
-    private void setState(final State newState) {
+    /**
+     * Sets the state
+     * @param newState New state
+     * @return true if state is set, false otherwise
+     * @throws StreamsException when there is an unexpected transition.
+     */
+    private boolean setState(final State newState) {
+        State oldState;
         synchronized (stateLock) {
-            final State oldState = state;
+            // there are cases when we shouldn't check if a transition is valid, e.g.,
+            // when, for testing, Kafka Streams is closed multiple times. We could either
+            // check here and immediately return for those cases, or add them to the transition
+            // diagram (but then the diagram would be confusing and have transitions like
+            // NOT_RUNNING->NOT_RUNNING). These cases include:
+            // - calling close() multiple times. Would mean going from NOT_RUNNING -> PENDING_SHUTDOWN
+            // - calling start() after close(). Would mean going from PENDING_SHUTDOWN (or NOT_RUNNING) -> RUNNING
+
+            // note we could be going from PENDING_SHUTDOWN to NOT_RUNNING, and we obviously want to allow that
+            // transition, hence the check newState != NOT_RUNNING.
+            if (newState != NOT_RUNNING &&
+                    (state == State.NOT_RUNNING || state == PENDING_SHUTDOWN)) {
+                return false;
+            }
+
+            oldState = state;
             if (!state.isValidTransition(newState)) {
                 log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
+                throw new StreamsException(logPrefix + " Unexpected state transition from " + oldState + " to " + newState);
             } else {
                 log.info("{} State transition from {} to {}.", logPrefix, oldState, newState);
             }
             state = newState;
-            if (stateListener != null) {
-                stateListener.onChange(state, oldState);
-            }
         }
+        if (stateListener != null) {
+            stateListener.onChange(state, oldState);
+        }
+
+        return true;
     }
 
     /**
      * Return the current {@link State} of this {@code KafkaStreams} instance.
      *
-     * @return the currnt state of this Kafka Streams instance
+     * @return the current state of this Kafka Streams instance
      */
-    public synchronized State state() {
-        return state;
+    public State state() {
+        synchronized (stateLock) {
+            return state;
+        }
     }
 
     /**
@@ -257,33 +299,103 @@ public class KafkaStreams {
         return Collections.unmodifiableMap(metrics.metrics());
     }
 
-    private final class StreamStateListener implements StreamThread.StateListener {
-
+    /**
+     * Class that handles stream thread transitions
+     */
+    final class StreamStateListener implements StreamThread.StateListener {
         private final Map<Long, StreamThread.State> threadState;
+        private GlobalStreamThread.State globalThreadState;
 
-        StreamStateListener(Map<Long, StreamThread.State> threadState) {
+        StreamStateListener(final Map<Long, StreamThread.State> threadState,
+                            final GlobalStreamThread.State globalThreadState) {
             this.threadState = threadState;
+            this.globalThreadState = globalThreadState;
+        }
+
+        /**
+         * If all threads are dead set to ERROR
+         */
+        private void checkAllThreadsDeadAndSetError() {
+
+            synchronized (stateLock) {
+                // if we are pending a shutdown, it's ok for all threads to die, in fact
+                // it is expected. Otherwise, it is an error
+                if (state != PENDING_SHUTDOWN) {
+                    // one thread died, check if we have enough threads running
+                    for (final StreamThread.State state : threadState.values()) {
+                        if (state != StreamThread.State.DEAD) {
+                            return;
+                        }
+                    }
+                    log.warn("{} All stream threads have died. The Kafka Streams instance will be in an error state and should be closed.",
+                            logPrefix);
+                    setState(ERROR);
+                }
+            }
         }
 
+        /**
+         * If the global thread is DEAD, set to ERROR (unless a shutdown is pending)
+         */
+        private void maybeSetErrorSinceGlobalStreamThreadIsDead() {
+
+            synchronized (stateLock) {
+                // if we are pending a shutdown, it's ok for all threads to die, in fact
+                // it is expected. Otherwise, it is an error
+                if (state != PENDING_SHUTDOWN) {
+                    log.warn("{} Global Stream thread has died. The Kafka Streams instance will be in an error state and should be closed.",
+                            logPrefix);
+                    setState(ERROR);
+                }
+            }
+        }
+
+        /**
+         * If all threads are up, including the global thread, set to RUNNING
+         */
+        private void maybeSetRunning() {
+            // one thread is running, check others, including global thread
+            for (final StreamThread.State state : threadState.values()) {
+                if (state != StreamThread.State.RUNNING) {
+                    return;
+                }
+            }
+            // the global state thread is relevant only if it is started. There are cases
+            // when we don't have a global state thread at all, e.g., when we don't have global KTables
+            if (globalThreadState != null && globalThreadState != GlobalStreamThread.State.RUNNING) {
+                return;
+            }
+
+            setState(State.RUNNING);
+        }
+
+
         @Override
-        public synchronized void onChange(final StreamThread thread,
-                                          final StreamThread.State newState,
-                                          final StreamThread.State oldState) {
-            if (newState != StreamThread.State.DEAD) {
+        public synchronized void onChange(final Thread thread,
+                                          final ThreadStateTransitionValidator abstractNewState,
+                                          final ThreadStateTransitionValidator abstractOldState) {
+            // StreamThreads first
+            if (thread instanceof StreamThread) {
+                StreamThread.State newState = (StreamThread.State) abstractNewState;
                 threadState.put(thread.getId(), newState);
-            } else {
-                threadState.remove(thread.getId());
-            }
-            if (newState == StreamThread.State.PARTITIONS_REVOKED ||
-                newState == StreamThread.State.ASSIGNING_PARTITIONS) {
-                setState(State.REBALANCING);
-            } else if (newState == StreamThread.State.RUNNING) {
-                for (final StreamThread.State state : threadState.values()) {
-                    if (state != StreamThread.State.RUNNING) {
-                        return;
-                    }
+
+                if (newState == StreamThread.State.PARTITIONS_REVOKED ||
+                        newState == StreamThread.State.ASSIGNING_PARTITIONS) {
+                    setState(State.REBALANCING);
+                } else if (newState == StreamThread.State.RUNNING && state() != State.RUNNING) {
+                    maybeSetRunning();
+                } else if (newState == StreamThread.State.DEAD) {
+                    checkAllThreadsDeadAndSetError();
+                }
+            } else if (thread instanceof GlobalStreamThread) {
+                // global stream thread has different invariants
+                GlobalStreamThread.State newState = (GlobalStreamThread.State) abstractNewState;
+                globalThreadState = newState;
+
+                // special case when global thread is dead
+                if (newState == GlobalStreamThread.State.DEAD) {
+                    maybeSetErrorSinceGlobalStreamThreadIsDead();
                 }
-                setState(State.RUNNING);
             }
         }
     }
@@ -350,6 +462,7 @@ public class KafkaStreams {
 
         threads = new StreamThread[config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG)];
         final Map<Long, StreamThread.State> threadState = new HashMap<>(threads.length);
+        GlobalStreamThread.State globalThreadState = null;
         final ArrayList<StateStoreProvider> storeProviders = new ArrayList<>();
         streamsMetadataState = new StreamsMetadataState(builder, parseHostInfo(config.getString(StreamsConfig.APPLICATION_SERVER_CONFIG)));
 
@@ -372,9 +485,9 @@ public class KafkaStreams {
                                                         metrics,
                                                         time,
                                                         globalThreadId);
+            globalThreadState = globalStreamThread.state();
         }
 
-        final StreamStateListener streamStateListener = new StreamStateListener(threadState);
         for (int i = 0; i < threads.length; i++) {
             threads[i] = new StreamThread(builder,
                                           config,
@@ -387,11 +500,16 @@ public class KafkaStreams {
                                           streamsMetadataState,
                                           cacheSizeBytes,
                                           stateDirectory);
-
-            threads[i].setStateListener(streamStateListener);
             threadState.put(threads[i].getId(), threads[i].state());
             storeProviders.add(new StreamThreadStateStoreProvider(threads[i]));
         }
+        final StreamStateListener streamStateListener = new StreamStateListener(threadState, globalThreadState);
+        if (globalTaskTopology != null) {
+            globalStreamThread.setStateListener(streamStateListener);
+        }
+        for (int i = 0; i < threads.length; i++) {
+            threads[i].setStateListener(streamStateListener);
+        }
         final GlobalStateStoreProvider globalStateStoreProvider = new GlobalStateStoreProvider(builder.globalStateStores());
         queryableStoreProvider = new QueryableStoreProvider(storeProviders, globalStateStoreProvider);
         final String cleanupThreadName = clientId + "-CleanupThread";
@@ -440,6 +558,17 @@ public class KafkaStreams {
 
     }
 
+    private void validateStartOnce() {
+        try {
+            if (setState(RUNNING)) {
+                return;
+            }
+        } catch (StreamsException e) {
+            // do nothing, will throw
+        }
+        throw new IllegalStateException("Cannot start again.");
+    }
+
     /**
      * Start the {@code KafkaStreams} instance by starting all its threads.
      * <p>
@@ -452,38 +581,32 @@ public class KafkaStreams {
      */
     public synchronized void start() throws IllegalStateException, StreamsException {
         log.debug("{} Starting Kafka Stream process.", logPrefix);
+        validateStartOnce();
+        checkBrokerVersionCompatibility();
 
-        if (state == State.CREATED) {
-            checkBrokerVersionCompatibility();
-            setState(State.RUNNING);
-
-            if (globalStreamThread != null) {
-                globalStreamThread.start();
-            }
-
-            for (final StreamThread thread : threads) {
-                thread.start();
-            }
+        if (globalStreamThread != null) {
+            globalStreamThread.start();
+        }
 
-            final Long cleanupDelay = config.getLong(StreamsConfig.STATE_CLEANUP_DELAY_MS_CONFIG);
-            stateDirCleaner.scheduleAtFixedRate(new Runnable() {
-                @Override
-                public void run() {
-                    synchronized (stateLock) {
-                        if (state == State.RUNNING) {
-                            stateDirectory.cleanRemovedTasks(cleanupDelay);
-                        }
+        for (final StreamThread thread : threads) {
+            thread.start();
+        }
+        final Long cleanupDelay = config.getLong(StreamsConfig.STATE_CLEANUP_DELAY_MS_CONFIG);
+        stateDirCleaner.scheduleAtFixedRate(new Runnable() {
+            @Override
+            public void run() {
+                synchronized (stateLock) {
+                    if (state == State.RUNNING) {
+                        stateDirectory.cleanRemovedTasks(cleanupDelay);
                     }
                 }
-            }, cleanupDelay, cleanupDelay, TimeUnit.MILLISECONDS);
-
-            log.info("{} Started Kafka Stream process", logPrefix);
-        } else {
-            throw new IllegalStateException("Cannot start again.");
-        }
+            }
+        }, cleanupDelay, cleanupDelay, TimeUnit.MILLISECONDS);
 
+        log.info("{} Started Kafka Stream process", logPrefix);
     }
 
+
     /**
      * Shutdown this {@code KafkaStreams} instance by signaling all the threads to stop, and then wait for them to join.
      * This will block until all threads have stopped.
@@ -492,6 +615,25 @@ public class KafkaStreams {
         close(DEFAULT_CLOSE_TIMEOUT, TimeUnit.SECONDS);
     }
 
+    private boolean checkFirstTimeClosing() {
+        return setState(PENDING_SHUTDOWN);
+    }
+
+    private void closeGlobalStreamThread() {
+        if (globalStreamThread != null) {
+            globalStreamThread.setStateListener(null);
+            globalStreamThread.close();
+            if (!globalStreamThread.stillRunning()) {
+                try {
+                    globalStreamThread.join();
+                } catch (final InterruptedException e) {
+                    Thread.interrupted();
+                }
+            }
+            globalStreamThread = null;
+        }
+    }
+
     /**
      * Shutdown this {@code KafkaStreams} by signaling all the threads to stop, and then wait up to the timeout for the
      * threads to join.
@@ -501,59 +643,52 @@ public class KafkaStreams {
      * @param timeUnit unit of time used for timeout
      * @return {@code true} if all threads were successfully stopped&mdash;{@code false} if the timeout was reached
      * before all threads stopped
+     * Note that this method must not be called in the {@code onChange} callback of {@link StateListener}.
      */
     public synchronized boolean close(final long timeout, final TimeUnit timeUnit) {
         log.debug("{} Stopping Kafka Stream process.", logPrefix);
-        if (state.isCreatedOrRunning()) {
-            setState(State.PENDING_SHUTDOWN);
-            stateDirCleaner.shutdownNow();
-            // save the current thread so that if it is a stream thread
-            // we don't attempt to join it and cause a deadlock
-            final Thread shutdown = new Thread(new Runnable() {
-                @Override
-                public void run() {
-                    // signal the threads to stop and wait
-                    for (final StreamThread thread : threads) {
-                        // avoid deadlocks by stopping any further state reports
-                        // from the thread since we're shutting down
-                        thread.setStateListener(null);
-                        thread.close();
-                    }
-                    if (globalStreamThread != null) {
-                        globalStreamThread.close();
-                        if (!globalStreamThread.stillRunning()) {
-                            try {
-                                globalStreamThread.join();
-                            } catch (final InterruptedException e) {
-                                Thread.interrupted();
-                            }
-                        }
-                    }
-                    for (final StreamThread thread : threads) {
-                        try {
-                            if (!thread.stillRunning()) {
-                                thread.join();
-                            }
-                        } catch (final InterruptedException ex) {
-                            Thread.interrupted();
+
+        // only clean up once
+        if (!checkFirstTimeClosing()) {
+            return true;
+        }
+        stateDirCleaner.shutdownNow();
+        // save the current thread so that if it is a stream thread
+        // we don't attempt to join it and cause a deadlock
+        final Thread shutdown = new Thread(new Runnable() {
+            @Override
+            public void run() {
+                // signal the threads to stop and wait
+                for (final StreamThread thread : threads) {
+                    // avoid deadlocks by stopping any further state reports
+                    // from the thread since we're shutting down
+                    thread.setStateListener(null);
+                    thread.close();
+                }
+                closeGlobalStreamThread();
+                for (final StreamThread thread : threads) {
+                    try {
+                        if (!thread.stillRunning()) {
+                            thread.join();
                         }
+                    } catch (final InterruptedException ex) {
+                        Thread.interrupted();
                     }
-
-                    metrics.close();
-                    log.info("{} Stopped Kafka Streams process.", logPrefix);
                 }
-            }, "kafka-streams-close-thread");
-            shutdown.setDaemon(true);
-            shutdown.start();
-            try {
-                shutdown.join(TimeUnit.MILLISECONDS.convert(timeout, timeUnit));
-            } catch (final InterruptedException e) {
-                Thread.interrupted();
+
+                metrics.close();
+                log.info("{} Stopped Kafka Streams process.", logPrefix);
             }
-            setState(State.NOT_RUNNING);
-            return !shutdown.isAlive();
+        }, "kafka-streams-close-thread");
+        shutdown.setDaemon(true);
+        shutdown.start();
+        try {
+            shutdown.join(TimeUnit.MILLISECONDS.convert(timeout, timeUnit));
+        } catch (final InterruptedException e) {
+            Thread.interrupted();
         }
-        return true;
+        setState(State.NOT_RUNNING);
+        return !shutdown.isAlive();
     }
 
     /**
@@ -590,6 +725,12 @@ public class KafkaStreams {
         return sb.toString();
     }
 
+    private boolean isRunning() {
+        synchronized (stateLock) {
+            return state.isRunning();
+        }
+    }
+
     /**
      * Do a clean up of the local {@link StateStore} directory ({@link StreamsConfig#STATE_DIR_CONFIG}) by deleting all
      * data with regard to the {@link StreamsConfig#APPLICATION_ID_CONFIG application ID}.
@@ -602,7 +743,7 @@ public class KafkaStreams {
      * @throws IllegalStateException if the instance is currently running
      */
     public void cleanUp() {
-        if (state.isRunning()) {
+        if (isRunning()) {
             throw new IllegalStateException("Cannot clean up while running.");
         }
         stateDirectory.cleanRemovedTasks(0);
@@ -747,7 +888,7 @@ public class KafkaStreams {
     }
 
     private void validateIsRunning() {
-        if (!state.isRunning()) {
+        if (!isRunning()) {
             throw new IllegalStateException("KafkaStreams is not running. State is " + state + ".");
         }
     }

http://git-wip-us.apache.org/repos/asf/kafka/blob/77b81c02/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
index 36a248e..0faf81e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
@@ -33,6 +33,11 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.Map;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Arrays;
+import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD;
+import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.PENDING_SHUTDOWN;
 
 /**
  * This is the thread responsible for keeping all Global State Stores updated.
@@ -49,9 +54,112 @@ public class GlobalStreamThread extends Thread {
     private final ThreadCache cache;
     private final StreamsMetrics streamsMetrics;
     private final ProcessorTopology topology;
-    private volatile boolean running = false;
     private volatile StreamsException startupException;
 
+    /**
+     * The states that the global stream thread can be in
+     *
+     * <pre>
+     *                +-------------+
+     *          +<--- | Created     |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +<--- | Running     |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +---> | Pending     |
+     *                | Shutdown    |
+     *                +-----+-------+
+     *                      |
+     *                      v
+     *                +-----+-------+
+     *                | Dead        |
+     *                +-------------+
+     * </pre>
+     *
+     * Note the following:
+     * - Any state can go to PENDING_SHUTDOWN and subsequently to DEAD
+     *
+     */
+    public enum State implements ThreadStateTransitionValidator {
+        CREATED(1, 2), RUNNING(2), PENDING_SHUTDOWN(3), DEAD;
+
+        private final Set<Integer> validTransitions = new HashSet<>();
+
+        State(final Integer... validTransitions) {
+            this.validTransitions.addAll(Arrays.asList(validTransitions));
+        }
+
+        public boolean isRunning() {
+            return !equals(PENDING_SHUTDOWN) && !equals(CREATED) && !equals(DEAD);
+        }
+
+        public boolean isValidTransition(final ThreadStateTransitionValidator newState) {
+            State tmpState = (State) newState;
+            return validTransitions.contains(tmpState.ordinal());
+        }
+    }
+
+    private volatile State state = State.CREATED;
+    private final Object stateLock = new Object();
+    private StreamThread.StateListener stateListener = null;
+    private final String logPrefix;
+
+
+    /**
+     * Set the {@link StreamThread.StateListener} to be notified when state changes. Note this API is internal to
+     * Kafka Streams and is not intended to be used by an external application.
+     */
+    public void setStateListener(final StreamThread.StateListener listener) {
+        stateListener = listener;
+    }
+
+    /**
+     * @return The state this instance is in
+     */
+    public State state() {
+        synchronized (stateLock) {
+            return state;
+        }
+    }
+
+    /**
+     * Sets the state
+     * @param newState New state
+     * @param ignoreWhenShuttingDownOrDead        if true, then we'll first check if the state is
+     *                                            PENDING_SHUTDOWN or DEAD, and if it is,
+     *                                            we immediately return. Effectively this enables
+     *                                            a conditional set, under the stateLock lock.
+     */
+    void setState(final State newState, boolean ignoreWhenShuttingDownOrDead) {
+        State oldState;
+        synchronized (stateLock) {
+            oldState = state;
+
+            if (ignoreWhenShuttingDownOrDead) {
+                if (state == PENDING_SHUTDOWN || state == DEAD) {
+                    return;
+                }
+            }
+
+            if (!state.isValidTransition(newState)) {
+                log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
+                throw new StreamsException(logPrefix + " Unexpected state transition from " + oldState + " to " + newState);
+            } else {
+                log.info("{} State transition from {} to {}.", logPrefix, oldState, newState);
+            }
+
+            state = newState;
+        }
+        if (stateListener != null) {
+            stateListener.onChange(this, state, oldState);
+        }
+    }
+
     public GlobalStreamThread(final ProcessorTopology topology,
                               final StreamsConfig config,
                               final Consumer<byte[], byte[]> globalConsumer,
@@ -69,6 +177,7 @@ public class GlobalStreamThread extends Thread {
                 (config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG) + 1));
         this.streamsMetrics = new StreamsMetricsImpl(metrics, threadClientId, Collections.singletonMap("client-id", threadClientId));
         this.cache = new ThreadCache(threadClientId, cacheSizeBytes, streamsMetrics);
+        this.logPrefix = String.format("global-stream-thread [%s]", threadClientId);
     }
 
     static class StateConsumer {
@@ -134,14 +243,19 @@ public class GlobalStreamThread extends Thread {
             return;
         }
 
+        // one could kill the thread before it had a chance to actually start
+        setState(State.RUNNING, true);
+
         try {
-            while (running) {
+            while (stillRunning()) {
                 stateConsumer.pollAndUpdate();
             }
             log.debug("Shutting down GlobalStreamThread at user request");
         } finally {
             try {
+                setState(PENDING_SHUTDOWN, true);
                 stateConsumer.close();
+                setState(DEAD, false);
             } catch (IOException e) {
                 log.error("Failed to cleanly shutdown GlobalStreamThread", e);
             }
@@ -164,7 +278,6 @@ public class GlobalStreamThread extends Thread {
                                         config.getLong(StreamsConfig.POLL_MS_CONFIG),
                                         config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG));
             stateConsumer.initialize();
-            running = true;
             return stateConsumer;
         } catch (StreamsException e) {
             startupException = e;
@@ -177,7 +290,7 @@ public class GlobalStreamThread extends Thread {
     @Override
     public synchronized void start() {
         super.start();
-        while (!running) {
+        while (!stillRunning()) {
             Utils.sleep(1);
             if (startupException != null) {
                 throw startupException;
@@ -185,14 +298,15 @@ public class GlobalStreamThread extends Thread {
         }
     }
 
-
     public void close() {
-        running = false;
+        // one could call close() multiple times, so ignore subsequent calls
+        // if already shutting down or dead
+        setState(PENDING_SHUTDOWN, true);
     }
 
     public boolean stillRunning() {
-        return running;
+        synchronized (stateLock) {
+            return state.isRunning();
+        }
     }
-
-
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/77b81c02/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index c2da0cb..d25af64 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -82,23 +82,23 @@ public class StreamThread extends Thread {
      *
      * <pre>
      *                +-------------+
-     *                | Created     |
-     *                +-----+-------+
-     *                      |
-     *                      v
-     *                +-----+-------+
+     *          +<--- | Created     |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
      *          +<--- | Running     | <----+
      *          |     +-----+-------+      |
      *          |           |              |
      *          |           v              |
      *          |     +-----+-------+      |
-     *          +<--- | Partitions  |      |
-     *          |     | Revoked     |      |
+     *          +<--- | Partitions  | <-+  |
+     *          |     | Revoked     | --+  |
      *          |     +-----+-------+      |
      *          |           |              |
      *          |           v              |
      *          |     +-----+-------+      |
-     *          |     | Assigning   |      |
+     *          +<--- | Assigning   |      |
      *          |     | Partitions  | ---->+
      *          |     +-----+-------+
      *          |           |
@@ -113,9 +113,15 @@ public class StreamThread extends Thread {
      *                | Dead        |
      *                +-------------+
      * </pre>
+     *
+     * Note the following:
+     * - Any state can go to PENDING_SHUTDOWN followed by a subsequent transition to DEAD.
+     * - A streams thread can stay in PARTITIONS_REVOKED indefinitely, in the corner case when
+     *   the coordinator repeatedly fails in-between revoking partitions and assigning new partitions.
+     *
      */
-    public enum State {
-        CREATED(1), RUNNING(1, 2, 4), PARTITIONS_REVOKED(3, 4), ASSIGNING_PARTITIONS(1, 4), PENDING_SHUTDOWN(5), DEAD;
+    public enum State implements ThreadStateTransitionValidator {
+        CREATED(1, 4), RUNNING(2, 4), PARTITIONS_REVOKED(2, 3, 4), ASSIGNING_PARTITIONS(1, 4), PENDING_SHUTDOWN(5), DEAD;
 
         private final Set<Integer> validTransitions = new HashSet<>();
 
@@ -127,8 +133,10 @@ public class StreamThread extends Thread {
             return !equals(PENDING_SHUTDOWN) && !equals(CREATED) && !equals(DEAD);
         }
 
-        public boolean isValidTransition(final State newState) {
-            return validTransitions.contains(newState.ordinal());
+        @Override
+        public boolean isValidTransition(final ThreadStateTransitionValidator newState) {
+            State tmpState = (State) newState;
+            return validTransitions.contains(tmpState.ordinal());
         }
     }
 
@@ -143,7 +151,7 @@ public class StreamThread extends Thread {
          * @param newState     current state
          * @param oldState     previous state
          */
-        void onChange(final StreamThread thread, final State newState, final State oldState);
+        void onChange(final Thread thread, final ThreadStateTransitionValidator newState, final ThreadStateTransitionValidator oldState);
     }
 
     private class RebalanceListener implements ConsumerRebalanceListener {
@@ -175,7 +183,7 @@ public class StreamThread extends Thread {
             final long start = time.milliseconds();
             try {
                 storeChangelogReader = new StoreChangelogReader(getName(), restoreConsumer, time, requestTimeOut);
-                setStateWhenNotInPendingShutdown(State.ASSIGNING_PARTITIONS);
+                setState(State.ASSIGNING_PARTITIONS);
                 // do this first as we may have suspended standby tasks that
                 // will become active or vice versa
                 closeNonAssignedSuspendedStandbyTasks();
@@ -185,7 +193,7 @@ public class StreamThread extends Thread {
                 addStandbyTasks(start);
                 streamsMetadataState.onChange(partitionAssignor.getPartitionsByHostState(), partitionAssignor.clusterMetadata());
                 lastCleanMs = time.milliseconds(); // start the cleaning cycle
-                setStateWhenNotInPendingShutdown(State.RUNNING);
+                setState(State.RUNNING);
             } catch (final Throwable t) {
                 rebalanceException = t;
                 throw t;
@@ -214,7 +222,7 @@ public class StreamThread extends Thread {
 
             final long start = time.milliseconds();
             try {
-                setStateWhenNotInPendingShutdown(State.PARTITIONS_REVOKED);
+                setState(State.PARTITIONS_REVOKED);
                 lastCleanMs = Long.MAX_VALUE; // stop the cleaning cycle until partitions are assigned
                 // suspend active tasks
                 suspendTasksAndState();
@@ -378,6 +386,7 @@ public class StreamThread extends Thread {
 
 
     private volatile State state = State.CREATED;
+    private final Object stateLock = new Object();
     private StreamThread.StateListener stateListener = null;
     final PartitionGrouper partitionGrouper;
     private final StreamsMetadataState streamsMetadataState;
@@ -501,7 +510,6 @@ public class StreamThread extends Thread {
         lastCleanMs = Long.MAX_VALUE; // the cleaning cycle won't start until partition assignment
         lastCommitMs = timerStartedMs;
         rebalanceListener = new RebalanceListener(time, config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG));
-        setState(State.RUNNING);
     }
 
     /**
@@ -513,7 +521,7 @@ public class StreamThread extends Thread {
     @Override
     public void run() {
         log.info("{} Starting", logPrefix);
-
+        setState(State.RUNNING);
         boolean cleanRun = false;
         try {
             runLoop();
@@ -895,17 +903,19 @@ public class StreamThread extends Thread {
 
     /**
      * Shutdown this stream thread.
+     * Note that there is nothing to prevent this function from being called multiple times
+     * (e.g., in testing), hence the state is set only the first time
      */
     public synchronized void close() {
         log.info("{} Informed thread to shut down", logPrefix);
         setState(State.PENDING_SHUTDOWN);
     }
 
-    public synchronized boolean isInitialized() {
+    public boolean isInitialized() {
         return state == State.RUNNING;
     }
 
-    public synchronized boolean stillRunning() {
+    public boolean stillRunning() {
         return state.isRunning();
     }
 
@@ -961,26 +971,42 @@ public class StreamThread extends Thread {
     /**
      * @return The state this instance is in
      */
-    public synchronized State state() {
+    public State state() {
         return state;
     }
 
-    private synchronized void setStateWhenNotInPendingShutdown(final State newState) {
-        if (state == State.PENDING_SHUTDOWN) {
-            return;
-        }
-        setState(newState);
-    }
+    /**
+     * Sets the state
+     * @param newState New state
+     */
+    void setState(final State newState) {
+        State oldState;
+        synchronized (stateLock) {
+            oldState = state;
+
+            // there are cases when we shouldn't check if a transition is valid, e.g.,
+            // when, for testing, a thread is closed multiple times. We could either
+            // check here and immediately return for those cases, or add them to the transition
+            // diagram (but then the diagram would be confusing and have transitions like
+            // PENDING_SHUTDOWN->PENDING_SHUTDOWN). These cases include:
+            // - normal close() sequence. State is set to PENDING_SHUTDOWN in close() as well as in shutdown().
+            // - calling close() on the thread after an exception within the thread has already called shutdown().
+
+            // note we could be going from PENDING_SHUTDOWN to DEAD, and we obviously want to allow that
+            // transition, hence the check newState != DEAD.
+            if (newState != State.DEAD &&
+                    (state == State.PENDING_SHUTDOWN || state == State.DEAD)) {
+                return;
+            }
+            if (!state.isValidTransition(newState)) {
+                log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
+                throw new StreamsException(logPrefix + " Unexpected state transition from " + oldState + " to " + newState);
+            } else {
+                log.info("{} State transition from {} to {}.", logPrefix, oldState, newState);
+            }
 
-    private synchronized void setState(final State newState) {
-        final State oldState = state;
-        if (!state.isValidTransition(newState)) {
-            log.warn("{} Unexpected state transition from {} to {}.", logPrefix, oldState, newState);
-        } else {
-            log.info("{} State transition from {} to {}.", logPrefix, oldState, newState);
+            state = newState;
         }
-
-        state = newState;
         if (stateListener != null) {
             stateListener.onChange(this, state, oldState);
         }
@@ -1038,6 +1064,7 @@ public class StreamThread extends Thread {
 
     private void shutdown(final boolean cleanRun) {
         log.info("{} Shutting down", logPrefix);
+        setState(State.PENDING_SHUTDOWN);
         shutdownTasksAndState(cleanRun);
 
         // close all embedded clients

http://git-wip-us.apache.org/repos/asf/kafka/blob/77b81c02/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java
new file mode 100644
index 0000000..4197c71
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+/**
+ * Basic interface for keeping track of the state of a thread.
+ */
+public interface ThreadStateTransitionValidator {
+    boolean isValidTransition(final ThreadStateTransitionValidator newState);
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/77b81c02/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
index a03b7cc..985dc93 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
@@ -28,6 +28,8 @@ import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.ForeachAction;
 import org.apache.kafka.streams.kstream.KStreamBuilder;
 import org.apache.kafka.streams.processor.StreamPartitioner;
+import org.apache.kafka.streams.processor.internals.GlobalStreamThread;
+import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.apache.kafka.test.IntegrationTest;
 import org.apache.kafka.test.MockMetricsReporter;
 import org.apache.kafka.test.TestCondition;
@@ -55,6 +57,7 @@ import static org.junit.Assert.assertTrue;
 public class KafkaStreamsTest {
 
     private static final int NUM_BROKERS = 1;
+    private static final int NUM_THREADS = 2;
     // We need this to avoid the KafkaConsumer hanging on poll (this may occur if the test doesn't complete
     // quick enough)
     @ClassRule
@@ -70,17 +73,14 @@ public class KafkaStreamsTest {
         props.setProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
         props.setProperty(StreamsConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName());
         props.setProperty(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
+        props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, NUM_THREADS);
         streams = new KafkaStreams(builder, props);
     }
 
     @Test
-    public void testInitializesAndDestroysMetricsReporters() throws Exception {
-        final int oldInitCount = MockMetricsReporter.INIT_COUNT.get();
+    public void testStateChanges() throws Exception {
         final KStreamBuilder builder = new KStreamBuilder();
         final KafkaStreams streams = new KafkaStreams(builder, props);
-        final int newInitCount = MockMetricsReporter.INIT_COUNT.get();
-        final int initDiff = newInitCount - oldInitCount;
-        assertTrue("some reporters should be initialized by calling on construction", initDiff > 0);
 
         StateListenerStub stateListener = new StateListenerStub();
         streams.setStateListener(stateListener);
@@ -88,19 +88,137 @@ public class KafkaStreamsTest {
         Assert.assertEquals(stateListener.numChanges, 0);
 
         streams.start();
-        Assert.assertEquals(streams.state(), KafkaStreams.State.RUNNING);
-        Assert.assertEquals(stateListener.numChanges, 1);
-        Assert.assertEquals(stateListener.oldState, KafkaStreams.State.CREATED);
-        Assert.assertEquals(stateListener.newState, KafkaStreams.State.RUNNING);
-        Assert.assertEquals(stateListener.mapStates.get(KafkaStreams.State.RUNNING).longValue(), 1L);
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.RUNNING;
+            }
+        }, 10 * 1000, "Streams never started.");
+        streams.close();
+        Assert.assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
+    }
+
+    @Test
+    public void testStateCloseAfterCreate() throws Exception {
+        final KStreamBuilder builder = new KStreamBuilder();
+        final KafkaStreams streams = new KafkaStreams(builder, props);
+
+        StateListenerStub stateListener = new StateListenerStub();
+        streams.setStateListener(stateListener);
+        streams.close();
+        Assert.assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
+    }
+
+    private void testStateThreadCloseHelper(final int numThreads) throws Exception {
+        final java.lang.reflect.Field threadsField = streams.getClass().getDeclaredField("threads");
+        threadsField.setAccessible(true);
+        final StreamThread[] threads = (StreamThread[]) threadsField.get(streams);
+
+        assertEquals(numThreads, threads.length);
+        assertEquals(streams.state(), KafkaStreams.State.CREATED);
+
+        streams.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.RUNNING;
+            }
+        }, 10 * 1000, "Streams never started.");
+
+        for (int i = 0; i < numThreads; i++) {
+            final StreamThread tmpThread = threads[i];
+            tmpThread.close();
+            TestUtils.waitForCondition(new TestCondition() {
+                @Override
+                public boolean conditionMet() {
+                    return tmpThread.state() == StreamThread.State.DEAD;
+                }
+            }, 10 * 1000, "Thread never stopped.");
+            threads[i].join();
+        }
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.ERROR;
+            }
+        }, 10 * 1000, "Streams never stopped.");
+        streams.close();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.NOT_RUNNING;
+            }
+        }, 10 * 1000, "Streams never stopped.");
+
+        final java.lang.reflect.Field globalThreadField = streams.getClass().getDeclaredField("globalStreamThread");
+        globalThreadField.setAccessible(true);
+        GlobalStreamThread globalStreamThread = (GlobalStreamThread) globalThreadField.get(streams);
+        assertEquals(globalStreamThread, null);
+    }
+
+    @Test
+    public void testStateThreadClose() throws Exception {
+        final int numThreads = 2;
+        final KStreamBuilder builder = new KStreamBuilder();
+        // make sure we have the global state thread running too
+        builder.globalTable("anyTopic");
+        props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numThreads);
+        final KafkaStreams streams = new KafkaStreams(builder, props);
+
+        testStateThreadCloseHelper(numThreads);
+    }
+    
+    @Test
+    public void testStateGlobalThreadClose() throws Exception {
+        final int numThreads = 2;
+        final KStreamBuilder builder = new KStreamBuilder();
+        // make sure we have the global state thread running too
+        builder.globalTable("anyTopic");
+        props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numThreads);
+        final KafkaStreams streams = new KafkaStreams(builder, props);
+
+
+        streams.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return streams.state() == KafkaStreams.State.RUNNING;
+            }
+        }, 10 * 1000, "Streams never started.");
+        final java.lang.reflect.Field globalThreadField = streams.getClass().getDeclaredField("globalStreamThread");
+        globalThreadField.setAccessible(true);
+        final GlobalStreamThread globalStreamThread = (GlobalStreamThread) globalThreadField.get(streams);
+        globalStreamThread.close();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return globalStreamThread.state() == GlobalStreamThread.State.DEAD;
+            }
+        }, 10 * 1000, "Thread never stopped.");
+        globalStreamThread.join();
+        assertEquals(streams.state(), KafkaStreams.State.ERROR);
+
+        streams.close();
+        assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
+
+    }
+
+    @Test
+    public void testInitializesAndDestroysMetricsReporters() throws Exception {
+        final int oldInitCount = MockMetricsReporter.INIT_COUNT.get();
+        final KStreamBuilder builder = new KStreamBuilder();
+        final KafkaStreams streams = new KafkaStreams(builder, props);
+        final int newInitCount = MockMetricsReporter.INIT_COUNT.get();
+        final int initDiff = newInitCount - oldInitCount;
+        assertTrue("some reporters should be initialized by calling on construction", initDiff > 0);
+
+        streams.start();
         final int oldCloseCount = MockMetricsReporter.CLOSE_COUNT.get();
         streams.close();
         assertEquals(oldCloseCount + initDiff, MockMetricsReporter.CLOSE_COUNT.get());
-        Assert.assertEquals(streams.state(), KafkaStreams.State.NOT_RUNNING);
-        Assert.assertEquals(stateListener.mapStates.get(KafkaStreams.State.RUNNING).longValue(), 1L);
-        Assert.assertEquals(stateListener.mapStates.get(KafkaStreams.State.NOT_RUNNING).longValue(), 1L);
     }
 
+
     @Test
     public void testCloseIsIdempotent() throws Exception {
         streams.close();
@@ -362,6 +480,23 @@ public class KafkaStreamsTest {
         public KafkaStreams.State oldState;
         public KafkaStreams.State newState;
         public Map<KafkaStreams.State, Long> mapStates = new HashMap<>();
+        private final boolean closeOnChange;
+        private final KafkaStreams streams;
+
+        public StateListenerStub() {
+            this.closeOnChange = false;
+            this.streams = null;
+        }
+
+        /**
+         * For testing only, we might want to test closing streams on a transition change
+         * @param closeOnChange
+         * @param streams
+         */
+        public StateListenerStub(final boolean closeOnChange, final KafkaStreams streams) {
+            this.closeOnChange = closeOnChange;
+            this.streams = streams;
+        }
 
         @Override
         public void onChange(final KafkaStreams.State newState, final KafkaStreams.State oldState) {
@@ -370,6 +505,11 @@ public class KafkaStreamsTest {
             this.oldState = oldState;
             this.newState = newState;
             this.mapStates.put(newState, prevCount + 1);
+            if (this.closeOnChange &&
+                    (newState == KafkaStreams.State.NOT_RUNNING || newState == KafkaStreams.State.ERROR)) {
+                streams.close();
+            }
         }
+
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/77b81c02/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
index 30582ed..5546474 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java
@@ -27,6 +27,7 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.kstream.KStreamBuilder;
 import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.test.TestCondition;
 import org.apache.kafka.test.TestUtils;
 import org.junit.Before;
 import org.junit.Test;
@@ -35,11 +36,13 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 
+import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.RUNNING;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.IsInstanceOf.instanceOf;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
 public class GlobalStreamThreadTest {
@@ -107,7 +110,7 @@ public class GlobalStreamThreadTest {
 
 
     @Test
-    public void shouldBeRunningAfterSuccesulStart() throws Exception {
+    public void shouldBeRunningAfterSuccessfulStart() throws Exception {
         initializeConsumer();
         globalStreamThread.start();
         assertTrue(globalStreamThread.stillRunning());
@@ -119,6 +122,7 @@ public class GlobalStreamThreadTest {
         globalStreamThread.start();
         globalStreamThread.close();
         globalStreamThread.join();
+        assertEquals(GlobalStreamThread.State.DEAD, globalStreamThread.state());
     }
 
     @Test
@@ -132,6 +136,46 @@ public class GlobalStreamThreadTest {
         assertFalse(globalStore.isOpen());
     }
 
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldTransitionToDeadOnClose() throws InterruptedException {
+
+        initializeConsumer();
+        globalStreamThread.start();
+        globalStreamThread.close();
+        globalStreamThread.join();
+
+        assertEquals(GlobalStreamThread.State.DEAD, globalStreamThread.state());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldStayDeadAfterTwoCloses() throws InterruptedException {
+
+        initializeConsumer();
+        globalStreamThread.start();
+        globalStreamThread.close();
+        globalStreamThread.join();
+        globalStreamThread.close();
+
+        assertEquals(GlobalStreamThread.State.DEAD, globalStreamThread.state());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldTransitiontoRunningOnStart() throws InterruptedException {
+
+        initializeConsumer();
+        globalStreamThread.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return globalStreamThread.state() == RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
+        globalStreamThread.close();
+    }
+
     private void initializeConsumer() {
         mockConsumer.updatePartitions("foo", Collections.singletonList(new PartitionInfo("foo",
                                                                                          0,

http://git-wip-us.apache.org/repos/asf/kafka/blob/77b81c02/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 3b280f1..ded1bfd 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -268,10 +268,10 @@ public class StreamThreadTest {
 
         final StateListenerStub stateListener = new StateListenerStub();
         thread.setStateListener(stateListener);
-        assertEquals(thread.state(), StreamThread.State.RUNNING);
+        assertEquals(thread.state(), StreamThread.State.CREATED);
 
         final ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
-
+        thread.setState(StreamThread.State.RUNNING);
         assertTrue(thread.tasks().isEmpty());
 
         List<TopicPartition> revokedPartitions;
@@ -284,9 +284,6 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
 
         assertEquals(thread.state(), StreamThread.State.PARTITIONS_REVOKED);
-        Assert.assertEquals(stateListener.numChanges, 1);
-        Assert.assertEquals(stateListener.oldState, StreamThread.State.RUNNING);
-        Assert.assertEquals(stateListener.newState, StreamThread.State.PARTITIONS_REVOKED);
 
         // assign single partition
         assignedPartitions = Collections.singletonList(t1p1);
@@ -295,9 +292,8 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
         assertEquals(thread.state(), StreamThread.State.RUNNING);
-        Assert.assertEquals(stateListener.numChanges, 3);
+        Assert.assertEquals(stateListener.numChanges, 4);
         Assert.assertEquals(stateListener.oldState, StreamThread.State.ASSIGNING_PARTITIONS);
-        Assert.assertEquals(stateListener.newState, StreamThread.State.RUNNING);
         assertTrue(thread.tasks().containsKey(task1));
         assertEquals(expectedGroup1, thread.tasks().get(task1).partitions());
         assertEquals(1, thread.tasks().size());
@@ -346,8 +342,7 @@ public class StreamThreadTest {
         assertTrue(thread.tasks().isEmpty());
 
         thread.close();
-        assertTrue((thread.state() == StreamThread.State.PENDING_SHUTDOWN) ||
-            (thread.state() == StreamThread.State.CREATED));
+        assertTrue(thread.state() == StreamThread.State.PENDING_SHUTDOWN);
     }
 
     @SuppressWarnings("unchecked")
@@ -370,10 +365,10 @@ public class StreamThreadTest {
 
         final StateListenerStub stateListener = new StateListenerStub();
         thread.setStateListener(stateListener);
-        assertEquals(thread.state(), StreamThread.State.RUNNING);
+        assertEquals(thread.state(), StreamThread.State.CREATED);
 
         final ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener;
-
+        thread.setState(StreamThread.State.RUNNING);
         assertTrue(thread.tasks().isEmpty());
 
         List<TopicPartition> revokedPartitions;
@@ -386,9 +381,6 @@ public class StreamThreadTest {
         rebalanceListener.onPartitionsRevoked(revokedPartitions);
 
         assertEquals(thread.state(), StreamThread.State.PARTITIONS_REVOKED);
-        Assert.assertEquals(stateListener.numChanges, 1);
-        Assert.assertEquals(stateListener.oldState, StreamThread.State.RUNNING);
-        Assert.assertEquals(stateListener.newState, StreamThread.State.PARTITIONS_REVOKED);
 
         // assign four new partitions of second subtopology
         assignedPartitions = Arrays.asList(t2p1, t2p2, t3p1, t3p2);
@@ -444,10 +436,7 @@ public class StreamThreadTest {
         assertTrue(thread.tasks().isEmpty());
 
         thread.close();
-
-        assertTrue((thread.state() == StreamThread.State.PENDING_SHUTDOWN) ||
-            (thread.state() == StreamThread.State.CREATED));
-
+        assertEquals(thread.state(), StreamThread.State.PENDING_SHUTDOWN);
     }
 
     @SuppressWarnings("unchecked")
@@ -485,6 +474,7 @@ public class StreamThreadTest {
             }
         }, 10 * 1000, "Thread never shut down.");
         thread.close();
+        assertEquals(thread.state(), StreamThread.State.DEAD);
     }
 
     private final static String TOPIC = "topic";
@@ -493,18 +483,18 @@ public class StreamThreadTest {
 
     @SuppressWarnings("unchecked")
     @Test
-    public void testHandingOverTaskFromOneToAnotherThread() throws Exception {
+    public void testHandingOverTaskFromOneToAnotherThread() throws InterruptedException {
         builder.addStateStore(
-            Stores
-                .create("store")
-                .withByteArrayKeys()
-                .withByteArrayValues()
-                .persistent()
-                .build()
+                Stores
+                        .create("store")
+                        .withByteArrayKeys()
+                        .withByteArrayValues()
+                        .persistent()
+                        .build()
         );
         builder.addSource("source", TOPIC);
 
-        clientSupplier.consumer.assign(Arrays.asList(new TopicPartition(TOPIC, 0), new TopicPartition(TOPIC, 1)));
+        //clientSupplier.consumer.assign(Arrays.asList(new TopicPartition(TOPIC, 0), new TopicPartition(TOPIC, 1)));
 
         final StreamThread thread1 = new StreamThread(
             builder,
@@ -540,6 +530,23 @@ public class StreamThreadTest {
         thread1.setPartitionAssignor(new MockStreamsPartitionAssignor(thread1Assignment));
         thread2.setPartitionAssignor(new MockStreamsPartitionAssignor(thread2Assignment));
 
+
+        thread1.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread1.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
+
+        thread2.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread2.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
+
         // revoke (to get threads in correct state)
         thread1.rebalanceListener.onPartitionsRevoked(EMPTY_SET);
         thread2.rebalanceListener.onPartitionsRevoked(EMPTY_SET);
@@ -613,7 +620,7 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void testMetrics() throws Exception {
+    public void testMetrics() {
         final StreamThread thread = new StreamThread(
             builder,
             config,
@@ -716,7 +723,8 @@ public class StreamThreadTest {
             //
             revokedPartitions = Collections.emptyList();
             assignedPartitions = Arrays.asList(t1p1, t1p2);
-
+            thread.setState(StreamThread.State.RUNNING);
+            thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
             rebalanceListener.onPartitionsRevoked(revokedPartitions);
             rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
@@ -778,6 +786,8 @@ public class StreamThreadTest {
         assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(assignment));
 
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new TopicPartition("someTopic", 0)));
 
         assertEquals(1, clientSupplier.producers.size());
@@ -816,6 +826,8 @@ public class StreamThreadTest {
 
         final Set<TopicPartition> assignedPartitions = new HashSet<>();
         Collections.addAll(assignedPartitions, new TopicPartition("someTopic", 0), new TopicPartition("someTopic", 2));
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
         assertNull(thread.threadProducer);
@@ -850,7 +862,8 @@ public class StreamThreadTest {
         assignment.put(new TaskId(0, 0), Collections.singleton(new TopicPartition("someTopic", 0)));
         assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(assignment));
-
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new TopicPartition("someTopic", 0)));
 
         thread.close();
@@ -883,6 +896,8 @@ public class StreamThreadTest {
         assignment.put(new TaskId(0, 1), Collections.singleton(new TopicPartition("someTopic", 1)));
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(assignment));
 
+        thread.setState(StreamThread.State.RUNNING);
+        thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.singleton(new TopicPartition("someTopic", 0)));
 
         thread.close();
@@ -915,6 +930,7 @@ public class StreamThreadTest {
             }
         });
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
     }
@@ -968,6 +984,8 @@ public class StreamThreadTest {
             }
         });
 
+        thread.setState(StreamThread.State.RUNNING);
+        thread.setState(StreamThread.State.PARTITIONS_REVOKED);
         thread.rebalanceListener.onPartitionsAssigned(activeTasks);
         thread.rebalanceListener.onPartitionsRevoked(activeTasks);
 
@@ -1026,6 +1044,7 @@ public class StreamThreadTest {
             }
         });
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Collections.<TopicPartition>emptyList());
 
@@ -1100,6 +1119,7 @@ public class StreamThreadTest {
             }
         });
 
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(Utils.mkSet(t2));
 
@@ -1175,6 +1195,7 @@ public class StreamThreadTest {
         builder.updateSubscriptions(subscriptionUpdates, null);
 
         // should create task for id 0_0 with a single partition
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(task00Partitions);
 
@@ -1220,14 +1241,19 @@ public class StreamThreadTest {
         activeTasks.put(task1, task0Assignment);
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
+        thread.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
 
         thread.rebalanceListener.onPartitionsRevoked(null);
         thread.rebalanceListener.onPartitionsAssigned(task0Assignment);
         assertThat(thread.tasks().size(), equalTo(1));
         final MockProducer producer = clientSupplier.producers.get(0);
 
-        thread.start();
-
         TestUtils.waitForCondition(
             new TestCondition() {
                 @Override
@@ -1309,7 +1335,7 @@ public class StreamThreadTest {
         activeTasks.put(task1, task0Assignment);
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
-
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(null);
         thread.rebalanceListener.onPartitionsAssigned(task0Assignment);
         assertThat(thread.tasks().size(), equalTo(1));
@@ -1371,11 +1397,10 @@ public class StreamThreadTest {
         activeTasks.put(testStreamTask.id(), testStreamTask.partitions);
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
-
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
 
-        thread.start();
         thread.close();
         thread.join();
         assertFalse("task shouldn't have been committed as there was an exception during shutdown", testStreamTask.committed);
@@ -1428,13 +1453,18 @@ public class StreamThreadTest {
 
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
-
+        thread.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread.state() == StreamThread.State.RUNNING;
+            }
+        }, 10 * 1000, "Thread never started.");
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
         // store should have been opened
         assertTrue(stateStore.isOpen());
 
-        thread.start();
         thread.close();
         thread.join();
         assertFalse("task shouldn't have been committed as there was an exception during shutdown", testStreamTask.committed);
@@ -1487,7 +1517,7 @@ public class StreamThreadTest {
         activeTasks.put(testStreamTask.id(), testStreamTask.partitions);
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
-
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
 
@@ -1546,7 +1576,7 @@ public class StreamThreadTest {
 
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
-
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
         try {
@@ -1607,7 +1637,7 @@ public class StreamThreadTest {
 
 
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
-
+        thread.setState(StreamThread.State.RUNNING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptyList());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
         try {
@@ -1730,7 +1760,6 @@ public class StreamThreadTest {
     @Test
     public void shouldReleaseStateDirLockIfFailureOnTaskCloseForSuspendedTask() throws Exception {
         final TaskId taskId = new TaskId(0, 0);
-
         final StateDirectory stateDirMock = mockStateDirInteractions(taskId);
 
         final StreamThread thread = setupTest(taskId, stateDirMock);
@@ -1739,6 +1768,7 @@ public class StreamThreadTest {
         EasyMock.verify(stateDirMock);
     }
 
+
     private StreamThread setupTest(final TaskId taskId, final StateDirectory stateDirectory) throws InterruptedException {
         final TopologyBuilder builder = new TopologyBuilder();
         builder.setApplicationId(applicationId);
@@ -1763,7 +1793,6 @@ public class StreamThreadTest {
             }
         };
 
-
         final StreamThread thread = new StreamThread(
             builder,
             config,
@@ -1786,6 +1815,12 @@ public class StreamThreadTest {
         activeTasks.put(testStreamTask.id, testStreamTask.partitions);
         thread.setPartitionAssignor(new MockStreamsPartitionAssignor(activeTasks));
         thread.start();
+        TestUtils.waitForCondition(new TestCondition() {
+            @Override
+            public boolean conditionMet() {
+                return thread.state() == StreamThread.State.RUNNING;
+            }
+        }, "thread didn't transition to running");
         thread.rebalanceListener.onPartitionsRevoked(Collections.<TopicPartition>emptySet());
         thread.rebalanceListener.onPartitionsAssigned(testStreamTask.partitions);
 
@@ -1795,7 +1830,6 @@ public class StreamThreadTest {
     @Test
     public void shouldReleaseStateDirLockIfFailureOnStandbyTaskSuspend() throws Exception {
         final TaskId taskId = new TaskId(0, 0);
-
         final StateDirectory stateDirMock = mockStateDirInteractions(taskId);
         final StreamThread thread = setupStandbyTest(taskId, stateDirMock);
 
@@ -1808,7 +1842,6 @@ public class StreamThreadTest {
             // ok
         } finally {
             thread.close();
-
         }
         EasyMock.verify(stateDirMock);
     }
@@ -1828,7 +1861,6 @@ public class StreamThreadTest {
     @Test
     public void shouldReleaseStateDirLockIfFailureOnStandbyTaskCloseForUnassignedSuspendedStandbyTask() throws Exception {
         final TaskId taskId = new TaskId(0, 0);
-
         final StateDirectory stateDirMock = mockStateDirInteractions(taskId);
         final StreamThread thread = setupStandbyTest(taskId, stateDirMock);
         startThreadAndRebalance(thread);
@@ -1936,11 +1968,12 @@ public class StreamThreadTest {
 
     private static class StateListenerStub implements StreamThread.StateListener {
         int numChanges = 0;
-        StreamThread.State oldState = null;
-        StreamThread.State newState = null;
+        ThreadStateTransitionValidator oldState = null;
+        ThreadStateTransitionValidator newState = null;
 
         @Override
-        public void onChange(final StreamThread thread, final StreamThread.State newState, final StreamThread.State oldState) {
+        public void onChange(final Thread thread, final ThreadStateTransitionValidator newState,
+                             final ThreadStateTransitionValidator oldState) {
             ++numChanges;
             if (this.newState != null) {
                 if (this.newState != oldState) {


Mime
View raw message