kafka-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From guozh...@apache.org
Subject [2/3] kafka git commit: KAFKA-5152: perform state restoration in poll loop
Date Tue, 22 Aug 2017 18:13:12 GMT
http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 4bdad9a..151ef35 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -16,26 +16,18 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.errors.ProducerFencedException;
-import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.TreeSet;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.Collections.singleton;
@@ -45,15 +37,9 @@ class TaskManager {
     // activeTasks needs to be concurrent as it can be accessed
     // by QueryableState
     private static final Logger log = LoggerFactory.getLogger(TaskManager.class);
-    private final Map<TaskId, Task> activeTasks = new ConcurrentHashMap<>();
-    private final Map<TaskId, Task> standbyTasks = new HashMap<>();
-    private final Map<TopicPartition, Task> activeTasksByPartition = new HashMap<>();
-    private final Map<TopicPartition, Task> standbyTasksByPartition = new HashMap<>();
-    private final Set<TaskId> prevActiveTasks = new TreeSet<>();
-    private final Map<TaskId, Task> suspendedTasks = new HashMap<>();
-    private final Map<TaskId, Task> suspendedStandbyTasks = new HashMap<>();
+    private final AssignedTasks active;
+    private final AssignedTasks standby;
     private final ChangelogReader changelogReader;
-    private final Time time;
     private final String logPrefix;
     private final Consumer<byte[], byte[]> restoreConsumer;
     private final StreamThread.AbstractTaskCreator taskCreator;
@@ -62,17 +48,19 @@ class TaskManager {
     private Consumer<byte[], byte[]> consumer;
 
     TaskManager(final ChangelogReader changelogReader,
-                final Time time,
                 final String logPrefix,
                 final Consumer<byte[], byte[]> restoreConsumer,
                 final StreamThread.AbstractTaskCreator taskCreator,
-                final StreamThread.AbstractTaskCreator standbyTaskCreator) {
+                final StreamThread.AbstractTaskCreator standbyTaskCreator,
+                final AssignedTasks active,
+                final AssignedTasks standby) {
         this.changelogReader = changelogReader;
-        this.time = time;
         this.logPrefix = logPrefix;
         this.restoreConsumer = restoreConsumer;
         this.taskCreator = taskCreator;
         this.standbyTaskCreator = standbyTaskCreator;
+        this.active = active;
+        this.standby = standby;
     }
 
     void createTasks(final Collection<TopicPartition> assignment) {
@@ -83,63 +71,29 @@ class TaskManager {
             throw new IllegalStateException(logPrefix + " consumer has not been initialized while adding stream tasks. This should not happen.");
         }
 
-        final long start = time.milliseconds();
-        changelogReader.clear();
+        changelogReader.reset();
         // do this first as we may have suspended standby tasks that
         // will become active or vice versa
-        closeNonAssignedSuspendedStandbyTasks();
+        standby.closeNonAssignedSuspendedTasks(threadMetadataProvider.standbyTasks());
         Map<TaskId, Set<TopicPartition>> assignedActiveTasks = threadMetadataProvider.activeTasks();
-        closeNonAssignedSuspendedTasks(assignedActiveTasks);
-        addStreamTasks(assignment, assignedActiveTasks, start);
-        changelogReader.restore();
-        addStandbyTasks(start);
+        active.closeNonAssignedSuspendedTasks(assignedActiveTasks);
+        addStreamTasks(assignment);
+        addStandbyTasks();
+        final Set<TopicPartition> partitions = active.uninitializedPartitions();
+        log.trace("{} pausing partitions: {}", logPrefix, partitions);
+        consumer.pause(partitions);
     }
 
     void setThreadMetadataProvider(final ThreadMetadataProvider threadMetadataProvider) {
         this.threadMetadataProvider = threadMetadataProvider;
     }
 
-    private void closeNonAssignedSuspendedStandbyTasks() {
-        final Set<TaskId> currentSuspendedTaskIds = threadMetadataProvider.standbyTasks().keySet();
-        final Iterator<Map.Entry<TaskId, Task>> standByTaskIterator = suspendedStandbyTasks.entrySet().iterator();
-        while (standByTaskIterator.hasNext()) {
-            final Map.Entry<TaskId, Task> suspendedTask = standByTaskIterator.next();
-            if (!currentSuspendedTaskIds.contains(suspendedTask.getKey())) {
-                final Task task = suspendedTask.getValue();
-                log.debug("{} Closing suspended and not re-assigned standby task {}", logPrefix, task.id());
-                try {
-                    task.close(true);
-                } catch (final Exception e) {
-                    log.error("{} Failed to remove suspended standby task {} due to the following error:", logPrefix, task.id(), e);
-                } finally {
-                    standByTaskIterator.remove();
-                }
-            }
-        }
-    }
-
-    private void closeNonAssignedSuspendedTasks(final Map<TaskId, Set<TopicPartition>> newTaskAssignment) {
-        final Iterator<Map.Entry<TaskId, Task>> suspendedTaskIterator = suspendedTasks.entrySet().iterator();
-        while (suspendedTaskIterator.hasNext()) {
-            final Map.Entry<TaskId, Task> next = suspendedTaskIterator.next();
-            final Task task = next.getValue();
-            final Set<TopicPartition> assignedPartitionsForTask = newTaskAssignment.get(next.getKey());
-            if (!task.partitions().equals(assignedPartitionsForTask)) {
-                log.debug("{} Closing suspended and not re-assigned task {}", logPrefix, task.id());
-                try {
-                    task.closeSuspended(true, null);
-                } catch (final Exception e) {
-                    log.error("{} Failed to close suspended task {} due to the following error:", logPrefix, next.getKey(), e);
-                } finally {
-                    suspendedTaskIterator.remove();
-                }
-            }
+    private void addStreamTasks(final Collection<TopicPartition> assignment) {
+        Map<TaskId, Set<TopicPartition>> assignedTasks = threadMetadataProvider.activeTasks();
+        if (assignedTasks.isEmpty()) {
+            return;
         }
-    }
-
-    private void addStreamTasks(final Collection<TopicPartition> assignment, final Map<TaskId, Set<TopicPartition>> assignedTasks, final long start) {
         final Map<TaskId, Set<TopicPartition>> newTasks = new HashMap<>();
-
         // collect newly assigned tasks and reopen re-assigned tasks
         log.debug("{} Adding assigned tasks as active: {}", logPrefix, assignedTasks);
         for (final Map.Entry<TaskId, Set<TopicPartition>> entry : assignedTasks.entrySet()) {
@@ -148,17 +102,7 @@ class TaskManager {
 
             if (assignment.containsAll(partitions)) {
                 try {
-                    final Task task = findMatchingSuspendedTask(taskId, partitions);
-                    if (task != null) {
-                        suspendedTasks.remove(taskId);
-                        task.resume();
-
-                        activeTasks.put(taskId, task);
-
-                        for (final TopicPartition partition : partitions) {
-                            activeTasksByPartition.put(partition, task);
-                        }
-                    } else {
+                    if (!active.maybeResumeSuspendedTask(taskId, partitions)) {
                         newTasks.put(taskId, partitions);
                     }
                 } catch (final StreamsException e) {
@@ -170,135 +114,59 @@ class TaskManager {
             }
         }
 
+        if (newTasks.isEmpty()) {
+            return;
+        }
+
         // create all newly assigned tasks (guard against race condition with other thread via backoff and retry)
         // -> other thread will call removeSuspendedTasks(); eventually
         log.trace("{} New active tasks to be created: {}", logPrefix, newTasks);
 
-        if (!newTasks.isEmpty()) {
-            final Map<Task, Set<TopicPartition>> createdTasks = taskCreator.retryWithBackoff(consumer, newTasks, start);
-            for (final Map.Entry<Task, Set<TopicPartition>> entry : createdTasks.entrySet()) {
-                final Task task = entry.getKey();
-                activeTasks.put(task.id(), task);
-                for (final TopicPartition partition : entry.getValue()) {
-                    activeTasksByPartition.put(partition, task);
-                }
-            }
+        for (final Task task : taskCreator.createTasks(consumer, newTasks)) {
+            active.addNewTask(task);
         }
     }
 
-    private void addStandbyTasks(final long start) {
-        final Map<TopicPartition, Long> checkpointedOffsets = new HashMap<>();
-
-        final Map<TaskId, Set<TopicPartition>> newStandbyTasks = new HashMap<>();
-
-        Map<TaskId, Set<TopicPartition>> assignedStandbyTasks = threadMetadataProvider.standbyTasks();
+    private void addStandbyTasks() {
+        final Map<TaskId, Set<TopicPartition>> assignedStandbyTasks = threadMetadataProvider.standbyTasks();
+        if (assignedStandbyTasks.isEmpty()) {
+            return;
+        }
         log.debug("{} Adding assigned standby tasks {}", logPrefix, assignedStandbyTasks);
+        final Map<TaskId, Set<TopicPartition>> newStandbyTasks = new HashMap<>();
         // collect newly assigned standby tasks and reopen re-assigned standby tasks
         for (final Map.Entry<TaskId, Set<TopicPartition>> entry : assignedStandbyTasks.entrySet()) {
             final TaskId taskId = entry.getKey();
             final Set<TopicPartition> partitions = entry.getValue();
-            final Task task = findMatchingSuspendedStandbyTask(taskId, partitions);
-
-            if (task != null) {
-                suspendedStandbyTasks.remove(taskId);
-                task.resume();
-            } else {
+            if (!standby.maybeResumeSuspendedTask(taskId, partitions)) {
                 newStandbyTasks.put(taskId, partitions);
             }
 
-            updateStandByTasks(checkpointedOffsets, taskId, partitions, task);
+        }
+
+        if (newStandbyTasks.isEmpty()) {
+            return;
         }
 
         // create all newly assigned standby tasks (guard against race condition with other thread via backoff and retry)
         // -> other thread will call removeSuspendedStandbyTasks(); eventually
         log.trace("{} New standby tasks to be created: {}", logPrefix, newStandbyTasks);
-        if (!newStandbyTasks.isEmpty()) {
-            final Map<Task, Set<TopicPartition>> createdStandbyTasks = standbyTaskCreator.retryWithBackoff(consumer, newStandbyTasks, start);
-            for (Map.Entry<Task, Set<TopicPartition>> entry : createdStandbyTasks.entrySet()) {
-                final Task task = entry.getKey();
-                updateStandByTasks(checkpointedOffsets, task.id(), entry.getValue(), task);
-            }
-        }
-
-        restoreConsumer.assign(checkpointedOffsets.keySet());
 
-        for (final Map.Entry<TopicPartition, Long> entry : checkpointedOffsets.entrySet()) {
-            final TopicPartition partition = entry.getKey();
-            final long offset = entry.getValue();
-            if (offset >= 0) {
-                restoreConsumer.seek(partition, offset);
-            } else {
-                restoreConsumer.seekToBeginning(singleton(partition));
-            }
-        }
-    }
-
-    private void updateStandByTasks(final Map<TopicPartition, Long> checkpointedOffsets,
-                                    final TaskId taskId,
-                                    final Set<TopicPartition> partitions,
-                                    final Task task) {
-        if (task != null) {
-            standbyTasks.put(taskId, task);
-            for (final TopicPartition partition : partitions) {
-                standbyTasksByPartition.put(partition, task);
-            }
-            // collect checked pointed offsets to position the restore consumer
-            // this include all partitions from which we restore states
-            for (final TopicPartition partition : task.checkpointedOffsets().keySet()) {
-                standbyTasksByPartition.put(partition, task);
-            }
-            checkpointedOffsets.putAll(task.checkpointedOffsets());
-        }
-    }
-
-    List<Task> allTasks() {
-        final List<Task> tasks = activeAndStandbytasks();
-        tasks.addAll(suspendedAndSuspendedStandbytasks());
-        return tasks;
-    }
-
-    private List<Task> activeAndStandbytasks() {
-        final List<Task> tasks = new ArrayList<>(activeTasks.values());
-        tasks.addAll(standbyTasks.values());
-        return tasks;
-    }
-
-    private List<Task> suspendedAndSuspendedStandbytasks() {
-        final List<Task> tasks = new ArrayList<>(suspendedTasks.values());
-        tasks.addAll(suspendedStandbyTasks.values());
-        return tasks;
-    }
-
-    private Task findMatchingSuspendedTask(final TaskId taskId, final Set<TopicPartition> partitions) {
-        if (suspendedTasks.containsKey(taskId)) {
-            final Task task = suspendedTasks.get(taskId);
-            if (task.partitions().equals(partitions)) {
-                return task;
-            }
+        for (final Task task : standbyTaskCreator.createTasks(consumer, newStandbyTasks)) {
+            standby.addNewTask(task);
         }
-        return null;
-    }
-
-    private Task findMatchingSuspendedStandbyTask(final TaskId taskId, final Set<TopicPartition> partitions) {
-        if (suspendedStandbyTasks.containsKey(taskId)) {
-            final Task task = suspendedStandbyTasks.get(taskId);
-            if (task.partitions().equals(partitions)) {
-                return task;
-            }
-        }
-        return null;
     }
 
     Set<TaskId> activeTaskIds() {
-        return Collections.unmodifiableSet(activeTasks.keySet());
+        return active.allAssignedTaskIds();
     }
 
     Set<TaskId> standbyTaskIds() {
-        return Collections.unmodifiableSet(standbyTasks.keySet());
+        return standby.allAssignedTaskIds();
     }
 
     Set<TaskId> prevActiveTaskIds() {
-        return Collections.unmodifiableSet(prevActiveTasks);
+        return active.previousTaskIds();
     }
 
     /**
@@ -307,58 +175,15 @@ class TaskManager {
      */
     void suspendTasksAndState()  {
         log.debug("{} Suspending all active tasks {} and standby tasks {}",
-                  logPrefix, activeTasks.keySet(), standbyTasks.keySet());
+                  logPrefix, active.runningTaskIds(), standby.runningTaskIds());
 
         final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
 
-        firstException.compareAndSet(null, performOnActiveTasks(new TaskAction() {
-            @Override
-            public String name() {
-                return "suspend";
-            }
-
-            @Override
-            public void apply(final Task task) {
-                try {
-                    task.suspend();
-                } catch (final CommitFailedException e) {
-                    // commit failed during suspension. Just log it.
-                    log.warn("{} Failed to commit task {} state when suspending due to CommitFailedException", logPrefix, task.id());
-                } catch (final Exception e) {
-                    log.error("{} Suspending task {} failed due to the following error:", logPrefix, task.id(), e);
-                    try {
-                        task.close(false);
-                    } catch (final Exception f) {
-                        log.error("{} After suspending failed, closing the same task {} failed again due to the following error:", logPrefix, task.id(), f);
-                    }
-                    throw e;
-                }
-            }
-        }));
-
-        for (final Task task : standbyTasks.values()) {
-            try {
-                try {
-                    task.suspend();
-                } catch (final Exception e) {
-                    log.error("{} Suspending standby task {} failed due to the following error:", logPrefix, task.id(), e);
-                    try {
-                        task.close(false);
-                    } catch (final Exception f) {
-                        log.error("{} After suspending failed, closing the same standby task {} failed again due to the following error:", logPrefix, task.id(), f);
-                    }
-                    throw e;
-                }
-            } catch (final RuntimeException e) {
-                firstException.compareAndSet(null, e);
-            }
-        }
-
+        firstException.compareAndSet(null, active.suspend());
+        firstException.compareAndSet(null, standby.suspend());
         // remove the changelog partitions from restore consumer
         firstException.compareAndSet(null, unAssignChangeLogPartitions());
 
-        updateSuspendedTasks();
-
         if (firstException.get() != null) {
             throw new StreamsException(logPrefix + " failed to suspend stream tasks", firstException.get());
         }
@@ -375,88 +200,13 @@ class TaskManager {
         return null;
     }
 
-    private void updateSuspendedTasks() {
-        suspendedTasks.clear();
-        suspendedTasks.putAll(activeTasks);
-        suspendedStandbyTasks.putAll(standbyTasks);
-    }
-
-    private void removeStreamTasks() {
-        log.debug("{} Removing all active tasks {}", logPrefix, activeTasks.keySet());
-
-        try {
-            prevActiveTasks.clear();
-            prevActiveTasks.addAll(activeTasks.keySet());
-
-            activeTasks.clear();
-            activeTasksByPartition.clear();
-        } catch (final Exception e) {
-            log.error("{} Failed to remove stream tasks due to the following error:", logPrefix, e);
-        }
-    }
-
-    void closeZombieTask(final Task task) {
-        log.warn("{} Producer of task {} fenced; closing zombie task", logPrefix, task.id());
-        try {
-            task.close(false);
-        } catch (final Exception e) {
-            log.warn("{} Failed to close zombie task due to {}, ignore and proceed", logPrefix, e);
-        }
-        activeTasks.remove(task.id());
-    }
-
-
-    RuntimeException performOnActiveTasks(final TaskAction action) {
-        return performOnTasks(action, activeTasks, "stream task");
-    }
-
-    RuntimeException performOnStandbyTasks(final TaskAction action) {
-        return performOnTasks(action, standbyTasks, "standby task");
-    }
-
-    private RuntimeException performOnTasks(final TaskAction action, final Map<TaskId, Task> tasks, final String taskType) {
-        RuntimeException firstException = null;
-        final Iterator<Map.Entry<TaskId, Task>> it = tasks.entrySet().iterator();
-        while (it.hasNext()) {
-            final Task task = it.next().getValue();
-            try {
-                action.apply(task);
-            } catch (final ProducerFencedException e) {
-                closeZombieTask(task);
-                it.remove();
-            } catch (final RuntimeException t) {
-                log.error("{} Failed to {} " + taskType + " {} due to the following error:",
-                          logPrefix,
-                          action.name(),
-                          task.id(),
-                          t);
-                if (firstException == null) {
-                    firstException = t;
-                }
-            }
-        }
-
-        return firstException;
-    }
-
-
-
     void shutdown(final boolean clean) {
         log.debug("{} Shutting down all active tasks {}, standby tasks {}, suspended tasks {}, and suspended standby tasks {}",
-                  logPrefix, activeTasks.keySet(), standbyTasks.keySet(),
-                  suspendedTasks.keySet(), suspendedStandbyTasks.keySet());
-
-        for (final Task task : allTasks()) {
-            try {
-                task.close(clean);
-            } catch (final RuntimeException e) {
-                log.error("{} Failed while closing {} {} due to the following error:",
-                          logPrefix,
-                          task.getClass().getSimpleName(),
-                          task.id(),
-                          e);
-            }
-        }
+                  logPrefix, active.runningTaskIds(), standby.runningTaskIds(),
+                  active.previousTaskIds(), standby.previousTaskIds());
+
+        active.close(clean);
+        standby.close(clean);
         try {
             threadMetadataProvider.close();
         } catch (final Throwable e) {
@@ -464,61 +214,104 @@ class TaskManager {
         }
         // remove the changelog partitions from restore consumer
         unAssignChangeLogPartitions();
+        taskCreator.close();
 
     }
 
     Set<TaskId> suspendedActiveTaskIds() {
-        return Collections.unmodifiableSet(suspendedTasks.keySet());
+        return active.previousTaskIds();
     }
 
     Set<TaskId> suspendedStandbyTaskIds() {
-        return Collections.unmodifiableSet(suspendedStandbyTasks.keySet());
+        return standby.previousTaskIds();
     }
 
-    void removeTasks() {
-        removeStreamTasks();
-        removeStandbyTasks();
+    Task activeTask(final TopicPartition partition) {
+        return active.runningTaskFor(partition);
     }
 
-    private void removeStandbyTasks() {
-        log.debug("{} Removing all standby tasks {}", logPrefix, standbyTasks.keySet());
-        standbyTasks.clear();
-        standbyTasksByPartition.clear();
+
+    Task standbyTask(final TopicPartition partition) {
+        return standby.runningTaskFor(partition);
     }
 
-    Task activeTask(final TopicPartition partition) {
-        return activeTasksByPartition.get(partition);
+    Map<TaskId, Task> activeTasks() {
+        return active.runningTaskMap();
     }
 
-    boolean hasStandbyTasks() {
-        return !standbyTasks.isEmpty();
+    void setConsumer(final Consumer<byte[], byte[]> consumer) {
+        this.consumer = consumer;
     }
 
-    Task standbyTask(final TopicPartition partition) {
-        return standbyTasksByPartition.get(partition);
+
+    boolean updateNewAndRestoringTasks() {
+        active.initializeNewTasks();
+        standby.initializeNewTasks();
+
+        final Collection<TopicPartition> restored = changelogReader.restore();
+        final Set<TopicPartition> resumed = active.updateRestored(restored);
+
+        if (!resumed.isEmpty()) {
+            log.trace("{} resuming partitions {}", logPrefix, resumed);
+            consumer.resume(resumed);
+        }
+        if (active.allTasksRunning()) {
+            assignStandbyPartitions();
+            return true;
+        }
+        return false;
     }
 
-    public Map<TaskId, Task> activeTasks() {
-        return activeTasks;
+    boolean hasActiveRunningTasks() {
+        return active.hasRunningTasks();
     }
 
-    boolean hasActiveTasks() {
-        return !activeTasks.isEmpty();
+    boolean hasStandbyRunningTasks() {
+        return standby.hasRunningTasks();
     }
 
-    void setConsumer(final Consumer<byte[], byte[]> consumer) {
-        this.consumer = consumer;
+    private void assignStandbyPartitions() {
+        final Collection<Task> running = standby.running();
+        final Map<TopicPartition, Long> checkpointedOffsets = new HashMap<>();
+        for (final Task standbyTask : running) {
+            checkpointedOffsets.putAll(standbyTask.checkpointedOffsets());
+        }
+
+        restoreConsumer.assign(checkpointedOffsets.keySet());
+        for (final Map.Entry<TopicPartition, Long> entry : checkpointedOffsets.entrySet()) {
+            final TopicPartition partition = entry.getKey();
+            final long offset = entry.getValue();
+            if (offset >= 0) {
+                restoreConsumer.seek(partition, offset);
+            } else {
+                restoreConsumer.seekToBeginning(singleton(partition));
+            }
+        }
     }
 
-    public void closeProducer() {
-        taskCreator.close();
+    int commitAll() {
+        int committed = active.commit();
+        return committed + standby.commit();
     }
 
+    int process() {
+        return active.process();
+    }
 
+    int punctuate() {
+        return active.punctuate();
+    }
 
+    int maybeCommitActiveTasks() {
+        return active.maybeCommit();
+    }
 
-    interface TaskAction {
-        String name();
-        void apply(final Task task);
+    public String toString(final String indent) {
+        final StringBuilder builder = new StringBuilder();
+        builder.append(indent).append("\tActive tasks:\n");
+        builder.append(active.toString(indent + "\t\t"));
+        builder.append(indent).append("\tStandby tasks:\n");
+        builder.append(standby.toString(indent + "\t\t"));
+        return builder.toString();
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
index c6accde..1a50d46 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
@@ -478,7 +478,7 @@ public class KafkaStreamsTest {
         CLUSTER.createTopic(topic);
         final StreamsBuilder builder = new StreamsBuilder();
 
-        builder.stream(Serdes.String(), Serdes.String(), topic);
+        builder.table(Serdes.String(), Serdes.String(), topic, topic);
 
         final KafkaStreams streams = new KafkaStreams(builder.build(), props);
         final CountDownLatch latch = new CountDownLatch(1);

http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java
index 1a7c3a0..99a524e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java
@@ -133,7 +133,7 @@ public class ResetIntegrationTest {
     public void testReprocessingFromScratchAfterResetWithIntermediateUserTopic() throws Exception {
         CLUSTER.createTopic(INTERMEDIATE_USER_TOPIC);
 
-        final Properties streamsConfiguration = prepareTest();
+        final Properties streamsConfiguration = prepareTest(4);
         final Properties resultTopicConsumerConfig = TestUtils.consumerConfig(
             CLUSTER.bootstrapServers(),
             APP_ID + "-standard-consumer-" + OUTPUT_TOPIC,
@@ -199,7 +199,7 @@ public class ResetIntegrationTest {
 
     @Test
     public void testReprocessingFromScratchAfterResetWithoutIntermediateUserTopic() throws Exception {
-        final Properties streamsConfiguration = prepareTest();
+        final Properties streamsConfiguration = prepareTest(1);
         final Properties resultTopicConsumerConfig = TestUtils.consumerConfig(
                 CLUSTER.bootstrapServers(),
                 APP_ID + "-standard-consumer-" + OUTPUT_TOPIC,
@@ -242,14 +242,14 @@ public class ResetIntegrationTest {
         cleanGlobal(null);
     }
 
-    private Properties prepareTest() throws Exception {
+    private Properties prepareTest(final int threads) throws Exception {
         final Properties streamsConfiguration = new Properties();
         streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID + testNo);
         streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
         streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
         streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Long().getClass());
         streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass());
-        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 4);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, threads);
         streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0);
         streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100);
         streamsConfiguration.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100);

http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java
new file mode 100644
index 0000000..69c42fe
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.ProducerConfig;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.serialization.IntegerDeserializer;
+import org.apache.kafka.common.serialization.IntegerSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.kstream.ForeachAction;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.apache.kafka.test.IntegrationTest;
+import org.apache.kafka.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsEqual.equalTo;
+import static org.junit.Assert.assertTrue;
+
+@Category({IntegrationTest.class})
+public class RestoreIntegrationTest {
+    private static final int NUM_BROKERS = 1;
+
+    @ClassRule
+    public static final EmbeddedKafkaCluster CLUSTER =
+            new EmbeddedKafkaCluster(NUM_BROKERS);
+    private final String inputStream = "input-stream";
+    private final int numberOfKeys = 10000;
+    private KafkaStreams kafkaStreams;
+    private String applicationId = "restore-test";
+
+
+    private void createTopics() throws InterruptedException {
+        CLUSTER.createTopic(inputStream, 2, 1);
+    }
+
+    @Before
+    public void before() throws IOException, InterruptedException {
+        createTopics();
+    }
+
+    private Properties props() {
+        Properties streamsConfiguration = new Properties();
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(applicationId).getPath());
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass());
+        streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass());
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000);
+        return streamsConfiguration;
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        if (kafkaStreams != null) {
+            kafkaStreams.close(30, TimeUnit.SECONDS);
+        }
+    }
+
+
+    @Test
+    public void shouldRestoreState() throws ExecutionException, InterruptedException {
+        final AtomicInteger numReceived = new AtomicInteger(0);
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        createStateForRestoration();
+
+        builder.table(Serdes.Integer(), Serdes.Integer(), inputStream, "store")
+                .toStream()
+                .foreach(new ForeachAction<Integer, Integer>() {
+                    @Override
+                    public void apply(final Integer key, final Integer value) {
+                        numReceived.incrementAndGet();
+                    }
+                });
+
+
+        final CountDownLatch startupLatch = new CountDownLatch(1);
+        kafkaStreams = new KafkaStreams(builder.build(), props());
+        kafkaStreams.setStateListener(new KafkaStreams.StateListener() {
+            @Override
+            public void onChange(final KafkaStreams.State newState, final KafkaStreams.State oldState) {
+                if (newState == KafkaStreams.State.RUNNING && oldState == KafkaStreams.State.REBALANCING) {
+                    startupLatch.countDown();
+                }
+            }
+        });
+
+        final AtomicLong restored = new AtomicLong(0);
+        kafkaStreams.setGlobalStateRestoreListener(new StateRestoreListener() {
+            @Override
+            public void onRestoreStart(final TopicPartition topicPartition, final String storeName, final long startingOffset, final long endingOffset) {
+
+            }
+
+            @Override
+            public void onBatchRestored(final TopicPartition topicPartition, final String storeName, final long batchEndOffset, final long numRestored) {
+
+            }
+
+            @Override
+            public void onRestoreEnd(final TopicPartition topicPartition, final String storeName, final long totalRestored) {
+                restored.addAndGet(totalRestored);
+            }
+        });
+        kafkaStreams.start();
+
+        assertTrue(startupLatch.await(30, TimeUnit.SECONDS));
+        assertThat(restored.get(), equalTo((long) numberOfKeys));
+        assertThat(numReceived.get(), equalTo(0));
+    }
+
+
+    private void createStateForRestoration()
+            throws ExecutionException, InterruptedException {
+        final Properties producerConfig = new Properties();
+        producerConfig.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
+
+        try (final KafkaProducer<Integer, Integer> producer =
+                     new KafkaProducer<>(producerConfig, new IntegerSerializer(), new IntegerSerializer())) {
+
+            for (int i = 0; i < numberOfKeys; i++) {
+                producer.send(new ProducerRecord<>(inputStream, i, i));
+            }
+        }
+
+        final Properties consumerConfig = new Properties();
+        consumerConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
+        consumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, applicationId);
+        consumerConfig.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class);
+        consumerConfig.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class);
+
+        final Consumer consumer = new KafkaConsumer(consumerConfig);
+        final List<TopicPartition> partitions = Arrays.asList(new TopicPartition(inputStream, 0),
+                                                              new TopicPartition(inputStream, 1));
+
+        consumer.assign(partitions);
+        consumer.seekToEnd(partitions);
+
+        final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        for (TopicPartition partition : partitions) {
+            final long position = consumer.position(partition);
+            offsets.put(partition, new OffsetAndMetadata(position + 1));
+        }
+
+        consumer.commitSync(offsets);
+        consumer.close();
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
index 353f740..d6709b8 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
@@ -25,63 +25,106 @@ import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.AuthorizationException;
 import org.apache.kafka.common.errors.WakeupException;
-import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.test.MockStateRestoreListener;
 import org.apache.kafka.test.TestUtils;
+import org.easymock.EasyMock;
+import org.junit.Before;
 import org.junit.Test;
 
+import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
 
+import static org.junit.Assert.fail;
+
 public class AbstractTaskTest {
 
+    private final TaskId id = new TaskId(0, 0);
+    private StateDirectory stateDirectory  = EasyMock.createMock(StateDirectory.class);
+
    @Before
    public void before() {
        // every test resolves the task's state directory; back it with a real temp dir
        EasyMock.expect(stateDirectory.directoryForTask(id)).andReturn(TestUtils.tempDirectory());
    }
+
    @Test(expected = ProcessorStateException.class)
    public void shouldThrowProcessorStateExceptionOnInitializeOffsetsWhenAuthorizationException() throws Exception {
        // an AuthorizationException from the consumer must surface as ProcessorStateException
        final Consumer consumer = mockConsumer(new AuthorizationException("blah"));
        final AbstractTask task = createTask(consumer, Collections.<StateStore>emptyList());
        task.updateOffsetLimits();
    }
 
    @Test(expected = ProcessorStateException.class)
    public void shouldThrowProcessorStateExceptionOnInitializeOffsetsWhenKafkaException() throws Exception {
        // a generic KafkaException from the consumer must surface as ProcessorStateException
        final Consumer consumer = mockConsumer(new KafkaException("blah"));
        final AbstractTask task = createTask(consumer, Collections.<StateStore>emptyList());
        task.updateOffsetLimits();
    }
 
    @Test(expected = WakeupException.class)
    public void shouldThrowWakeupExceptionOnInitializeOffsetsWhenWakeupException() throws Exception {
        // WakeupException is passed through unwrapped so the caller can retry
        final Consumer consumer = mockConsumer(new WakeupException());
        final AbstractTask task = createTask(consumer, Collections.<StateStore>emptyList());
        task.updateOffsetLimits();
    }
 
-    private AbstractTask createTask(final Consumer consumer) {
-        final MockTime time = new MockTime();
    @Test
    public void shouldThrowLockExceptionIfFailedToLockStateDirectoryWhenTopologyHasStores() throws IOException {
        final Consumer consumer = EasyMock.createNiceMock(Consumer.class);
        final StateStore store = EasyMock.createNiceMock(StateStore.class);
        // simulate failure to acquire the task directory lock after 5 attempts
        EasyMock.expect(stateDirectory.lock(id, 5)).andReturn(false);
        EasyMock.replay(stateDirectory);

        final AbstractTask task = createTask(consumer, Collections.singletonList(store));

        try {
            task.initializeStateStores();
            fail("Should have thrown LockException");
        } catch (final LockException e) {
            // ok
        }

    }
+
    @Test
    public void shouldNotAttemptToLockIfNoStores() throws IOException {
        final Consumer consumer = EasyMock.createNiceMock(Consumer.class);
        // no lock(...) expectation recorded: any lock call would fail verification
        EasyMock.replay(stateDirectory);

        final AbstractTask task = createTask(consumer, Collections.<StateStore>emptyList());

        task.initializeStateStores();

        // should fail if lock is called
        EasyMock.verify(stateDirectory);
    }
+
+    private AbstractTask createTask(final Consumer consumer, final List<StateStore> stateStores) {
         final Properties properties = new Properties();
         properties.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-id");
         properties.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummyhost:9092");
         final StreamsConfig config = new StreamsConfig(properties);
-        return new AbstractTask(new TaskId(0, 0),
+        return new AbstractTask(id,
                                 "app",
                                 Collections.singletonList(new TopicPartition("t", 0)),
                                 new ProcessorTopology(Collections.<ProcessorNode>emptyList(),
                                                       Collections.<String, SourceNode>emptyMap(),
                                                       Collections.<String, SinkNode>emptyMap(),
-                                                      Collections.<StateStore>emptyList(),
+                                                      stateStores,
                                                       Collections.<String, String>emptyMap(),
                                                       Collections.<StateStore>emptyList()),
                                 consumer,
                                 new StoreChangelogReader(consumer, Time.SYSTEM, 5000, new MockStateRestoreListener()),
                                 false,
-                                new StateDirectory("app", TestUtils.tempDirectory().getPath(), time),
+                                stateDirectory,
                                 config) {
             @Override
             public void resume() {}
@@ -111,6 +154,11 @@ public class AbstractTaskTest {
             }
 
             @Override
+            public boolean commitNeeded() {
+                return false;
+            }
+
+            @Override
             public boolean maybePunctuateStreamTime() {
                 return false;
             }
@@ -131,7 +179,7 @@ public class AbstractTaskTest {
             }
 
             @Override
-            public boolean commitNeeded() {
+            public boolean initialize() {
                 return false;
             }
         };

http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
new file mode 100644
index 0000000..de2a489
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.CommitFailedException;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.ProducerFencedException;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.processor.TaskId;
+import org.easymock.EasyMock;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Set;
+
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.CoreMatchers.nullValue;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsEqual.equalTo;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
/**
 * Unit tests for {@code AssignedTasks}: lifecycle transitions of tasks
 * (new -> restoring -> running -> suspended -> resumed), commit and
 * punctuate dispatch over running tasks, and the exception-handling
 * contract during suspend/commit (ProducerFencedException closes the
 * task, CommitFailedException is swallowed, anything else propagates).
 */
public class AssignedTasksTest {

    private final Task t1 = EasyMock.createMock(Task.class);
    private final Task t2 = EasyMock.createMock(Task.class);
    private final TopicPartition tp1 = new TopicPartition("t1", 0);
    private final TopicPartition tp2 = new TopicPartition("t2", 0);
    private final TopicPartition changeLog1 = new TopicPartition("cl1", 0);
    private final TopicPartition changeLog2 = new TopicPartition("cl2", 0);
    private final TaskId taskId1 = new TaskId(0, 0);
    private final TaskId taskId2 = new TaskId(1, 0);
    private AssignedTasks assignedTasks;

    @Before
    public void before() {
        assignedTasks = new AssignedTasks("log", "task");
        // ids may be queried any number of times as tasks move between states
        EasyMock.expect(t1.id()).andReturn(taskId1).anyTimes();
        EasyMock.expect(t2.id()).andReturn(taskId2).anyTimes();
    }

    @Test
    public void shouldGetPartitionsFromNewTasksThatHaveStateStores() {
        EasyMock.expect(t1.hasStateStores()).andReturn(true);
        EasyMock.expect(t2.hasStateStores()).andReturn(true);
        EasyMock.expect(t1.partitions()).andReturn(Collections.singleton(tp1));
        EasyMock.expect(t2.partitions()).andReturn(Collections.singleton(tp2));
        EasyMock.replay(t1, t2);

        assignedTasks.addNewTask(t1);
        assignedTasks.addNewTask(t2);

        // stateful tasks contribute their partitions until they are initialized
        final Set<TopicPartition> partitions = assignedTasks.uninitializedPartitions();
        assertThat(partitions, equalTo(Utils.mkSet(tp1, tp2)));
        EasyMock.verify(t1, t2);
    }

    @Test
    public void shouldNotGetPartitionsFromNewTasksWithoutStateStores() {
        EasyMock.expect(t1.hasStateStores()).andReturn(false);
        EasyMock.expect(t2.hasStateStores()).andReturn(false);
        EasyMock.replay(t1, t2);

        assignedTasks.addNewTask(t1);
        assignedTasks.addNewTask(t2);

        // stateless tasks need no restoration, so nothing is reported
        final Set<TopicPartition> partitions = assignedTasks.uninitializedPartitions();
        assertTrue(partitions.isEmpty());
        EasyMock.verify(t1, t2);
    }

    @Test
    public void shouldInitializeNewTasks() {
        // initialize() == false means the task still needs restoration
        EasyMock.expect(t1.initialize()).andReturn(false);
        EasyMock.replay(t1);

        addAndInitTask();

        EasyMock.verify(t1);
    }

    @Test
    public void shouldMoveInitializedTasksNeedingRestoreToRestoring() {
        EasyMock.expect(t1.initialize()).andReturn(false);
        EasyMock.expect(t2.initialize()).andReturn(true);
        EasyMock.expect(t2.partitions()).andReturn(Collections.singleton(tp2));
        EasyMock.expect(t2.changelogPartitions()).andReturn(Collections.<TopicPartition>emptyList());

        EasyMock.replay(t1, t2);

        assignedTasks.addNewTask(t1);
        assignedTasks.addNewTask(t2);

        assignedTasks.initializeNewTasks();

        // only t1 (initialize() == false) should be left in the restoring set
        Collection<Task> restoring = assignedTasks.restoringTasks();
        assertThat(restoring.size(), equalTo(1));
        assertSame(restoring.iterator().next(), t1);
    }

    @Test
    public void shouldMoveInitializedTasksThatDontNeedRestoringToRunning() {
        EasyMock.expect(t2.initialize()).andReturn(true);
        EasyMock.expect(t2.partitions()).andReturn(Collections.singleton(tp2));
        EasyMock.expect(t2.changelogPartitions()).andReturn(Collections.<TopicPartition>emptyList());

        EasyMock.replay(t2);

        assignedTasks.addNewTask(t2);
        assignedTasks.initializeNewTasks();

        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId2)));
    }

    @Test
    public void shouldTransitionFullyRestoredTasksToRunning() {
        final Set<TopicPartition> task1Partitions = Utils.mkSet(tp1);
        EasyMock.expect(t1.initialize()).andReturn(false);
        EasyMock.expect(t1.partitions()).andReturn(task1Partitions).anyTimes();
        EasyMock.expect(t1.changelogPartitions()).andReturn(Utils.mkSet(changeLog1, changeLog2)).anyTimes();
        EasyMock.replay(t1);

        addAndInitTask();

        // only one of two changelogs restored: task must not transition yet
        assertTrue(assignedTasks.updateRestored(Utils.mkSet(changeLog1)).isEmpty());
        // second changelog completes restoration: task partitions are returned
        Set<TopicPartition> partitions = assignedTasks.updateRestored(Utils.mkSet(changeLog2));
        assertThat(partitions, equalTo(task1Partitions));
        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
    }

    @Test
    public void shouldSuspendRunningTasks() {
        mockRunningTaskSuspension();
        EasyMock.replay(t1);

        suspendTask();

        assertThat(assignedTasks.previousTaskIds(), equalTo(Collections.singleton(taskId1)));
        EasyMock.verify(t1);
    }

    @Test
    public void shouldCloseRestoringTasks() {
        EasyMock.expect(t1.initialize()).andReturn(false);
        // restoring tasks are closed (not suspended) on suspend; close(false) = not clean
        t1.close(false);
        EasyMock.expectLastCall();
        EasyMock.replay(t1);

        suspendTask();
        EasyMock.verify(t1);
    }

    @Test
    public void shouldClosedUnInitializedTasksOnSuspend() {
        t1.close(false);
        EasyMock.expectLastCall();
        EasyMock.replay(t1);

        assignedTasks.addNewTask(t1);
        assignedTasks.suspend();

        EasyMock.verify(t1);
    }

    @Test
    public void shouldNotSuspendSuspendedTasks() {
        mockRunningTaskSuspension();
        EasyMock.replay(t1);

        suspendTask();
        // second suspend is a no-op; verify ensures suspend() ran exactly once on t1
        assignedTasks.suspend();
        EasyMock.verify(t1);
    }


    @Test
    public void shouldCloseTaskOnSuspendWhenRuntimeException() {
        mockTaskInitialization();
        t1.suspend();
        EasyMock.expectLastCall().andThrow(new RuntimeException("KABOOM!"));
        t1.close(false);
        EasyMock.expectLastCall();
        EasyMock.replay(t1);

        // RuntimeException is returned to the caller; task still moves to previous
        assertThat(suspendTask(), not(nullValue()));
        assertThat(assignedTasks.previousTaskIds(), equalTo(Collections.singleton(taskId1)));
        EasyMock.verify(t1);
    }

    @Test
    public void shouldCloseTaskOnSuspendWhenProducerFencedException() {
        mockTaskInitialization();
        t1.suspend();
        EasyMock.expectLastCall().andThrow(new ProducerFencedException("KABOOM!"));
        t1.close(false);
        EasyMock.expectLastCall();
        EasyMock.replay(t1);


        // fenced producer: exception swallowed and the task is dropped entirely
        assertThat(suspendTask(), nullValue());
        assertTrue(assignedTasks.previousTaskIds().isEmpty());
        EasyMock.verify(t1);
    }

    @Test
    public void shouldResumeMatchingSuspendedTasks() {
        mockRunningTaskSuspension();
        t1.resume();
        EasyMock.expectLastCall();
        EasyMock.replay(t1);

        suspendTask();

        // same task id + same partitions: the suspended task is reused
        assertTrue(assignedTasks.maybeResumeSuspendedTask(taskId1, Collections.singleton(tp1)));
        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
        EasyMock.verify(t1);
    }


    // records the expectations needed to take t1 straight to running
    private void mockTaskInitialization() {
        EasyMock.expect(t1.initialize()).andReturn(true);
        EasyMock.expect(t1.partitions()).andReturn(Collections.singleton(tp1));
        EasyMock.expect(t1.changelogPartitions()).andReturn(Collections.<TopicPartition>emptyList());
    }

    @Test
    public void shouldCommitRunningTasks() {
        mockTaskInitialization();
        t1.commit();
        EasyMock.expectLastCall();
        EasyMock.replay(t1);

        addAndInitTask();

        assignedTasks.commit();
        EasyMock.verify(t1);
    }

    @Test
    public void shouldCloseTaskOnCommitIfProduceFencedException() {
        mockTaskInitialization();
        t1.commit();
        EasyMock.expectLastCall().andThrow(new ProducerFencedException(""));
        t1.close(false);
        EasyMock.expectLastCall();
        EasyMock.replay(t1);
        addAndInitTask();

        assignedTasks.commit();
        EasyMock.verify(t1);
    }

    @Test
    public void shouldNotThrowCommitFailedExceptionOnCommit() {
        mockTaskInitialization();
        t1.commit();
        EasyMock.expectLastCall().andThrow(new CommitFailedException());
        EasyMock.replay(t1);
        addAndInitTask();

        // CommitFailedException is swallowed and the task keeps running
        assignedTasks.commit();
        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
        EasyMock.verify(t1);
    }

    @Test
    public void shouldThrowExceptionOnCommitWhenNotCommitFailedOrProducerFenced() {
        mockTaskInitialization();
        t1.commit();
        EasyMock.expectLastCall().andThrow(new RuntimeException(""));
        EasyMock.replay(t1);
        addAndInitTask();

        try {
            assignedTasks.commit();
            fail("Should have thrown exception");
        } catch (Exception e) {
            // ok
        }
        // task remains running even though the commit attempt failed
        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
        EasyMock.verify(t1);
    }

    @Test
    public void shouldCommitRunningTasksIfNeeded() {
        mockTaskInitialization();
        EasyMock.expect(t1.commitNeeded()).andReturn(true);
        t1.commit();
        EasyMock.expectLastCall();
        EasyMock.replay(t1);

        addAndInitTask();

        // maybeCommit returns the number of tasks committed
        assertThat(assignedTasks.maybeCommit(), equalTo(1));
        EasyMock.verify(t1);
    }

    @Test
    public void shouldCloseTaskOnMaybeCommitIfProduceFencedException() {
        mockTaskInitialization();
        EasyMock.expect(t1.commitNeeded()).andReturn(true);
        t1.commit();
        EasyMock.expectLastCall().andThrow(new ProducerFencedException(""));
        t1.close(false);
        EasyMock.expectLastCall();
        EasyMock.replay(t1);
        addAndInitTask();

        assignedTasks.maybeCommit();
        EasyMock.verify(t1);
    }

    @Test
    public void shouldNotThrowCommitFailedExceptionOnMaybeCommit() {
        mockTaskInitialization();
        EasyMock.expect(t1.commitNeeded()).andReturn(true);
        t1.commit();
        EasyMock.expectLastCall().andThrow(new CommitFailedException());
        EasyMock.replay(t1);
        addAndInitTask();

        assignedTasks.maybeCommit();
        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
        EasyMock.verify(t1);
    }

    @Test
    public void shouldThrowExceptionOnMaybeCommitWhenNotCommitFailedOrProducerFenced() {
        mockTaskInitialization();
        EasyMock.expect(t1.commitNeeded()).andReturn(true);
        t1.commit();
        EasyMock.expectLastCall().andThrow(new RuntimeException(""));
        EasyMock.replay(t1);

        addAndInitTask();

        try {
            assignedTasks.maybeCommit();
            fail("Should have thrown exception");
        } catch (Exception e) {
            // ok
        }
        assertThat(assignedTasks.runningTaskIds(), equalTo(Collections.singleton(taskId1)));
        EasyMock.verify(t1);
    }



    @Test
    public void shouldPunctuateRunningTasks() {
        mockTaskInitialization();
        EasyMock.expect(t1.maybePunctuateStreamTime()).andReturn(true);
        EasyMock.expect(t1.maybePunctuateSystemTime()).andReturn(true);
        EasyMock.replay(t1);

        addAndInitTask();

        // both stream-time and system-time punctuations fired => 2
        assertThat(assignedTasks.punctuate(), equalTo(2));
        EasyMock.verify(t1);
    }

    @Test
    public void shouldReturnNumberOfPunctuations() {
        mockTaskInitialization();
        EasyMock.expect(t1.maybePunctuateStreamTime()).andReturn(true);
        EasyMock.expect(t1.maybePunctuateSystemTime()).andReturn(false);
        EasyMock.replay(t1);

        addAndInitTask();

        assertThat(assignedTasks.punctuate(), equalTo(1));
        EasyMock.verify(t1);
    }

    // adds t1 and drives it through initializeNewTasks()
    private void addAndInitTask() {
        assignedTasks.addNewTask(t1);
        assignedTasks.initializeNewTasks();
    }

    // suspends all tasks; returns the first exception seen, or null
    private RuntimeException suspendTask() {
        addAndInitTask();
        return assignedTasks.suspend();
    }

    // expectations for a task that goes straight to running and is then suspended
    private void mockRunningTaskSuspension() {
        EasyMock.expect(t1.initialize()).andReturn(true);
        EasyMock.expect(t1.partitions()).andReturn(Collections.singleton(tp1)).anyTimes();
        EasyMock.expect(t1.changelogPartitions()).andReturn(Collections.<TopicPartition>emptyList()).anyTimes();
        t1.suspend();
        EasyMock.expectLastCall();
    }


}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
index 369987e..8aedf36 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
@@ -22,7 +22,6 @@ import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
-import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
@@ -36,10 +35,7 @@ import org.junit.Test;
 
 import java.io.File;
 import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.nio.channels.FileLock;
 import java.nio.charset.Charset;
-import java.nio.file.StandardOpenOption;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
@@ -433,33 +429,6 @@ public class ProcessorStateManagerTest {
         assertThat(read, equalTo(Collections.<TopicPartition, Long>emptyMap()));
     }
 
-    @Test
-    public void shouldThrowLockExceptionIfFailedToLockStateDirectory() throws Exception {
-        final File taskDirectory = stateDirectory.directoryForTask(taskId);
-        final FileChannel channel = FileChannel.open(new File(taskDirectory,
-                                                              StateDirectory.LOCK_FILE_NAME).toPath(),
-                                                     StandardOpenOption.CREATE,
-                                                     StandardOpenOption.WRITE);
-        // lock the task directory
-        final FileLock lock = channel.lock();
-
-        try {
-            new ProcessorStateManager(
-                taskId,
-                noPartitions,
-                false,
-                stateDirectory,
-                Collections.<String, String>emptyMap(),
-                changelogReader,
-                false);
-            fail("Should have thrown LockException");
-        } catch (final LockException e) {
-           // pass
-        } finally {
-            lock.release();
-            channel.close();
-        }
-    }
 
     @Test
     public void shouldThrowIllegalArgumentExceptionIfStoreNameIsSameAsCheckpointFileName() throws Exception {

http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index e232316..f22e773 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -165,7 +165,7 @@ public class StandbyTaskTest {
     public void testStorePartitions() throws Exception {
         StreamsConfig config = createConfig(baseDir);
         StandbyTask task = new StandbyTask(taskId, applicationId, topicPartitions, topology, consumer, changelogReader, config, null, stateDirectory);
-
+        task.initialize();
         assertEquals(Utils.mkSet(partition2), new HashSet<>(task.checkpointedOffsets().keySet()));
 
     }
@@ -188,7 +188,7 @@ public class StandbyTaskTest {
     public void testUpdate() throws Exception {
         StreamsConfig config = createConfig(baseDir);
         StandbyTask task = new StandbyTask(taskId, applicationId, topicPartitions, topology, consumer, changelogReader, config, null, stateDirectory);
-
+        task.initialize();
         restoreStateConsumer.assign(new ArrayList<>(task.checkpointedOffsets().keySet()));
 
         for (ConsumerRecord<Integer, Integer> record : Arrays.asList(
@@ -245,7 +245,7 @@ public class StandbyTaskTest {
 
         StreamsConfig config = createConfig(baseDir);
         StandbyTask task = new StandbyTask(taskId, applicationId, ktablePartitions, ktableTopology, consumer, changelogReader, config, null, stateDirectory);
-
+        task.initialize();
         restoreStateConsumer.assign(new ArrayList<>(task.checkpointedOffsets().keySet()));
 
         for (ConsumerRecord<Integer, Integer> record : Arrays.asList(
@@ -367,6 +367,7 @@ public class StandbyTaskTest {
                                                  null,
                                                  stateDirectory
         );
+        task.initialize();
 
 
         restoreStateConsumer.assign(new ArrayList<>(task.checkpointedOffsets().keySet()));
@@ -419,7 +420,7 @@ public class StandbyTaskTest {
                 closedStateManager.set(true);
             }
         };
-
+        task.initialize();
         try {
             task.close(true);
             fail("should have thrown exception");

http://git-wip-us.apache.org/repos/asf/kafka/blob/b78d7ba5/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 648a15d..a9d3cac 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -125,6 +125,7 @@ public class StreamTaskTest {
     private final MockTime time = new MockTime();
     private File baseDir = TestUtils.tempDirectory();
     private StateDirectory stateDirectory;
+    private final RecordCollectorImpl recordCollector = new RecordCollectorImpl(producer, "taskId");
     private StreamsConfig config;
     private StreamsConfig eosConfig;
     private StreamTask task;
@@ -164,6 +165,7 @@ public class StreamTaskTest {
         stateDirectory = new StateDirectory("applicationId", baseDir.getPath(), new MockTime());
         task = new StreamTask(taskId00, applicationId, partitions, topology, consumer,
                               changelogReader, config, streamsMetrics, stateDirectory, null, time, producer);
+        task.initialize();
     }
 
     @After
@@ -456,6 +458,7 @@ public class StreamTaskTest {
 
         task = new StreamTask(taskId00, applicationId, partitions, topology, consumer, changelogReader, config,
             streamsMetrics, stateDirectory, null, time, producer);
+        task.initialize();
         final int offset = 20;
         task.addRecords(partition1, Collections.singletonList(
                 new ConsumerRecord<>(partition1.topic(), partition1.partition(), offset, 0L, TimestampType.CREATE_TIME, 0L, 0, 0, recordKey, recordValue)));
@@ -613,6 +616,7 @@ public class StreamTaskTest {
                 };
             }
         };
+        streamTask.initialize();
 
         time.sleep(config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG));
 
@@ -739,6 +743,7 @@ public class StreamTaskTest {
     public void shouldThrowExceptionIfAnyExceptionsRaisedDuringCloseButStillCloseAllProcessorNodesTopology() throws Exception {
         task.close(true);
         task = createTaskThatThrowsExceptionOnClose();
+        task.initialize();
         try {
             task.close(true);
             fail("should have thrown runtime exception");
@@ -955,7 +960,7 @@ public class StreamTaskTest {
         final StreamTask task = new StreamTask(taskId00, applicationId, Utils.mkSet(partition1), topology, consumer,
                                                changelogReader, eosConfig, streamsMetrics, stateDirectory, null, time, producer);
 
-
+        task.initialize();
         try {
             task.suspend();
             fail("should have thrown an exception");
@@ -992,6 +997,16 @@ public class StreamTaskTest {
         assertTrue(stateManagerCloseCalled.get());
     }
 
+    @Test
+    public void shouldNotCloseTopologyProcessorNodesIfNotInitialized() {
+        final StreamTask task = createTaskThatThrowsExceptionOnClose();
+        try {
+            task.close(true);
+        } catch (Exception e) {
+            fail("should have not closed uninitialized topology");
+        }
+    }
+
     @SuppressWarnings("unchecked")
     private StreamTask createTaskThatThrowsExceptionOnClose() {
         final MockSourceNode processorNode = new MockSourceNode(topic1, intDeserializer, intDeserializer) {


Mime
View raw message