hive-commits mailing list archives

From vgumas...@apache.org
Subject svn commit: r1645733 - in /hive/branches/branch-0.14: common/src/java/org/apache/hadoop/hive/conf/ itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/ metastore/src/java/org/apache/hadoop/hive/metastore/ service/src/java/org/apache/hi...
Date Mon, 15 Dec 2014 19:40:32 GMT
Author: vgumashta
Date: Mon Dec 15 19:40:31 2014
New Revision: 1645733

URL: http://svn.apache.org/r1645733
Log:
HIVE-6468: HS2 & Metastore using SASL out of memory error when curl sends a get request
(Navis Ryu, Vaibhav Gumashta reviewed by Thejas Nair, Ravi Prakash)
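
For background on the failure mode (illustrative, not part of the commit message): a SASL server
transport reads a 1-byte status followed by a 4-byte big-endian payload length, and the unpatched
code allocates a buffer of that length before validating anything. The first bytes of a curl-style
"GET /" request therefore decode to a gigabyte-scale allocation. A minimal sketch of the
arithmetic, using the same EncodingUtils call the patch relies on:

    import org.apache.thrift.EncodingUtils;

    public class SaslHeaderDemo {
      public static void main(String[] args) {
        // First five bytes of "GET / HTTP/1.1", as the SASL server transport sees them.
        byte[] header = {'G', 'E', 'T', ' ', '/'};
        byte status = header[0];                               // 0x47, not a valid NegotiationStatus
        int length = EncodingUtils.decodeBigEndian(header, 1); // 0x4554202F = 1163141167
        // An unpatched transport would try to allocate a byte[] of this ~1.1GB size and OOM.
        System.out.println("status=" + status + ", claimed payload=" + length + " bytes");
      }
    }

The new hive.thrift.sasl.message.limit config below caps that claimed length before allocation.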

Modified:
    hive/branches/branch-0.14/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
    hive/branches/branch-0.14/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java
    hive/branches/branch-0.14/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java
    hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
    hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
    hive/branches/branch-0.14/shims/common-secure/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge20S.java
    hive/branches/branch-0.14/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java

Modified: hive/branches/branch-0.14/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
URL: http://svn.apache.org/viewvc/hive/branches/branch-0.14/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java?rev=1645733&r1=1645732&r2=1645733&view=diff
==============================================================================
--- hive/branches/branch-0.14/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java (original)
+++ hive/branches/branch-0.14/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java Mon Dec 15 19:40:31 2014
@@ -1592,6 +1592,11 @@ public class HiveConf extends Configurat
     HIVE_SSL_PROTOCOL_BLACKLIST("hive.ssl.protocol.blacklist", "SSLv2,SSLv2Hello,SSLv3",
         "SSL Versions to disable for all Hive Servers"),
 
+     // General Thrift configs (Thrift configs common to Metastore and HiveServer2)
+     HIVE_THRIFT_SASL_MESSAGE_LIMIT("hive.thrift.sasl.message.limit", 104857600,
+        "If the length of incoming sasl message is greater than this, regard it as invalid
and close the transport. " +
+        "Zero or less value disables this. Default is 100MB."),
+
      // HiveServer2 specific configs
    HIVE_SERVER2_MAX_START_ATTEMPTS("hive.server2.max.start.attempts", 30L, new RangeValidator(0L, null),
         "Number of times HiveServer2 will attempt to start before exiting, sleeping 60 seconds
" +

Modified: hive/branches/branch-0.14/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java
URL: http://svn.apache.org/viewvc/hive/branches/branch-0.14/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java?rev=1645733&r1=1645732&r2=1645733&view=diff
==============================================================================
--- hive/branches/branch-0.14/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java (original)
+++ hive/branches/branch-0.14/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java Mon Dec 15 19:40:31 2014
@@ -80,7 +80,7 @@ public class TestHadoop20SAuthBridge ext
         super();
       }
       @Override
-      public TTransportFactory createTransportFactory(Map<String, String> saslProps)
+      public TTransportFactory createTransportFactory(Map<String, String> saslProps, int saslMessageLimit)
       throws TTransportException {
         TSaslServerTransport.Factory transFactory =
           new TSaslServerTransport.Factory();

Modified: hive/branches/branch-0.14/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java
URL: http://svn.apache.org/viewvc/hive/branches/branch-0.14/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java?rev=1645733&r1=1645732&r2=1645733&view=diff
==============================================================================
--- hive/branches/branch-0.14/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java (original)
+++ hive/branches/branch-0.14/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java Mon Dec 15 19:40:31 2014
@@ -5773,8 +5773,10 @@ public class HiveMetaStore extends Thrif
             conf.getVar(HiveConf.ConfVars.METASTORE_KERBEROS_PRINCIPAL));
         // start delegation token manager
         saslServer.startDelegationTokenSecretManager(conf, baseHandler.getMS(), ServerMode.METASTORE);
-        transFactory = saslServer.createTransportFactory(
-                MetaStoreUtils.getMetaStoreSaslProperties(conf));
+        int saslMessageLimit = conf.getIntVar(ConfVars.HIVE_THRIFT_SASL_MESSAGE_LIMIT);
+        transFactory =
+            saslServer.createTransportFactory(MetaStoreUtils.getMetaStoreSaslProperties(conf),
+                saslMessageLimit);
         processor = saslServer.wrapProcessor(
           new ThriftHiveMetastore.Processor<IHMSHandler>(handler));
         LOG.info("Starting DB backed MetaStore Server in Secure Mode");

Modified: hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
URL: http://svn.apache.org/viewvc/hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java?rev=1645733&r1=1645732&r2=1645733&view=diff
==============================================================================
--- hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java (original)
+++ hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java Mon Dec 15 19:40:31 2014
@@ -18,7 +18,6 @@
 package org.apache.hive.service.auth;
 
 import java.io.IOException;
-import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
@@ -28,7 +27,6 @@ import java.util.List;
 import java.util.Map;
 
 import javax.net.ssl.SSLServerSocket;
-import javax.security.auth.login.LoginException;
 import javax.security.sasl.Sasl;
 
 import org.apache.hadoop.hive.conf.HiveConf;
@@ -57,30 +55,28 @@ import org.slf4j.LoggerFactory;
 public class HiveAuthFactory {
   private static final Logger LOG = LoggerFactory.getLogger(HiveAuthFactory.class);
 
-
   public enum AuthTypes {
-    NOSASL("NOSASL"),
-    NONE("NONE"),
-    LDAP("LDAP"),
-    KERBEROS("KERBEROS"),
-    CUSTOM("CUSTOM"),
-    PAM("PAM");
-
-    private final String authType;
-
-    AuthTypes(String authType) {
-      this.authType = authType;
-    }
-
-    public String getAuthName() {
-      return authType;
-    }
+    NOSASL, NONE, LDAP, KERBEROS, CUSTOM, PAM
+  }
 
+  public static enum TransTypes {
+    HTTP {
+      AuthTypes getDefaultAuthType() {
+        return AuthTypes.NOSASL;
+      }
+    },
+    BINARY {
+      AuthTypes getDefaultAuthType() {
+        return AuthTypes.NONE;
+      }
+    };
+    abstract AuthTypes getDefaultAuthType();
   }
 
-  private HadoopThriftAuthBridge.Server saslServer;
-  private String authTypeStr;
-  private final String transportMode;
+  private final HadoopThriftAuthBridge.Server saslServer;
+  private final AuthTypes authType;
+  private final TransTypes transportType;
+  private final int saslMessageLimit;
   private final HiveConf conf;
 
   public static final String HS2_PROXY_USER = "hive.server2.proxy.user";
@@ -88,30 +84,28 @@ public class HiveAuthFactory {
 
   public HiveAuthFactory(HiveConf conf) throws TTransportException {
     this.conf = conf;
-    transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE);
-    authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION);
-
-    // In http mode we use NOSASL as the default auth type
-    if ("http".equalsIgnoreCase(transportMode)) {
-      if (authTypeStr == null) {
-        authTypeStr = AuthTypes.NOSASL.getAuthName();
+    saslMessageLimit = conf.getIntVar(ConfVars.HIVE_THRIFT_SASL_MESSAGE_LIMIT);
+    String transTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE);
+    String authTypeStr = conf.getVar(ConfVars.HIVE_SERVER2_AUTHENTICATION);
+    transportType = TransTypes.valueOf(transTypeStr.toUpperCase());
+    authType =
+        authTypeStr == null ? transportType.getDefaultAuthType() : AuthTypes.valueOf(authTypeStr
+            .toUpperCase());
+    if (transportType == TransTypes.BINARY
+        && authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.name())
+        && ShimLoader.getHadoopShims().isSecureShimImpl()) {
+      saslServer =
+          ShimLoader.getHadoopThriftAuthBridge().createServer(
+              conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
+              conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL));
+      // start delegation token manager
+      try {
+        saslServer.startDelegationTokenSecretManager(conf, null, ServerMode.HIVESERVER2);
+      } catch (Exception e) {
+        throw new TTransportException("Failed to start token manager", e);
       }
     } else {
-      if (authTypeStr == null) {
-        authTypeStr = AuthTypes.NONE.getAuthName();
-      }
-      if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())
-          && ShimLoader.getHadoopShims().isSecureShimImpl()) {
-        saslServer = ShimLoader.getHadoopThriftAuthBridge()
-          .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
-                        conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL));
-        // start delegation token manager
-        try {
-          saslServer.startDelegationTokenSecretManager(conf, null, ServerMode.HIVESERVER2);
-        } catch (IOException e) {
-          throw new TTransportException("Failed to start token manager", e);
-        }
-      }
+      saslServer = null;
     }
   }
 
@@ -123,42 +117,28 @@ public class HiveAuthFactory {
     return saslProps;
   }
 
-  public TTransportFactory getAuthTransFactory() throws LoginException {
-    TTransportFactory transportFactory;
-    if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
-      try {
-        transportFactory = saslServer.createTransportFactory(getSaslProperties());
-      } catch (TTransportException e) {
-        throw new LoginException(e.getMessage());
-      }
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NONE.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.LDAP.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.PAM.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NOSASL.getAuthName())) {
-      transportFactory = new TTransportFactory();
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.CUSTOM.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else {
-      throw new LoginException("Unsupported authentication type " + authTypeStr);
+  public TTransportFactory getAuthTransFactory() throws Exception {
+    if (authType == AuthTypes.KERBEROS) {
+      return saslServer.createTransportFactory(getSaslProperties(), saslMessageLimit);
     }
-    return transportFactory;
+    if (authType == AuthTypes.NOSASL) {
+      return new TTransportFactory();
+    }
+    return PlainSaslHelper.getPlainTransportFactory(authType.name(), saslMessageLimit);
   }
 
   /**
    * Returns the thrift processor factory for HiveServer2 running in binary mode
+   *
    * @param service
    * @return
    * @throws LoginException
    */
-  public TProcessorFactory getAuthProcFactory(ThriftCLIService service) throws LoginException {
-    if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
+  public TProcessorFactory getAuthProcFactory(ThriftCLIService service) {
+    if (authType == AuthTypes.KERBEROS) {
       return KerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
-    } else {
-      return PlainSaslHelper.getPlainProcessorFactory(service);
     }
+    return PlainSaslHelper.getPlainProcessorFactory(service);
   }
 
   public String getRemoteUser() {

Modified: hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
URL: http://svn.apache.org/viewvc/hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java?rev=1645733&r1=1645732&r2=1645733&view=diff
==============================================================================
--- hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java (original)
+++ hive/branches/branch-0.14/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java Mon Dec 15 19:40:31 2014
@@ -30,6 +30,7 @@ import javax.security.sasl.Authenticatio
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.SaslException;
 
+import org.apache.hadoop.hive.thrift.HadoopThriftAuthBridge;
 import org.apache.hive.service.auth.AuthenticationProviderFactory.AuthMethods;
 import org.apache.hive.service.auth.PlainSaslServer.SaslPlainProvider;
 import org.apache.hive.service.cli.thrift.TCLIService.Iface;
@@ -42,7 +43,6 @@ import org.apache.thrift.transport.TTran
 import org.apache.thrift.transport.TTransportFactory;
 
 public final class PlainSaslHelper {
-
   public static TProcessorFactory getPlainProcessorFactory(ThriftCLIService service) {
     return new SQLPlainProcessorFactory(service);
   }
@@ -52,16 +52,18 @@ public final class PlainSaslHelper {
     Security.addProvider(new SaslPlainProvider());
   }
 
-  public static TTransportFactory getPlainTransportFactory(String authTypeStr)
-    throws LoginException {
-    TSaslServerTransport.Factory saslFactory = new TSaslServerTransport.Factory();
-    try {
-      saslFactory.addServerDefinition("PLAIN", authTypeStr, null, new HashMap<String, String>(),
-        new PlainServerCallbackHandler(authTypeStr));
-    } catch (AuthenticationException e) {
-      throw new LoginException("Error setting callback handler" + e);
+  public static TTransportFactory getPlainTransportFactory(String authTypeStr, int saslMessageLimit)
+      throws LoginException, AuthenticationException {
+    TSaslServerTransport.Factory saslTransportFactory;
+    if (saslMessageLimit > 0) {
+      saslTransportFactory =
+          new HadoopThriftAuthBridge.HiveSaslServerTransportFactory(saslMessageLimit);
+    } else {
+      saslTransportFactory = new TSaslServerTransport.Factory();
     }
-    return saslFactory;
+    saslTransportFactory.addServerDefinition("PLAIN", authTypeStr, null,
+        new HashMap<String, String>(), new PlainServerCallbackHandler(authTypeStr));
+    return saslTransportFactory;
   }
 
   public static TTransport getPlainTransport(String username, String password,
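
A hedged call-site sketch for the changed signature (the "NONE" value is an assumption; it is one
of the PLAIN-backed auth types HiveAuthFactory passes in):

    import org.apache.thrift.transport.TTransportFactory;

    // limit > 0 selects the size-checking HiveSaslServerTransportFactory (added below in
    // HadoopThriftAuthBridge); zero or less falls back to the stock TSaslServerTransport.Factory.
    TTransportFactory factory = PlainSaslHelper.getPlainTransportFactory("NONE", 104857600);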

Modified: hive/branches/branch-0.14/shims/common-secure/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge20S.java
URL: http://svn.apache.org/viewvc/hive/branches/branch-0.14/shims/common-secure/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge20S.java?rev=1645733&r1=1645732&r2=1645733&view=diff
==============================================================================
--- hive/branches/branch-0.14/shims/common-secure/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge20S.java (original)
+++ hive/branches/branch-0.14/shims/common-secure/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge20S.java Mon Dec 15 19:40:31 2014
@@ -100,7 +100,8 @@ public class HadoopThriftAuthBridge20S e
   }
 
   @Override
-  public Server createServer(String keytabFile, String principalConf) throws TTransportException {
+  public Server createServer(String keytabFile, String principalConf)
+      throws TTransportException {
     return new Server(keytabFile, principalConf);
   }
 
@@ -328,6 +329,7 @@ public class HadoopThriftAuthBridge20S e
         throw new TTransportException(ioe);
       }
     }
+
     /**
      * Create a server with a kerberos keytab/principal.
      */
@@ -339,7 +341,6 @@ public class HadoopThriftAuthBridge20S e
       if (principalConf == null || principalConf.isEmpty()) {
         throw new TTransportException("No principal specified");
       }
-
       // Login from the keytab
       String kerberosName;
       try {
@@ -355,34 +356,34 @@ public class HadoopThriftAuthBridge20S e
     }
 
     /**
-     * Create a TTransportFactory that, upon connection of a client socket,
-     * negotiates a Kerberized SASL transport. The resulting TTransportFactory
-     * can be passed as both the input and output transport factory when
-     * instantiating a TThreadPoolServer, for example.
+     * Create a TTransportFactory that, upon connection of a client socket, negotiates a Kerberized
+     * SASL transport. The resulting TTransportFactory can be passed as both the input and output
+     * transport factory when instantiating a TThreadPoolServer, for example.
      *
      * @param saslProps Map of SASL properties
      */
     @Override
-    public TTransportFactory createTransportFactory(Map<String, String> saslProps)
-        throws TTransportException {
+    public TTransportFactory createTransportFactory(Map<String, String> saslProps,
+        int saslMessageLimit) throws TTransportException {
       // Parse out the kerberos principal, host, realm.
       String kerberosName = realUgi.getUserName();
       final String names[] = SaslRpcServer.splitKerberosName(kerberosName);
       if (names.length != 3) {
         throw new TTransportException("Kerberos principal should have 3 parts: " + kerberosName);
       }
-
-      TSaslServerTransport.Factory transFactory = new TSaslServerTransport.Factory();
-      transFactory.addServerDefinition(
-          AuthMethod.KERBEROS.getMechanismName(),
-          names[0], names[1],  // two parts of kerberos principal
-          saslProps,
-          new SaslRpcServer.SaslGssCallbackHandler());
-      transFactory.addServerDefinition(AuthMethod.DIGEST.getMechanismName(),
-          null, SaslRpcServer.SASL_DEFAULT_REALM,
-          saslProps, new SaslDigestCallbackHandler(secretManager));
-
-      return new TUGIAssumingTransportFactory(transFactory, realUgi);
+      TSaslServerTransport.Factory saslTransportFactory;
+      if (saslMessageLimit > 0) {
+        saslTransportFactory = new HadoopThriftAuthBridge.HiveSaslServerTransportFactory(saslMessageLimit);
+      } else {
+        saslTransportFactory = new TSaslServerTransport.Factory();
+      }
+      saslTransportFactory.addServerDefinition(AuthMethod.KERBEROS.getMechanismName(), names[0], names[1],
+          saslProps, new SaslRpcServer.SaslGssCallbackHandler());
+      saslTransportFactory
+          .addServerDefinition(AuthMethod.DIGEST.getMechanismName(), null,
+              SaslRpcServer.SASL_DEFAULT_REALM, saslProps, new SaslDigestCallbackHandler(
+                  secretManager));
+      return new TUGIAssumingTransportFactory(saslTransportFactory, realUgi);
     }
 
     /**
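
As the javadoc above says, the returned factory serves as both the input and output transport
factory of a Thrift server. A hedged wiring sketch (the processor and port parameters are
assumptions, not from this patch):

    import org.apache.thrift.TProcessor;
    import org.apache.thrift.server.TThreadPoolServer;
    import org.apache.thrift.transport.TServerSocket;
    import org.apache.thrift.transport.TTransportException;
    import org.apache.thrift.transport.TTransportFactory;

    public class SecureServerSketch {
      // saslServer would come from ShimLoader.getHadoopThriftAuthBridge().createServer(...).
      static TThreadPoolServer build(HadoopThriftAuthBridge.Server saslServer,
          TProcessor processor, int port) throws TTransportException {
        TTransportFactory transFactory = saslServer.createTransportFactory(
            new java.util.HashMap<String, String>(), 104857600);
        return new TThreadPoolServer(new TThreadPoolServer.Args(new TServerSocket(port))
            .processor(processor)
            .transportFactory(transFactory)); // used for both input and output transports
      }
    }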

Modified: hive/branches/branch-0.14/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java
URL: http://svn.apache.org/viewvc/hive/branches/branch-0.14/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java?rev=1645733&r1=1645732&r2=1645733&view=diff
==============================================================================
--- hive/branches/branch-0.14/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java (original)
+++ hive/branches/branch-0.14/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java Mon Dec 15 19:40:31 2014
@@ -19,15 +19,26 @@
 package org.apache.hadoop.hive.thrift;
 
 import java.io.IOException;
+import java.io.UnsupportedEncodingException;
+import java.lang.ref.WeakReference;
 import java.net.InetAddress;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
+import java.util.WeakHashMap;
+
+import javax.security.auth.callback.CallbackHandler;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.thrift.EncodingUtils;
 import org.apache.thrift.TProcessor;
+import org.apache.thrift.transport.TSaslServerTransport;
 import org.apache.thrift.transport.TTransport;
 import org.apache.thrift.transport.TTransportException;
 import org.apache.thrift.transport.TTransportFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * This class is only overridden by the secure hadoop shim. It allows
@@ -58,8 +69,7 @@ public class HadoopThriftAuthBridge {
         "The current version of Hadoop does not support Authentication");
   }
 
-  public Server createServer(String keytabFile, String principalConf)
-      throws TTransportException {
-  public Server createServer(String keytabFile, String principalConf) throws TTransportException {
     throw new UnsupportedOperationException(
         "The current version of Hadoop does not support Authentication");
   }
@@ -102,7 +112,9 @@ public class HadoopThriftAuthBridge {
     public enum ServerMode {
       HIVESERVER2, METASTORE
     };
-    public abstract TTransportFactory createTransportFactory(Map<String, String> saslProps) throws TTransportException;
+
+    public abstract TTransportFactory createTransportFactory(Map<String, String> saslProps,
+        int saslMessageLimit) throws TTransportException;
     public abstract TProcessor wrapProcessor(TProcessor processor);
     public abstract TProcessor wrapNonAssumingProcessor(TProcessor processor);
     public abstract InetAddress getRemoteAddress();
@@ -117,5 +129,105 @@ public class HadoopThriftAuthBridge {
     public abstract void cancelDelegationToken(String tokenStrForm) throws IOException;
     public abstract String getUserFromToken(String tokenStr) throws IOException;
   }
+
+  public static class HiveSaslServerTransportFactory extends TSaslServerTransport.Factory {
+    private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class);
+    private final int saslMessageLimit;
+
+    public HiveSaslServerTransportFactory(int saslMessageLimit) {
+      this.saslMessageLimit = saslMessageLimit;
+    }
+
+    private static class TSaslServerDefinition {
+      public String mechanism;
+      public String protocol;
+      public String serverName;
+      public Map<String, String> props;
+      public CallbackHandler cbh;
+
+      public TSaslServerDefinition(String mechanism, String protocol, String serverName,
+          Map<String, String> props, CallbackHandler cbh) {
+        this.mechanism = mechanism;
+        this.protocol = protocol;
+        this.serverName = serverName;
+        this.props = props;
+        this.cbh = cbh;
+      }
+    }
+
+    private static Map<TTransport, WeakReference<TSaslServerTransport>> transportMap = Collections
+        .synchronizedMap(new WeakHashMap<TTransport, WeakReference<TSaslServerTransport>>());
+    private Map<String, TSaslServerDefinition> serverDefinitionMap =
+        new HashMap<String, TSaslServerDefinition>();
+
+    public void addServerDefinition(String mechanism, String protocol, String serverName,
+        Map<String, String> props, CallbackHandler cbh) {
+      serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName,
+          props, cbh));
+    }
+
+    @Override
+    public TTransport getTransport(TTransport base) {
+      WeakReference<TSaslServerTransport> ret = transportMap.get(base);
+      TSaslServerTransport transport = ret == null ? null : ret.get();
+      if (transport == null) {
+        LOGGER.debug("transport map does not contain key {}", base);
+        transport = newSaslTransport(base);
+        try {
+          transport.open();
+        } catch (TTransportException e) {
+          LOGGER.debug("failed to open server transport", e);
+          throw new RuntimeException(e);
+        }
+        transportMap.put(base, new WeakReference<TSaslServerTransport>(transport));
+      } else {
+        LOGGER.debug("transport map does contain key {}", base);
+      }
+      return transport;
+    }
+
+    private TSaslServerTransport newSaslTransport(final TTransport base) {
+      // Anonymous subclass of TSaslServerTransport. TSaslServerTransport#receiveSaslMessage
+      // is replaced with one that has an additional check for the message size.
+      TSaslServerTransport transport = new TSaslServerTransport(base) {
+        private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];
+
+        @Override
+        protected SaslResponse receiveSaslMessage() throws TTransportException {
+          underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
+          byte statusByte = messageHeader[0];
+          int length = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
+          if (length > saslMessageLimit) {
+            base.close();
+            throw new TTransportException("Sasl message is too big (" + length + " bytes).
"
+                + "The peer connection is possibly using a protocol other than thrift.");
+          }
+          byte[] payload = new byte[length];
+          underlyingTransport.readAll(payload, 0, payload.length);
+          NegotiationStatus status = NegotiationStatus.byValue(statusByte);
+          if (status == null) {
+            sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
+          } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
+            try {
+              String remoteMessage = new String(payload, "UTF-8");
+              throw new TTransportException("Peer indicated failure: " + remoteMessage);
+            } catch (UnsupportedEncodingException e) {
+              throw new TTransportException(e);
+            }
+          }
+          if (LOGGER.isDebugEnabled())
+            LOGGER.debug(getRole() + ": Received message with status {} and payload length {}",
+                status, payload.length);
+          return new SaslResponse(status, payload);
+        }
+      };
+      for (Map.Entry<String, TSaslServerDefinition> entry : serverDefinitionMap.entrySet()) {
+        TSaslServerDefinition definition = entry.getValue();
+        transport.addServerDefinition(entry.getKey(), definition.protocol, definition.serverName,
+            definition.props, definition.cbh);
+      }
+      return transport;
+    }
+  }
 }
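
For completeness, a direct-construction sketch of the new factory (hypothetical; within this patch
it is always built through createTransportFactory or getPlainTransportFactory above):

    // Any connection whose first SASL frame claims more than the limit is closed
    // instead of buffered, which is the fix for HIVE-6468.
    TSaslServerTransport.Factory factory =
        new HadoopThriftAuthBridge.HiveSaslServerTransportFactory(100 * 1024 * 1024);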
 


