cassandra-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sn...@apache.org
Subject cassandra git commit: Push notification when tracing completes for an operation
Date Tue, 31 Mar 2015 08:41:27 GMT
Repository: cassandra
Updated Branches:
  refs/heads/trunk 082dedf97 -> f6217ae19


Push notification when tracing completes for an operation

Patch by Robert Stupp; Reviewed by Stefania for CASSANDRA-7807


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/f6217ae1
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/f6217ae1
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/f6217ae1

Branch: refs/heads/trunk
Commit: f6217ae198861d95225f6201dceae679b3304cc0
Parents: 082dedf
Author: Robert Stupp <snazy@snazy.de>
Authored: Tue Mar 31 10:40:34 2015 +0200
Committer: Robert Stupp <snazy@snazy.de>
Committed: Tue Mar 31 10:40:34 2015 +0200

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 doc/native_protocol_v4.spec                     |   4 +
 .../apache/cassandra/service/QueryState.java    |  14 +-
 .../apache/cassandra/tracing/TraceState.java    |  82 +++++++-
 .../org/apache/cassandra/tracing/Tracing.java   |  56 ++---
 .../org/apache/cassandra/transport/Client.java  |  25 ++-
 .../apache/cassandra/transport/Connection.java  |  14 +-
 .../org/apache/cassandra/transport/Event.java   |  76 ++++++-
 .../org/apache/cassandra/transport/Server.java  |   5 +-
 .../cassandra/transport/ServerConnection.java   |   2 +-
 .../cassandra/transport/SimpleClient.java       |  63 +++++-
 .../transport/messages/BatchMessage.java        |   2 +-
 .../transport/messages/ExecuteMessage.java      |   2 +-
 .../transport/messages/PrepareMessage.java      |   2 +-
 .../transport/messages/QueryMessage.java        |   2 +-
 .../transport/messages/RegisterMessage.java     |  10 +-
 .../org/apache/cassandra/cql3/CQLTester.java    |   7 +
 .../cassandra/tracing/TraceCompleteTest.java    | 204 +++++++++++++++++++
 .../cassandra/transport/MessagePayloadTest.java | 135 ++++++++++--
 19 files changed, 602 insertions(+), 104 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index b600312..c43b688 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 3.0
+ * Push notification when tracing completes for an operation (CASSANDRA-7807)
  * Delay "node up" and "node added" notifications until native protocol server is started (CASSANDRA-8236)
  * Compressed Commit Log (CASSANDRA-6809)
  * Optimise IntervalTree (CASSANDRA-8988)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/doc/native_protocol_v4.spec
----------------------------------------------------------------------
diff --git a/doc/native_protocol_v4.spec b/doc/native_protocol_v4.spec
index ed089ae..ac57749 100644
--- a/doc/native_protocol_v4.spec
+++ b/doc/native_protocol_v4.spec
@@ -742,6 +742,9 @@ Table of Contents
             - [string] keyspace containing the user defined function / aggregate
             - [string] the function/aggregate name
             - [string list] one string for each argument type (as CQL type)
+    - "TRACE_COMPLETE": notification that a trace session has completed, at
+      least on the coordinator. After the event type, the rest of the message
+      contains the trace session ID [uuid] as its only argument.
 
   All EVENT messages have a streamId of -1 (Section 2.3).
 
@@ -1125,3 +1128,4 @@ Table of Contents
   * Read_failure error code was added.
   * Function_failure error code was added.
   * Add custom payload to frames for custom QueryHandler implementations (ignored by Cassandra's standard QueryHandler)
+  * Add "TRACE_COMPLETE" event (section 4.2.6).

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/service/QueryState.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/service/QueryState.java b/src/java/org/apache/cassandra/service/QueryState.java
index af31f47..5e89ac8 100644
--- a/src/java/org/apache/cassandra/service/QueryState.java
+++ b/src/java/org/apache/cassandra/service/QueryState.java
@@ -22,6 +22,7 @@ import java.util.UUID;
 import java.util.concurrent.ThreadLocalRandom;
 
 import org.apache.cassandra.tracing.Tracing;
+import org.apache.cassandra.transport.Connection;
 
 /**
  * Represents the state related to a given query.
@@ -76,15 +77,20 @@ public class QueryState
 
     public void createTracingSession()
     {
-        if (this.preparedTracingSession == null)
+        createTracingSession(null);
+    }
+
+    public void createTracingSession(Connection connection)
+    {
+        UUID session = this.preparedTracingSession;
+        if (session == null)
         {
-            Tracing.instance.newSession();
+            Tracing.instance.newSession(connection);
         }
         else
         {
-            UUID session = this.preparedTracingSession;
+            Tracing.instance.newSession(connection, session);
             this.preparedTracingSession = null;
-            Tracing.instance.newSession(session);
         }
     }
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/tracing/TraceState.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/tracing/TraceState.java b/src/java/org/apache/cassandra/tracing/TraceState.java
index 758dceb..c029ac7 100644
--- a/src/java/org/apache/cassandra/tracing/TraceState.java
+++ b/src/java/org/apache/cassandra/tracing/TraceState.java
@@ -19,9 +19,10 @@ package org.apache.cassandra.tracing;
 
 import java.net.InetAddress;
 import java.nio.ByteBuffer;
-import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.UUID;
+import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -30,6 +31,12 @@ import org.slf4j.helpers.MessageFormatter;
 
 import org.apache.cassandra.concurrent.Stage;
 import org.apache.cassandra.concurrent.StageManager;
+import org.apache.cassandra.db.ConsistencyLevel;
+import org.apache.cassandra.db.Mutation;
+import org.apache.cassandra.exceptions.OverloadedException;
+import org.apache.cassandra.service.StorageProxy;
+import org.apache.cassandra.transport.Connection;
+import org.apache.cassandra.transport.Event;
 import org.apache.cassandra.utils.ByteBufferUtil;
 import org.apache.cassandra.utils.WrappedRunnable;
 import org.apache.cassandra.utils.progress.ProgressEvent;
@@ -50,9 +57,13 @@ public class TraceState implements ProgressEventNotifier
     public final int ttl;
 
     private boolean notify;
-    private List<ProgressListener> listeners = new ArrayList<>();
+    private final List<ProgressListener> listeners = new CopyOnWriteArrayList<>();
     private String tag;
 
+    private final boolean withFinishEvent;
+    private final AtomicInteger pendingMutations = new AtomicInteger();
+    private final Connection connection;
+
     public enum Status
     {
         IDLE,
@@ -60,29 +71,31 @@ public class TraceState implements ProgressEventNotifier
         STOPPED
     }
 
-    private Status status;
+    private volatile Status status;
 
     // Multiple requests can use the same TraceState at a time, so we need to reference count.
     // See CASSANDRA-7626 for more details.
     private final AtomicInteger references = new AtomicInteger(1);
 
-    public TraceState(InetAddress coordinator, UUID sessionId)
+    public TraceState(InetAddress coordinator, UUID sessionId, Tracing.TraceType traceType)
     {
-        this(coordinator, sessionId, Tracing.TraceType.QUERY);
+        this(coordinator, null, sessionId, traceType, false);
     }
 
-    public TraceState(InetAddress coordinator, UUID sessionId, Tracing.TraceType traceType)
+    public TraceState(InetAddress coordinator, Connection connection, UUID sessionId, Tracing.TraceType traceType, boolean withFinishEvent)
     {
         assert coordinator != null;
         assert sessionId != null;
 
         this.coordinator = coordinator;
+        this.connection = connection;
         this.sessionId = sessionId;
         sessionIdBytes = ByteBufferUtil.bytes(sessionId);
         this.traceType = traceType;
         this.ttl = traceType.getTTL();
         watch = Stopwatch.createStarted();
         this.status = Status.IDLE;
+        this.withFinishEvent = withFinishEvent;
     }
 
     /**
@@ -121,6 +134,19 @@ public class TraceState implements ProgressEventNotifier
     {
         status = Status.STOPPED;
         notifyAll();
+        pushEventIfStopped();
+    }
+
+    private void pushEventIfStopped()
+    {
+        if (status == Status.STOPPED && pendingMutations.get() == 0)
+        {
+            // poor-man's prevention of duplicate tracing-finished events
+            pendingMutations.set(Integer.MIN_VALUE);
+
+            if (connection != null && withFinishEvent)
+                connection.sendIfRegistered(new Event.TraceComplete(sessionId));
+        }
     }
 
     /*
@@ -177,7 +203,10 @@ public class TraceState implements ProgressEventNotifier
         if (notify)
             notifyActivity();
 
-        TraceState.mutateWithTracing(sessionIdBytes, message, elapsed(), ttl);
+        final String threadName = Thread.currentThread().getName();
+        final int elapsed = elapsed();
+
+        executeMutation(TraceKeyspace.makeEventMutation(sessionIdBytes, message, elapsed, threadName, ttl));
 
         for (ProgressListener listener : listeners)
         {
@@ -185,6 +214,31 @@ public class TraceState implements ProgressEventNotifier
         }
     }
 
+    void executeMutation(final Mutation mutation)
+    {
+        pendingMutations.incrementAndGet();
+
+        StageManager.getStage(Stage.TRACING).execute(new WrappedRunnable()
+        {
+            protected void runMayThrow() throws Exception
+            {
+                try
+                {
+                    mutateWithCatch(mutation);
+                }
+                finally
+                {
+                    if (pendingMutations.decrementAndGet() == 0)
+                        pushEventIfStopped();
+                }
+            }
+        });
+    }
+
+    /**
+     * Called from {@link org.apache.cassandra.net.OutboundTcpConnection} for non-local traces, i.e. traces
+     * that were not initiated by the local node acting as coordinator.
+     */
     public static void mutateWithTracing(final ByteBuffer sessionId, final String message, final int elapsed, final int ttl)
     {
         final String threadName = Thread.currentThread().getName();
@@ -193,11 +247,23 @@ public class TraceState implements ProgressEventNotifier
         {
             public void runMayThrow()
             {
-                Tracing.mutateWithCatch(TraceKeyspace.makeEventMutation(sessionId, message, elapsed, threadName, ttl));
+                mutateWithCatch(TraceKeyspace.makeEventMutation(sessionId, message, elapsed, threadName, ttl));
             }
         });
     }
 
+    static void mutateWithCatch(Mutation mutation)
+    {
+        try
+        {
+            StorageProxy.mutate(Collections.singletonList(mutation), ConsistencyLevel.ANY);
+        }
+        catch (OverloadedException e)
+        {
+            Tracing.logger.warn("Too many nodes are overloaded to save trace events");
+        }
+    }
+
     public boolean acquireReference()
     {
         while (true)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/tracing/Tracing.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/tracing/Tracing.java b/src/java/org/apache/cassandra/tracing/Tracing.java
index dc7067e..0e49cd0 100644
--- a/src/java/org/apache/cassandra/tracing/Tracing.java
+++ b/src/java/org/apache/cassandra/tracing/Tracing.java
@@ -21,7 +21,6 @@ package org.apache.cassandra.tracing;
 
 import java.net.InetAddress;
 import java.nio.ByteBuffer;
-import java.util.Arrays;
 import java.util.Map;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
@@ -30,15 +29,11 @@ import java.util.concurrent.ConcurrentMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.cassandra.concurrent.Stage;
-import org.apache.cassandra.concurrent.StageManager;
 import org.apache.cassandra.config.DatabaseDescriptor;
-import org.apache.cassandra.db.*;
 import org.apache.cassandra.db.marshal.TimeUUIDType;
-import org.apache.cassandra.exceptions.OverloadedException;
 import org.apache.cassandra.net.MessageIn;
 import org.apache.cassandra.net.MessagingService;
-import org.apache.cassandra.service.StorageProxy;
+import org.apache.cassandra.transport.Connection;
 import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.UUIDGen;
 
@@ -82,7 +77,7 @@ public class Tracing
         }
     }
 
-    private static final Logger logger = LoggerFactory.getLogger(Tracing.class);
+    static final Logger logger = LoggerFactory.getLogger(Tracing.class);
 
     private final InetAddress localAddress = FBUtilities.getLocalAddress();
 
@@ -118,26 +113,31 @@ public class Tracing
         return instance.state.get() != null;
     }
 
-    public UUID newSession()
+    public UUID newSession(Connection connection)
     {
-        return newSession(TraceType.QUERY);
+        return newSession(connection, TraceType.QUERY);
     }
 
     public UUID newSession(TraceType traceType)
     {
-        return newSession(TimeUUIDType.instance.compose(ByteBuffer.wrap(UUIDGen.getTimeUUIDBytes())), traceType);
+        return newSession(null, traceType);
     }
 
-    public UUID newSession(UUID sessionId)
+    public UUID newSession(Connection connection, TraceType traceType)
     {
-        return newSession(sessionId, TraceType.QUERY);
+        return newSession(connection, TimeUUIDType.instance.compose(ByteBuffer.wrap(UUIDGen.getTimeUUIDBytes())), traceType, false);
     }
 
-    public UUID newSession(UUID sessionId, TraceType traceType)
+    public UUID newSession(Connection connection, UUID sessionId)
+    {
+        return newSession(connection, sessionId, TraceType.QUERY, true);
+    }
+
+    private UUID newSession(Connection connection, UUID sessionId, TraceType traceType, boolean withFinishEvent)
     {
         assert state.get() == null;
 
-        TraceState ts = new TraceState(localAddress, sessionId, traceType);
+        TraceState ts = new TraceState(localAddress, connection, sessionId, traceType, withFinishEvent);
         state.set(ts);
         sessions.put(sessionId, ts);
 
@@ -166,13 +166,7 @@ public class Tracing
             final ByteBuffer sessionId = state.sessionIdBytes;
             final int ttl = state.ttl;
 
-            StageManager.getStage(Stage.TRACING).execute(new Runnable()
-            {
-                public void run()
-                {
-                    mutateWithCatch(TraceKeyspace.makeStopSessionMutation(sessionId, elapsed, ttl));
-                }
-            });
+            state.executeMutation(TraceKeyspace.makeStopSessionMutation(sessionId, elapsed, ttl));
 
             state.stop();
             sessions.remove(state.sessionId);
@@ -210,13 +204,7 @@ public class Tracing
         final String command = state.traceType.toString();
         final int ttl = state.ttl;
 
-        StageManager.getStage(Stage.TRACING).execute(new Runnable()
-        {
-            public void run()
-            {
-                mutateWithCatch(TraceKeyspace.makeStartSessionMutation(sessionId, client, parameters, request, startedAt, command, ttl));
-            }
-        });
+        state.executeMutation(TraceKeyspace.makeStartSessionMutation(sessionId, client, parameters, request, startedAt, command, ttl));
 
         return state;
     }
@@ -304,16 +292,4 @@ public class Tracing
 
         state.trace(format, args);
     }
-
-    static void mutateWithCatch(Mutation mutation)
-    {
-        try
-        {
-            StorageProxy.mutate(Arrays.asList(mutation), ConsistencyLevel.ANY);
-        }
-        catch (OverloadedException e)
-        {
-            logger.warn("Too many nodes are overloaded to save trace events");
-        }
-    }
 }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/Client.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/Client.java b/src/java/org/apache/cassandra/transport/Client.java
index 571a7ce..92466d2 100644
--- a/src/java/org/apache/cassandra/transport/Client.java
+++ b/src/java/org/apache/cassandra/transport/Client.java
@@ -27,6 +27,7 @@ import java.util.*;
 import com.google.common.base.Splitter;
 
 import org.apache.cassandra.auth.PasswordAuthenticator;
+import org.apache.cassandra.config.Config;
 import org.apache.cassandra.cql3.QueryOptions;
 import org.apache.cassandra.db.ConsistencyLevel;
 import org.apache.cassandra.db.marshal.Int32Type;
@@ -40,9 +41,12 @@ import static org.apache.cassandra.config.EncryptionOptions.ClientEncryptionOpti
 
 public class Client extends SimpleClient
 {
-    public Client(String host, int port, ClientEncryptionOptions encryptionOptions)
+    private final SimpleEventHandler eventHandler = new SimpleEventHandler();
+
+    public Client(String host, int port, int version, ClientEncryptionOptions encryptionOptions)
     {
-        super(host, port, encryptionOptions);
+        super(host, port, version, encryptionOptions);
+        setEventHandler(eventHandler);
     }
 
     public void run() throws IOException
@@ -56,6 +60,12 @@ public class Client extends SimpleClient
         BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
         for (;;)
         {
+            Event event;
+            while ((event = eventHandler.queue.poll()) != null)
+            {
+                System.out.println("<< " + event);
+            }
+
             System.out.print(">> ");
             System.out.flush();
             String line = in.readLine();
@@ -228,21 +238,24 @@ public class Client extends SimpleClient
 
     public static void main(String[] args) throws Exception
     {
+        Config.setClientMode(true);
+
         // Print usage if no argument is specified.
-        if (args.length != 2)
+        if (args.length < 2 || args.length > 3)
         {
-            System.err.println("Usage: " + Client.class.getSimpleName() + " <host> <port>");
+            System.err.println("Usage: " + Client.class.getSimpleName() + " <host> <port> [<version>]");
             return;
         }
 
         // Parse options.
         String host = args[0];
         int port = Integer.parseInt(args[1]);
+        int version = args.length == 3 ? Integer.parseInt(args[2]) : Server.CURRENT_VERSION;
 
         ClientEncryptionOptions encryptionOptions = new ClientEncryptionOptions();
-        System.out.println("CQL binary protocol console " + host + "@" + port);
+        System.out.println("CQL binary protocol console " + host + "@" + port + " using native protocol version " + version);
 
-        new Client(host, port, encryptionOptions).run();
+        new Client(host, port, version, encryptionOptions).run();
         System.exit(0);
     }
 }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/Connection.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/Connection.java b/src/java/org/apache/cassandra/transport/Connection.java
index aa571a7..e2811e9 100644
--- a/src/java/org/apache/cassandra/transport/Connection.java
+++ b/src/java/org/apache/cassandra/transport/Connection.java
@@ -19,6 +19,7 @@ package org.apache.cassandra.transport;
 
 import io.netty.channel.Channel;
 import io.netty.util.AttributeKey;
+import org.apache.cassandra.transport.messages.EventMessage;
 
 public class Connection
 {
@@ -64,14 +65,21 @@ public class Connection
         return channel;
     }
 
+    public void sendIfRegistered(Event event)
+    {
+        if (getTracker().isRegistered(event.type, channel))
+            channel.writeAndFlush(new EventMessage(event));
+    }
+
     public interface Factory
     {
-        public Connection newConnection(Channel channel, int version);
+        Connection newConnection(Channel channel, int version);
     }
 
     public interface Tracker
     {
-        public void addConnection(Channel ch, Connection connection);
-        public void closeAll();
+        void addConnection(Channel ch, Connection connection);
+
+        boolean isRegistered(Event.Type type, Channel ch);
     }
 }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/Event.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/Event.java b/src/java/org/apache/cassandra/transport/Event.java
index 5e9c6b7..070c27b 100644
--- a/src/java/org/apache/cassandra/transport/Event.java
+++ b/src/java/org/apache/cassandra/transport/Event.java
@@ -21,13 +21,26 @@ import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.util.Iterator;
 import java.util.List;
+import java.util.UUID;
 
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
 public abstract class Event
 {
-    public enum Type { TOPOLOGY_CHANGE, STATUS_CHANGE, SCHEMA_CHANGE }
+    public enum Type {
+        TOPOLOGY_CHANGE(Server.VERSION_2),
+        STATUS_CHANGE(Server.VERSION_2),
+        SCHEMA_CHANGE(Server.VERSION_2),
+        TRACE_COMPLETE(Server.VERSION_4);
+
+        public final int minimumVersion;
+
+        Type(int minimumVersion)
+        {
+            this.minimumVersion = minimumVersion;
+        }
+    }
 
     public final Type type;
 
@@ -38,7 +51,10 @@ public abstract class Event
 
     public static Event deserialize(ByteBuf cb, int version)
     {
-        switch (CBUtil.readEnumValue(Type.class, cb))
+        Type eventType = CBUtil.readEnumValue(Type.class, cb);
+        if (eventType.minimumVersion > version)
+            throw new ProtocolException("Event " + eventType.name() + " not valid for protocol version " + version);
+        switch (eventType)
         {
             case TOPOLOGY_CHANGE:
                 return TopologyChange.deserializeEvent(cb, version);
@@ -46,12 +62,16 @@ public abstract class Event
                 return StatusChange.deserializeEvent(cb, version);
             case SCHEMA_CHANGE:
                 return SchemaChange.deserializeEvent(cb, version);
+            case TRACE_COMPLETE:
+                return TraceComplete.deserializeEvent(cb, version);
         }
         throw new AssertionError();
     }
 
     public void serialize(ByteBuf dest, int version)
     {
+        if (type.minimumVersion > version)
+            throw new ProtocolException("Event " + type.name() + " not valid for protocol version " + version);
         CBUtil.writeEnumValue(type, dest);
         serializeEvent(dest, version);
     }
@@ -397,4 +417,56 @@ public abstract class Event
                 && Objects.equal(argTypes, scc.argTypes);
         }
     }
+
+    /**
+     * @since native protocol v4
+     */
+    public static class TraceComplete extends Event
+    {
+        public final UUID traceSessionId;
+
+        public TraceComplete(UUID traceSessionId)
+        {
+            super(Type.TRACE_COMPLETE);
+            this.traceSessionId = traceSessionId;
+        }
+
+        public static Event deserializeEvent(ByteBuf cb, int version)
+        {
+            UUID traceSessionId = CBUtil.readUUID(cb);
+            return new TraceComplete(traceSessionId);
+        }
+
+        protected void serializeEvent(ByteBuf dest, int version)
+        {
+            CBUtil.writeUUID(traceSessionId, dest);
+        }
+
+        protected int eventSerializedSize(int version)
+        {
+            return CBUtil.sizeOfUUID(traceSessionId);
+        }
+
+        @Override
+        public String toString()
+        {
+            return traceSessionId.toString();
+        }
+
+        @Override
+        public int hashCode()
+        {
+            return Objects.hashCode(traceSessionId);
+        }
+
+        @Override
+        public boolean equals(Object other)
+        {
+            if (!(other instanceof TraceComplete))
+                return false;
+
+            TraceComplete tf = (TraceComplete)other;
+            return Objects.equal(traceSessionId, tf.traceSessionId);
+        }
+    }
 }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/Server.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/Server.java b/src/java/org/apache/cassandra/transport/Server.java
index c7c1bdb..40a3371 100644
--- a/src/java/org/apache/cassandra/transport/Server.java
+++ b/src/java/org/apache/cassandra/transport/Server.java
@@ -233,10 +233,9 @@ public class Server implements CassandraDaemon.Server
             groups.get(type).add(ch);
         }
 
-        public void unregister(Channel ch)
+        public boolean isRegistered(Event.Type type, Channel ch)
         {
-            for (ChannelGroup group : groups.values())
-                group.remove(ch);
+            return groups.get(type).contains(ch);
         }
 
         public void send(Event event)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/ServerConnection.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/ServerConnection.java b/src/java/org/apache/cassandra/transport/ServerConnection.java
index 24eb643..dbaf123 100644
--- a/src/java/org/apache/cassandra/transport/ServerConnection.java
+++ b/src/java/org/apache/cassandra/transport/ServerConnection.java
@@ -34,7 +34,7 @@ public class ServerConnection extends Connection
     private final ClientState clientState;
     private volatile State state;
 
-    private final ConcurrentMap<Integer, QueryState> queryStates = new NonBlockingHashMap<Integer, QueryState>();
+    private final ConcurrentMap<Integer, QueryState> queryStates = new NonBlockingHashMap<>();
 
     public ServerConnection(Channel channel, int version, Connection.Tracker tracker)
     {

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/SimpleClient.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/SimpleClient.java b/src/java/org/apache/cassandra/transport/SimpleClient.java
index 3e4631b..b39f166 100644
--- a/src/java/org/apache/cassandra/transport/SimpleClient.java
+++ b/src/java/org/apache/cassandra/transport/SimpleClient.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.SynchronousQueue;
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
@@ -45,6 +46,7 @@ import org.apache.cassandra.db.ConsistencyLevel;
 import org.apache.cassandra.security.SSLFactory;
 import org.apache.cassandra.transport.messages.CredentialsMessage;
 import org.apache.cassandra.transport.messages.ErrorMessage;
+import org.apache.cassandra.transport.messages.EventMessage;
 import org.apache.cassandra.transport.messages.ExecuteMessage;
 import org.apache.cassandra.transport.messages.PrepareMessage;
 import org.apache.cassandra.transport.messages.QueryMessage;
@@ -72,8 +74,9 @@ public class SimpleClient
 
     protected final ResponseHandler responseHandler = new ResponseHandler();
     protected final Connection.Tracker tracker = new ConnectionTracker();
+    protected final int version;
     // We don't track connection really, so we don't need one Connection per channel
-    protected final Connection connection = new Connection(null, Server.CURRENT_VERSION, tracker);
+    protected Connection connection;
     protected Bootstrap bootstrap;
     protected Channel channel;
     protected ChannelFuture lastWriteFuture;
@@ -82,18 +85,28 @@ public class SimpleClient
     {
         public Connection newConnection(Channel channel, int version)
         {
-            assert version == Server.CURRENT_VERSION;
             return connection;
         }
     };
 
-    public SimpleClient(String host, int port, ClientEncryptionOptions encryptionOptions)
+    public SimpleClient(String host, int port, int version, ClientEncryptionOptions encryptionOptions)
     {
         this.host = host;
         this.port = port;
+        this.version = version;
         this.encryptionOptions = encryptionOptions;
     }
 
+    public SimpleClient(String host, int port, ClientEncryptionOptions encryptionOptions)
+    {
+        this(host, port, Server.CURRENT_VERSION, encryptionOptions);
+    }
+
+    public SimpleClient(String host, int port, int version)
+    {
+        this(host, port, version, new ClientEncryptionOptions());
+    }
+
     public SimpleClient(String host, int port)
     {
         this(host, port, new ClientEncryptionOptions());
@@ -103,7 +116,7 @@ public class SimpleClient
     {
         establishConnection();
 
-        Map<String, String> options = new HashMap<String, String>();
+        Map<String, String> options = new HashMap<>();
         options.put(StartupMessage.CQL_VERSION, "3.0.0");
         if (useCompression)
         {
@@ -113,6 +126,11 @@ public class SimpleClient
         execute(new StartupMessage(options));
     }
 
+    public void setEventHandler(EventHandler eventHandler)
+    {
+        responseHandler.eventHandler = eventHandler;
+    }
+
     protected void establishConnection() throws IOException
     {
         // Configure the client.
@@ -188,7 +206,7 @@ public class SimpleClient
         bootstrap.group().shutdownGracefully();
     }
 
-    protected Message.Response execute(Message.Request request)
+    public Message.Response execute(Message.Request request)
     {
         try
         {
@@ -205,6 +223,21 @@ public class SimpleClient
         }
     }
 
+    public interface EventHandler
+    {
+        void onEvent(Event event);
+    }
+
+    public static class SimpleEventHandler implements EventHandler
+    {
+        public final LinkedBlockingQueue<Event> queue = new LinkedBlockingQueue<>();
+
+        public void onEvent(Event event)
+        {
+            queue.add(event);
+        }
+    }
+
     // Stateless handlers
     private static final Message.ProtocolDecoder messageDecoder = new Message.ProtocolDecoder();
     private static final Message.ProtocolEncoder messageEncoder = new Message.ProtocolEncoder();
@@ -215,13 +248,20 @@ public class SimpleClient
     private static class ConnectionTracker implements Connection.Tracker
     {
         public void addConnection(Channel ch, Connection connection) {}
-        public void closeAll() {}
+
+        public boolean isRegistered(Event.Type type, Channel ch)
+        {
+            return false;
+        }
     }
 
     private class Initializer extends ChannelInitializer<Channel>
     {
         protected void initChannel(Channel channel) throws Exception
         {
+            connection = new Connection(channel, version, tracker);
+            channel.attr(Connection.attributeKey).set(connection);
+
             ChannelPipeline pipeline = channel.pipeline();
             pipeline.addLast("frameDecoder", new Frame.Decoder(connectionFactory));
             pipeline.addLast("frameEncoder", frameEncoder);
@@ -259,14 +299,21 @@ public class SimpleClient
     @ChannelHandler.Sharable
     private static class ResponseHandler extends SimpleChannelInboundHandler<Message.Response>
     {
-        public final BlockingQueue<Message.Response> responses = new SynchronousQueue<Message.Response>(true);
+        public final BlockingQueue<Message.Response> responses = new SynchronousQueue<>(true);
+        public EventHandler eventHandler;
 
         @Override
         public void channelRead0(ChannelHandlerContext ctx, Message.Response r)
         {
             try
             {
-                responses.put(r);
+                if (r instanceof EventMessage)
+                {
+                    if (eventHandler != null)
+                        eventHandler.onEvent(((EventMessage) r).event);
+                }
+                else
+                    responses.put(r);
             }
             catch (InterruptedException ie)
             {

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/messages/BatchMessage.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/messages/BatchMessage.java b/src/java/org/apache/cassandra/transport/messages/BatchMessage.java
index 3acdbdd..88394e7 100644
--- a/src/java/org/apache/cassandra/transport/messages/BatchMessage.java
+++ b/src/java/org/apache/cassandra/transport/messages/BatchMessage.java
@@ -165,7 +165,7 @@ public class BatchMessage extends Message.Request
 
             if (state.traceNextQuery())
             {
-                state.createTracingSession();
+                state.createTracingSession(connection);
                 // TODO we don't have [typed] access to CQL bind variables here.  CASSANDRA-4560 is open to add support.
                 Tracing.instance.begin("Execute batch of CQL3 queries", state.getClientAddress(), Collections.<String, String>emptyMap());
             }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java b/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java
index 50f6619..5f3e368 100644
--- a/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java
+++ b/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java
@@ -122,7 +122,7 @@ public class ExecuteMessage extends Message.Request
 
             if (state.traceNextQuery())
             {
-                state.createTracingSession();
+                state.createTracingSession(connection);
 
                 ImmutableMap.Builder<String, String> builder = ImmutableMap.builder();
                 if (options.getPageSize() > 0)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/messages/PrepareMessage.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/messages/PrepareMessage.java b/src/java/org/apache/cassandra/transport/messages/PrepareMessage.java
index f54d1d9..db9e304 100644
--- a/src/java/org/apache/cassandra/transport/messages/PrepareMessage.java
+++ b/src/java/org/apache/cassandra/transport/messages/PrepareMessage.java
@@ -71,7 +71,7 @@ public class PrepareMessage extends Message.Request
 
             if (state.traceNextQuery())
             {
-                state.createTracingSession();
+                state.createTracingSession(connection);
                 Tracing.instance.begin("Preparing CQL3 query", state.getClientAddress(), ImmutableMap.of("query", query));
             }
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/messages/QueryMessage.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/messages/QueryMessage.java b/src/java/org/apache/cassandra/transport/messages/QueryMessage.java
index 4e21678..fe86a89 100644
--- a/src/java/org/apache/cassandra/transport/messages/QueryMessage.java
+++ b/src/java/org/apache/cassandra/transport/messages/QueryMessage.java
@@ -106,7 +106,7 @@ public class QueryMessage extends Message.Request
 
             if (state.traceNextQuery())
             {
-                state.createTracingSession();
+                state.createTracingSession(connection);
 
                 ImmutableMap.Builder<String, String> builder = ImmutableMap.builder();
                 builder.put("query", query);

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java b/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java
index ee410bb..928e676 100644
--- a/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java
+++ b/src/java/org/apache/cassandra/transport/messages/RegisterMessage.java
@@ -32,7 +32,7 @@ public class RegisterMessage extends Message.Request
         public RegisterMessage decode(ByteBuf body, int version)
         {
             int length = body.readUnsignedShort();
-            List<Event.Type> eventTypes = new ArrayList<Event.Type>(length);
+            List<Event.Type> eventTypes = new ArrayList<>(length);
             for (int i = 0; i < length; ++i)
                 eventTypes.add(CBUtil.readEnumValue(Event.Type.class, body));
             return new RegisterMessage(eventTypes);
@@ -65,10 +65,14 @@ public class RegisterMessage extends Message.Request
     public Response execute(QueryState state)
     {
         assert connection instanceof ServerConnection;
-        Connection.Tracker tracker = ((ServerConnection)connection).getTracker();
+        Connection.Tracker tracker = connection.getTracker();
         assert tracker instanceof Server.ConnectionTracker;
         for (Event.Type type : eventTypes)
-            ((Server.ConnectionTracker)tracker).register(type, connection().channel());
+        {
+            if (type.minimumVersion > connection.getVersion())
+                throw new ProtocolException("Event " + type.name() + " not valid for protocol version " + connection.getVersion());
+            ((Server.ConnectionTracker) tracker).register(type, connection().channel());
+        }
         return new ReadyMessage();
     }
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/test/unit/org/apache/cassandra/cql3/CQLTester.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java
index e49250b..0fe323e 100644
--- a/test/unit/org/apache/cassandra/cql3/CQLTester.java
+++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java
@@ -475,6 +475,13 @@ public abstract class CQLTester
         return session[protocolVersion-1].execute(formatQuery(query), values);
     }
 
+    protected Session sessionNet(int protocolVersion)
+    {
+        requireNetwork();
+
+        return session[protocolVersion-1];
+    }
+
     private String formatQuery(String query)
     {
         String currentTable = currentTable();

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/test/unit/org/apache/cassandra/tracing/TraceCompleteTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/tracing/TraceCompleteTest.java b/test/unit/org/apache/cassandra/tracing/TraceCompleteTest.java
new file mode 100644
index 0000000..8ef7e52
--- /dev/null
+++ b/test/unit/org/apache/cassandra/tracing/TraceCompleteTest.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.tracing;
+
+import java.util.Collections;
+import java.util.concurrent.TimeUnit;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.cassandra.cql3.CQLTester;
+import org.apache.cassandra.cql3.QueryOptions;
+import org.apache.cassandra.service.StorageService;
+import org.apache.cassandra.transport.Event;
+import org.apache.cassandra.transport.Message;
+import org.apache.cassandra.transport.ProtocolException;
+import org.apache.cassandra.transport.Server;
+import org.apache.cassandra.transport.SimpleClient;
+import org.apache.cassandra.transport.messages.QueryMessage;
+import org.apache.cassandra.transport.messages.RegisterMessage;
+
+public class TraceCompleteTest extends CQLTester
+{
+    @Test
+    public void testTraceComplete() throws Throwable
+    {
+        sessionNet(3);
+
+        SimpleClient clientA = new SimpleClient(nativeAddr.getHostAddress(), nativePort);
+        clientA.connect(false);
+        try
+        {
+            SimpleClient.SimpleEventHandler eventHandlerA = new SimpleClient.SimpleEventHandler();
+            clientA.setEventHandler(eventHandlerA);
+
+            SimpleClient clientB = new SimpleClient(nativeAddr.getHostAddress(), nativePort);
+            clientB.connect(false);
+            try
+            {
+                SimpleClient.SimpleEventHandler eventHandlerB = new SimpleClient.SimpleEventHandler();
+                clientB.setEventHandler(eventHandlerB);
+
+                Message.Response resp = clientA.execute(new RegisterMessage(Collections.singletonList(Event.Type.TRACE_COMPLETE)));
+                Assert.assertSame(Message.Type.READY, resp.type);
+
+                createTable("CREATE TABLE %s (pk int PRIMARY KEY, v text)");
+
+                QueryMessage query = new QueryMessage("SELECT * FROM " + KEYSPACE + '.' + currentTable(), QueryOptions.DEFAULT);
+                query.setTracingRequested();
+                resp = clientA.execute(query);
+
+                Event event = eventHandlerA.queue.poll(250, TimeUnit.MILLISECONDS);
+                Assert.assertNotNull(event);
+
+                // assert that only the connection that started the trace receives the trace-complete event
+                Assert.assertNull(eventHandlerB.queue.poll(100, TimeUnit.MILLISECONDS));
+
+                Assert.assertSame(Event.Type.TRACE_COMPLETE, event.type);
+                Assert.assertEquals(resp.getTracingId(), ((Event.TraceComplete) event).traceSessionId);
+            }
+            finally
+            {
+                clientB.close();
+            }
+        }
+        finally
+        {
+            clientA.close();
+        }
+    }
+
+    @Test
+    public void testTraceCompleteVersion3() throws Throwable
+    {
+        sessionNet(3);
+
+        SimpleClient clientA = new SimpleClient(nativeAddr.getHostAddress(), nativePort, Server.VERSION_3);
+        clientA.connect(false);
+        try
+        {
+            SimpleClient.SimpleEventHandler eventHandlerA = new SimpleClient.SimpleEventHandler();
+            clientA.setEventHandler(eventHandlerA);
+
+            try
+            {
+                clientA.execute(new RegisterMessage(Collections.singletonList(Event.Type.TRACE_COMPLETE)));
+                Assert.fail();
+            }
+            catch (RuntimeException e)
+            {
+                Assert.assertTrue(e.getCause() instanceof ProtocolException); // that's what we want
+            }
+
+            createTable("CREATE TABLE %s (pk int PRIMARY KEY, v text)");
+
+            QueryMessage query = new QueryMessage("SELECT * FROM " + KEYSPACE + '.' + currentTable(), QueryOptions.DEFAULT);
+            query.setTracingRequested();
+            clientA.execute(query);
+
+            Event event = eventHandlerA.queue.poll(250, TimeUnit.MILLISECONDS);
+            Assert.assertNull(event);
+        }
+        finally
+        {
+            clientA.close();
+        }
+    }
+
+    @Test
+    public void testTraceCompleteNotRegistered() throws Throwable
+    {
+        sessionNet(3);
+
+        SimpleClient clientA = new SimpleClient(nativeAddr.getHostAddress(), nativePort);
+        clientA.connect(false);
+        try
+        {
+            SimpleClient.SimpleEventHandler eventHandlerA = new SimpleClient.SimpleEventHandler();
+            clientA.setEventHandler(eventHandlerA);
+
+            createTable("CREATE TABLE %s (pk int PRIMARY KEY, v text)");
+
+            // check that we do NOT receive a trace-complete event, since we didn't register for that
+
+            // with setTracingRequested()
+            QueryMessage query = new QueryMessage("SELECT * FROM " + KEYSPACE + '.' + currentTable(), QueryOptions.DEFAULT);
+            query.setTracingRequested();
+            clientA.execute(query);
+            // and without setTracingRequested()
+            query = new QueryMessage("SELECT * FROM " + KEYSPACE + '.' + currentTable(), QueryOptions.DEFAULT);
+            clientA.execute(query);
+
+            Event event = eventHandlerA.queue.poll(250, TimeUnit.MILLISECONDS);
+            Assert.assertNull(event);
+        }
+        finally
+        {
+            clientA.close();
+        }
+    }
+
+    @Test
+    public void testTraceCompleteWithProbability() throws Throwable
+    {
+        sessionNet(3);
+
+        double traceProbability = StorageService.instance.getTraceProbability();
+        // the trace-probability check in QueryState.traceNextQuery() is x<y, not x<=y, hence a probability above 1.0
+        StorageService.instance.setTraceProbability(1.1d);
+
+        SimpleClient clientA = new SimpleClient(nativeAddr.getHostAddress(), nativePort);
+        clientA.connect(false);
+        try
+        {
+            SimpleClient.SimpleEventHandler eventHandlerA = new SimpleClient.SimpleEventHandler();
+            clientA.setEventHandler(eventHandlerA);
+
+            SimpleClient clientB = new SimpleClient(nativeAddr.getHostAddress(), nativePort);
+            clientB.connect(false);
+            try
+            {
+                SimpleClient.SimpleEventHandler eventHandlerB = new SimpleClient.SimpleEventHandler();
+                clientB.setEventHandler(eventHandlerB);
+
+                Message.Response resp = clientA.execute(new RegisterMessage(Collections.singletonList(Event.Type.TRACE_COMPLETE)));
+                Assert.assertSame(Message.Type.READY, resp.type);
+
+                createTable("CREATE TABLE %s (pk int PRIMARY KEY, v text)");
+
+                QueryMessage query = new QueryMessage("SELECT * FROM " + KEYSPACE + '.' + currentTable(), QueryOptions.DEFAULT);
+                clientA.execute(query);
+
+                Event event = eventHandlerA.queue.poll(2000, TimeUnit.MILLISECONDS);
+                Assert.assertNull(event);
+
+                Assert.assertNull(eventHandlerB.queue.poll(100, TimeUnit.MILLISECONDS));
+            }
+            finally
+            {
+                clientB.close();
+            }
+        }
+        finally
+        {
+            StorageService.instance.setTraceProbability(traceProbability);
+            clientA.close();
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/cassandra/blob/f6217ae1/test/unit/org/apache/cassandra/transport/MessagePayloadTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/transport/MessagePayloadTest.java b/test/unit/org/apache/cassandra/transport/MessagePayloadTest.java
index 6df889f..afb738a 100644
--- a/test/unit/org/apache/cassandra/transport/MessagePayloadTest.java
+++ b/test/unit/org/apache/cassandra/transport/MessagePayloadTest.java
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer;
 import java.util.Collections;
 import java.util.Map;
 
+import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
@@ -81,11 +82,12 @@ public class MessagePayloadTest extends CQLTester
             return;
         try
         {
-            cqlQueryHandlerField.setAccessible(false);
-
             Field modifiersField = Field.class.getDeclaredField("modifiers");
             modifiersField.setAccessible(true);
             modifiersField.setInt(cqlQueryHandlerField, cqlQueryHandlerField.getModifiers() | Modifier.FINAL);
+
+            cqlQueryHandlerField.setAccessible(false);
+
             modifiersField.setAccessible(modifiersAccessible);
         }
         catch (IllegalAccessException | NoSuchFieldException e)
@@ -94,6 +96,19 @@ public class MessagePayloadTest extends CQLTester
         }
     }
 
+    @After
+    public void dropCreatedTable()
+    {
+        try
+        {
+            QueryProcessor.executeOnceInternal("DROP TABLE " + KEYSPACE + ".atable");
+        }
+        catch (Throwable t)
+        {
+            // ignore
+        }
+    }
+
     @Test
     public void testMessagePayload() throws Throwable
     {
@@ -163,6 +178,102 @@ public class MessagePayloadTest extends CQLTester
         }
     }
 
+    @Test
+    public void testMessagePayloadVersion3() throws Throwable
+    {
+        QueryHandler queryHandler = (QueryHandler) cqlQueryHandlerField.get(null);
+        cqlQueryHandlerField.set(null, new TestQueryHandler());
+        try
+        {
+            requireNetwork();
+
+            Assert.assertSame(TestQueryHandler.class, ClientState.getCQLQueryHandler().getClass());
+
+            SimpleClient client = new SimpleClient(nativeAddr.getHostAddress(), nativePort, Server.VERSION_3);
+            try
+            {
+                client.connect(false);
+
+                Map<String, byte[]> reqMap;
+
+                QueryMessage queryMessage = new QueryMessage(
+                                                            "CREATE TABLE " + KEYSPACE + ".atable (pk int PRIMARY KEY, v text)",
+                                                            QueryOptions.DEFAULT
+                );
+                PrepareMessage prepareMessage = new PrepareMessage("SELECT * FROM " + KEYSPACE + ".atable");
+
+                reqMap = Collections.singletonMap("foo", "42".getBytes());
+                responsePayload = Collections.singletonMap("bar", "42".getBytes());
+                queryMessage.setCustomPayload(reqMap);
+                try
+                {
+                    client.execute(queryMessage);
+                    Assert.fail();
+                }
+                catch (ProtocolException e)
+                {
+                    // that's what we want
+                }
+                queryMessage.setCustomPayload(null);
+                client.execute(queryMessage);
+
+                reqMap = Collections.singletonMap("foo", "43".getBytes());
+                responsePayload = Collections.singletonMap("bar", "43".getBytes());
+                prepareMessage.setCustomPayload(reqMap);
+                try
+                {
+                    client.execute(prepareMessage);
+                    Assert.fail();
+                }
+                catch (ProtocolException e)
+                {
+                    // that's what we want
+                }
+                prepareMessage.setCustomPayload(null);
+                ResultMessage.Prepared prepareResponse = (ResultMessage.Prepared) client.execute(prepareMessage);
+
+                ExecuteMessage executeMessage = new ExecuteMessage(prepareResponse.statementId, QueryOptions.DEFAULT);
+                reqMap = Collections.singletonMap("foo", "44".getBytes());
+                responsePayload = Collections.singletonMap("bar", "44".getBytes());
+                executeMessage.setCustomPayload(reqMap);
+                try
+                {
+                    client.execute(executeMessage);
+                    Assert.fail();
+                }
+                catch (ProtocolException e)
+                {
+                    // that's what we want
+                }
+
+                BatchMessage batchMessage = new BatchMessage(BatchStatement.Type.UNLOGGED,
+                                                             Collections.<Object>singletonList("INSERT INTO " + KEYSPACE + ".atable (pk,v) VALUES (1, 'foo')"),
+                                                             Collections.singletonList(Collections.<ByteBuffer>emptyList()),
+                                                             QueryOptions.DEFAULT);
+                reqMap = Collections.singletonMap("foo", "45".getBytes());
+                responsePayload = Collections.singletonMap("bar", "45".getBytes());
+                batchMessage.setCustomPayload(reqMap);
+                try
+                {
+                    client.execute(batchMessage);
+                    Assert.fail();
+                }
+                catch (ProtocolException e)
+                {
+                    // that's what we want
+                }
+            }
+            finally
+            {
+                client.close();
+            }
+        }
+        finally
+        {
+            cqlQueryHandlerField.set(null, queryHandler);
+        }
+    }
+
     private static void payloadEquals(Map<String, byte[]> map1, Map<String, byte[]> map2)
     {
         Assert.assertNotNull(map1);
@@ -184,26 +295,6 @@ public class MessagePayloadTest extends CQLTester
             return QueryProcessor.instance.getPreparedForThrift(id);
         }
 
-        public ResultMessage processPrepared(CQLStatement statement, QueryState state, QueryOptions options) throws RequestExecutionException, RequestValidationException
-        {
-            return processPrepared(statement, state, options, null);
-        }
-
-        public ResultMessage processBatch(BatchStatement statement, QueryState state, BatchQueryOptions options) throws RequestExecutionException, RequestValidationException
-        {
-            return processBatch(statement, state, options, null);
-        }
-
-        public ResultMessage process(String query, QueryState state, QueryOptions options) throws RequestExecutionException, RequestValidationException
-        {
-            return process(query, state, options, null);
-        }
-
-        public ResultMessage.Prepared prepare(String query, QueryState state) throws RequestValidationException
-        {
-            return prepare(query, state, null);
-        }
-
         public ResultMessage.Prepared prepare(String query, QueryState state, Map<String, byte[]> customPayload) throws RequestValidationException
         {
             if (customPayload != null)


Mime
View raw message