hadoop-common-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From da...@apache.org
Subject svn commit: r1508090 - in /hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common: ./ src/main/java/org/apache/hadoop/ipc/ src/main/java/org/apache/hadoop/security/ src/test/java/org/apache/hadoop/ipc/
Date Mon, 29 Jul 2013 14:52:10 GMT
Author: daryn
Date: Mon Jul 29 14:52:10 2013
New Revision: 1508090

URL: http://svn.apache.org/r1508090
Log:
merge -c 1508086 FIXES: HADOOP-9698. [RPC v9] Client must honor server's SASL negotiate response (daryn)

Modified:
    hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/CHANGES.txt
    hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
    hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java
    hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
    hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java

Modified: hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/CHANGES.txt
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/CHANGES.txt?rev=1508090&r1=1508089&r2=1508090&view=diff
==============================================================================
--- hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/CHANGES.txt (original)
+++ hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/CHANGES.txt Mon Jul 29 14:52:10 2013
@@ -43,6 +43,8 @@ Release 2.1.0-beta - 2013-07-02
 
     HADOOP-9683. [RPC v9] Wrap IpcConnectionContext in RPC headers (daryn)
 
+    HADOOP-9698. [RPC v9] Client must honor server's SASL negotiate response (daryn)
+
   NEW FEATURES
 
     HADOOP-9283. Add support for running the Hadoop client on AIX. (atm)

Modified: hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java?rev=1508090&r1=1508089&r2=1508090&view=diff
==============================================================================
--- hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java (original)
+++ hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java Mon Jul 29 14:52:10 2013
@@ -80,11 +80,6 @@ import org.apache.hadoop.security.SaslRp
 import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
 import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
-import org.apache.hadoop.security.token.Token;
-import org.apache.hadoop.security.token.TokenIdentifier;
-import org.apache.hadoop.security.token.TokenInfo;
-import org.apache.hadoop.security.token.TokenSelector;
 import org.apache.hadoop.util.ProtoUtil;
 import org.apache.hadoop.util.ReflectionUtils;
 import org.apache.hadoop.util.StringUtils;
@@ -313,10 +308,9 @@ public class Client {
    * socket: responses may be delivered out of order. */
   private class Connection extends Thread {
     private InetSocketAddress server;             // server ip:port
-    private String serverPrincipal;  // server's krb5 principal name
     private final ConnectionId remoteId;                // connection id
     private AuthMethod authMethod; // authentication method
-    private Token<? extends TokenIdentifier> token;
+    private AuthProtocol authProtocol;
     private int serviceClass;
     private SaslRpcClient saslRpcClient;
     
@@ -363,45 +357,11 @@ public class Client {
       }
 
       UserGroupInformation ticket = remoteId.getTicket();
-      Class<?> protocol = remoteId.getProtocol();
-      if (protocol != null) {
-        TokenInfo tokenInfo = SecurityUtil.getTokenInfo(protocol, conf);
-        if (tokenInfo != null) {
-          TokenSelector<? extends TokenIdentifier> tokenSelector = null;
-          try {
-            tokenSelector = tokenInfo.value().newInstance();
-          } catch (InstantiationException e) {
-            throw new IOException(e.toString());
-          } catch (IllegalAccessException e) {
-            throw new IOException(e.toString());
-          }
-          token = tokenSelector.selectToken(
-              SecurityUtil.buildTokenService(server),
-              ticket.getTokens());
-        }
-        KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf);
-        if (krbInfo != null) {
-          serverPrincipal = remoteId.getServerPrincipal();
-          if (LOG.isDebugEnabled()) {
-            LOG.debug("RPC Server's Kerberos principal name for protocol="
-                + protocol.getCanonicalName() + " is " + serverPrincipal);
-          }
-        }
-      }
-      
-      AuthenticationMethod authentication;
-      if (token != null) {
-        authentication = AuthenticationMethod.TOKEN;
-      } else if (ticket != null) {
-        authentication = ticket.getRealAuthenticationMethod();
-      } else { // this only happens in lazy tests
-        authentication = AuthenticationMethod.SIMPLE;
-      }
-      authMethod = authentication.getAuthMethod();
-      
-      if (LOG.isDebugEnabled())
-        LOG.debug("Use " + authMethod + " authentication for protocol "
-            + (protocol == null? null: protocol.getSimpleName()));
+      // try SASL if security is enabled or if the ugi contains tokens.
+      // this causes a SIMPLE client with tokens to attempt SASL
+      boolean trySasl = UserGroupInformation.isSecurityEnabled() ||
+                        (ticket != null && !ticket.getTokens().isEmpty());
+      this.authProtocol = trySasl ? AuthProtocol.SASL : AuthProtocol.NONE;
       
       this.setName("IPC Client (" + socketFactory.hashCode() +") connection to " +
           server.toString() +
@@ -512,11 +472,10 @@ public class Client {
       return false;
     }
     
-    private synchronized boolean setupSaslConnection(final InputStream in2, 
-        final OutputStream out2) 
-        throws IOException {
-      saslRpcClient = new SaslRpcClient(authMethod, token, serverPrincipal,
-          fallbackAllowed);
+    private synchronized AuthMethod setupSaslConnection(final InputStream in2, 
+        final OutputStream out2) throws IOException, InterruptedException {
+      saslRpcClient = new SaslRpcClient(remoteId.getTicket(),
+          remoteId.getProtocol(), remoteId.getAddress(), conf);
       return saslRpcClient.saslConnect(in2, out2);
     }
 
@@ -554,7 +513,8 @@ public class Client {
            * client, to ensure Server matching address of the client connection
            * to host name in principal passed.
            */
-          if (UserGroupInformation.isSecurityEnabled()) {
+          UserGroupInformation ticket = remoteId.getTicket();
+          if (ticket != null && ticket.hasKerberosCredentials()) {
             KerberosInfo krbInfo = 
               remoteId.getProtocol().getAnnotation(KerberosInfo.class);
             if (krbInfo != null && krbInfo.clientPrincipal() != null) {
@@ -632,7 +592,7 @@ public class Client {
             } else {
               String msg = "Couldn't setup connection for "
                   + UserGroupInformation.getLoginUser().getUserName() + " to "
-                  + serverPrincipal;
+                  + remoteId;
               LOG.warn(msg);
               throw (IOException) new IOException(msg).initCause(ex);
             }
@@ -668,19 +628,19 @@ public class Client {
           InputStream inStream = NetUtils.getInputStream(socket);
           OutputStream outStream = NetUtils.getOutputStream(socket);
           writeConnectionHeader(outStream);
-          if (authMethod != AuthMethod.SIMPLE) {
+          if (authProtocol == AuthProtocol.SASL) {
             final InputStream in2 = inStream;
             final OutputStream out2 = outStream;
             UserGroupInformation ticket = remoteId.getTicket();
             if (ticket.getRealUser() != null) {
               ticket = ticket.getRealUser();
             }
-            boolean continueSasl = false;
             try {
-              continueSasl = ticket
-                  .doAs(new PrivilegedExceptionAction<Boolean>() {
+              authMethod = ticket
+                  .doAs(new PrivilegedExceptionAction<AuthMethod>() {
                     @Override
-                    public Boolean run() throws IOException {
+                    public AuthMethod run()
+                        throws IOException, InterruptedException {
                       return setupSaslConnection(in2, out2);
                     }
                   });
@@ -692,13 +652,15 @@ public class Client {
                   ticket);
               continue;
             }
-            if (continueSasl) {
+            if (authMethod != AuthMethod.SIMPLE) {
               // Sasl connect is successful. Let's set up Sasl i/o streams.
               inStream = saslRpcClient.getInputStream(inStream);
               outStream = saslRpcClient.getOutputStream(outStream);
-            } else {
-              // fall back to simple auth because server told us so.
-              authMethod = AuthMethod.SIMPLE;
+            } else if (UserGroupInformation.isSecurityEnabled() &&
+                       !fallbackAllowed) {
+              throw new IOException("Server asks us to fall back to SIMPLE " +
+                  "auth, but this client is configured to only allow secure " +
+                  "connections.");
             }
           }
         
@@ -818,14 +780,6 @@ public class Client {
       out.write(RpcConstants.HEADER.array());
       out.write(RpcConstants.CURRENT_VERSION);
       out.write(serviceClass);
-      final AuthProtocol authProtocol;
-      switch (authMethod) {
-        case SIMPLE:
-          authProtocol = AuthProtocol.NONE;
-          break;
-        default:
-          authProtocol = AuthProtocol.SASL;
-      }
       out.write(authProtocol.callId);
       out.flush();
     }
@@ -1434,7 +1388,6 @@ public class Client {
     final Class<?> protocol;
     private static final int PRIME = 16777619;
     private final int rpcTimeout;
-    private final String serverPrincipal;
     private final int maxIdleTime; //connections will be culled if it was idle for 
     //maxIdleTime msecs
     private final RetryPolicy connectionRetryPolicy;
@@ -1445,15 +1398,13 @@ public class Client {
     private final int pingInterval; // how often sends ping to the server in msecs
     
     ConnectionId(InetSocketAddress address, Class<?> protocol, 
-                 UserGroupInformation ticket, int rpcTimeout,
-                 String serverPrincipal, int maxIdleTime, 
+                 UserGroupInformation ticket, int rpcTimeout, int maxIdleTime, 
                  RetryPolicy connectionRetryPolicy, int maxRetriesOnSocketTimeouts,
                  boolean tcpNoDelay, boolean doPing, int pingInterval) {
       this.protocol = protocol;
       this.address = address;
       this.ticket = ticket;
       this.rpcTimeout = rpcTimeout;
-      this.serverPrincipal = serverPrincipal;
       this.maxIdleTime = maxIdleTime;
       this.connectionRetryPolicy = connectionRetryPolicy;
       this.maxRetriesOnSocketTimeouts = maxRetriesOnSocketTimeouts;
@@ -1478,10 +1429,6 @@ public class Client {
       return rpcTimeout;
     }
     
-    String getServerPrincipal() {
-      return serverPrincipal;
-    }
-    
     int getMaxIdleTime() {
       return maxIdleTime;
     }
@@ -1531,11 +1478,9 @@ public class Client {
             max, 1, TimeUnit.SECONDS);
       }
 
-      String remotePrincipal = getRemotePrincipal(conf, addr, protocol);
       boolean doPing =
         conf.getBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, true);
-      return new ConnectionId(addr, protocol, ticket,
-          rpcTimeout, remotePrincipal,
+      return new ConnectionId(addr, protocol, ticket, rpcTimeout,
           conf.getInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY,
               CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_DEFAULT),
           connectionRetryPolicy,
@@ -1548,25 +1493,6 @@ public class Client {
           (doPing ? Client.getPingInterval(conf) : 0));
     }
     
-    private static String getRemotePrincipal(Configuration conf,
-        InetSocketAddress address, Class<?> protocol) throws IOException {
-      if (!UserGroupInformation.isSecurityEnabled() || protocol == null) {
-        return null;
-      }
-      KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf);
-      if (krbInfo != null) {
-        String serverKey = krbInfo.serverPrincipal();
-        if (serverKey == null) {
-          throw new IOException(
-              "Can't obtain server Kerberos config key from protocol="
-                  + protocol.getCanonicalName());
-        }
-        return SecurityUtil.getServerPrincipal(conf.get(serverKey), address
-            .getAddress());
-      }
-      return null;
-    }
-    
     static boolean isEqual(Object a, Object b) {
       return a == null ? b == null : a.equals(b);
     }
@@ -1585,7 +1511,6 @@ public class Client {
             && this.pingInterval == that.pingInterval
             && isEqual(this.protocol, that.protocol)
             && this.rpcTimeout == that.rpcTimeout
-            && isEqual(this.serverPrincipal, that.serverPrincipal)
             && this.tcpNoDelay == that.tcpNoDelay
             && isEqual(this.ticket, that.ticket);
       }
@@ -1601,8 +1526,6 @@ public class Client {
       result = PRIME * result + pingInterval;
       result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode());
       result = PRIME * result + rpcTimeout;
-      result = PRIME * result
-          + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode());
       result = PRIME * result + (tcpNoDelay ? 1231 : 1237);
       result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode());
       return result;
@@ -1610,7 +1533,7 @@ public class Client {
     
     @Override
     public String toString() {
-      return serverPrincipal + "@" + address;
+      return address.toString();
     }
   }  
 

Modified: hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java?rev=1508090&r1=1508089&r2=1508090&view=diff
==============================================================================
--- hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java (original)
+++ hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java Mon Jul 29 14:52:10 2013
@@ -85,6 +85,7 @@ import org.apache.hadoop.ipc.protobuf.Rp
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcErrorCodeProto;
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto.RpcStatusProto;
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto;
+import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.SaslAuth;
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.SaslState;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.security.AccessControlException;
@@ -792,7 +793,10 @@ public abstract class Server {
         LOG.info(getName() + ": readAndProcess caught InterruptedException", ieo);
         throw ieo;
       } catch (Exception e) {
-        // log stack trace for "interesting" exceptions not sent to client
+        // a WrappedRpcServerException is an exception that has been sent
+        // to the client, so the stacktrace is unnecessary; any other
+        // exceptions are unexpected internal server errors and thus the
+        // stacktrace should be logged
         LOG.info(getName() + ": readAndProcess from client " +
             c.getHostAddress() + " threw exception [" + e + "]",
             (e instanceof WrappedRpcServerException) ? null : e);
@@ -1161,7 +1165,6 @@ public abstract class Server {
     private AuthMethod authMethod;
     private AuthProtocol authProtocol;
     private boolean saslContextEstablished;
-    private boolean skipInitialSaslHandshake;
     private ByteBuffer connectionHeaderBuf = null;
     private ByteBuffer unwrappedData;
     private ByteBuffer unwrappedDataLengthBuffer;
@@ -1337,23 +1340,39 @@ public abstract class Server {
                 "Client already attempted negotiation");
           }
           saslResponse = buildSaslNegotiateResponse();
+          // simple-only server negotiate response is success which client
+          // interprets as switch to simple
+          if (saslResponse.getState() == SaslState.SUCCESS) {
+            switchToSimple();
+          }
           break;
         }
         case INITIATE: {
           if (saslMessage.getAuthsCount() != 1) {
             throw new SaslException("Client mechanism is malformed");
           }
-          String authMethodName = saslMessage.getAuths(0).getMethod();
-          authMethod = createSaslServer(authMethodName);
-          if (authMethod == null) { // the auth method is not supported
+          // verify the client requested an advertised authType
+          SaslAuth clientSaslAuth = saslMessage.getAuths(0);
+          if (!negotiateResponse.getAuthsList().contains(clientSaslAuth)) {
             if (sentNegotiate) {
               throw new AccessControlException(
-                  authMethodName + " authentication is not enabled."
+                  clientSaslAuth.getMethod() + " authentication is not enabled."
                       + "  Available:" + enabledAuthMethods);
             }
             saslResponse = buildSaslNegotiateResponse();
             break;
           }
+          authMethod = AuthMethod.valueOf(clientSaslAuth.getMethod());
+          // abort SASL for SIMPLE auth, server has already ensured that
+          // SIMPLE is a legit option above.  we will send no response
+          if (authMethod == AuthMethod.SIMPLE) {
+            switchToSimple();
+            break;
+          }
+          // sasl server for tokens may already be instantiated
+          if (saslServer == null || authMethod != AuthMethod.TOKEN) {
+            saslServer = createSaslServer(authMethod);
+          }
           // fallthru to process sasl token
         }
         case RESPONSE: {
@@ -1376,6 +1395,12 @@ public abstract class Server {
       }
       return saslResponse;
     }
+
+    private void switchToSimple() {
+      // disable SASL and blank out any SASL server
+      authProtocol = AuthProtocol.NONE;
+      saslServer = null;
+    }
     
     private RpcSaslProto buildSaslResponse(SaslState state, byte[] replyToken) {
       if (LOG.isDebugEnabled()) {
@@ -1432,7 +1457,8 @@ public abstract class Server {
       }
     }
 
-    public int readAndProcess() throws IOException, InterruptedException {
+    public int readAndProcess()
+        throws WrappedRpcServerException, IOException, InterruptedException {
       while (true) {
         /* Read at most one RPC. If the header is not read completely yet
          * then iterate until we read first RPC or until there is no data left.
@@ -1535,15 +1561,7 @@ public abstract class Server {
           }
           break;
         }
-        case SASL: {
-          // switch to simple hack, but don't switch if other auths are
-          // supported, ex. tokens
-          if (isSimpleEnabled && enabledAuthMethods.size() == 1) {
-            authProtocol = AuthProtocol.NONE;
-            skipInitialSaslHandshake = true;
-            doSaslReply(buildSaslResponse(SaslState.SUCCESS, null));
-          }
-          // else wait for a negotiate or initiate
+        default: {
           break;
         }
       }
@@ -1568,25 +1586,6 @@ public abstract class Server {
       return negotiateMessage;
     }
     
-    private AuthMethod createSaslServer(String authMethodName)
-        throws IOException, InterruptedException {
-      AuthMethod authMethod;
-      try {
-        authMethod = AuthMethod.valueOf(authMethodName);
-        if (!enabledAuthMethods.contains(authMethod)) {
-          authMethod = null;
-        }
-      } catch (IllegalArgumentException iae) {
-        authMethod = null;
-      }
-      if (authMethod != null &&
-          // sasl server for tokens may already be instantiated
-          (saslServer == null || authMethod != AuthMethod.TOKEN)) {
-        saslServer = createSaslServer(authMethod);
-      }
-      return authMethod;
-    }
-
     private SaslServer createSaslServer(AuthMethod authMethod)
         throws IOException, InterruptedException {
       return new SaslRpcServer(authMethod).create(this, secretManager);
@@ -1701,8 +1700,8 @@ public abstract class Server {
      *   or the request could not be decoded into a Call
      * @throws InterruptedException
      */    
-    private void processRpcRequestPacket(byte[] buf) throws IOException,
-        InterruptedException {
+    private void processRpcRequestPacket(byte[] buf)
+        throws WrappedRpcServerException, IOException, InterruptedException {
       if (saslContextEstablished && useWrap) {
         if (LOG.isDebugEnabled())
           LOG.debug("Have read input token of size " + buf.length
@@ -1715,8 +1714,8 @@ public abstract class Server {
       }
     }
     
-    private void unwrapPacketAndProcessRpcs(byte[] inBuf) throws IOException,
-        InterruptedException {
+    private void unwrapPacketAndProcessRpcs(byte[] inBuf)
+        throws WrappedRpcServerException, IOException, InterruptedException {
       ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(
           inBuf));
       // Read all RPCs contained in the inBuf, even partial ones
@@ -1901,13 +1900,9 @@ public abstract class Server {
       } else if (callId == AuthProtocol.SASL.callId) {
         // if client was switched to simple, ignore first SASL message
         if (authProtocol != AuthProtocol.SASL) {
-          if (!skipInitialSaslHandshake) {
-            throw new WrappedRpcServerException(
-                RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
-                "SASL protocol not requested by client");
-          }
-          skipInitialSaslHandshake = false;
-          return;
+          throw new WrappedRpcServerException(
+              RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER,
+              "SASL protocol not requested by client");
         }
         RpcSaslProto response = saslReadAndProcess(dis);
         // send back response if any, may throw IOException
@@ -2218,17 +2213,23 @@ public abstract class Server {
   private RpcSaslProto buildNegotiateResponse(List<AuthMethod> authMethods)
       throws IOException {
     RpcSaslProto.Builder negotiateBuilder = RpcSaslProto.newBuilder();
-    negotiateBuilder.setState(SaslState.NEGOTIATE);
-    for (AuthMethod authMethod : authMethods) {
-      if (authMethod == AuthMethod.SIMPLE) { // not a SASL method
-        continue;
-      }
-      SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod);      
-      negotiateBuilder.addAuthsBuilder()
-          .setMethod(authMethod.toString())
-          .setMechanism(saslRpcServer.mechanism)
-          .setProtocol(saslRpcServer.protocol)
-          .setServerId(saslRpcServer.serverId);
+    if (authMethods.contains(AuthMethod.SIMPLE) && authMethods.size() == 1) {
+      // SIMPLE-only servers return success in response to negotiate
+      negotiateBuilder.setState(SaslState.SUCCESS);
+    } else {
+      negotiateBuilder.setState(SaslState.NEGOTIATE);
+      for (AuthMethod authMethod : authMethods) {
+        SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod);      
+        SaslAuth.Builder builder = negotiateBuilder.addAuthsBuilder()
+            .setMethod(authMethod.toString())
+            .setMechanism(saslRpcServer.mechanism);
+        if (saslRpcServer.protocol != null) {
+          builder.setProtocol(saslRpcServer.protocol);
+        }
+        if (saslRpcServer.serverId != null) {
+          builder.setServerId(saslRpcServer.serverId);
+        }
+      }
     }
     return negotiateBuilder.build();
   }

Modified: hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java?rev=1508090&r1=1508089&r2=1508090&view=diff
==============================================================================
--- hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java (original)
+++ hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java Mon Jul 29 14:52:10 2013
@@ -25,6 +25,9 @@ import java.io.DataOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 
 import javax.security.auth.callback.Callback;
@@ -32,6 +35,7 @@ import javax.security.auth.callback.Call
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.PasswordCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.kerberos.KerberosPrincipal;
 import javax.security.sasl.RealmCallback;
 import javax.security.sasl.RealmChoiceCallback;
 import javax.security.sasl.Sasl;
@@ -42,6 +46,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper;
 import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper;
 import org.apache.hadoop.ipc.RPC.RpcKind;
@@ -58,6 +63,8 @@ import org.apache.hadoop.security.SaslRp
 import org.apache.hadoop.security.authentication.util.KerberosName;
 import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.security.token.TokenIdentifier;
+import org.apache.hadoop.security.token.TokenInfo;
+import org.apache.hadoop.security.token.TokenSelector;
 import org.apache.hadoop.util.ProtoUtil;
 
 import com.google.protobuf.ByteString;
@@ -69,9 +76,13 @@ import com.google.protobuf.ByteString;
 public class SaslRpcClient {
   public static final Log LOG = LogFactory.getLog(SaslRpcClient.class);
 
-  private final AuthMethod authMethod;
-  private final SaslClient saslClient;
-  private final boolean fallbackAllowed;
+  private final UserGroupInformation ugi;
+  private final Class<?> protocol;
+  private final InetSocketAddress serverAddr;  
+  private final Configuration conf;
+
+  private SaslClient saslClient;
+  
   private static final RpcRequestHeaderProto saslHeader = ProtoUtil
       .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER,
           OperationProto.RPC_FINAL_PACKET, AuthProtocol.SASL.callId,
@@ -80,44 +91,121 @@ public class SaslRpcClient {
       RpcSaslProto.newBuilder().setState(SaslState.NEGOTIATE).build();
   
   /**
-   * Create a SaslRpcClient for an authentication method
+   * Create a SaslRpcClient that can be used by a RPC client to negotiate
+   * SASL authentication with a RPC server
+   * @param ugi - connecting user
+   * @param protocol - RPC protocol
+   * @param serverAddr - InetSocketAddress of remote server
+   * @param conf - Configuration
+   */
+  public SaslRpcClient(UserGroupInformation ugi, Class<?> protocol,
+      InetSocketAddress serverAddr, Configuration conf) {
+    this.ugi = ugi;
+    this.protocol = protocol;
+    this.serverAddr = serverAddr;
+    this.conf = conf;
+  }
+  
+  /**
+   * Instantiate a sasl client for the first supported auth type in the
+   * given list.  The auth type must be defined, enabled, and the user
+   * must possess the required credentials, else the next auth is tried.
    * 
-   * @param method
-   *          the requested authentication method
-   * @param token
-   *          token to use if needed by the authentication method
+   * @param authTypes to attempt in the given order
+   * @return SaslAuth of instantiated client
+   * @throws AccessControlException - client doesn't support any of the auths
+   * @throws IOException - misc errors
    */
-  public SaslRpcClient(AuthMethod method,
-      Token<? extends TokenIdentifier> token, String serverPrincipal,
-      boolean fallbackAllowed)
-      throws IOException {
-    this.authMethod = method;
-    this.fallbackAllowed = fallbackAllowed;
+  private SaslAuth selectSaslClient(List<SaslAuth> authTypes)
+      throws SaslException, AccessControlException, IOException {
+    SaslAuth selectedAuthType = null;
+    boolean switchToSimple = false;
+    for (SaslAuth authType : authTypes) {
+      if (!isValidAuthType(authType)) {
+        continue; // don't know what it is, try next
+      }
+      AuthMethod authMethod = AuthMethod.valueOf(authType.getMethod());
+      if (authMethod == AuthMethod.SIMPLE) {
+        switchToSimple = true;
+      } else {
+        saslClient = createSaslClient(authType);
+        if (saslClient == null) { // client lacks credentials, try next
+          continue;
+        }
+      }
+      selectedAuthType = authType;
+      break;
+    }
+    if (saslClient == null && !switchToSimple) {
+      List<String> serverAuthMethods = new ArrayList<String>();
+      for (SaslAuth authType : authTypes) {
+        serverAuthMethods.add(authType.getMethod());
+      }
+      throw new AccessControlException(
+          "Client cannot authenticate via:" + serverAuthMethods);
+    }
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Use " + selectedAuthType.getMethod() +
+          " authentication for protocol " + protocol.getSimpleName());
+    }
+    return selectedAuthType;
+  }
+  
+
+  private boolean isValidAuthType(SaslAuth authType) {
+    AuthMethod authMethod;
+    try {
+      authMethod = AuthMethod.valueOf(authType.getMethod());
+    } catch (IllegalArgumentException iae) { // unknown auth
+      authMethod = null;
+    }
+    // do we know what it is?  is it using our mechanism?
+    return authMethod != null &&
+           authMethod.getMechanismName().equals(authType.getMechanism());
+  }  
+  
+  /**
+   * Try to create a SaslClient for an authentication type.  May return
+   * null if the type isn't supported or the client lacks the required
+   * credentials.
+   * 
+   * @param authType - the requested authentication method
+   * @return SaslClient for the authType or null
+   * @throws SaslException - error instantiating client
+   * @throws IOException - misc errors
+   */
+  private SaslClient createSaslClient(SaslAuth authType)
+      throws SaslException, IOException {
     String saslUser = null;
-    String saslProtocol = null;
-    String saslServerName = null;
+    // SASL requires the client and server to use the same proto and serverId
+    // if necessary, auth types below will verify they are valid
+    final String saslProtocol = authType.getProtocol();
+    final String saslServerName = authType.getServerId();
     Map<String, String> saslProperties = SaslRpcServer.SASL_PROPS;
     CallbackHandler saslCallback = null;
     
+    final AuthMethod method = AuthMethod.valueOf(authType.getMethod());
     switch (method) {
       case TOKEN: {
-        saslProtocol = "";
-        saslServerName = SaslRpcServer.SASL_DEFAULT_REALM;
+        Token<?> token = getServerToken(authType);
+        if (token == null) {
+          return null; // tokens aren't supported or user doesn't have one
+        }
         saslCallback = new SaslClientCallbackHandler(token);
         break;
       }
       case KERBEROS: {
-        if (serverPrincipal == null || serverPrincipal.isEmpty()) {
-          throw new IOException(
-              "Failed to specify server's Kerberos principal name");
+        if (ugi.getRealAuthenticationMethod().getAuthMethod() !=
+            AuthMethod.KERBEROS) {
+          return null; // client isn't using kerberos
+        }
+        String serverPrincipal = getServerPrincipal(authType);
+        if (serverPrincipal == null) {
+          return null; // protocol doesn't use kerberos
         }
-        KerberosName name = new KerberosName(serverPrincipal);
-        saslProtocol = name.getServiceName();
-        saslServerName = name.getHostName();
-        if (saslServerName == null) {
-          throw new IOException(
-              "Kerberos principal name does NOT have the expected hostname part: "
-                  + serverPrincipal);
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("RPC Server's Kerberos principal name for protocol="
+              + protocol.getCanonicalName() + " is " + serverPrincipal);
         }
         break;
       }
@@ -127,16 +215,85 @@ public class SaslRpcClient {
     
     String mechanism = method.getMechanismName();
     if (LOG.isDebugEnabled()) {
-      LOG.debug("Creating SASL " + mechanism + "(" + authMethod + ") "
+      LOG.debug("Creating SASL " + mechanism + "(" + method + ") "
           + " client to authenticate to service at " + saslServerName);
     }
-    saslClient = Sasl.createSaslClient(
+    return Sasl.createSaslClient(
         new String[] { mechanism }, saslUser, saslProtocol, saslServerName,
         saslProperties, saslCallback);
-    if (saslClient == null) {
-      throw new IOException("Unable to find SASL client implementation");
+  }
+  
+  /**
+   * Try to locate the required token for the server.
+   * 
+   * @param authType of the SASL client
+   * @return Token<?> for server, or null if no token available
+   * @throws IOException - token selector cannot be instantiated
+   */
+  private Token<?> getServerToken(SaslAuth authType) throws IOException {
+    TokenInfo tokenInfo = SecurityUtil.getTokenInfo(protocol, conf);
+    LOG.debug("Get token info proto:"+protocol+" info:"+tokenInfo);
+    if (tokenInfo == null) { // protocol has no support for tokens
+      return null;
+    }
+    TokenSelector<?> tokenSelector = null;
+    try {
+      tokenSelector = tokenInfo.value().newInstance();
+    } catch (InstantiationException e) {
+      throw new IOException(e.toString(), e); // keep the cause for diagnosis
+    } catch (IllegalAccessException e) {
+      throw new IOException(e.toString(), e); // keep the cause for diagnosis
     }
+    return tokenSelector.selectToken(
+        SecurityUtil.buildTokenService(serverAddr), ugi.getTokens());
   }
+  
+  /**
+   * Get the remote server's principal.  The value will be obtained from
+   * the config and cross-checked against the server's advertised principal.
+   * 
+   * @param authType of the SASL client
+   * @return server's principal, or null if protocol has no kerberos support
+   * @throws IOException - error determining configured principal
+   */
+
+  // try to get the configured principal for the remote server
+  private String getServerPrincipal(SaslAuth authType) throws IOException {
+    KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf);
+    LOG.debug("Get kerberos info proto:"+protocol+" info:"+krbInfo);
+    if (krbInfo == null) { // protocol has no support for kerberos
+      return null;
+    }
+    String serverKey = krbInfo.serverPrincipal();
+    if (serverKey == null) {
+      throw new IllegalArgumentException(
+          "Can't obtain server Kerberos config key from protocol="
+              + protocol.getCanonicalName());
+    }
+    // construct the expected principal from the config
+    String confPrincipal = SecurityUtil.getServerPrincipal(
+        conf.get(serverKey), serverAddr.getAddress());
+    if (confPrincipal == null || confPrincipal.isEmpty()) {
+      throw new IllegalArgumentException(
+          "Failed to specify server's Kerberos principal name");
+    }
+    // ensure it looks like a host-based service principal
+    KerberosName name = new KerberosName(confPrincipal);
+    if (name.getHostName() == null) {
+      throw new IllegalArgumentException(
+          "Kerberos principal name does NOT have the expected hostname part: "
+              + confPrincipal);
+    }
+    // check that the server advertised principal matches our conf
+    KerberosPrincipal serverPrincipal = new KerberosPrincipal(
+        authType.getProtocol() + "/" + authType.getServerId());
+    if (!serverPrincipal.getName().equals(confPrincipal)) {
+      throw new IllegalArgumentException(
+          "Server has invalid Kerberos principal: " + serverPrincipal);
+    }
+    return confPrincipal;
+  }
+  
 
   /**
    * Do client side SASL authentication with server via the given InputStream
@@ -146,18 +303,18 @@ public class SaslRpcClient {
    *          InputStream to use
    * @param outS
    *          OutputStream to use
-   * @return true if connection is set up, or false if needs to switch 
-   *             to simple Auth.
+   * @return AuthMethod used to negotiate the connection
    * @throws IOException
    */
-  public boolean saslConnect(InputStream inS, OutputStream outS)
+  public AuthMethod saslConnect(InputStream inS, OutputStream outS)
       throws IOException {
     DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
     DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(
         outS));
     
-    // track if SASL ever started, or server switched us to simple
-    boolean inSasl = false;
+    // redefined if/when a SASL negotiation completes
+    AuthMethod authMethod = AuthMethod.SIMPLE;
+    
     sendSaslMessage(outStream, negotiateRequest);
     
     // loop until sasl is complete or a rpc error occurs
@@ -191,50 +348,48 @@ public class SaslRpcClient {
       RpcSaslProto.Builder response = null;
       switch (saslMessage.getState()) {
         case NEGOTIATE: {
-          inSasl = true;
-          // TODO: should instantiate sasl client based on advertisement
-          // but just blindly use the pre-instantiated sasl client for now
-          String clientAuthMethod = authMethod.toString();
-          SaslAuth saslAuthType = null;
-          for (SaslAuth authType : saslMessage.getAuthsList()) {
-            if (clientAuthMethod.equals(authType.getMethod())) {
-              saslAuthType = authType;
-              break;
-            }
-          }
-          if (saslAuthType == null) {
-            saslAuthType = SaslAuth.newBuilder()
-                .setMethod(clientAuthMethod)
-                .setMechanism(saslClient.getMechanismName())
-                .build();
-          }
+          // create a compatible SASL client, throws if no supported auths
+          SaslAuth saslAuthType = selectSaslClient(saslMessage.getAuthsList());
+          authMethod = AuthMethod.valueOf(saslAuthType.getMethod());
           
-          byte[] challengeToken = null;
-          if (saslAuthType != null && saslAuthType.hasChallenge()) {
-            // server provided the first challenge
-            challengeToken = saslAuthType.getChallenge().toByteArray();
-            saslAuthType =
-              SaslAuth.newBuilder(saslAuthType).clearChallenge().build();
-          } else if (saslClient.hasInitialResponse()) {
-            challengeToken = new byte[0];
+          byte[] responseToken = null;
+          if (authMethod == AuthMethod.SIMPLE) { // switching to SIMPLE
+            done = true; // not going to wait for success ack
+          } else {
+            byte[] challengeToken = null;
+            if (saslAuthType.hasChallenge()) {
+              // server provided the first challenge
+              challengeToken = saslAuthType.getChallenge().toByteArray();
+              saslAuthType =
+                  SaslAuth.newBuilder(saslAuthType).clearChallenge().build();
+            } else if (saslClient.hasInitialResponse()) {
+              challengeToken = new byte[0];
+            }
+            responseToken = (challengeToken != null)
+                ? saslClient.evaluateChallenge(challengeToken)
+                    : new byte[0];
           }
-          byte[] responseToken = (challengeToken != null)
-              ? saslClient.evaluateChallenge(challengeToken)
-              : new byte[0];
-          
           response = createSaslReply(SaslState.INITIATE, responseToken);
           response.addAuths(saslAuthType);
           break;
         }
         case CHALLENGE: {
-          inSasl = true;
+          if (saslClient == null) {
+            // should probably instantiate a client to allow a server to
+            // demand a specific negotiation
+            throw new SaslException("Server sent unsolicited challenge");
+          }
           byte[] responseToken = saslEvaluateToken(saslMessage, false);
           response = createSaslReply(SaslState.RESPONSE, responseToken);
           break;
         }
         case SUCCESS: {
-          if (inSasl && saslEvaluateToken(saslMessage, true) != null) {
-            throw new SaslException("SASL client generated spurious token");
+          // simple server sends immediate success to a SASL client for
+          // switch to simple
+          if (saslClient == null) {
+            authMethod = AuthMethod.SIMPLE;
+          } else {
+            saslEvaluateToken(saslMessage, true);
           }
           done = true;
           break;
@@ -248,12 +403,7 @@ public class SaslRpcClient {
         sendSaslMessage(outStream, response.build());
       }
     } while (!done);
-    if (!inSasl && !fallbackAllowed) {
-      throw new IOException("Server asks us to fall back to SIMPLE " +
-          "auth, but this client is configured to only allow secure " +
-          "connections.");
-    }
-    return inSasl;
+    return authMethod;
   }
   
   private void sendSaslMessage(DataOutputStream out, RpcSaslProto message)
@@ -268,17 +418,37 @@ public class SaslRpcClient {
     out.flush();    
   }
   
+  /**
+   * Evaluate the server provided challenge.  The server must send a token
+   * if it's not done.  If the server is done, the challenge token is
+   * optional because not all mechanisms send a final token for the client to
+   * update its internal state.  The client must also be done after
+   * evaluating the optional token to ensure a malicious server doesn't
+   * prematurely end the negotiation with a phony success.
+   *  
+   * @param saslResponse - SASL message from the server, may carry a token
+   * @param serverIsDone - server negotiation state
+   * @throws SaslException - any problems with negotiation
+   */
   private byte[] saslEvaluateToken(RpcSaslProto saslResponse,
-      boolean done) throws SaslException {
+      boolean serverIsDone) throws SaslException {
     byte[] saslToken = null;
     if (saslResponse.hasToken()) {
       saslToken = saslResponse.getToken().toByteArray();
       saslToken = saslClient.evaluateChallenge(saslToken);
-    } else if (!done) {
-      throw new SaslException("Challenge contains no token");
-    }
-    if (done && !saslClient.isComplete()) {
-      throw new SaslException("Client is out of sync with server");
+    } else if (!serverIsDone) {
+      // the server may only omit a token when it's done
+      throw new SaslException("Server challenge contains no token");
+    }
+    if (serverIsDone) {
+      // server tried to report success before our client completed
+      if (!saslClient.isComplete()) {
+        throw new SaslException("Client is out of sync with server");
+      }
+      // a client cannot generate a response to a success message
+      if (saslToken != null) {
+        throw new SaslException("Client generated spurious response");        
+      }
     }
     return saslToken;
   }
@@ -327,7 +497,10 @@ public class SaslRpcClient {
 
   /** Release resources used by wrapped saslClient */
   public void dispose() throws SaslException {
-    saslClient.dispose();
+    if (saslClient != null) {
+      saslClient.dispose();
+      saslClient = null;
+    }
   }
 
   private static class SaslClientCallbackHandler implements CallbackHandler {
@@ -377,4 +550,4 @@ public class SaslRpcClient {
       }
     }
   }
-}
+}
\ No newline at end of file

Modified: hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java?rev=1508090&r1=1508089&r2=1508090&view=diff
==============================================================================
--- hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java (original)
+++ hadoop/common/branches/branch-2.1-beta/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestSaslRPC.java Mon Jul 29 14:52:10 2013
@@ -18,14 +18,9 @@
 
 package org.apache.hadoop.ipc;
 
-import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.KERBEROS;
-import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.SIMPLE;
-import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.TOKEN;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
+import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
+import static org.apache.hadoop.security.SaslRpcServer.AuthMethod.*;
+import static org.junit.Assert.*;
 
 import java.io.DataInput;
 import java.io.DataOutput;
@@ -51,6 +46,7 @@ import javax.security.sasl.SaslServer;
 
 import junit.framework.Assert;
 
+import org.apache.commons.lang.StringUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.commons.logging.impl.Log4JLogger;
@@ -103,6 +99,13 @@ public class TestSaslRPC {
   static Boolean forceSecretManager = null;
   static Boolean clientFallBackToSimpleAllowed = true;
   
+  static enum UseToken {
+    NONE(),
+    VALID(),
+    INVALID(),
+    OTHER();
+  }
+  
   @BeforeClass
   public static void setupKerb() {
     System.setProperty("java.security.krb5.kdc", "");
@@ -113,9 +116,11 @@ public class TestSaslRPC {
   @Before
   public void setup() {
     conf = new Configuration();
-    SecurityUtil.setAuthenticationMethod(KERBEROS, conf);
+    conf.set(HADOOP_SECURITY_AUTHENTICATION, KERBEROS.toString());
     UserGroupInformation.setConfiguration(conf);
     enableSecretManager = null;
+    forceSecretManager = null;
+    clientFallBackToSimpleAllowed = true;
   }
 
   static {
@@ -368,28 +373,6 @@ public class TestSaslRPC {
   }
   
   @Test
-  public void testGetRemotePrincipal() throws Exception {
-    try {
-      Configuration newConf = new Configuration(conf);
-      newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
-      ConnectionId remoteId = ConnectionId.getConnectionId(
-          new InetSocketAddress(0), TestSaslProtocol.class, null, 0, newConf);
-      assertEquals(SERVER_PRINCIPAL_1, remoteId.getServerPrincipal());
-      // this following test needs security to be off
-      SecurityUtil.setAuthenticationMethod(SIMPLE, newConf);
-      UserGroupInformation.setConfiguration(newConf);
-      remoteId = ConnectionId.getConnectionId(new InetSocketAddress(0),
-          TestSaslProtocol.class, null, 0, newConf);
-      assertEquals(
-          "serverPrincipal should be null when security is turned off", null,
-          remoteId.getServerPrincipal());
-    } finally {
-      // revert back to security is on
-      UserGroupInformation.setConfiguration(conf);
-    }
-  }
-  
-  @Test
   public void testPerConnectionConf() throws Exception {
     TestTokenSecretManager sm = new TestTokenSecretManager();
     final Server server = new RPC.Builder(conf)
@@ -409,12 +392,13 @@ public class TestSaslRPC {
     Configuration newConf = new Configuration(conf);
     newConf.set(CommonConfigurationKeysPublic.
         HADOOP_RPC_SOCKET_FACTORY_CLASS_DEFAULT_KEY, "");
-    newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
 
     TestSaslProtocol proxy1 = null;
     TestSaslProtocol proxy2 = null;
     TestSaslProtocol proxy3 = null;
+    int timeouts[] = {111222, 3333333};
     try {
+      newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[0]);
       proxy1 = RPC.getProxy(TestSaslProtocol.class,
           TestSaslProtocol.versionID, addr, newConf);
       proxy1.getAuthMethod();
@@ -427,20 +411,21 @@ public class TestSaslRPC {
       proxy2.getAuthMethod();
       assertEquals("number of connections in cache is wrong", 1, conns.size());
       // different conf, new connection should be set up
-      newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
+      newConf.setInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, timeouts[1]);
       proxy3 = RPC.getProxy(TestSaslProtocol.class,
           TestSaslProtocol.versionID, addr, newConf);
       proxy3.getAuthMethod();
-      ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
-      assertEquals("number of connections in cache is wrong", 2,
-          connsArray.length);
-      String p1 = connsArray[0].getServerPrincipal();
-      String p2 = connsArray[1].getServerPrincipal();
-      assertFalse("should have different principals", p1.equals(p2));
-      assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1)
-          || p1.equals(SERVER_PRINCIPAL_2));
-      assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1)
-          || p2.equals(SERVER_PRINCIPAL_2));
+      assertEquals("number of connections in cache is wrong", 2, conns.size());
+      // now verify the proxies have the correct connection ids and timeouts
+      ConnectionId[] connsArray = {
+          RPC.getConnectionIdForProxy(proxy1),
+          RPC.getConnectionIdForProxy(proxy2),
+          RPC.getConnectionIdForProxy(proxy3)
+      };
+      assertEquals(connsArray[0], connsArray[1]);
+      assertEquals(connsArray[0].getMaxIdleTime(), timeouts[0]);
+      assertFalse(connsArray[0].equals(connsArray[2]));
+      assertNotSame(connsArray[2].getMaxIdleTime(), timeouts[1]);
     } finally {
       server.stop();
       RPC.stopProxy(proxy1);
@@ -599,75 +584,118 @@ public class TestSaslRPC {
   private static Pattern KrbFailed =
       Pattern.compile(".*Failed on local exception:.* " +
                       "Failed to specify server's Kerberos principal name.*");
-  private static Pattern Denied(AuthenticationMethod method) {
+  private static Pattern Denied(AuthMethod method) {
       return Pattern.compile(".*RemoteException.*AccessControlException.*: "
-          +method.getAuthMethod() + " authentication is not enabled.*");
+          + method + " authentication is not enabled.*");
+  }
+  private static Pattern No(AuthMethod ... method) {
+    String methods = StringUtils.join(method, ",\\s*");
+    return Pattern.compile(".*Failed on local exception:.* " +
+        "Client cannot authenticate via:\\[" + methods + "\\].*");
   }
   private static Pattern NoTokenAuth =
       Pattern.compile(".*IllegalArgumentException: " +
                       "TOKEN authentication requires a secret manager");
-  
+  private static Pattern NoFallback = 
+      Pattern.compile(".*Failed on local exception:.* " +
+          "Server asks us to fall back to SIMPLE auth, " +
+          "but this client is configured to only allow secure connections.*");
+
   /*
    *  simple server
    */
   @Test
   public void testSimpleServer() throws Exception {
     assertAuthEquals(SIMPLE,    getAuthMethod(SIMPLE,   SIMPLE));
-    // SASL methods are reverted to SIMPLE, but test setup fails
-    assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, SIMPLE));
+    assertAuthEquals(SIMPLE,    getAuthMethod(SIMPLE,   SIMPLE, UseToken.OTHER));
+    // SASL methods are normally reverted to SIMPLE
+    assertAuthEquals(SIMPLE,    getAuthMethod(KERBEROS, SIMPLE));
+    assertAuthEquals(SIMPLE,    getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
   }
 
   @Test
-  public void testSimpleServerWithTokensWithNoClientFallbackToSimple()
+  public void testNoClientFallbackToSimple()
       throws Exception {
-
     clientFallBackToSimpleAllowed = false;
-
-    try{
-      // Client has a token even though its configs says simple auth. Server
-      // is configured for simple auth, but as client sends the token, and
-      // server asks to switch to simple, this should fail.
-      getAuthMethod(SIMPLE,   SIMPLE, true);
-    } catch (IOException ioe) {
-      Assert
-        .assertTrue(ioe.getMessage().contains("Failed on local exception: " +
-        		"java.io.IOException: java.io.IOException: " +
-        		"Server asks us to fall back to SIMPLE auth, " +
-        		"but this client is configured to only allow secure connections"
-          ));
-    }
+    // tokens are irrelevant w/o secret manager enabled
+    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE));
+    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER));
+    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE, UseToken.VALID));
+    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID));
+
+    // A secure client must not fallback
+    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE));
+    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
+    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
+    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));
 
     // Now set server to simple and also force the secret-manager. Now server
     // should have both simple and token enabled.
     forceSecretManager = true;
-    assertAuthEquals(TOKEN, getAuthMethod(SIMPLE,   SIMPLE, true));
-    forceSecretManager = false;
-    clientFallBackToSimpleAllowed = true;
+    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE));
+    assertAuthEquals(SIMPLE,     getAuthMethod(SIMPLE, SIMPLE, UseToken.OTHER));
+    assertAuthEquals(TOKEN,      getAuthMethod(SIMPLE, SIMPLE, UseToken.VALID));
+    assertAuthEquals(BadToken,   getAuthMethod(SIMPLE, SIMPLE, UseToken.INVALID));
+
+    // A secure client must not fallback
+    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE));
+    assertAuthEquals(NoFallback, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
+    assertAuthEquals(TOKEN,      getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
+    assertAuthEquals(BadToken,   getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));
+    
+    // doesn't try SASL
+    assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE, TOKEN));
+    // does try SASL
+    assertAuthEquals(No(TOKEN),      getAuthMethod(SIMPLE, TOKEN, UseToken.OTHER));
+    assertAuthEquals(TOKEN,          getAuthMethod(SIMPLE, TOKEN, UseToken.VALID));
+    assertAuthEquals(BadToken,       getAuthMethod(SIMPLE, TOKEN, UseToken.INVALID));
+    
+    assertAuthEquals(No(TOKEN),      getAuthMethod(KERBEROS, TOKEN));
+    assertAuthEquals(No(TOKEN),      getAuthMethod(KERBEROS, TOKEN, UseToken.OTHER));
+    assertAuthEquals(TOKEN,          getAuthMethod(KERBEROS, TOKEN, UseToken.VALID));
+    assertAuthEquals(BadToken,       getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID));
   }
 
   @Test
   public void testSimpleServerWithTokens() throws Exception {
     // Client not using tokens
     assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE));
-    // SASL methods are reverted to SIMPLE, but test setup fails
-    assertAuthEquals(KrbFailed, getAuthMethod(KERBEROS, SIMPLE));
+    // SASL methods are reverted to SIMPLE
+    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE));
 
     // Use tokens. But tokens are ignored because client is reverted to simple
-    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, true));
+    // due to server not using tokens
+    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
+    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
 
+    // server isn't really advertising tokens
     enableSecretManager = true;
-    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, true));
-    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, true));
+    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.VALID));
+    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.OTHER));
+    
+    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
+    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
+    
+    // now the simple server takes tokens
+    forceSecretManager = true;
+    assertAuthEquals(TOKEN,  getAuthMethod(SIMPLE,   SIMPLE, UseToken.VALID));
+    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.OTHER));
+    
+    assertAuthEquals(TOKEN,  getAuthMethod(KERBEROS, SIMPLE, UseToken.VALID));
+    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.OTHER));
   }
 
   @Test
   public void testSimpleServerWithInvalidTokens() throws Exception {
     // Tokens are ignored because client is reverted to simple
-    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, false));
-    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, false));
+    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.INVALID));
+    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));
     enableSecretManager = true;
-    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, false));
-    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, false));
+    assertAuthEquals(SIMPLE, getAuthMethod(SIMPLE,   SIMPLE, UseToken.INVALID));
+    assertAuthEquals(SIMPLE, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));
+    forceSecretManager = true;
+    assertAuthEquals(BadToken, getAuthMethod(SIMPLE,   SIMPLE, UseToken.INVALID));
+    assertAuthEquals(BadToken, getAuthMethod(KERBEROS, SIMPLE, UseToken.INVALID));
   }
   
   /*
@@ -675,26 +703,29 @@ public class TestSaslRPC {
    */
   @Test
   public void testTokenOnlyServer() throws Exception {
+    // simple client w/o tokens won't try SASL, so server denies
     assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE,   TOKEN));
-    assertAuthEquals(KrbFailed,      getAuthMethod(KERBEROS, TOKEN));
+    assertAuthEquals(No(TOKEN),      getAuthMethod(SIMPLE,   TOKEN, UseToken.OTHER));
+    assertAuthEquals(No(TOKEN),      getAuthMethod(KERBEROS, TOKEN));
+    assertAuthEquals(No(TOKEN),      getAuthMethod(KERBEROS, TOKEN, UseToken.OTHER));
   }
 
   @Test
   public void testTokenOnlyServerWithTokens() throws Exception {
-    assertAuthEquals(TOKEN, getAuthMethod(SIMPLE,   TOKEN, true));
-    assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, TOKEN, true));
+    assertAuthEquals(TOKEN,       getAuthMethod(SIMPLE,   TOKEN, UseToken.VALID));
+    assertAuthEquals(TOKEN,       getAuthMethod(KERBEROS, TOKEN, UseToken.VALID));
     enableSecretManager = false;
-    assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE,   TOKEN, true));
-    assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, true));
+    assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE,   TOKEN, UseToken.VALID));
+    assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, UseToken.VALID));
   }
 
   @Test
   public void testTokenOnlyServerWithInvalidTokens() throws Exception {
-    assertAuthEquals(BadToken, getAuthMethod(SIMPLE,   TOKEN, false));
-    assertAuthEquals(BadToken, getAuthMethod(KERBEROS, TOKEN, false));
+    assertAuthEquals(BadToken,    getAuthMethod(SIMPLE,   TOKEN, UseToken.INVALID));
+    assertAuthEquals(BadToken,    getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID));
     enableSecretManager = false;
-    assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE,   TOKEN, false));
-    assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, false));
+    assertAuthEquals(NoTokenAuth, getAuthMethod(SIMPLE,   TOKEN, UseToken.INVALID));
+    assertAuthEquals(NoTokenAuth, getAuthMethod(KERBEROS, TOKEN, UseToken.INVALID));
   }
 
   /*
@@ -702,38 +733,43 @@ public class TestSaslRPC {
    */
   @Test
   public void testKerberosServer() throws Exception {
-    assertAuthEquals(Denied(SIMPLE), getAuthMethod(SIMPLE,   KERBEROS));
-    assertAuthEquals(KrbFailed,      getAuthMethod(KERBEROS, KERBEROS));    
+    // doesn't try SASL
+    assertAuthEquals(Denied(SIMPLE),     getAuthMethod(SIMPLE,   KERBEROS));
+    // does try SASL
+    assertAuthEquals(No(TOKEN,KERBEROS), getAuthMethod(SIMPLE,   KERBEROS, UseToken.OTHER));
+    // no tgt
+    assertAuthEquals(KrbFailed,          getAuthMethod(KERBEROS, KERBEROS));
+    assertAuthEquals(KrbFailed,          getAuthMethod(KERBEROS, KERBEROS, UseToken.OTHER));
   }
 
   @Test
   public void testKerberosServerWithTokens() throws Exception {
     // can use tokens regardless of auth
-    assertAuthEquals(TOKEN, getAuthMethod(SIMPLE,   KERBEROS, true));
-    assertAuthEquals(TOKEN, getAuthMethod(KERBEROS, KERBEROS, true));
-    // can't fallback to simple when using kerberos w/o tokens
+    assertAuthEquals(TOKEN,        getAuthMethod(SIMPLE,   KERBEROS, UseToken.VALID));
+    assertAuthEquals(TOKEN,        getAuthMethod(KERBEROS, KERBEROS, UseToken.VALID));
     enableSecretManager = false;
-    assertAuthEquals(Denied(TOKEN), getAuthMethod(SIMPLE,   KERBEROS, true));
-    assertAuthEquals(Denied(TOKEN), getAuthMethod(KERBEROS, KERBEROS, true));
+    // shouldn't even try token because server didn't tell us to
+    assertAuthEquals(No(KERBEROS), getAuthMethod(SIMPLE,   KERBEROS, UseToken.VALID));
+    assertAuthEquals(KrbFailed,    getAuthMethod(KERBEROS, KERBEROS, UseToken.VALID));
   }
 
   @Test
   public void testKerberosServerWithInvalidTokens() throws Exception {
-    assertAuthEquals(BadToken, getAuthMethod(SIMPLE,   KERBEROS, false));
-    assertAuthEquals(BadToken, getAuthMethod(KERBEROS, KERBEROS, false));
+    assertAuthEquals(BadToken,     getAuthMethod(SIMPLE,   KERBEROS, UseToken.INVALID));
+    assertAuthEquals(BadToken,     getAuthMethod(KERBEROS, KERBEROS, UseToken.INVALID));
     enableSecretManager = false;
-    assertAuthEquals(Denied(TOKEN), getAuthMethod(SIMPLE,   KERBEROS, false));
-    assertAuthEquals(Denied(TOKEN), getAuthMethod(KERBEROS, KERBEROS, false));
+    assertAuthEquals(No(KERBEROS), getAuthMethod(SIMPLE,   KERBEROS, UseToken.INVALID));
+    assertAuthEquals(KrbFailed,    getAuthMethod(KERBEROS, KERBEROS, UseToken.INVALID));
   }
 
 
   // test helpers
 
   private String getAuthMethod(
-      final AuthenticationMethod clientAuth,
-      final AuthenticationMethod serverAuth) throws Exception {
+      final AuthMethod clientAuth,
+      final AuthMethod serverAuth) throws Exception {
     try {
-      return internalGetAuthMethod(clientAuth, serverAuth, false, false);
+      return internalGetAuthMethod(clientAuth, serverAuth, UseToken.NONE);
     } catch (Exception e) {
       LOG.warn("Auth method failure", e);
       return e.toString();
@@ -741,11 +777,11 @@ public class TestSaslRPC {
   }
 
   private String getAuthMethod(
-      final AuthenticationMethod clientAuth,
-      final AuthenticationMethod serverAuth,
-      final boolean useValidToken) throws Exception {
+      final AuthMethod clientAuth,
+      final AuthMethod serverAuth,
+      final UseToken tokenType) throws Exception {
     try {
-      return internalGetAuthMethod(clientAuth, serverAuth, true, useValidToken);
+      return internalGetAuthMethod(clientAuth, serverAuth, tokenType);
     } catch (Exception e) {
       LOG.warn("Auth method failure", e);
       return e.toString();
@@ -753,15 +789,14 @@ public class TestSaslRPC {
   }
   
   private String internalGetAuthMethod(
-      final AuthenticationMethod clientAuth,
-      final AuthenticationMethod serverAuth,
-      final boolean useToken,
-      final boolean useValidToken) throws Exception {
+      final AuthMethod clientAuth,
+      final AuthMethod serverAuth,
+      final UseToken tokenType) throws Exception {
     
     String currentUser = UserGroupInformation.getCurrentUser().getUserName();
     
     final Configuration serverConf = new Configuration(conf);
-    SecurityUtil.setAuthenticationMethod(serverAuth, serverConf);
+    serverConf.set(HADOOP_SECURITY_AUTHENTICATION, serverAuth.toString());
     UserGroupInformation.setConfiguration(serverConf);
     
     final UserGroupInformation serverUgi =
@@ -793,7 +828,7 @@ public class TestSaslRPC {
     });
 
     final Configuration clientConf = new Configuration(conf);
-    SecurityUtil.setAuthenticationMethod(clientAuth, clientConf);
+    clientConf.set(HADOOP_SECURITY_AUTHENTICATION, clientAuth.toString());
     clientConf.setBoolean(
         CommonConfigurationKeys.IPC_CLIENT_FALLBACK_TO_SIMPLE_AUTH_ALLOWED_KEY,
         clientFallBackToSimpleAllowed);
@@ -804,16 +839,26 @@ public class TestSaslRPC {
     clientUgi.setAuthenticationMethod(clientAuth);    
 
     final InetSocketAddress addr = NetUtils.getConnectAddress(server);
-    if (useToken) {
+    if (tokenType != UseToken.NONE) {
       TestTokenIdentifier tokenId = new TestTokenIdentifier(
           new Text(clientUgi.getUserName()));
-      Token<TestTokenIdentifier> token = useValidToken
-          ? new Token<TestTokenIdentifier>(tokenId, sm)
-          : new Token<TestTokenIdentifier>(
+      Token<TestTokenIdentifier> token = null;
+      switch (tokenType) {
+        case VALID:
+          token = new Token<TestTokenIdentifier>(tokenId, sm);
+          SecurityUtil.setTokenService(token, addr);
+          break;
+        case INVALID:
+          token = new Token<TestTokenIdentifier>(
               tokenId.getBytes(), "bad-password!".getBytes(),
               tokenId.getKind(), null);
-      
-      SecurityUtil.setTokenService(token, addr);
+          SecurityUtil.setTokenService(token, addr);
+          break;
+        case OTHER:
+          token = new Token<TestTokenIdentifier>();
+          break;
+        case NONE: // won't get here
+      }
       clientUgi.addToken(token);
     }
 
@@ -848,7 +893,7 @@ public class TestSaslRPC {
     }
   }
 
-  private static void assertAuthEquals(AuthenticationMethod expect,
+  private static void assertAuthEquals(AuthMethod expect,
       String actual) {
     assertEquals(expect.toString(), actual);
   }



Mime
View raw message