activemq-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From chir...@apache.org
Subject svn commit: r1034238 - /activemq/activemq-apollo/trunk/apollo-stomp/src/main/scala/org/apache/activemq/apollo/stomp/StompProtocol.scala
Date Fri, 12 Nov 2010 01:27:53 GMT
Author: chirino
Date: Fri Nov 12 01:27:53 2010
New Revision: 1034238

URL: http://svn.apache.org/viewvc?rev=1034238&view=rev
Log:
Simplified stomp handler logic a bit.

Modified:
    activemq/activemq-apollo/trunk/apollo-stomp/src/main/scala/org/apache/activemq/apollo/stomp/StompProtocol.scala

Modified: activemq/activemq-apollo/trunk/apollo-stomp/src/main/scala/org/apache/activemq/apollo/stomp/StompProtocol.scala
URL: http://svn.apache.org/viewvc/activemq/activemq-apollo/trunk/apollo-stomp/src/main/scala/org/apache/activemq/apollo/stomp/StompProtocol.scala?rev=1034238&r1=1034237&r2=1034238&view=diff
==============================================================================
--- activemq/activemq-apollo/trunk/apollo-stomp/src/main/scala/org/apache/activemq/apollo/stomp/StompProtocol.scala	(original)
+++ activemq/activemq-apollo/trunk/apollo-stomp/src/main/scala/org/apache/activemq/apollo/stomp/StompProtocol.scala	Fri Nov 12 01:27:53 2010
@@ -141,7 +141,7 @@ class StompProtocolHandler extends Proto
     }
     
     def perform_ack(msgid: AsciiBuffer, uow:StoreUOW=null) = {
-      die("The subscription ack mode does not expect ACK frames")
+      async_die("The subscription ack mode does not expect ACK frames")
     }
   }
 
@@ -176,7 +176,7 @@ class StompProtocolHandler extends Proto
       }
 
       if( acked.isEmpty ) {
-        die("ACK failed, invalid message id: %s".format(msgid))
+        async_die("ACK failed, invalid message id: %s".format(msgid))
       } else {
         consumer_acks = not_acked
         acked.foreach{case (id, ack)=>
@@ -213,7 +213,7 @@ class StompProtocolHandler extends Proto
           if( ack!=null ) {
             ack(uow)
           }
-        case None => die("ACK failed, invalid message id: %s".format(msgid))
+        case None => async_die("ACK failed, invalid message id: %s".format(msgid))
       }
 
       if( protocol_version eq V1_0 ) {
@@ -321,6 +321,64 @@ class StompProtocolHandler extends Proto
     rc
   }

+  // NOTE(review): ProtocolException is not thrown in any hunk of this commit — confirm it is still needed.
+  class ProtocolException(msg:String) extends RuntimeException(msg)
+  /**
+   * Control-flow marker thrown by `die` once the ERROR frame has been queued, so the
+   * current handler stops processing. `async_die` swallows it; callers of `die` let it
+   * propagate to a `case e: Break =>` clause.
+   */
+  class Break extends RuntimeException
+
+  /**
+   * Same as `die(msg, e)`, but swallows the resulting [[Break]] so it can be used where
+   * the caller must carry on afterwards (callbacks, ack handlers, catch blocks).
+   *
+   * @param msg reason reported to the client in the ERROR frame's message header
+   * @param e   optional cause; handed to the logger together with `msg` when non-null
+   */
+  private def async_die(msg:String, e:Throwable=null):Unit = try {
+    // Forward the cause: callers such as the internal-error handler pass it so it gets logged.
+    die(msg, e)
+  } catch {
+    case x:Break=>
+  }
+
+  /**
+   * Logs (at debug level) why the connection is being shut down, passing `e` to the
+   * logger when given, then sends an ERROR frame carrying `msg` in its message header.
+   *
+   * @throws Break always; this method never returns normally
+   */
+  private def die[T](msg:String, e:Throwable=null):T = {
+    if( e!=null) {
+      debug(e, "Shutting connection down due to: "+msg)
+    } else {
+      debug("Shutting connection down due to: "+msg)
+    }
+    die((MESSAGE_HEADER, ascii(msg))::Nil, "")
+  }
+
+  /**
+   * Queues an ERROR frame with the given headers and body, suspends reading, and stops
+   * the connection after `die_delay` ms so the frame has a chance to reach the client.
+   * When the connection is already stopped nothing is sent.
+   *
+   * @throws Break always; this method never returns normally
+   */
+  private def die[T](headers:HeaderMap, body:String):T = {
+    if( !connection.stopped ) {
+      suspendRead("shutdown")
+      connection.transport.offer(StompFrame(ERROR, headers, BufferContent(ascii(body))) )
+      // TODO: if there are too many open connections we should just close the connection
+      // without waiting for the error to get sent to the client.
+      queue.after(die_delay, TimeUnit.MILLISECONDS) {
+        connection.stop()
+      }
+    }
+    throw new Break()
+  }
+
   override def onTransportConnected() = {
 
     session_manager = new SinkMux[StompFrame]( MapSink(connection.transportSink){x=>
@@ -412,9 +443,9 @@ class StompProtocolHandler extends Proto
           die("Internal Server Error");
       }
     }  catch {
+      case e: Break =>
       case e:Exception =>
-        warn(e, "Internal Server Error")
-        die("Internal Server Error");
+        async_die("Internal Server Error", e);
     }
   }
 
@@ -438,10 +469,9 @@ class StompProtocolHandler extends Proto
       case Some(x) => x
       case None=>
         val supported_versions = SUPPORTED_PROTOCOL_VERSIONS.mkString(",")
-        _die((MESSAGE_HEADER, ascii("version not supported"))::
+        die((MESSAGE_HEADER, ascii("version not supported"))::
             (VERSION, ascii(supported_versions))::Nil,
             "Supported protocol versions are %s".format(supported_versions))
-        return
     }
 
     val heart_beat = get(headers, HEART_BEAT).getOrElse(DEFAULT_HEAT_BEAT)
@@ -458,7 +488,7 @@ class StompProtocolHandler extends Proto
             heart_beat_monitor.read_interval += heart_beat_monitor.read_interval.min(5000)
 
             heart_beat_monitor.on_dead = () => {
-              die("Stale connection.  Missed heartbeat.")
+              async_die("Stale connection.  Missed heartbeat.")
             }
           }
           if( outbound_heartbeat>=0 && please_send > 0 ) {
@@ -474,11 +504,9 @@ class StompProtocolHandler extends Proto
         } catch {
           case x:NumberFormatException=>
             die("Invalid heart-beat header: "+heart_beat)
-            return
         }
       case _ =>
         die("Invalid heart-beat header: "+heart_beat)
-        return
     }
 
     def noop = shift {  k: (Unit=>Unit) => k() }
@@ -493,7 +521,7 @@ class StompProtocolHandler extends Proto
       }
       resumeRead
       if(host==null) {
-        die("Invalid virtual host: "+host_header.get)
+        async_die("Invalid virtual host: "+host_header.get)
         noop // to make the cps compiler plugin happy.
       } else {
         this.host=host
@@ -508,8 +536,8 @@ class StompProtocolHandler extends Proto
           noop // to make the cps compiler plugin happy.
         }
 
-        if( authenticated ) {
-            die("Authentication failed.")
+        if( !authenticated ) {
+            async_die("Authentication failed.")
         } else {
           val outbound_heart_beat_header = ascii("%d,%d".format(outbound_heartbeat,inbound_heartbeat))
           session_id = ascii(this.host.config.id + ":"+this.host.session_counter.incrementAndGet)
@@ -560,8 +588,8 @@ class StompProtocolHandler extends Proto
           case None=>
             perform_send(frame)
           case Some(txid)=>
-            get_or_create_tx_queue(txid){ txqueue=>
-              txqueue.add(frame, (uow)=>{perform_send(frame, uow)} )
+            get_or_create_tx_queue(txid).add { uow=>
+              perform_send(frame, uow)
             }
         }
 
@@ -659,26 +687,19 @@ class StompProtocolHandler extends Proto
   }
 
   def on_stomp_subscribe(headers:HeaderMap):Unit = {
-    val dest = get(headers, DESTINATION) match {
-      case Some(dest)=> dest
-      case None=>
-        die("destination not set.")
-        return
-    }
+    val dest = get(headers, DESTINATION).getOrElse(die("destination not set."))
     val destination:Destination = dest
 
     val subscription_id = get(headers, ID)
-    var id:AsciiBuffer = subscription_id match {
-      case None =>
-        if( protocol_version eq V1_0 ) {
+    var id:AsciiBuffer = subscription_id.getOrElse {
+      if( protocol_version eq V1_0 ) {
           // in 1.0 it's ok if the client does not send us the
           // the id header
           dest
         } else {
           die("The id header is missing from the SUBSCRIBE frame");
-          return
         }
-      case Some(x:AsciiBuffer)=> x
+
     }
 
     val topic = destination.getDomain == Router.TOPIC_DOMAIN
@@ -694,7 +715,6 @@ class StompProtocolHandler extends Proto
         case ACK_MODE_MESSAGE=> new MessageAckHandler
         case ack:AsciiBuffer =>
           die("Unsuported ack mode: "+ack);
-          return;
       }
     }
 
@@ -706,13 +726,11 @@ class StompProtocolHandler extends Proto
         } catch {
           case e:FilterException =>
             die("Invalid selector expression: "+e.getMessage)
-            return;
         }
     }
 
     if ( consumers.contains(id) ) {
       die("A subscription with identified with '"+id+"' allready exists")
-      return;
     }
 
     val binding: BindingDTO = if( topic && !persistent ) {
@@ -775,18 +793,15 @@ class StompProtocolHandler extends Proto
           case Some(dest)=> dest
           case None=>
             die("destination not set.")
-            return
         }
       } else {
         die("The id header is missing from the UNSUBSCRIBE frame");
-        return
       }
     }
 
     consumers.get(id) match {
       case None=>
         die("The subscription '%s' not found.".format(id))
-        return;
       case Some(consumer)=>
         // consumer.close
         if( consumer.binding==null ) {
@@ -812,67 +827,29 @@ class StompProtocolHandler extends Proto
 
   def on_stomp_ack(frame:StompFrame):Unit = {
     val headers = frame.headers
-    get(headers, MESSAGE_ID) match {
-      case Some(messageId)=>
+    val messageId = get(headers, MESSAGE_ID).getOrElse(die("message id header not set"))
 
-        val subscription_id = get(headers, SUBSCRIPTION);
-        if( subscription_id == None && !(protocol_version eq V1_0) ) {
+    val subscription_id = get(headers, SUBSCRIPTION);
+    val handler = subscription_id match {
+      case None=>
+        if( !(protocol_version eq V1_0) ) {
           die("The subscription header is required")
-          return
-        }
-
-        val handler = subscription_id match {
-          case None=>
-
-            connection_ack_handlers.get(messageId) match {
-              case None =>
-                die("Not expecting ack for message id '%s'".format(messageId))
-                None
-              case Some(handler) =>
-                Some(handler)
-            }
-
-          case Some(id) =>
-            consumers.get(id) match {
-              case None=>
-                die("The subscription '%s' does not exist".format(id))
-                None
-              case Some(consumer)=>
-                Some(consumer.ack_handler)
-            }
-        }
-
-        handler.foreach{ handler=>
-
-          get(headers, TRANSACTION) match {
-            case None=>
-              handler.perform_ack(messageId, null)
-            case Some(txid)=>
-              get_or_create_tx_queue(txid){ _.add(frame, (uow)=>{ handler.perform_ack(messageId, uow)} ) }
-          }
-
-          send_receipt(headers)
         }
-
-
-      case None=> die("message id header not set")
+        connection_ack_handlers.get(messageId).orElse(die("Not expecting ack for message id '%s'".format(messageId)))
+      case Some(id) =>
+        consumers.get(id).map(_.ack_handler).orElse(die("The subscription '%s' does not exist".format(id)))
     }
-  }
 
-  private def die(msg:String, explained:String="") = {
-    debug("Shutting connection down due to: "+msg)
-    _die((MESSAGE_HEADER, ascii(msg))::Nil, explained)
-  }
-
-  private def _die(headers:HeaderMap, explained:String="") = {
-    if( !connection.stopped ) {
-      suspendRead("shutdown")
-      connection.transport.offer(StompFrame(ERROR, headers, BufferContent(ascii(explained))) )
-      // TODO: if there are too many open connections we should just close the connection
-      // without waiting for the error to get sent to the client.
-      queue.after(die_delay, TimeUnit.MILLISECONDS) {
-        connection.stop()
+    handler.foreach{ handler=>
+      get(headers, TRANSACTION) match {
+        case None=>
+          handler.perform_ack(messageId, null)
+        case Some(txid)=>
+          get_or_create_tx_queue(txid).add{ uow=>
+            handler.perform_ack(messageId, uow)
+          }
       }
+      send_receipt(headers)
     }
   }
 
@@ -885,24 +862,24 @@ class StompProtocolHandler extends Proto
   }
 
 
-  def require_transaction_header[T](headers:HeaderMap)(proc:(AsciiBuffer)=>T):Option[T] = {
-    get(headers, TRANSACTION) match {
-      case None=> die("transaction header not set")
-      None
-      case Some(txid)=> Some(proc(txid))
-    }
+  def require_transaction_header[T](headers:HeaderMap):AsciiBuffer = {
+    get(headers, TRANSACTION).getOrElse(die("transaction header not set"))
   }
 
   def on_stomp_begin(headers:HeaderMap) = {
-    require_transaction_header(headers){ txid=>create_tx_queue(txid){ _ => send_receipt(headers) }  }
+    create_tx_queue(require_transaction_header(headers))
+    send_receipt(headers)
   }
 
   def on_stomp_commit(headers:HeaderMap) = {
-    require_transaction_header(headers){ txid=>remove_tx_queue(txid){ _.commit { send_receipt(headers) } } }
+    remove_tx_queue(require_transaction_header(headers)).commit {
+      send_receipt(headers)
+    }
   }
 
   def on_stomp_abort(headers:HeaderMap) = {
-    require_transaction_header(headers){ txid=>remove_tx_queue(txid){ _.rollback { send_receipt(headers) } } }
+    remove_tx_queue(require_transaction_header(headers)).rollback
+    send_receipt(headers)
   }
 
 
@@ -920,10 +897,10 @@ class StompProtocolHandler extends Proto
     // TODO: eventually we want to back this /w a broker Queue which
     // can provides persistence and memory swapping.
 
-    val queue = ListBuffer[(StompFrame, (StoreUOW)=>Unit)]()
+    val queue = ListBuffer[(StoreUOW)=>Unit]()
 
-    def add(frame:StompFrame, proc:(StoreUOW)=>Unit) = {
-      queue += ( frame->proc )
+    def add(proc:(StoreUOW)=>Unit):Unit = {
+      queue += proc
     }
 
     def commit(onComplete: => Unit) = {
@@ -934,16 +911,7 @@ class StompProtocolHandler extends Proto
         null
       }
 
-      queue.foreach { case (frame, proc) =>
-        proc(uow)
-//        frame.action match {
-//          case SEND =>
-//            perform_send(frame, uow)
-//          case ACK =>
-//            perform_ack(frame, uow)
-//          case _ => throw new java.lang.AssertionError("assertion failed: only send or ack frames are transactional")
-//        }
-      }
+      queue.foreach{ _(uow) }
       if( uow!=null ) {
         uow.onComplete(^{
           onComplete
@@ -955,34 +923,30 @@ class StompProtocolHandler extends Proto
 
     }
 
-    def rollback(onComplete: => Unit) = {
+    def rollback = {
       queue.clear
-      onComplete
     }
 
   }
 
   val transactions = HashMap[AsciiBuffer, TransactionQueue]()
 
-  def create_tx_queue(txid:AsciiBuffer)(proc:(TransactionQueue)=>Unit) = {
+  def create_tx_queue(txid:AsciiBuffer):TransactionQueue = {
     if ( transactions.contains(txid) ) {
       die("transaction allready started")
     } else {
       val queue = new TransactionQueue
       transactions.put(txid, queue)
-      proc( queue )
+      queue
     }
   }
 
-  def get_or_create_tx_queue(txid:AsciiBuffer)(proc:(TransactionQueue)=>Unit) = {
-    proc(transactions.getOrElseUpdate(txid, new TransactionQueue))
+  def get_or_create_tx_queue(txid:AsciiBuffer):TransactionQueue = {
+    transactions.getOrElseUpdate(txid, new TransactionQueue)
   }
 
-  def remove_tx_queue(txid:AsciiBuffer)(proc:(TransactionQueue)=>Unit) = {
-    transactions.remove(txid) match {
-      case None=> die("transaction not active: %d".format(txid))
-      case Some(txqueue)=> proc(txqueue)
-    }
+  def remove_tx_queue(txid:AsciiBuffer):TransactionQueue = {  // detaches and returns txid's queue; die()s if txid is not active
+    transactions.remove(txid).getOrElse(die("transaction not active: %s".format(txid)))  // %s: txid is an AsciiBuffer, which %d rejects at runtime
   }
 
 }



Mime
View raw message