hadoop-common-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From omal...@apache.org
Subject svn commit: r1077673 - in /hadoop/common/branches/branch-0.20-security-patches/src: core/org/apache/hadoop/ipc/ core/org/apache/hadoop/security/token/delegation/ test/org/apache/hadoop/ipc/
Date Fri, 04 Mar 2011 04:42:45 GMT
Author: omalley
Date: Fri Mar  4 04:42:45 2011
New Revision: 1077673

URL: http://svn.apache.org/viewvc?rev=1077673&view=rev
Log:
commit 3d53bc3547d9414f2dcaa0bda82c230b1cf65b48
Author: Devaraj Das <ddas@yahoo-inc.com>
Date:   Mon Sep 13 15:43:48 2010 -0700

    HADOOP-6907 from https://issues.apache.org/jira/secure/attachment/12453844/c6907-Y20S.1xx.05.patch
    
    +++ b/YAHOO-CHANGES.txt
    +    HADOOP-6907. Make RPC client to use per-proxy configuration.
    +    (Kan Zhang via ddas)
    +

Modified:
    hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/Client.java
    hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/RPC.java
    hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/security/token/delegation/AbstractDelegationTokenSecretManager.java
    hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestIPC.java
    hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestSaslRPC.java

Modified: hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/Client.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/Client.java?rev=1077673&r1=1077672&r2=1077673&view=diff
==============================================================================
--- hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/Client.java (original)
+++ hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/Client.java Fri Mar  4 04:42:45 2011
@@ -37,6 +37,7 @@ import java.security.PrivilegedException
 import java.util.Hashtable;
 import java.util.Iterator;
 import java.util.Random;
+import java.util.Set;
 import java.util.Map.Entry;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
@@ -80,11 +81,6 @@ public class Client {
   private int counter;                            // counter for call ids
   private AtomicBoolean running = new AtomicBoolean(true); // if client runs
   final private Configuration conf;
-  final private int maxIdleTime; //connections will be culled if it was idle for 
-                           //maxIdleTime msecs
-  final private int maxRetries; //the max. no. of retries for socket connections
-  private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
-  private int pingInterval; // how often sends ping to the server in msecs
 
   private SocketFactory socketFactory;           // how to create sockets
   private int refCount = 1;
@@ -198,6 +194,11 @@ public class Client {
     private Socket socket = null;                 // connected socket
     private DataInputStream in;
     private DataOutputStream out;
+    private int maxIdleTime; //connections will be culled if it was idle for
+         //maxIdleTime msecs
+    private int maxRetries; //the max. no. of retries for socket connections
+    private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
+    private int pingInterval; // how often sends ping to the server in msecs
     
     // currently active calls
     private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
@@ -212,6 +213,13 @@ public class Client {
         throw new UnknownHostException("unknown host: " + 
                                        remoteId.getAddress().getHostName());
       }
+      this.maxIdleTime = remoteId.getMaxIdleTime();
+      this.maxRetries = remoteId.getMaxRetries();
+      this.tcpNoDelay = remoteId.getTcpNoDelay();
+      this.pingInterval = remoteId.getPingInterval();
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("The ping interval is" + this.pingInterval + "ms.");
+      }
       
       UserGroupInformation ticket = remoteId.getTicket();
       Class<?> protocol = remoteId.getProtocol();
@@ -234,15 +242,9 @@ public class Client {
         }
         KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
         if (krbInfo != null) {
-          String serverKey = krbInfo.serverPrincipal();
-          if (serverKey == null) {
-            throw new IOException(
-                "Can't obtain server Kerberos config key from KerberosInfo");
-          }
-          serverPrincipal = SecurityUtil.getServerPrincipal(
-              conf.get(serverKey), server.getAddress().getCanonicalHostName());
+          serverPrincipal = remoteId.getServerPrincipal();
           if (LOG.isDebugEnabled()) {
-            LOG.debug("RPC Server Kerberos principal name for protocol="
+            LOG.debug("RPC Server's Kerberos principal name for protocol="
                 + protocol.getCanonicalName() + " is " + serverPrincipal);
           }
         }
@@ -838,14 +840,6 @@ public class Client {
   public Client(Class<? extends Writable> valueClass, Configuration conf, 
       SocketFactory factory) {
     this.valueClass = valueClass;
-    this.maxIdleTime = 
-      conf.getInt("ipc.client.connection.maxidletime", 10000); //10s
-    this.maxRetries = conf.getInt("ipc.client.connect.max.retries", 10);
-    this.tcpNoDelay = conf.getBoolean("ipc.client.tcpnodelay", false);
-    this.pingInterval = getPingInterval(conf);
-    if (LOG.isDebugEnabled()) {
-      LOG.debug("The ping interval is" + this.pingInterval + "ms.");
-    }
     this.conf = conf;
     this.socketFactory = factory;
   }
@@ -897,7 +891,7 @@ public class Client {
   /** Make a call, passing <code>param</code>, to the IPC server running at
    * <code>address</code>, returning the value.  Throws exceptions if there are
    * network problems or if the remote code threw an exception.
-   * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation)} instead 
+   * @deprecated Use {@link #call(Writable, ConnectionId)} instead 
    */
   @Deprecated
   public Writable call(Writable param, InetSocketAddress address)
@@ -910,25 +904,56 @@ public class Client {
    * the value.  
    * Throws exceptions if there are network problems or if the remote code 
    * threw an exception.
-   * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation)} instead 
+   * @deprecated Use {@link #call(Writable, ConnectionId)} instead 
    */
   @Deprecated
   public Writable call(Writable param, InetSocketAddress addr, 
       UserGroupInformation ticket)  
       throws InterruptedException, IOException {
-    return call(param, addr, null, ticket);
+    ConnectionId remoteId = ConnectionId.getConnectionId(addr, null, ticket,
+        conf);
+    return call(param, remoteId);
   }
   
   /** Make a call, passing <code>param</code>, to the IPC server running at
    * <code>address</code> which is servicing the <code>protocol</code> protocol, 
    * with the <code>ticket</code> credentials, returning the value.  
    * Throws exceptions if there are network problems or if the remote code 
-   * threw an exception. */
+   * threw an exception. 
+   * @deprecated Use {@link #call(Writable, ConnectionId)} instead 
+   */
+  @Deprecated
   public Writable call(Writable param, InetSocketAddress addr, 
                        Class<?> protocol, UserGroupInformation ticket)  
                        throws InterruptedException, IOException {
+    ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
+        ticket, conf);
+    return call(param, remoteId);
+  }
+  
+  /** Make a call, passing <code>param</code>, to the IPC server running at
+   * <code>address</code> which is servicing the <code>protocol</code> protocol, 
+   * with the <code>ticket</code> credentials and <code>conf</code> as 
+   * configuration for this connection, returning the value.  
+   * Throws exceptions if there are network problems or if the remote code 
+   * threw an exception. */
+  public Writable call(Writable param, InetSocketAddress addr, 
+                       Class<?> protocol, UserGroupInformation ticket,
+                       Configuration conf)  
+                       throws InterruptedException, IOException {
+    ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
+        ticket, conf);
+    return call(param, remoteId);
+  }
+  
+  /** Make a call, passing <code>param</code>, to the IPC server defined by
+   * <code>remoteId</code>, returning the value.  
+   * Throws exceptions if there are network problems or if the remote code 
+   * threw an exception. */
+  public Writable call(Writable param, ConnectionId remoteId)  
+                       throws InterruptedException, IOException {
     Call call = new Call(param);
-    Connection connection = getConnection(addr, protocol, ticket, call);
+    Connection connection = getConnection(remoteId, call);
     connection.sendParam(call);                 // send the parameter
     boolean interrupted = false;
     synchronized (call) {
@@ -951,7 +976,7 @@ public class Client {
           call.error.fillInStackTrace();
           throw call.error;
         } else { // local exception
-          throw wrapException(addr, call.error);
+          throw wrapException(remoteId.getAddress(), call.error);
         }
       } else {
         return call.value;
@@ -991,25 +1016,34 @@ public class Client {
   }
 
   /** 
-   * Makes a set of calls in parallel.  Each parameter is sent to the
-   * corresponding address.  When all values are available, or have timed out
-   * or errored, the collected results are returned in an array.  The array
-   * contains nulls for calls that timed out or errored.
-   * @deprecated Use {@link #call(Writable[], InetSocketAddress[], Class, UserGroupInformation)} instead 
+   * @deprecated Use {@link #call(Writable[], InetSocketAddress[], 
+   * Class, UserGroupInformation, Configuration)} instead 
    */
   @Deprecated
   public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
     throws IOException, InterruptedException {
-    return call(params, addresses, null, null);
+    return call(params, addresses, null, null, conf);
+  }
+  
+  /**  
+   * @deprecated Use {@link #call(Writable[], InetSocketAddress[], 
+   * Class, UserGroupInformation, Configuration)} instead 
+   */
+  @Deprecated
+  public Writable[] call(Writable[] params, InetSocketAddress[] addresses, 
+                         Class<?> protocol, UserGroupInformation ticket)
+    throws IOException, InterruptedException {
+    return call(params, addresses, protocol, ticket, conf);
   }
   
+
   /** Makes a set of calls in parallel.  Each parameter is sent to the
    * corresponding address.  When all values are available, or have timed out
    * or errored, the collected results are returned in an array.  The array
    * contains nulls for calls that timed out or errored.  */
-  public Writable[] call(Writable[] params, InetSocketAddress[] addresses, 
-                         Class<?> protocol, UserGroupInformation ticket)
-    throws IOException, InterruptedException {
+  public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
+      Class<?> protocol, UserGroupInformation ticket, Configuration conf)
+      throws IOException, InterruptedException {
     if (addresses.length == 0) return new Writable[0];
 
     ParallelResults results = new ParallelResults(params.length);
@@ -1017,8 +1051,9 @@ public class Client {
       for (int i = 0; i < params.length; i++) {
         ParallelCall call = new ParallelCall(params[i], results, i);
         try {
-          Connection connection = 
-            getConnection(addresses[i], protocol, ticket, call);
+          ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i],
+              protocol, ticket, conf);
+          Connection connection = getConnection(remoteId, call);
           connection.sendParam(call);             // send each parameter
         } catch (IOException e) {
           // log errors
@@ -1037,11 +1072,16 @@ public class Client {
     }
   }
 
+  //for unit testing only
+  Set<ConnectionId> getConnectionIds() {
+    synchronized (connections) {
+      return connections.keySet();
+    }
+  }
+
   /** Get a connection from the pool, or create a new one and add it to the
-   * pool.  Connections to a given host/port are reused. */
-  private Connection getConnection(InetSocketAddress addr,
-                                   Class<?> protocol,
-                                   UserGroupInformation ticket,
+   * pool.  Connections to a given ConnectionId are reused. */
+   private Connection getConnection(ConnectionId remoteId,
                                    Call call)
                                    throws IOException, InterruptedException {
     if (!running.get()) {
@@ -1053,7 +1093,6 @@ public class Client {
      * connectionsId object and with set() method. We need to manage the
      * refs for keys in HashMap properly. For now its ok.
      */
-    ConnectionId remoteId = new ConnectionId(addr, protocol, ticket);
     do {
       synchronized (connections) {
         connection = connections.get(remoteId);
@@ -1072,51 +1111,136 @@ public class Client {
     return connection;
   }
 
-  /**
-   * This class holds the address and the user ticket. The client connections
-   * to servers are uniquely identified by <remoteAddress, protocol, ticket>
-   */
-  private static class ConnectionId {
-    InetSocketAddress address;
-    UserGroupInformation ticket;
-    Class<?> protocol;
-    private static final int PRIME = 16777619;
-    
-    ConnectionId(InetSocketAddress address, Class<?> protocol, 
-                 UserGroupInformation ticket) {
-      this.protocol = protocol;
-      this.address = address;
-      this.ticket = ticket;
-    }
-    
-    InetSocketAddress getAddress() {
-      return address;
-    }
-    
-    Class<?> getProtocol() {
-      return protocol;
-    }
-    
-    UserGroupInformation getTicket() {
-      return ticket;
-    }
-    
-    
-    @Override
-    public boolean equals(Object obj) {
-     if (obj instanceof ConnectionId) {
-       ConnectionId id = (ConnectionId) obj;
-       return address.equals(id.address) && protocol == id.protocol && 
-              ((ticket != null && ticket.equals(id.ticket)) ||
-               (ticket == id.ticket));
+   /**
+    * This class holds the address and the user ticket. The client connections
+    * to servers are uniquely identified by <remoteAddress, protocol, ticket>
+    */
+   static class ConnectionId {
+     InetSocketAddress address;
+     UserGroupInformation ticket;
+     Class<?> protocol;
+     private static final int PRIME = 16777619;
+     private String serverPrincipal;
+     private int maxIdleTime; //connections will be culled if it was idle for 
+     //maxIdleTime msecs
+     private int maxRetries; //the max. no. of retries for socket connections
+     private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
+     private int pingInterval; // how often sends ping to the server in msecs
+     
+     ConnectionId(InetSocketAddress address, Class<?> protocol, 
+                  UserGroupInformation ticket,
+                  String serverPrincipal, int maxIdleTime, 
+                  int maxRetries, boolean tcpNoDelay,
+                  int pingInterval) {
+       this.protocol = protocol;
+       this.address = address;
+       this.ticket = ticket;
+       this.serverPrincipal = serverPrincipal;
+       this.maxIdleTime = maxIdleTime;
+       this.maxRetries = maxRetries;
+       this.tcpNoDelay = tcpNoDelay;
+       this.pingInterval = pingInterval;
      }
-     return false;
-    }
-    
-    @Override
-    public int hashCode() {
-      return (address.hashCode() + PRIME * System.identityHashCode(protocol)) ^ 
-             (ticket == null ? 0 : ticket.hashCode());
-    }
-  }  
+     
+     InetSocketAddress getAddress() {
+       return address;
+     }
+     
+     Class<?> getProtocol() {
+       return protocol;
+     }
+     
+     UserGroupInformation getTicket() {
+       return ticket;
+     }
+     
+     String getServerPrincipal() {
+       return serverPrincipal;
+     }
+     
+     int getMaxIdleTime() {
+       return maxIdleTime;
+     }
+     
+     int getMaxRetries() {
+       return maxRetries;
+     }
+     
+     boolean getTcpNoDelay() {
+       return tcpNoDelay;
+     }
+     
+     int getPingInterval() {
+       return pingInterval;
+     }
+     
+     static ConnectionId getConnectionId(InetSocketAddress addr,
+         Class<?> protocol, UserGroupInformation ticket,
+         Configuration conf) throws IOException {
+       String remotePrincipal = getRemotePrincipal(conf, addr, protocol);
+       return new ConnectionId(addr, protocol, ticket,
+           remotePrincipal,
+           conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s
+           conf.getInt("ipc.client.connect.max.retries", 10),
+           conf.getBoolean("ipc.client.tcpnodelay", false),
+           Client.getPingInterval(conf));
+     }
+     
+     private static String getRemotePrincipal(Configuration conf,
+         InetSocketAddress address, Class<?> protocol) throws IOException {
+       if (!UserGroupInformation.isSecurityEnabled() || protocol == null) {
+         return null;
+       }
+       KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
+       if (krbInfo != null) {
+         String serverKey = krbInfo.serverPrincipal();
+         if (serverKey == null) {
+           throw new IOException(
+               "Can't obtain server Kerberos config key from protocol="
+                   + protocol.getCanonicalName());
+         }
+         return SecurityUtil.getServerPrincipal(conf.get(serverKey), address
+             .getAddress().getCanonicalHostName());
+       }
+       return null;
+     }
+     
+     static boolean isEqual(Object a, Object b) {
+       return a == null ? b == null : a.equals(b);
+     }
+
+     @Override
+     public boolean equals(Object obj) {
+       if (obj == this) {
+         return true;
+       }
+       if (obj instanceof ConnectionId) {
+         ConnectionId that = (ConnectionId) obj;
+         return isEqual(this.address, that.address)
+             && this.maxIdleTime == that.maxIdleTime
+             && this.maxRetries == that.maxRetries
+             && this.pingInterval == that.pingInterval
+             && isEqual(this.protocol, that.protocol)
+             && isEqual(this.serverPrincipal, that.serverPrincipal)
+             && this.tcpNoDelay == that.tcpNoDelay
+             && isEqual(this.ticket, that.ticket);
+       }
+       return false;
+     }
+     
+     @Override
+     public int hashCode() {
+       int result = 1;
+       result = PRIME * result + ((address == null) ? 0 : address.hashCode());
+       result = PRIME * result + maxIdleTime;
+       result = PRIME * result + maxRetries;
+       result = PRIME * result + pingInterval;
+       result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode());
+       result = PRIME * result
+           + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode());
+       result = PRIME * result + (tcpNoDelay ? 1231 : 1237);
+       result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode());
+       return result;
+     }
+   }  
 }

Modified: hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/RPC.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/RPC.java?rev=1077673&r1=1077672&r2=1077673&view=diff
==============================================================================
--- hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/RPC.java (original)
+++ hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/ipc/RPC.java Fri Mar  4 04:42:45 2011
@@ -195,19 +195,21 @@ public class RPC {
 
   private static ClientCache CLIENTS=new ClientCache();
   
+  //for unit testing only
+  static Client getClient(Configuration conf) {
+    return CLIENTS.getClient(conf);
+  }
+  
   private static class Invoker implements InvocationHandler {
-    private Class<? extends VersionedProtocol> protocol;
-    private InetSocketAddress address;
-    private UserGroupInformation ticket;
+    private Client.ConnectionId remoteId;
     private Client client;
     private boolean isClosed = false;
 
     public Invoker(Class<? extends VersionedProtocol> protocol,
         InetSocketAddress address, UserGroupInformation ticket,
-        Configuration conf, SocketFactory factory) {
-      this.protocol = protocol;
-      this.address = address;
-      this.ticket = ticket;
+        Configuration conf, SocketFactory factory) throws IOException {
+      this.remoteId = Client.ConnectionId.getConnectionId(address, protocol,
+          ticket, conf);
       this.client = CLIENTS.getClient(conf, factory);
     }
 
@@ -220,8 +222,7 @@ public class RPC {
       }
 
       ObjectWritable value = (ObjectWritable)
-        client.call(new Invocation(method, args), address, 
-                    protocol, ticket);
+        client.call(new Invocation(method, args), remoteId);
       if (logDebug) {
         long callTime = System.currentTimeMillis() - startTime;
         LOG.debug("Call: " + method.getName() + " " + callTime);
@@ -421,7 +422,7 @@ public class RPC {
     Client client = CLIENTS.getClient(conf);
     try {
     Writable[] wrappedValues = 
-      client.call(invocations, addrs, method.getDeclaringClass(), ticket);
+      client.call(invocations, addrs, method.getDeclaringClass(), ticket, conf);
     
     if (method.getReturnType() == Void.TYPE) {
       return null;

Modified: hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/security/token/delegation/AbstractDelegationTokenSecretManager.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/security/token/delegation/AbstractDelegationTokenSecretManager.java?rev=1077673&r1=1077672&r2=1077673&view=diff
==============================================================================
--- hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/security/token/delegation/AbstractDelegationTokenSecretManager.java (original)
+++ hadoop/common/branches/branch-0.20-security-patches/src/core/org/apache/hadoop/security/token/delegation/AbstractDelegationTokenSecretManager.java Fri Mar  4 04:42:45 2011
@@ -206,11 +206,12 @@ extends AbstractDelegationTokenIdentifie
       throws InvalidToken {
     DelegationTokenInformation info = currentTokens.get(identifier);
     if (info == null) {
-      throw new InvalidToken("token is expired or doesn't exist");
+      throw new InvalidToken("token (" + identifier.toString()
+          + ") can't be found in cache");
     }
     long now = System.currentTimeMillis();
     if (info.getRenewDate() < now) {
-      throw new InvalidToken("token is expired");
+      throw new InvalidToken("token (" + identifier.toString() + ") is expired");
     }
     return info.getPassword();
   }

Modified: hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestIPC.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestIPC.java?rev=1077673&r1=1077672&r2=1077673&view=diff
==============================================================================
--- hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestIPC.java (original)
+++ hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestIPC.java Fri Mar  4 04:42:45 2011
@@ -89,7 +89,7 @@ public class TestIPC extends TestCase {
         try {
           LongWritable param = new LongWritable(RANDOM.nextLong());
           LongWritable value =
-            (LongWritable)client.call(param, server, null, null);
+            (LongWritable)client.call(param, server, null, null, conf);
           if (!param.equals(value)) {
             LOG.fatal("Call failed!");
             failed = true;
@@ -122,7 +122,7 @@ public class TestIPC extends TestCase {
           Writable[] params = new Writable[addresses.length];
           for (int j = 0; j < addresses.length; j++)
             params[j] = new LongWritable(RANDOM.nextLong());
-          Writable[] values = client.call(params, addresses, null, null);
+          Writable[] values = client.call(params, addresses, null, null, conf);
           for (int j = 0; j < addresses.length; j++) {
             if (!params[j].equals(values[j])) {
               LOG.fatal("Call failed!");
@@ -217,7 +217,7 @@ public class TestIPC extends TestCase {
     InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
     try {
       client.call(new LongWritable(RANDOM.nextLong()),
-              address, null, null);
+              address, null, null, conf);
       fail("Expected an exception to have been thrown");
     } catch (IOException e) {
       String message = e.getMessage();
@@ -258,7 +258,7 @@ public class TestIPC extends TestCase {
     Client client = new Client(LongErrorWritable.class, conf);
     try {
       client.call(new LongErrorWritable(RANDOM.nextLong()),
-          addr, null, null);
+          addr, null, null, conf);
       fail("Expected an exception to have been thrown");
     } catch (IOException e) {
       // check error

Modified: hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestSaslRPC.java
URL: http://svn.apache.org/viewvc/hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestSaslRPC.java?rev=1077673&r1=1077672&r2=1077673&view=diff
==============================================================================
--- hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestSaslRPC.java (original)
+++ hadoop/common/branches/branch-0.20-security-patches/src/test/org/apache/hadoop/ipc/TestSaslRPC.java Fri Mar  4 04:42:45 2011
@@ -27,6 +27,7 @@ import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.security.PrivilegedExceptionAction;
 import java.util.Collection;
+import java.util.Set;
 
 import junit.framework.Assert;
 
@@ -35,6 +36,7 @@ import org.apache.commons.logging.impl.L
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.Text;
+import org.apache.hadoop.ipc.Client.ConnectionId;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.security.KerberosInfo;
 import org.apache.hadoop.security.token.SecretManager;
@@ -64,6 +66,9 @@ public class TestSaslRPC {
   static final String ERROR_MESSAGE = "Token is invalid";
   static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
   static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
+  static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
+  static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";
+  
   private static Configuration conf;
   static {
     conf = new Configuration();
@@ -245,6 +250,85 @@ public class TestSaslRPC {
     }
   }
   
+  @Test
+  public void testGetRemotePrincipal() throws Exception {
+    try {
+      Configuration newConf = new Configuration(conf);
+      newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
+      ConnectionId remoteId = ConnectionId.getConnectionId(
+          new InetSocketAddress(0), TestSaslProtocol.class, null, newConf);
+      assertEquals(SERVER_PRINCIPAL_1, remoteId.getServerPrincipal());
+      // this following test needs security to be off
+      newConf.set(HADOOP_SECURITY_AUTHENTICATION, "simple");
+      UserGroupInformation.setConfiguration(newConf);
+      remoteId = ConnectionId.getConnectionId(new InetSocketAddress(0),
+          TestSaslProtocol.class, null, newConf);
+      assertEquals(
+          "serverPrincipal should be null when security is turned off", null,
+          remoteId.getServerPrincipal());
+    } finally {
+      // revert back to security is on
+      UserGroupInformation.setConfiguration(conf);
+    }
+  }
+  
+  @Test
+  public void testPerConnectionConf() throws Exception {
+    TestTokenSecretManager sm = new TestTokenSecretManager();
+    final Server server = RPC.getServer(
+        new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
+    server.start();
+    final UserGroupInformation current = UserGroupInformation.getCurrentUser();
+    final InetSocketAddress addr = NetUtils.getConnectAddress(server);
+    TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
+        .getUserName()));
+    Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
+        sm);
+    Text host = new Text(addr.getAddress().getHostAddress() + ":"
+        + addr.getPort());
+    token.setService(host);
+    LOG.info("Service IP address for token is " + host);
+    current.addToken(token);
+
+    Configuration newConf = new Configuration(conf);
+    newConf.set("hadoop.rpc.socket.factory.class.default", "");
+    newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
+
+    TestSaslProtocol proxy1 = null;
+    TestSaslProtocol proxy2 = null;
+    TestSaslProtocol proxy3 = null;
+    try {
+      proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+          TestSaslProtocol.versionID, addr, newConf);
+      Client client = RPC.getClient(conf);
+      Set<ConnectionId> conns = client.getConnectionIds();
+      assertEquals("number of connections in cache is wrong", 1, conns.size());
+      // same conf, connection should be re-used
+      proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+          TestSaslProtocol.versionID, addr, newConf);
+      assertEquals("number of connections in cache is wrong", 1, conns.size());
+      // different conf, new connection should be set up
+      newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
+      proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+          TestSaslProtocol.versionID, addr, newConf);
+      ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
+      assertEquals("number of connections in cache is wrong", 2,
+          connsArray.length);
+      String p1 = connsArray[0].getServerPrincipal();
+      String p2 = connsArray[1].getServerPrincipal();
+      assertFalse("should have different principals", p1.equals(p2));
+      assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1)
+          || p1.equals(SERVER_PRINCIPAL_2));
+      assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1)
+          || p2.equals(SERVER_PRINCIPAL_2));
+    } finally {
+      server.stop();
+      RPC.stopProxy(proxy1);
+      RPC.stopProxy(proxy2);
+      RPC.stopProxy(proxy3);
+    }
+  }
+  
   static void testKerberosRpc(String principal, String keytab) throws Exception {
     final Configuration newConf = new Configuration(conf);
     newConf.set(SERVER_PRINCIPAL_KEY, principal);



Mime
View raw message