nifi-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From thena...@apache.org
Subject [nifi] branch main updated: NIFI-7468 Updated SSLSocketChannel to support TLS 1.3
Date Wed, 23 Jun 2021 02:31:31 GMT
This is an automated email from the ASF dual-hosted git repository.

thenatog pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git


The following commit(s) were added to refs/heads/main by this push:
     new 6a83115  NIFI-7468 Updated SSLSocketChannel to support TLS 1.3
6a83115 is described below

commit 6a83115d6aea68a857c7cd409f33b6ba1d519a1e
Author: exceptionfactory <exceptionfactory@apache.org>
AuthorDate: Thu Jun 10 14:40:53 2021 -0500

    NIFI-7468 Updated SSLSocketChannel to support TLS 1.3
    
    - Handling additional FINISHED Handshake Status for TLS 1.3 Post-Handshake Messages per RFC 8446 Section 4.6
    - Removed clearing buffers after handshake to avoid losing packets
    - Updated read() method to check Handshake Status after SSLEngine.unwrap()
    - Changed SSLSocketChannelSender to close SSLSocketChannel before other resources
    - Added ChannelStatus enum and convenience logging methods for tracing status
    - Added unit tests for TLS 1.2 and 1.3 using Netty server and client handlers
    
    NIFI-8704 Updated netty-handler to 4.1.65.Final
    
    NIFI-7468 Corrected SSLSocketChannel.read() to return the byte read
    
    NIFI-7468 Adjusted comment formatting
    
    Signed-off-by: Nathan Gough <thenatog@gmail.com>
    
    This closes #5152.
---
 nifi-commons/nifi-security-socket-ssl/pom.xml      |  12 +
 .../remote/io/socket/ssl/SSLSocketChannel.java     | 942 ++++++++++++---------
 .../remote/io/socket/ssl/SSLSocketChannelTest.java | 315 +++++++
 .../util/put/sender/SSLSocketChannelSender.java    |   5 +-
 4 files changed, 862 insertions(+), 412 deletions(-)

diff --git a/nifi-commons/nifi-security-socket-ssl/pom.xml b/nifi-commons/nifi-security-socket-ssl/pom.xml
index c4edb99..b4592a5 100644
--- a/nifi-commons/nifi-security-socket-ssl/pom.xml
+++ b/nifi-commons/nifi-security-socket-ssl/pom.xml
@@ -31,5 +31,17 @@
             <groupId>org.slf4j</groupId>
             <artifactId>slf4j-api</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.apache.nifi</groupId>
+            <artifactId>nifi-security-utils</artifactId>
+            <version>1.14.0-SNAPSHOT</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>io.netty</groupId>
+            <artifactId>netty-handler</artifactId>
+            <version>4.1.65.Final</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 </project>
diff --git a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
index 59902d0..9a5cdd8 100644
--- a/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
+++ b/nifi-commons/nifi-security-socket-ssl/src/main/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannel.java
@@ -26,8 +26,9 @@ import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
 import javax.net.ssl.SSLEngineResult;
 import javax.net.ssl.SSLEngineResult.Status;
+import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLHandshakeException;
-import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.net.ssl.SSLSession;
 import java.io.Closeable;
 import java.io.IOException;
 import java.net.InetAddress;
@@ -38,89 +39,87 @@ import java.net.SocketTimeoutException;
 import java.nio.ByteBuffer;
 import java.nio.channels.ClosedByInterruptException;
 import java.nio.channels.SocketChannel;
-import java.security.cert.Certificate;
-import java.security.cert.CertificateException;
-import java.security.cert.X509Certificate;
 import java.util.concurrent.TimeUnit;
 
+/**
+ * SSLSocketChannel supports reading and writing bytes using TLS and NIO SocketChannels with configurable timeouts
+ */
 public class SSLSocketChannel implements Closeable {
+    private static final Logger LOGGER = LoggerFactory.getLogger(SSLSocketChannel.class);
 
-    public static final int MAX_WRITE_SIZE = 65536;
-
-    private static final Logger logger = LoggerFactory.getLogger(SSLSocketChannel.class);
+    private static final int DISCARD_BUFFER_LENGTH = 8192;
+    private static final int END_OF_STREAM = -1;
+    private static final byte[] EMPTY_MESSAGE = new byte[0];
     private static final long BUFFER_FULL_EMPTY_WAIT_NANOS = TimeUnit.NANOSECONDS.convert(1, TimeUnit.MILLISECONDS);
+    private static final long FINISH_CONNECT_SLEEP = 50;
+    private static final long INITIAL_INCREMENTAL_SLEEP = 1;
+    private static final boolean CLIENT_AUTHENTICATION_REQUIRED = true;
 
     private final String remoteAddress;
     private final int port;
     private final SSLEngine engine;
     private final SocketAddress socketAddress;
-
-    private BufferStateManager streamInManager;
-    private BufferStateManager streamOutManager;
-    private BufferStateManager appDataManager;
-
-    private SocketChannel channel;
-
-    private final byte[] oneByteBuffer = new byte[1];
-
+    private final BufferStateManager streamInManager;
+    private final BufferStateManager streamOutManager;
+    private final BufferStateManager appDataManager;
+    private final SocketChannel channel;
     private int timeoutMillis = 30000;
-    private volatile boolean connected = false;
-    private boolean handshaking = false;
-    private boolean closed = false;
-    private volatile boolean interrupted = false;
 
-    public SSLSocketChannel(final SSLContext sslContext, final String hostname, final int port, final InetAddress localAddress, final boolean client) throws IOException {
-        this.socketAddress = new InetSocketAddress(hostname, port);
-        this.channel = SocketChannel.open();
-        if (localAddress != null) {
-            final SocketAddress localSocketAddress = new InetSocketAddress(localAddress, 0);
-            this.channel.bind(localSocketAddress);
-        }
-        this.remoteAddress = hostname;
+    private volatile boolean interrupted = false;
+    private volatile ChannelStatus channelStatus = ChannelStatus.DISCONNECTED;
+
+    /**
+     * SSLSocketChannel constructor with SSLContext and remote address parameters
+     *
+     * @param sslContext    SSLContext used to create SSLEngine with specified client mode
+     * @param remoteAddress Remote Address used for connection
+     * @param port          Remote Port used for connection
+     * @param bindAddress   Local address used for binding server channel when provided
+     * @param useClientMode Use Client Mode
+     * @throws IOException Thrown on failures creating Socket Channel
+     */
+    public SSLSocketChannel(final SSLContext sslContext, final String remoteAddress, final int port, final InetAddress bindAddress, final boolean useClientMode) throws IOException {
+        this.engine = createEngine(sslContext, useClientMode);
+        this.channel = createSocketChannel(bindAddress);
+        this.socketAddress = new InetSocketAddress(remoteAddress, port);
+        this.remoteAddress = remoteAddress;
         this.port = port;
-        this.engine = sslContext.createSSLEngine();
-        this.engine.setUseClientMode(client);
-        engine.setNeedClientAuth(true);
 
         streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
         streamOutManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
         appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()));
     }
 
-    public SSLSocketChannel(final SSLContext sslContext, final SocketChannel socketChannel, final boolean client) throws IOException {
-        if (!socketChannel.isConnected()) {
-            throw new IllegalArgumentException("Cannot pass an un-connected SocketChannel");
-        }
-
-        this.channel = socketChannel;
-
-        this.socketAddress = socketChannel.getRemoteAddress();
-        final Socket socket = socketChannel.socket();
-        this.remoteAddress = socket.getInetAddress().toString();
-        this.port = socket.getPort();
-
-        this.engine = sslContext.createSSLEngine();
-        this.engine.setUseClientMode(client);
-        this.engine.setNeedClientAuth(true);
-
-        streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
-        streamOutManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
-        appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()));
+    /**
+     * SSLSocketChannel constructor with SSLContext and connected SocketChannel
+     *
+     * @param sslContext    SSLContext used to create SSLEngine with specified client mode
+     * @param socketChannel Connected SocketChannel
+     * @param useClientMode Use Client Mode
+     * @throws IOException Thrown on SocketChannel.getRemoteAddress()
+     */
+    public SSLSocketChannel(final SSLContext sslContext, final SocketChannel socketChannel, final boolean useClientMode) throws IOException {
+        this(createEngine(sslContext, useClientMode), socketChannel);
     }
 
+    /**
+     * SSLSocketChannel constructor with configured SSLEngine and connected SocketChannel
+     *
+     * @param sslEngine     SSLEngine configured with mode and client authentication
+     * @param socketChannel Connected SocketChannel
+     * @throws IOException Thrown on SocketChannel.getRemoteAddress()
+     */
     public SSLSocketChannel(final SSLEngine sslEngine, final SocketChannel socketChannel) throws IOException {
         if (!socketChannel.isConnected()) {
-            throw new IllegalArgumentException("Cannot pass an un-connected SocketChannel");
+            throw new IllegalArgumentException("Connected SocketChannel required");
         }
 
+        socketChannel.configureBlocking(false);
         this.channel = socketChannel;
-
         this.socketAddress = socketChannel.getRemoteAddress();
         final Socket socket = socketChannel.socket();
         this.remoteAddress = socket.getInetAddress().toString();
         this.port = socket.getPort();
-
-        // don't set useClientMode or needClientAuth, use the engine as is and let the caller configure it
         this.engine = sslEngine;
 
         streamInManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getPacketBufferSize()));
@@ -128,166 +127,64 @@ public class SSLSocketChannel implements Closeable {
         appDataManager = new BufferStateManager(ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()));
     }
 
-    public void setTimeout(final int millis) {
-        this.timeoutMillis = millis;
+    public void setTimeout(final int timeoutMillis) {
+        this.timeoutMillis = timeoutMillis;
     }
 
     public int getTimeout() {
         return timeoutMillis;
     }
 
+    /**
+     * Connect Channel when not connected and perform handshake process
+     *
+     * @throws IOException Thrown on connection failures
+     */
     public void connect() throws IOException {
+        channelStatus = ChannelStatus.CONNECTING;
+
         try {
-            channel.configureBlocking(false);
             if (!channel.isConnected()) {
-                final long startTime = System.currentTimeMillis();
+                logOperation("Connection Started");
+                final long started = System.currentTimeMillis();
 
                 if (!channel.connect(socketAddress)) {
                     while (!channel.finishConnect()) {
-                        if (interrupted) {
-                            throw new TransmissionDisabledException();
-                        }
-                        if (System.currentTimeMillis() > startTime + timeoutMillis) {
-                            throw new SocketTimeoutException("Timed out connecting to " + remoteAddress + ":" + port);
-                        }
+                        checkInterrupted();
+                        checkTimeoutExceeded(started);
 
                         try {
-                            Thread.sleep(50L);
+                            TimeUnit.MILLISECONDS.sleep(FINISH_CONNECT_SLEEP);
                         } catch (final InterruptedException e) {
+                            logOperation("Connection Interrupted");
                         }
                     }
                 }
             }
-            engine.beginHandshake();
-
-            performHandshake();
-            logger.debug("{} Successfully completed SSL handshake", this);
-
-            streamInManager.clear();
-            streamOutManager.clear();
-            appDataManager.clear();
-
-            connected = true;
+            channelStatus = ChannelStatus.CONNECTED;
         } catch (final Exception e) {
-            logger.error("{} failed to connect", this, e);
-            closeQuietly(channel);
-            engine.closeInbound();
-            engine.closeOutbound();
-            throw e;
-        }
-    }
-
-    public String getDn() throws CertificateException, SSLPeerUnverifiedException {
-        final Certificate[] certs = engine.getSession().getPeerCertificates();
-        if (certs == null || certs.length == 0) {
-            throw new SSLPeerUnverifiedException("No certificates found");
+            close();
+            throw new SSLException(String.format("[%s:%d] Connection Failed", remoteAddress, port), e);
         }
 
-        final Certificate certificate = certs[0];
-        if (certificate instanceof X509Certificate) {
-            final X509Certificate peerCertificate = (X509Certificate) certificate;
-            peerCertificate.checkValidity();
-            return peerCertificate.getSubjectDN().getName().trim();
-        } else {
-            throw new CertificateException(String.format("X.509 Certificate class not found [%s]", certificate.getClass()));
-        }
-    }
-
-    private void performHandshake() throws IOException {
-        // Generate handshake message
-        final byte[] emptyMessage = new byte[0];
-        handshaking = true;
-        logger.debug("{} Performing Handshake", this);
-
         try {
-            while (true) {
-                switch (engine.getHandshakeStatus()) {
-                    case FINISHED:
-                        return;
-                    case NEED_WRAP: {
-                        final ByteBuffer appDataOut = ByteBuffer.wrap(emptyMessage);
-
-                        final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-
-                        final SSLEngineResult wrapHelloResult = engine.wrap(appDataOut, outboundBuffer);
-                        if (wrapHelloResult.getStatus() == Status.BUFFER_OVERFLOW) {
-                            streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-                            continue;
-                        }
-
-                        if (wrapHelloResult.getStatus() != Status.OK) {
-                            throw new SSLHandshakeException("Could not generate SSL Handshake information: SSLEngineResult: "
-                                    + wrapHelloResult.toString());
-                        }
-
-                        logger.trace("{} Handshake response after wrapping: {}", this, wrapHelloResult);
-
-                        final ByteBuffer readableStreamOut = streamOutManager.prepareForRead(1);
-                        final int bytesToSend = readableStreamOut.remaining();
-                        writeFully(readableStreamOut);
-                        logger.trace("{} Sent {} bytes of wrapped data for handshake", this, bytesToSend);
-
-                        streamOutManager.clear();
-                    }
-                    continue;
-                    case NEED_UNWRAP: {
-                        final ByteBuffer readableDataIn = streamInManager.prepareForRead(0);
-                        final ByteBuffer appData = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-
-                        // Read handshake response from other side
-                        logger.trace("{} Unwrapping: {} to {}", this, readableDataIn, appData);
-                        SSLEngineResult handshakeResponseResult = engine.unwrap(readableDataIn, appData);
-                        logger.trace("{} Handshake response after unwrapping: {}", this, handshakeResponseResult);
-
-                        if (handshakeResponseResult.getStatus() == Status.BUFFER_UNDERFLOW) {
-                            final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
-                            final int bytesRead = readData(writableDataIn);
-                            if (bytesRead > 0) {
-                                logger.trace("{} Read {} bytes for handshake", this, bytesRead);
-                            }
-
-                            if (bytesRead < 0) {
-                                throw new SSLHandshakeException("Reached End-of-File marker while performing handshake");
-                            }
-                        } else if (handshakeResponseResult.getStatus() == Status.CLOSED) {
-                            throw new IOException("Channel was closed by peer during handshake");
-                        } else {
-                            streamInManager.compact();
-                            appDataManager.clear();
-                        }
-                    }
-                    break;
-                    case NEED_TASK:
-                        performTasks();
-                        continue;
-                    case NOT_HANDSHAKING:
-                        return;
-                }
-            }
-        } finally {
-            handshaking = false;
-        }
-    }
-
-    private void performTasks() {
-        Runnable runnable;
-        while ((runnable = engine.getDelegatedTask()) != null) {
-            runnable.run();
-        }
-    }
-
-    private void closeQuietly(final Closeable closeable) {
-        try {
-            closeable.close();
-        } catch (final Exception e) {
+            performHandshake();
+        } catch (final IOException e) {
+            close();
+            throw new SSLException(String.format("[%s:%d] Handshake Failed", remoteAddress, port), e);
         }
     }
 
+    /**
+     * Shutdown Socket Channel input and read available bytes
+     *
+     * @throws IOException Thrown on Socket Channel failures
+     */
     public void consume() throws IOException {
         channel.shutdownInput();
 
-        final byte[] b = new byte[4096];
-        final ByteBuffer buffer = ByteBuffer.wrap(b);
+        final byte[] byteBuffer = new byte[DISCARD_BUFFER_LENGTH];
+        final ByteBuffer buffer = ByteBuffer.wrap(byteBuffer);
         int readCount;
         do {
             readCount = channel.read(buffer);
@@ -295,209 +192,104 @@ public class SSLSocketChannel implements Closeable {
         } while (readCount > 0);
     }
 
-    private int readData(final ByteBuffer dest) throws IOException {
-        final long startTime = System.currentTimeMillis();
-
-        while (true) {
-            if (interrupted) {
-                throw new TransmissionDisabledException();
-            }
-
-            if (dest.remaining() == 0) {
-                return 0;
-            }
-
-            final int readCount = channel.read(dest);
-
-            long sleepNanos = 1L;
-            if (readCount == 0) {
-                if (System.currentTimeMillis() > startTime + timeoutMillis) {
-                    throw new SocketTimeoutException("Timed out reading from socket connected to " + remoteAddress + ":" + port);
-                }
-                try {
-                    TimeUnit.NANOSECONDS.sleep(sleepNanos);
-                } catch (InterruptedException e) {
-                    close();
-                    Thread.currentThread().interrupt(); // set the interrupt status
-                    throw new ClosedByInterruptException();
-                }
-
-                sleepNanos = Math.min(sleepNanos * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
-
-                continue;
-            }
-
-            logger.trace("{} Read {} bytes", this, readCount);
-            return readCount;
-        }
-    }
-
-    private Status encryptAndWriteFully(final BufferStateManager src) throws IOException {
-        SSLEngineResult result = null;
-
-        final ByteBuffer buff = src.prepareForRead(0);
-        final ByteBuffer outBuff = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-
-        logger.trace("{} Encrypting {} bytes", this, buff.remaining());
-        while (buff.remaining() > 0) {
-            result = engine.wrap(buff, outBuff);
-            if (result.getStatus() == Status.OK) {
-                final ByteBuffer readableOutBuff = streamOutManager.prepareForRead(0);
-                writeFully(readableOutBuff);
-                streamOutManager.clear();
-            } else {
-                return result.getStatus();
-            }
-        }
-
-        return result.getStatus();
-    }
-
-    private void writeFully(final ByteBuffer src) throws IOException {
-        long lastByteWrittenTime = System.currentTimeMillis();
-
-        int bytesWritten = 0;
-        while (src.hasRemaining()) {
-            if (interrupted) {
-                throw new TransmissionDisabledException();
-            }
-
-            final int written = channel.write(src);
-            bytesWritten += written;
-            final long now = System.currentTimeMillis();
-            long sleepNanos = 1L;
-
-            if (written > 0) {
-                lastByteWrittenTime = now;
-            } else {
-                if (now > lastByteWrittenTime + timeoutMillis) {
-                    throw new SocketTimeoutException("Timed out writing to socket connected to " + remoteAddress + ":" + port);
-                }
-                try {
-                    TimeUnit.NANOSECONDS.sleep(sleepNanos);
-                } catch (final InterruptedException e) {
-                    close();
-                    Thread.currentThread().interrupt(); // set the interrupt status
-                    throw new ClosedByInterruptException();
-                }
-
-                sleepNanos = Math.min(sleepNanos * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
-            }
-        }
-
-        logger.trace("{} Wrote {} bytes", this, bytesWritten);
-    }
-
+    /**
+     * Is Channel Closed
+     *
+     * @return Channel Closed Status
+     */
     public boolean isClosed() {
-        if (closed) {
+        if (ChannelStatus.CLOSED == channelStatus) {
             return true;
         }
-        // need to detect if peer has sent closure handshake...if so the answer is true
-        final ByteBuffer writableInBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
-        int readCount = 0;
+
+        // Read Channel to determine closed status
+        final ByteBuffer inputBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+        int bytesRead;
         try {
-            readCount = channel.read(writableInBuffer);
-        } catch (IOException e) {
-            logger.error("{} failed to read data", this, e);
-            readCount = -1; // treat the condition same as if End of Stream
+            bytesRead = channel.read(inputBuffer);
+        } catch (final IOException e) {
+            LOGGER.warn("[{}:{}] Closed Status Read Failed", remoteAddress, port, e);
+            bytesRead = END_OF_STREAM;
         }
-        if (readCount == 0) {
-            return false;
-        }
-        if (readCount > 0) {
-            logger.trace("{} Read {} bytes", this, readCount);
+        logOperationBytes("Closed Status Read", bytesRead);
 
-            final ByteBuffer streamInBuffer = streamInManager.prepareForRead(1);
-            final ByteBuffer appDataBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+        if (bytesRead == 0) {
+            return false;
+        } else if (bytesRead > 0) {
             try {
-                SSLEngineResult unwrapResponse = engine.unwrap(streamInBuffer, appDataBuffer);
-                logger.trace("{} When checking if closed, (handshake={}) Unwrap response: {}", this, handshaking, unwrapResponse);
-                if (unwrapResponse.getStatus().equals(Status.CLOSED)) {
-                    // Drain the incoming TCP buffer
-                    final ByteBuffer discardBuffer = ByteBuffer.allocate(8192);
-                    int bytesDiscarded = channel.read(discardBuffer);
-                    while (bytesDiscarded > 0) {
-                        discardBuffer.clear();
-                        bytesDiscarded = channel.read(discardBuffer);
-                    }
+                final SSLEngineResult unwrapResult = unwrap();
+                if (Status.CLOSED == unwrapResult.getStatus()) {
+                    readChannelDiscard();
                     engine.closeInbound();
                 } else {
                     streamInManager.compact();
                     return false;
                 }
-            } catch (IOException e) {
-                logger.error("{} failed to check if closed. Closing channel.", this, e);
+            } catch (final IOException e) {
+                LOGGER.warn("[{}:{}] Closed Status Unwrap Failed", remoteAddress, port, e);
             }
         }
-        // either readCount is -1, indicating an end of stream, or the peer sent a closure handshake
-        // so go ahead and close down the channel
-        closeQuietly(channel.socket());
-        closeQuietly(channel);
-        closed = true;
+
+        // Close Channel when encountering end of stream or closed status
+        try {
+            close();
+        } catch (final IOException e) {
+            LOGGER.warn("[{}:{}] Close Failed", remoteAddress, port, e);
+        }
         return true;
     }
 
+    /**
+     * Close Channel and process notifications
+     *
+     * @throws IOException Thrown on SSLEngine.wrap() failures
+     */
     @Override
     public void close() throws IOException {
-        logger.debug("{} Closing Connection", this);
-        if (channel == null) {
-            return;
-        }
-
-        if (closed) {
+        logOperation("Close Requested");
+        if (channelStatus == ChannelStatus.CLOSED) {
             return;
         }
 
         try {
             engine.closeOutbound();
 
-            final byte[] emptyMessage = new byte[0];
-
-            final ByteBuffer appDataOut = ByteBuffer.wrap(emptyMessage);
-            final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-            final SSLEngineResult handshakeResult = engine.wrap(appDataOut, outboundBuffer);
-
-            if (handshakeResult.getStatus() != Status.CLOSED) {
-                throw new IOException("Invalid close state - will not send network data");
+            streamOutManager.clear();
+            final ByteBuffer inputBuffer = ByteBuffer.wrap(EMPTY_MESSAGE);
+            final ByteBuffer outputBuffer = streamOutManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+            SSLEngineResult wrapResult = wrap(inputBuffer, outputBuffer);
+            Status status = wrapResult.getStatus();
+            if (Status.OK == status) {
+                logOperation("Clearing Outbound Buffer");
+                outputBuffer.clear();
+                wrapResult = wrap(inputBuffer, outputBuffer);
+                status = wrapResult.getStatus();
             }
-
-            final ByteBuffer readableStreamOut = streamOutManager.prepareForRead(1);
-            writeFully(readableStreamOut);
-        } finally {
-            // Drain the incoming TCP buffer
-            final ByteBuffer discardBuffer = ByteBuffer.allocate(8192);
-            try {
-                int bytesDiscarded = channel.read(discardBuffer);
-                while (bytesDiscarded > 0) {
-                    discardBuffer.clear();
-                    bytesDiscarded = channel.read(discardBuffer);
+            if (Status.CLOSED == status) {
+                final ByteBuffer streamOutputBuffer = streamOutManager.prepareForRead(1);
+                try {
+                    writeChannel(streamOutputBuffer);
+                } catch (final IOException e) {
+                    logOperation(String.format("Write Close Notification Failed: %s", e.getMessage()));
                 }
-            } catch (Exception e) {
+            } else {
+                throw new SSLException(String.format("[%s:%d] Invalid Wrap Result Status [%s]", remoteAddress, port, status));
             }
-
+        } finally {
+            channelStatus = ChannelStatus.CLOSED;
+            readChannelDiscard();
             closeQuietly(channel.socket());
             closeQuietly(channel);
-            closed = true;
-        }
-    }
-
-    private int copyFromAppDataBuffer(final byte[] buffer, final int offset, final int len) {
-        // If any data already exists in the application data buffer, copy it to the buffer.
-        final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
-
-        final int appDataRemaining = appDataBuffer.remaining();
-        if (appDataRemaining > 0) {
-            final int bytesToCopy = Math.min(len, appDataBuffer.remaining());
-            appDataBuffer.get(buffer, offset, bytesToCopy);
-
-            final int bytesCopied = appDataRemaining - appDataBuffer.remaining();
-            logger.trace("{} Copied {} ({}) bytes from unencrypted application buffer to user space",
-                    this, bytesToCopy, bytesCopied);
-            return bytesCopied;
+            logOperation("Close Completed");
         }
-        return 0;
     }
 
+    /**
+     * Get application bytes available for reading
+     *
+     * @return Number of application bytes available for reading
+     * @throws IOException Thrown on failures checking for available bytes
+     */
     public int available() throws IOException {
         ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
         ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
@@ -506,8 +298,7 @@ public class SSLSocketChannel implements Closeable {
             return buffered;
         }
 
-        final boolean wasAbleToRead = isDataAvailable();
-        if (!wasAbleToRead) {
+        if (!isDataAvailable()) {
             return 0;
         }
 
@@ -516,6 +307,12 @@ public class SSLSocketChannel implements Closeable {
         return appDataBuffer.remaining() + streamDataBuffer.remaining();
     }
 
+    /**
+     * Is data available for reading
+     *
+     * @return Data available status
+     * @throws IOException Thrown on SocketChannel.read() failures
+     */
     public boolean isDataAvailable() throws IOException {
         final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
         final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
@@ -529,101 +326,139 @@ public class SSLSocketChannel implements Closeable {
         return (bytesRead > 0);
     }
 
+    /**
+     * Read and return one byte
+     *
+     * @return Byte read or -1 when end of stream reached
+     * @throws IOException Thrown on read failures
+     */
     public int read() throws IOException {
-        final int bytesRead = read(oneByteBuffer);
-        if (bytesRead == -1) {
-            return -1;
+        final byte[] buffer = new byte[1];
+
+        final int bytesRead = read(buffer);
+        if (bytesRead == END_OF_STREAM) {
+            return END_OF_STREAM;
         }
-        return oneByteBuffer[0] & 0xFF;
+
+        return Byte.toUnsignedInt(buffer[0]);
     }
 
+    /**
+     * Read available bytes into buffer
+     *
+     * @param buffer Byte array buffer
+     * @return Number of bytes read
+     * @throws IOException Thrown on read failures
+     */
     public int read(final byte[] buffer) throws IOException {
         return read(buffer, 0, buffer.length);
     }
 
+    /**
+     * Read available bytes into buffer based on offset and length requested
+     *
+     * @param buffer Byte array buffer
+     * @param offset Buffer offset
+     * @param len    Length of bytes to read
+     * @return Number of bytes read
+     * @throws IOException Thrown on read failures
+     */
     public int read(final byte[] buffer, final int offset, final int len) throws IOException {
-        logger.debug("{} Reading up to {} bytes of data", this, len);
+        logOperationBytes("Read Requested", len);
+        checkChannelStatus();
 
-        if (!connected) {
-            connect();
-        }
-
-        int copied = copyFromAppDataBuffer(buffer, offset, len);
-        if (copied > 0) {
-            return copied;
+        int applicationBytesRead = readApplicationBuffer(buffer, offset, len);
+        if (applicationBytesRead > 0) {
+            return applicationBytesRead;
         }
-
         appDataManager.clear();
 
         while (true) {
-            // prepare buffers and call unwrap
-            final ByteBuffer streamInBuffer = streamInManager.prepareForRead(1);
-            SSLEngineResult unwrapResponse = null;
-            final ByteBuffer appDataBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
-            unwrapResponse = engine.unwrap(streamInBuffer, appDataBuffer);
-            logger.trace("{} When reading data, (handshake={}) Unwrap response: {}", this, handshaking, unwrapResponse);
-
-            switch (unwrapResponse.getStatus()) {
+            final SSLEngineResult unwrapResult = unwrap();
+
+            if (SSLEngineResult.HandshakeStatus.FINISHED == unwrapResult.getHandshakeStatus()) {
+                // RFC 8446 Section 4.6 describes Post-Handshake Messages for TLS 1.3
+                logOperation("Processing Post-Handshake Messages");
+                continue;
+            }
+
+            final Status status = unwrapResult.getStatus();
+            switch (status) {
                 case BUFFER_OVERFLOW:
-                    throw new SSLHandshakeException("Buffer Overflow, which is not allowed to happen from an unwrap");
-                case BUFFER_UNDERFLOW: {
-//                appDataManager.prepareForRead(engine.getSession().getApplicationBufferSize());
-
-                    final ByteBuffer writableInBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
-                    final int bytesRead = readData(writableInBuffer);
-                    if (bytesRead < 0) {
-                        return -1;
+                    throw new IllegalStateException(String.format("SSLEngineResult Status [%s] not allowed from unwrap", status));
+                case BUFFER_UNDERFLOW:
+                    final ByteBuffer streamBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+                    final int channelBytesRead = readChannel(streamBuffer);
+                    logOperationBytes("Channel Read Completed", channelBytesRead);
+                    if (channelBytesRead == END_OF_STREAM) {
+                        return END_OF_STREAM;
                     }
-
-                    continue;
-                }
+                    break;
                 case CLOSED:
-                    copied = copyFromAppDataBuffer(buffer, offset, len);
-                    if (copied == 0) {
-                        return -1;
+                    applicationBytesRead = readApplicationBuffer(buffer, offset, len);
+                    if (applicationBytesRead == 0) {
+                        return END_OF_STREAM;
                     }
                     streamInManager.compact();
-                    return copied;
-                case OK: {
-                    copied = copyFromAppDataBuffer(buffer, offset, len);
-                    if (copied == 0) {
-                        throw new IOException("Failed to decrypt data");
+                    return applicationBytesRead;
+                case OK:
+                    applicationBytesRead = readApplicationBuffer(buffer, offset, len);
+                    if (applicationBytesRead == 0) {
+                        throw new IOException("Read Application Buffer Failed");
                     }
                     streamInManager.compact();
-                    return copied;
-                }
+                    return applicationBytesRead;
             }
         }
     }
 
+    /**
+     * Write one byte to channel
+     *
+     * @param data Byte to be written
+     * @throws IOException Thrown on write failures
+     */
     public void write(final int data) throws IOException {
         write(new byte[]{(byte) data}, 0, 1);
     }
 
+    /**
+     * Write bytes to channel
+     *
+     * @param data Byte array to be written
+     * @throws IOException Thrown on write failures
+     */
     public void write(final byte[] data) throws IOException {
         write(data, 0, data.length);
     }
 
+    /**
+     * Write data to channel, performing multiple iterations sized to the session application buffer
+     *
+     * @param data   Byte array to be written
+     * @param offset Byte array offset
+     * @param len    Length of bytes for writing
+     * @throws IOException Thrown on write failures
+     */
     public void write(final byte[] data, final int offset, final int len) throws IOException {
-        logger.debug("{} Writing {} bytes of data", this, len);
+        logOperationBytes("Write Started", len);
+        checkChannelStatus();
 
-        if (!connected) {
-            connect();
-        }
-
-        int iterations = len / MAX_WRITE_SIZE;
-        if (len % MAX_WRITE_SIZE > 0) {
+        final int applicationBufferSize = engine.getSession().getApplicationBufferSize();
+        logOperationBytes("Write Application Buffer Size", applicationBufferSize);
+        int iterations = len / applicationBufferSize;
+        if (len % applicationBufferSize > 0) {
             iterations++;
         }
 
         for (int i = 0; i < iterations; i++) {
             streamOutManager.clear();
-            final int itrOffset = offset + i * MAX_WRITE_SIZE;
-            final int itrLen = Math.min(len - itrOffset, MAX_WRITE_SIZE);
+            final int itrOffset = offset + i * applicationBufferSize;
+            final int itrLen = Math.min(offset + len - itrOffset, applicationBufferSize);
             final ByteBuffer byteBuffer = ByteBuffer.wrap(data, itrOffset, itrLen);
 
-            final BufferStateManager buffMan = new BufferStateManager(byteBuffer, Direction.READ);
-            final Status status = encryptAndWriteFully(buffMan);
+            final BufferStateManager bufferStateManager = new BufferStateManager(byteBuffer, Direction.READ);
+            final Status status = wrapWriteChannel(bufferStateManager);
             switch (status) {
                 case BUFFER_OVERFLOW:
                     streamOutManager.ensureSize(engine.getSession().getPacketBufferSize());
@@ -639,7 +474,294 @@ public class SSLSocketChannel implements Closeable {
         }
     }
 
+    /**
+     * Interrupt processing and disable transmission
+     */
     public void interrupt() {
         this.interrupted = true;
     }
+
+    private void performHandshake() throws IOException {
+        logOperation("Handshake Started");
+        channelStatus = ChannelStatus.HANDSHAKING;
+        engine.beginHandshake();
+
+        SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
+        while (true) {
+            logHandshakeStatus(handshakeStatus);
+
+            switch (handshakeStatus) {
+                case FINISHED:
+                case NOT_HANDSHAKING:
+                    channelStatus = ChannelStatus.ESTABLISHED;
+                    final SSLSession session = engine.getSession();
+                    LOGGER.debug("[{}:{}] [{}] Negotiated Protocol [{}] Cipher Suite [{}]",
+                            remoteAddress,
+                            port,
+                            channelStatus,
+                            session.getProtocol(),
+                            session.getCipherSuite()
+                    );
+                    return;
+                case NEED_TASK:
+                    runDelegatedTasks();
+                    handshakeStatus = engine.getHandshakeStatus();
+                    break;
+                case NEED_UNWRAP:
+                    final SSLEngineResult unwrapResult = unwrap();
+                    handshakeStatus = unwrapResult.getHandshakeStatus();
+                    final Status unwrapResultStatus = unwrapResult.getStatus();
+
+                    if (unwrapResultStatus == Status.BUFFER_UNDERFLOW) {
+                        final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+                        final int bytesRead = readChannel(writableDataIn);
+                        logOperationBytes("Handshake Channel Read", bytesRead);
+
+                        if (bytesRead == END_OF_STREAM) {
+                            throw getHandshakeException(handshakeStatus, "End of Stream Found");
+                        }
+                    } else if (unwrapResultStatus == Status.CLOSED) {
+                        throw getHandshakeException(handshakeStatus, "Channel Closed");
+                    } else {
+                        streamInManager.compact();
+                        appDataManager.clear();
+                    }
+                    break;
+                case NEED_WRAP:
+                    final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+                    final SSLEngineResult wrapResult = wrap(ByteBuffer.wrap(EMPTY_MESSAGE), outboundBuffer);
+                    handshakeStatus = wrapResult.getHandshakeStatus();
+                    final Status wrapResultStatus = wrapResult.getStatus();
+
+                    if (wrapResultStatus == Status.BUFFER_OVERFLOW) {
+                        streamOutManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+                    } else if (wrapResultStatus == Status.OK) {
+                        final ByteBuffer streamBuffer = streamOutManager.prepareForRead(1);
+                        final int bytesRemaining = streamBuffer.remaining();
+                        writeChannel(streamBuffer);
+                        logOperationBytes("Handshake Channel Write Completed", bytesRemaining);
+                        streamOutManager.clear();
+                    } else {
+                        throw getHandshakeException(handshakeStatus, String.format("Wrap Failed [%s]", wrapResult.getStatus()));
+                    }
+                    break;
+            }
+        }
+    }
+
+    private int readChannel(final ByteBuffer outputBuffer) throws IOException {
+        logOperation("Channel Read Started");
+
+        final long started = System.currentTimeMillis();
+        long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
+        while (true) {
+            checkInterrupted();
+
+            if (outputBuffer.remaining() == 0) {
+                return 0;
+            }
+
+            final int channelBytesRead = channel.read(outputBuffer);
+            if (channelBytesRead == 0) {
+                checkTimeoutExceeded(started);
+                sleepNanoseconds = incrementalSleep(sleepNanoseconds);
+                continue;
+            }
+
+            return channelBytesRead;
+        }
+    }
+
+    private void writeChannel(final ByteBuffer inputBuffer) throws IOException {
+        long lastWriteCompleted = System.currentTimeMillis();
+
+        int totalBytes = 0;
+        long sleepNanoseconds = INITIAL_INCREMENTAL_SLEEP;
+        while (inputBuffer.hasRemaining()) {
+            checkInterrupted();
+
+            final int written = channel.write(inputBuffer);
+            totalBytes += written;
+
+            if (written > 0) {
+                lastWriteCompleted = System.currentTimeMillis();
+            } else {
+                checkTimeoutExceeded(lastWriteCompleted);
+                sleepNanoseconds = incrementalSleep(sleepNanoseconds);
+            }
+        }
+
+        logOperationBytes("Channel Write Completed", totalBytes);
+    }
+
+    private long incrementalSleep(final long nanoseconds) throws IOException {
+        try {
+            TimeUnit.NANOSECONDS.sleep(nanoseconds);
+        } catch (final InterruptedException e) {
+            close();
+            Thread.currentThread().interrupt();
+            throw new ClosedByInterruptException();
+        }
+        return Math.min(nanoseconds * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
+    }
+
+    private void readChannelDiscard() {
+        try {
+            final ByteBuffer readBuffer = ByteBuffer.allocate(DISCARD_BUFFER_LENGTH);
+            int bytesRead = channel.read(readBuffer);
+            while (bytesRead > 0) {
+                readBuffer.clear();
+                bytesRead = channel.read(readBuffer);
+            }
+        } catch (final IOException e) {
+            LOGGER.debug("[{}:{}] Read Channel Discard Failed", remoteAddress, port, e);
+        }
+    }
+
+    private int readApplicationBuffer(final byte[] buffer, final int offset, final int len) {
+        logOperationBytes("Application Buffer Read Requested", len);
+        final ByteBuffer appDataBuffer = appDataManager.prepareForRead(len);
+
+        final int appDataRemaining = appDataBuffer.remaining();
+        logOperationBytes("Application Buffer Remaining", appDataRemaining);
+        if (appDataRemaining > 0) {
+            final int bytesToCopy = Math.min(len, appDataBuffer.remaining());
+            appDataBuffer.get(buffer, offset, bytesToCopy);
+
+            final int bytesCopied = appDataRemaining - appDataBuffer.remaining();
+            logOperationBytes("Application Buffer Copied", bytesCopied);
+            return bytesCopied;
+        }
+        return 0;
+    }
+
+    private Status wrapWriteChannel(final BufferStateManager inputManager) throws IOException {
+        final ByteBuffer inputBuffer = inputManager.prepareForRead(0);
+        final ByteBuffer outputBuffer = streamOutManager.prepareForWrite(engine.getSession().getPacketBufferSize());
+
+        logOperationBytes("Wrap Started", inputBuffer.remaining());
+        Status status = Status.OK;
+        while (inputBuffer.remaining() > 0) {
+            final SSLEngineResult result = wrap(inputBuffer, outputBuffer);
+            status = result.getStatus();
+            if (status == Status.OK) {
+                final ByteBuffer readableOutBuff = streamOutManager.prepareForRead(0);
+                writeChannel(readableOutBuff);
+                streamOutManager.clear();
+            } else {
+                break;
+            }
+        }
+
+        return status;
+    }
+
+    private SSLEngineResult wrap(final ByteBuffer inputBuffer, final ByteBuffer outputBuffer) throws SSLException {
+        final SSLEngineResult result = engine.wrap(inputBuffer, outputBuffer);
+        logEngineResult(result, "WRAP Completed");
+        return result;
+    }
+
+    private SSLEngineResult unwrap() throws IOException {
+        final ByteBuffer streamBuffer = streamInManager.prepareForRead(engine.getSession().getPacketBufferSize());
+        final ByteBuffer applicationBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
+        final SSLEngineResult result = engine.unwrap(streamBuffer, applicationBuffer);
+        logEngineResult(result, "UNWRAP Completed");
+        return result;
+    }
+
+    private void runDelegatedTasks() {
+        Runnable delegatedTask;
+        while ((delegatedTask = engine.getDelegatedTask()) != null) {
+            logOperation("Running Delegated Task");
+            delegatedTask.run();
+        }
+    }
+
+    private void closeQuietly(final Closeable closeable) {
+        try {
+            closeable.close();
+        } catch (final Exception e) {
+            logOperation(String.format("Close failed: %s", e.getMessage()));
+        }
+    }
+
+    private SSLHandshakeException getHandshakeException(final SSLEngineResult.HandshakeStatus handshakeStatus, final String message) {
+        final String formatted = String.format("[%s:%d] Handshake Status [%s] %s", remoteAddress, port, handshakeStatus, message);
+        return new SSLHandshakeException(formatted);
+    }
+
+    private void checkChannelStatus() throws IOException {
+        if (ChannelStatus.ESTABLISHED != channelStatus) {
+            connect();
+        }
+    }
+
+    private void checkInterrupted() {
+        if (interrupted) {
+            throw new TransmissionDisabledException();
+        }
+    }
+
+    private void checkTimeoutExceeded(final long started) throws SocketTimeoutException {
+        if (System.currentTimeMillis() > started + timeoutMillis) {
+            throw new SocketTimeoutException(String.format("Timeout Exceeded [%d ms] for [%s:%d]", timeoutMillis, remoteAddress, port));
+        }
+    }
+
+    private void logOperation(final String operation) {
+        LOGGER.trace("[{}:{}] [{}] {}", remoteAddress, port, channelStatus, operation);
+    }
+
+    private void logOperationBytes(final String operation, final int bytes) {
+        LOGGER.trace("[{}:{}] [{}] {} Bytes [{}]", remoteAddress, port, channelStatus, operation, bytes);
+    }
+
+    private void logHandshakeStatus(final SSLEngineResult.HandshakeStatus handshakeStatus) {
+        LOGGER.trace("[{}:{}] [{}] Handshake Status [{}]", remoteAddress, port, channelStatus, handshakeStatus);
+    }
+
+    private void logEngineResult(final SSLEngineResult result, final String method) {
+        LOGGER.trace("[{}:{}] [{}] {} Status [{}] Handshake Status [{}] Produced [{}] Consumed [{}]",
+                remoteAddress,
+                port,
+                channelStatus,
+                method,
+                result.getStatus(),
+                result.getHandshakeStatus(),
+                result.bytesProduced(),
+                result.bytesConsumed()
+        );
+    }
+
+    private static SocketChannel createSocketChannel(final InetAddress bindAddress) throws IOException {
+        final SocketChannel socketChannel = SocketChannel.open();
+        if (bindAddress != null) {
+            final SocketAddress socketAddress = new InetSocketAddress(bindAddress, 0);
+            socketChannel.bind(socketAddress);
+        }
+        socketChannel.configureBlocking(false);
+        return socketChannel;
+    }
+
+    private static SSLEngine createEngine(final SSLContext sslContext, final boolean useClientMode) {
+        final SSLEngine sslEngine = sslContext.createSSLEngine();
+        sslEngine.setUseClientMode(useClientMode);
+        sslEngine.setNeedClientAuth(CLIENT_AUTHENTICATION_REQUIRED);
+        return sslEngine;
+    }
+
+    private enum ChannelStatus {
+        DISCONNECTED,
+
+        CONNECTING,
+
+        CONNECTED,
+
+        HANDSHAKING,
+
+        ESTABLISHED,
+
+        CLOSED
+    }
 }
diff --git a/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java b/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java
new file mode 100644
index 0000000..aa9dde5
--- /dev/null
+++ b/nifi-commons/nifi-security-socket-ssl/src/test/java/org/apache/nifi/remote/io/socket/ssl/SSLSocketChannelTest.java
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.remote.io.socket.ssl;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.DelimiterBasedFrameDecoder;
+import io.netty.handler.codec.Delimiters;
+import io.netty.handler.codec.string.StringDecoder;
+import io.netty.handler.codec.string.StringEncoder;
+import io.netty.handler.ssl.SslHandler;
+import org.apache.nifi.remote.io.socket.NetworkUtils;
+import org.apache.nifi.security.util.KeyStoreUtils;
+import org.apache.nifi.security.util.SslContextFactory;
+import org.apache.nifi.security.util.TlsConfiguration;
+import org.apache.nifi.security.util.TlsPlatform;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLException;
+import java.io.File;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.nio.channels.ServerSocketChannel;
+import java.nio.channels.SocketChannel;
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.security.GeneralSecurityException;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+public class SSLSocketChannelTest {
+    private static final String LOCALHOST = "localhost";
+
+    private static final int GROUP_THREADS = 1;
+
+    private static final boolean CLIENT_CHANNEL = true;
+
+    private static final boolean SERVER_CHANNEL = false;
+
+    private static final int CHANNEL_TIMEOUT = 15000;
+
+    private static final int CHANNEL_FAILURE_TIMEOUT = 100;
+
+    private static final int CHANNEL_POLL_TIMEOUT = 5000;
+
+    private static final long CHANNEL_SLEEP_BEFORE_READ = 100;
+
+    private static final int MAX_MESSAGE_LENGTH = 1024;
+
+    private static final String TLS_1_3 = "TLSv1.3";
+
+    private static final String TLS_1_2 = "TLSv1.2";
+
+    private static final String MESSAGE = "PING\n";
+
+    private static final Charset MESSAGE_CHARSET = StandardCharsets.UTF_8;
+
+    private static final byte[] MESSAGE_BYTES = MESSAGE.getBytes(StandardCharsets.UTF_8);
+
+    private static final int FIRST_BYTE_OFFSET = 1;
+
+    private static SSLContext sslContext;
+
+    @BeforeClass
+    public static void setConfiguration() throws GeneralSecurityException, IOException {
+        final TlsConfiguration tlsConfiguration = KeyStoreUtils.createTlsConfigAndNewKeystoreTruststore();
+        new File(tlsConfiguration.getKeystorePath()).deleteOnExit();
+        new File(tlsConfiguration.getTruststorePath()).deleteOnExit();
+        sslContext = SslContextFactory.createSslContext(tlsConfiguration);
+    }
+
+    @Test
+    public void testClientConnectFailed() throws IOException {
+        final int port = NetworkUtils.getAvailableTcpPort();
+        final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
+        sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
+        assertThrows(Exception.class, sslSocketChannel::connect);
+    }
+
+    @Test
+    public void testClientConnectHandshakeFailed() throws IOException {
+        assumeProtocolSupported(TLS_1_2);
+        final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
+
+        try (final SocketChannel socketChannel = SocketChannel.open()) {
+            final int port = NetworkUtils.getAvailableTcpPort();
+            startServer(group, port, TLS_1_2);
+
+            socketChannel.connect(new InetSocketAddress(LOCALHOST, port));
+            final SSLEngine sslEngine = createSslEngine(TLS_1_2, CLIENT_CHANNEL);
+
+            final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel);
+            sslSocketChannel.setTimeout(CHANNEL_FAILURE_TIMEOUT);
+
+            group.shutdownGracefully().syncUninterruptibly();
+            assertThrows(SSLException.class, sslSocketChannel::connect);
+        } finally {
+            group.shutdownGracefully().syncUninterruptibly();
+        }
+    }
+
+    @Test
+    public void testClientConnectWriteReadTls12() throws Exception {
+        assumeProtocolSupported(TLS_1_2);
+        assertChannelConnectedWriteReadClosed(TLS_1_2);
+    }
+
+    @Test
+    public void testClientConnectWriteReadTls13() throws Exception {
+        assumeProtocolSupported(TLS_1_3);
+        assertChannelConnectedWriteReadClosed(TLS_1_3);
+    }
+
+    @Test(timeout = CHANNEL_TIMEOUT)
+    public void testServerReadWriteTls12() throws Exception {
+        assumeProtocolSupported(TLS_1_2);
+        assertServerChannelConnectedReadClosed(TLS_1_2);
+    }
+
+    @Test(timeout = CHANNEL_TIMEOUT)
+    public void testServerReadWriteTls13() throws Exception {
+        assumeProtocolSupported(TLS_1_3);
+        assertServerChannelConnectedReadClosed(TLS_1_3);
+    }
+
+    private void assumeProtocolSupported(final String protocol) {
+        Assume.assumeTrue(String.format("Protocol [%s] not supported", protocol), TlsPlatform.getSupportedProtocols().contains(protocol));
+    }
+
+    private void assertServerChannelConnectedReadClosed(final String enabledProtocol) throws IOException, InterruptedException {
+        final int port = NetworkUtils.getAvailableTcpPort();
+        final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
+        final SocketAddress socketAddress = new InetSocketAddress(LOCALHOST, port);
+        serverSocketChannel.bind(socketAddress);
+
+        final Executor executor = Executors.newSingleThreadExecutor();
+        final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
+        try {
+            final Channel channel = startClient(group, port, enabledProtocol);
+
+            try {
+                final SocketChannel socketChannel = serverSocketChannel.accept();
+                final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, socketChannel, SERVER_CHANNEL);
+
+                final BlockingQueue<String> queue = new LinkedBlockingQueue<>();
+                final Runnable readCommand = () -> {
+                    final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
+                    try {
+                        final int messageBytesRead = sslSocketChannel.read(messageBytes);
+                        if (messageBytesRead == MESSAGE_BYTES.length) {
+                            queue.add(new String(messageBytes, MESSAGE_CHARSET));
+                        }
+                    } catch (IOException e) {
+                        throw new UncheckedIOException(e);
+                    }
+                };
+                executor.execute(readCommand);
+                channel.writeAndFlush(MESSAGE).syncUninterruptibly();
+
+                final String messageRead = queue.poll(CHANNEL_POLL_TIMEOUT, TimeUnit.MILLISECONDS);
+                assertEquals("Message not matched", MESSAGE, messageRead);
+            } finally {
+                channel.close();
+            }
+        } finally {
+            group.shutdownGracefully().syncUninterruptibly();
+            serverSocketChannel.close();
+        }
+    }
+
+    private void assertChannelConnectedWriteReadClosed(final String enabledProtocol) throws IOException {
+        processClientSslSocketChannel(enabledProtocol, (sslSocketChannel -> {
+            try {
+                sslSocketChannel.connect();
+                assertFalse("Channel closed", sslSocketChannel.isClosed());
+
+                assertChannelWriteRead(sslSocketChannel);
+
+                sslSocketChannel.close();
+                assertTrue("Channel not closed", sslSocketChannel.isClosed());
+            } catch (final IOException e) {
+                throw new UncheckedIOException(String.format("Channel Failed for %s", enabledProtocol), e);
+            }
+        }));
+    }
+
+    private void assertChannelWriteRead(final SSLSocketChannel sslSocketChannel) throws IOException {
+        sslSocketChannel.write(MESSAGE_BYTES);
+
+        while (sslSocketChannel.available() == 0) {
+            try {
+                TimeUnit.MILLISECONDS.sleep(CHANNEL_SLEEP_BEFORE_READ);
+            } catch (final InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        final byte firstByteRead = (byte) sslSocketChannel.read();
+        assertEquals("Channel Message first byte not matched", MESSAGE_BYTES[0], firstByteRead);
+
+        final byte[] messageBytes = new byte[MESSAGE_BYTES.length];
+        messageBytes[0] = firstByteRead;
+
+        final int messageBytesRead = sslSocketChannel.read(messageBytes, FIRST_BYTE_OFFSET, messageBytes.length - FIRST_BYTE_OFFSET);
+        assertEquals("Channel Message Bytes Read not matched", messageBytes.length - FIRST_BYTE_OFFSET, messageBytesRead);
+
+        final String message = new String(messageBytes, MESSAGE_CHARSET);
+        assertEquals("Channel Message not matched", MESSAGE, message);
+    }
+
+    private void processClientSslSocketChannel(final String enabledProtocol, final Consumer<SSLSocketChannel> channelConsumer) throws IOException {
+        final EventLoopGroup group = new NioEventLoopGroup(GROUP_THREADS);
+
+        try {
+            final int port = NetworkUtils.getAvailableTcpPort();
+            startServer(group, port, enabledProtocol);
+            final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, LOCALHOST, port, null, CLIENT_CHANNEL);
+            sslSocketChannel.setTimeout(CHANNEL_TIMEOUT);
+            channelConsumer.accept(sslSocketChannel);
+        } finally {
+            group.shutdownGracefully().syncUninterruptibly();
+        }
+    }
+
+    private Channel startClient(final EventLoopGroup group, final int port, final String enabledProtocol) {
+        final Bootstrap bootstrap = new Bootstrap();
+        bootstrap.group(group);
+        bootstrap.channel(NioSocketChannel.class);
+        bootstrap.handler(new ChannelInitializer<Channel>() {
+            @Override
+            protected void initChannel(final Channel channel) {
+                final ChannelPipeline pipeline = channel.pipeline();
+                final SSLEngine sslEngine = createSslEngine(enabledProtocol, CLIENT_CHANNEL);
+                setPipelineHandlers(pipeline, sslEngine);
+            }
+        });
+        return bootstrap.connect(LOCALHOST, port).syncUninterruptibly().channel();
+    }
+
+    private void startServer(final EventLoopGroup group, final int port, final String enabledProtocol) {
+        final ServerBootstrap bootstrap = new ServerBootstrap();
+        bootstrap.group(group);
+        bootstrap.channel(NioServerSocketChannel.class);
+        bootstrap.childHandler(new ChannelInitializer<Channel>() {
+            @Override
+            protected void initChannel(final Channel channel) {
+                final ChannelPipeline pipeline = channel.pipeline();
+                final SSLEngine sslEngine = createSslEngine(enabledProtocol, SERVER_CHANNEL);
+                setPipelineHandlers(pipeline, sslEngine);
+                pipeline.addLast(new SimpleChannelInboundHandler<String>() {
+                    @Override
+                    protected void channelRead0(ChannelHandlerContext channelHandlerContext, String s) throws Exception {
+                        channelHandlerContext.channel().writeAndFlush(MESSAGE).sync();
+                    }
+                });
+            }
+        });
+
+        final ChannelFuture bindFuture = bootstrap.bind(LOCALHOST, port);
+        bindFuture.syncUninterruptibly();
+    }
+
+    private SSLEngine createSslEngine(final String enabledProtocol, final boolean useClientMode) {
+        final SSLEngine sslEngine = sslContext.createSSLEngine();
+        sslEngine.setUseClientMode(useClientMode);
+        sslEngine.setEnabledProtocols(new String[]{enabledProtocol});
+        return sslEngine;
+    }
+
+    private void setPipelineHandlers(final ChannelPipeline pipeline, final SSLEngine sslEngine) {
+        pipeline.addLast(new SslHandler(sslEngine));
+        pipeline.addLast(new DelimiterBasedFrameDecoder(MAX_MESSAGE_LENGTH, Delimiters.lineDelimiter()));
+        pipeline.addLast(new StringDecoder());
+        pipeline.addLast(new StringEncoder());
+    }
+}
diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java
index 70771f1..e2f05cc 100644
--- a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java
+++ b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SSLSocketChannelSender.java
@@ -68,9 +68,10 @@ public class SSLSocketChannelSender extends SocketChannelSender {
 
     @Override
     public void close() {
-        super.close();
-        IOUtils.closeQuietly(sslOutputStream);
+        // Close SSLSocketChannel before closing other resources
         IOUtils.closeQuietly(sslChannel);
+        IOUtils.closeQuietly(sslOutputStream);
+        super.close();
         sslChannel = null;
     }
 

Mime
View raw message