From commits-return-7423-archive-asf-public=cust-asf.ponee.io@zookeeper.apache.org Tue Nov 27 17:56:54 2018 Return-Path: X-Original-To: archive-asf-public@cust-asf.ponee.io Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by mx-eu-01.ponee.io (Postfix) with SMTP id 05BE9180677 for ; Tue, 27 Nov 2018 17:56:52 +0100 (CET) Received: (qmail 74738 invoked by uid 500); 27 Nov 2018 16:56:52 -0000 Mailing-List: contact commits-help@zookeeper.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@zookeeper.apache.org Delivered-To: mailing list commits@zookeeper.apache.org Received: (qmail 74710 invoked by uid 99); 27 Nov 2018 16:56:52 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Tue, 27 Nov 2018 16:56:52 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id E86D0E12E9; Tue, 27 Nov 2018 16:56:51 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: andor@apache.org To: commits@zookeeper.apache.org Date: Tue, 27 Nov 2018 16:56:51 -0000 Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: [1/2] zookeeper git commit: ZOOKEEPER-3172: Quorum TLS - fix port unification to allow rolling upgrades Repository: zookeeper Updated Branches: refs/heads/branch-3.5 019e841e7 -> f3f1146a7 http://git-wip-us.apache.org/repos/asf/zookeeper/blob/f3f1146a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java ---------------------------------------------------------------------- diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java index 09a5d41..5e4e619 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java @@ -17,156 +17,584 @@ */ package org.apache.zookeeper.server.quorum; +import java.io.BufferedInputStream; +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.HandshakeCompletedEvent; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SSLSocket; + import org.apache.zookeeper.PortAssignment; -import org.apache.zookeeper.client.ZKClientConfig; +import org.apache.zookeeper.common.BaseX509ParameterizedTestCase; import org.apache.zookeeper.common.ClientX509Util; -import org.apache.zookeeper.common.Time; +import org.apache.zookeeper.common.KeyStoreFileType; +import org.apache.zookeeper.common.X509Exception; +import org.apache.zookeeper.common.X509KeyType; +import org.apache.zookeeper.common.X509TestContext; import org.apache.zookeeper.common.X509Util; -import org.apache.zookeeper.server.ServerCnxnFactory; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; -import javax.net.ssl.HandshakeCompletedEvent; -import javax.net.ssl.HandshakeCompletedListener; -import javax.net.ssl.SSLSocket; -import java.io.IOException; -import java.net.ConnectException; -import java.net.InetSocketAddress; -import java.net.Socket; +@RunWith(Parameterized.class) +public class UnifiedServerSocketTest extends BaseX509ParameterizedTestCase { -import static org.hamcrest.CoreMatchers.equalTo; -import static org.junit.Assert.assertThat; - -public class UnifiedServerSocketTest { + @Parameterized.Parameters + public static Collection params() { + ArrayList result = new ArrayList<>(); + int paramIndex = 0; + for (X509KeyType caKeyType : X509KeyType.values()) { + for (X509KeyType certKeyType : X509KeyType.values()) { + for (Boolean hostnameVerification : new Boolean[] { true, false }) { + result.add(new Object[]{ + caKeyType, + certKeyType, + hostnameVerification, + paramIndex++ + }); + } + } + } + return result; + } private static final int MAX_RETRIES = 5; private static final int TIMEOUT = 1000; + private static final byte[] DATA_TO_CLIENT = "hello client".getBytes(); + private static final byte[] DATA_FROM_CLIENT = "hello server".getBytes(); private X509Util x509Util; - private int port; - private volatile boolean handshakeCompleted; + private InetSocketAddress localServerAddress; + private final Object handshakeCompletedLock = new Object(); + // access only inside synchronized(handshakeCompletedLock) { ... } blocks + private boolean handshakeCompleted = false; + + public UnifiedServerSocketTest( + final X509KeyType caKeyType, + final X509KeyType certKeyType, + final Boolean hostnameVerification, + final Integer paramIndex) { + super(paramIndex, () -> { + try { + return X509TestContext.newBuilder() + .setTempDir(tempDir) + .setKeyStoreKeyType(certKeyType) + .setTrustStoreKeyType(caKeyType) + .setHostnameVerification(hostnameVerification) + .build(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } @Before public void setUp() throws Exception { - handshakeCompleted = false; - - port = PortAssignment.unique(); + localServerAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), PortAssignment.unique()); + x509Util = new ClientX509Util(); + x509TestContext.setSystemProperties(x509Util, KeyStoreFileType.JKS, KeyStoreFileType.JKS); + } - String testDataPath = System.getProperty("test.data.dir", "build/test/data"); - System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY, "org.apache.zookeeper.server.NettyServerCnxnFactory"); - System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, "org.apache.zookeeper.ClientCnxnSocketNetty"); - System.setProperty(ZKClientConfig.SECURE_CLIENT, "true"); + @After + public void tearDown() throws Exception { + x509TestContext.clearSystemProperties(x509Util); + } - x509Util = new ClientX509Util(); + private static void forceClose(Socket s) { + if (s == null || s.isClosed()) { + return; + } + try { + s.close(); + } catch (IOException e) { + } + } - System.setProperty(x509Util.getSslKeystoreLocationProperty(), testDataPath + "/ssl/testKeyStore.jks"); - System.setProperty(x509Util.getSslKeystorePasswdProperty(), "testpass"); - System.setProperty(x509Util.getSslTruststoreLocationProperty(), testDataPath + "/ssl/testTrustStore.jks"); - System.setProperty(x509Util.getSslTruststorePasswdProperty(), "testpass"); - System.setProperty(x509Util.getSslHostnameVerificationEnabledProperty(), "false"); + private static void forceClose(ServerSocket s) { + if (s == null || s.isClosed()) { + return; + } + try { + s.close(); + } catch (IOException e) { + } } - @Test - public void testConnectWithSSL() throws Exception { - class ServerThread extends Thread { - public void run() { - try { - Socket unifiedSocket = new UnifiedServerSocket(x509Util, port).accept(); - ((SSLSocket)unifiedSocket).getSession(); // block until handshake completes - } catch (IOException e) { - e.printStackTrace(); + private static final class UnifiedServerThread extends Thread { + private final byte[] dataToClient; + private List dataFromClients; + private ExecutorService workerPool; + private UnifiedServerSocket serverSocket; + + UnifiedServerThread(X509Util x509Util, + InetSocketAddress bindAddress, + boolean allowInsecureConnection, + byte[] dataToClient) throws IOException { + this.dataToClient = dataToClient; + dataFromClients = new ArrayList<>(); + workerPool = Executors.newCachedThreadPool(); + serverSocket = new UnifiedServerSocket(x509Util, allowInsecureConnection); + serverSocket.bind(bindAddress); + } + + @Override + public void run() { + try { + Random rnd = new Random(); + while (true) { + final Socket unifiedSocket = serverSocket.accept(); + final boolean tcpNoDelay = rnd.nextBoolean(); + unifiedSocket.setTcpNoDelay(tcpNoDelay); + unifiedSocket.setSoTimeout(TIMEOUT); + final boolean keepAlive = rnd.nextBoolean(); + unifiedSocket.setKeepAlive(keepAlive); + // Note: getting the input stream should not block the thread or trigger mode detection. + BufferedInputStream bis = new BufferedInputStream(unifiedSocket.getInputStream()); + workerPool.submit(new Runnable() { + @Override + public void run() { + try { + byte[] buf = new byte[1024]; + int bytesRead = unifiedSocket.getInputStream().read(buf, 0, 1024); + // Make sure the settings applied above before the socket was potentially upgraded to + // TLS still apply. + Assert.assertEquals(tcpNoDelay, unifiedSocket.getTcpNoDelay()); + Assert.assertEquals(TIMEOUT, unifiedSocket.getSoTimeout()); + Assert.assertEquals(keepAlive, unifiedSocket.getKeepAlive()); + if (bytesRead > 0) { + byte[] dataFromClient = new byte[bytesRead]; + System.arraycopy(buf, 0, dataFromClient, 0, bytesRead); + synchronized (dataFromClients) { + dataFromClients.add(dataFromClient); + } + } + unifiedSocket.getOutputStream().write(dataToClient); + unifiedSocket.getOutputStream().flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + forceClose(unifiedSocket); + } + } + }); } + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + forceClose(serverSocket); + workerPool.shutdown(); } } - ServerThread serverThread = new ServerThread(); - serverThread.start(); + public void shutdown(long millis) throws InterruptedException { + forceClose(serverSocket); // this should break the run() loop + workerPool.awaitTermination(millis, TimeUnit.MILLISECONDS); + this.join(millis); + } + + synchronized byte[] getDataFromClient(int index) { + return dataFromClients.get(index); + } + } + + private SSLSocket connectWithSSL() throws IOException, X509Exception, InterruptedException { SSLSocket sslSocket = null; int retries = 0; while (retries < MAX_RETRIES) { try { sslSocket = x509Util.createSSLSocket(); + sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() { + @Override + public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) { + synchronized (handshakeCompletedLock) { + handshakeCompleted = true; + handshakeCompletedLock.notifyAll(); + } + } + }); sslSocket.setSoTimeout(TIMEOUT); - sslSocket.connect(new InetSocketAddress(port), TIMEOUT); + sslSocket.connect(localServerAddress, TIMEOUT); break; } catch (ConnectException connectException) { connectException.printStackTrace(); + forceClose(sslSocket); + sslSocket = null; Thread.sleep(TIMEOUT); } retries++; } - sslSocket.addHandshakeCompletedListener(new HandshakeCompletedListener() { - @Override - public void handshakeCompleted(HandshakeCompletedEvent handshakeCompletedEvent) { - completeHandshake(); + Assert.assertNotNull("Failed to connect to server with SSL", sslSocket); + return sslSocket; + } + + private Socket connectWithoutSSL() throws IOException, InterruptedException { + Socket socket = null; + int retries = 0; + while (retries < MAX_RETRIES) { + try { + socket = new Socket(); + socket.setSoTimeout(TIMEOUT); + socket.connect(localServerAddress, TIMEOUT); + break; + } catch (ConnectException connectException) { + connectException.printStackTrace(); + forceClose(socket); + socket = null; + Thread.sleep(TIMEOUT); } - }); - sslSocket.startHandshake(); + retries++; + } + Assert.assertNotNull("Failed to connect to server without SSL", socket); + return socket; + } + + // In the tests below, a "Strict" server means a UnifiedServerSocket that + // does not allow plaintext connections (in other words, it's SSL-only). + // A "Non Strict" server means a UnifiedServerSocket that allows both + // plaintext and SSL incoming connections. + + /** + * Attempting to connect to a SSL-or-plaintext server with SSL should work. + */ + @Test + public void testConnectWithSSLToNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + Socket sslSocket = connectWithSSL(); + try { + sslSocket.getOutputStream().write(DATA_FROM_CLIENT); + sslSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + synchronized (handshakeCompletedLock) { + if (!handshakeCompleted) { + handshakeCompletedLock.wait(TIMEOUT); + } + Assert.assertTrue(handshakeCompleted); + } + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } finally { + forceClose(sslSocket); + serverThread.shutdown(TIMEOUT); + } + } - serverThread.join(TIMEOUT); + /** + * Attempting to connect to a SSL-only server with SSL should work. + */ + @Test + public void testConnectWithSSLToStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + + Socket sslSocket = connectWithSSL(); + try { + sslSocket.getOutputStream().write(DATA_FROM_CLIENT); + sslSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = sslSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); - long start = Time.currentElapsedTime(); - while (Time.currentElapsedTime() < start + TIMEOUT) { - if (handshakeCompleted) { - return; + synchronized (handshakeCompletedLock) { + if (!handshakeCompleted) { + handshakeCompletedLock.wait(TIMEOUT); + } + Assert.assertTrue(handshakeCompleted); } + + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } finally { + forceClose(sslSocket); + serverThread.shutdown(TIMEOUT); } + } - Assert.fail("failed to complete handshake"); + /** + * Attempting to connect to a SSL-or-plaintext server without SSL should work. + */ + @Test + public void testConnectWithoutSSLToNonStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + try { + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } finally { + forceClose(socket); + serverThread.shutdown(TIMEOUT); + } } - private void completeHandshake() { - handshakeCompleted = true; + /** + * Attempting to connect to a SSL-or-plaintext server without SSL with a + * small initial data write should work. This makes sure that sending + * less than 5 bytes does not break the logic in the server's initial 5 + * byte read. + */ + @Test + public void testConnectWithoutSSLToNonStrictServerPartialWrite() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + try { + // Write only 2 bytes of the message, wait a bit, then write the rest. + // This makes sure that writes smaller than 5 bytes don't break the plaintext mode on the server + // once it decides that the input doesn't look like a TLS handshake. + socket.getOutputStream().write(DATA_FROM_CLIENT, 0, 2); + socket.getOutputStream().flush(); + Thread.sleep(TIMEOUT / 2); + socket.getOutputStream().write(DATA_FROM_CLIENT, 2, DATA_FROM_CLIENT.length - 2); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = socket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } finally { + forceClose(socket); + serverThread.shutdown(TIMEOUT); + } } + /** + * Attempting to connect to a SSL-only server without SSL should fail. + */ @Test - public void testConnectWithoutSSL() throws Exception { - final byte[] testData = "hello there".getBytes(); - final String[] dataReadFromClient = {null}; - - class ServerThread extends Thread { - public void run() { - try { - Socket unifiedSocket = new UnifiedServerSocket(x509Util, port).accept(); - unifiedSocket.getOutputStream().write(testData); - unifiedSocket.getOutputStream().flush(); - byte[] inputbuff = new byte[5]; - unifiedSocket.getInputStream().read(inputbuff, 0, 5); - dataReadFromClient[0] = new String(inputbuff); - } catch (IOException e) { - e.printStackTrace(); + public void testConnectWithoutSSLToStrictServer() throws Exception { + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + + Socket socket = connectWithoutSSL(); + socket.getOutputStream().write(DATA_FROM_CLIENT); + socket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + try { + socket.getInputStream().read(buf, 0, buf.length); + } catch (SocketException e) { + // We expect the other end to hang up the connection + return; + } finally { + forceClose(socket); + serverThread.shutdown(TIMEOUT); + } + Assert.fail("Expected server to hang up the connection. Read from server succeeded unexpectedly."); + } + + /** + * This test makes sure that UnifiedServerSocket used properly (a single + * thread accept()-ing connections and handing the resulting sockets to + * other threads for processing) is not vulnerable to blocking the + * accept() thread while doing mode detection if a misbehaving client + * connects. A misbehaving client is one that either disconnects + * immediately, or connects but does not send any data. + * + * This version of the test uses a non-strict server socket (i.e. it + * accepts both TLS and plaintext connections). + */ + @Test + public void testTLSDetectionNonBlockingNonStrictServerIdleClient() throws Exception { + Socket badClientSocket = null; + Socket clientSocket = null; + Socket secureClientSocket = null; + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + try { + badClientSocket = connectWithoutSSL(); // Leave the bad client socket idle + + clientSocket = connectWithoutSSL(); + clientSocket.getOutputStream().write(DATA_FROM_CLIENT); + clientSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = clientSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + + synchronized (handshakeCompletedLock) { + Assert.assertFalse(handshakeCompleted); + } + + secureClientSocket = connectWithSSL(); + secureClientSocket.getOutputStream().write(DATA_FROM_CLIENT); + secureClientSocket.getOutputStream().flush(); + buf = new byte[DATA_TO_CLIENT.length]; + bytesRead = secureClientSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(1)); + + synchronized (handshakeCompletedLock) { + if (!handshakeCompleted) { + handshakeCompletedLock.wait(TIMEOUT); } + Assert.assertTrue(handshakeCompleted); } + } finally { + forceClose(badClientSocket); + forceClose(clientSocket); + forceClose(secureClientSocket); + serverThread.shutdown(TIMEOUT); } - ServerThread serverThread = new ServerThread(); + } + + /** + * Like the above test, but with a strict server socket (closes non-TLS + * connections after seeing that there is no handshake). + */ + @Test + public void testTLSDetectionNonBlockingStrictServerIdleClient() throws Exception { + Socket badClientSocket = null; + Socket secureClientSocket = null; + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); serverThread.start(); - Socket socket = null; - int retries = 0; - while (retries < MAX_RETRIES) { - try { - socket = new Socket(); - socket.setSoTimeout(TIMEOUT); - socket.connect(new InetSocketAddress(port), TIMEOUT); - break; - } catch (ConnectException connectException) { - connectException.printStackTrace(); - Thread.sleep(TIMEOUT); + try { + badClientSocket = connectWithoutSSL(); // Leave the bad client socket idle + + secureClientSocket = connectWithSSL(); + secureClientSocket.getOutputStream().write(DATA_FROM_CLIENT); + secureClientSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = secureClientSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + synchronized (handshakeCompletedLock) { + if (!handshakeCompleted) { + handshakeCompletedLock.wait(TIMEOUT); + } + Assert.assertTrue(handshakeCompleted); } - retries++; + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } finally { + forceClose(badClientSocket); + forceClose(secureClientSocket); + serverThread.shutdown(TIMEOUT); } + } - socket.getOutputStream().write("hellobello".getBytes()); - socket.getOutputStream().flush(); + /** + * Similar to the tests above, but the bad client disconnects immediately + * without sending any data. + */ + @Test + public void testTLSDetectionNonBlockingNonStrictServerDisconnectedClient() throws Exception { + Socket clientSocket = null; + Socket secureClientSocket = null; + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, true, DATA_TO_CLIENT); + serverThread.start(); + + try { + Socket badClientSocket = connectWithoutSSL(); + forceClose(badClientSocket); // close the bad client socket immediately + + clientSocket = connectWithoutSSL(); + clientSocket.getOutputStream().write(DATA_FROM_CLIENT); + clientSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = clientSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + + synchronized (handshakeCompletedLock) { + Assert.assertFalse(handshakeCompleted); + } - byte[] readBytes = new byte[testData.length]; - socket.getInputStream().read(readBytes, 0, testData.length); + secureClientSocket = connectWithSSL(); + secureClientSocket.getOutputStream().write(DATA_FROM_CLIENT); + secureClientSocket.getOutputStream().flush(); + buf = new byte[DATA_TO_CLIENT.length]; + bytesRead = secureClientSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(1)); - serverThread.join(TIMEOUT); + synchronized (handshakeCompletedLock) { + if (!handshakeCompleted) { + handshakeCompletedLock.wait(TIMEOUT); + } + Assert.assertTrue(handshakeCompleted); + } + } finally { + forceClose(clientSocket); + forceClose(secureClientSocket); + serverThread.shutdown(TIMEOUT); + } + } - Assert.assertArrayEquals(testData, readBytes); - assertThat("Data sent by the client is invalid", dataReadFromClient[0], equalTo("hello")); + /** + * Like the above test, but with a strict server socket (closes non-TLS + * connections after seeing that there is no handshake). + */ + @Test + public void testTLSDetectionNonBlockingStrictServerDisconnectedClient() throws Exception { + Socket secureClientSocket = null; + UnifiedServerThread serverThread = new UnifiedServerThread( + x509Util, localServerAddress, false, DATA_TO_CLIENT); + serverThread.start(); + + try { + Socket badClientSocket = connectWithoutSSL(); + forceClose(badClientSocket); // close the bad client socket immediately + + secureClientSocket = connectWithSSL(); + secureClientSocket.getOutputStream().write(DATA_FROM_CLIENT); + secureClientSocket.getOutputStream().flush(); + byte[] buf = new byte[DATA_TO_CLIENT.length]; + int bytesRead = secureClientSocket.getInputStream().read(buf, 0, buf.length); + Assert.assertEquals(buf.length, bytesRead); + Assert.assertArrayEquals(DATA_TO_CLIENT, buf); + + synchronized (handshakeCompletedLock) { + if (!handshakeCompleted) { + handshakeCompletedLock.wait(TIMEOUT); + } + Assert.assertTrue(handshakeCompleted); + } + Assert.assertArrayEquals(DATA_FROM_CLIENT, serverThread.getDataFromClient(0)); + } finally { + forceClose(secureClientSocket); + serverThread.shutdown(TIMEOUT); + } } }