diff --git a/modules/core/src/main/java/org/apache/ignite/internal/util/nio/GridNioServer.java b/modules/core/src/main/java/org/apache/ignite/internal/util/nio/GridNioServer.java index 695eaed5228b3..fc23516c8afb5 100644 --- a/modules/core/src/main/java/org/apache/ignite/internal/util/nio/GridNioServer.java +++ b/modules/core/src/main/java/org/apache/ignite/internal/util/nio/GridNioServer.java @@ -1425,7 +1425,8 @@ private void processWriteSsl(SelectionKey key) throws IOException { try { boolean writeFinished = writeSslSystem(ses, sockCh); - if (!handshakeFinished) { + // If post-handshake message is not written fully (possible on JDK 17), we should retry. + if (!handshakeFinished || !writeFinished) { if (writeFinished) stopPollingForWrite(key, ses); diff --git a/modules/core/src/main/java/org/apache/ignite/internal/util/nio/ssl/BlockingSslHandler.java b/modules/core/src/main/java/org/apache/ignite/internal/util/nio/ssl/BlockingSslHandler.java index 40ded3c314ba5..1da7c8ab62ae3 100644 --- a/modules/core/src/main/java/org/apache/ignite/internal/util/nio/ssl/BlockingSslHandler.java +++ b/modules/core/src/main/java/org/apache/ignite/internal/util/nio/ssl/BlockingSslHandler.java @@ -85,8 +85,8 @@ public BlockingSslHandler(SSLEngine sslEngine, SocketChannel ch, boolean directBuf, ByteOrder order, - IgniteLogger log) - throws SSLException { + IgniteLogger log + ) { this.ch = ch; this.log = log; this.sslEngine = sslEngine; diff --git a/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/GridNioServerWrapper.java b/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/GridNioServerWrapper.java index 8302b1c950885..187b7fc3edb21 100644 --- a/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/GridNioServerWrapper.java +++ b/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/GridNioServerWrapper.java @@ -1186,12 +1186,6 @@ private long safeTcpHandshake( try { return tcpHandshakeExecutor.tcpHandshake(ch, rmtNodeId, sslMeta, msg); } - catch (IOException e) { - if (log.isDebugEnabled()) - log.debug("Failed to read from channel: " + e); - - throw new IgniteCheckedException("Failed to read from channel.", e); - } finally { if (!timeoutObject.cancel()) throw handshakeTimeoutException(); diff --git a/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/TcpHandshakeExecutor.java b/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/TcpHandshakeExecutor.java index cb5e256503ec7..387ceed0fdb8c 100644 --- a/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/TcpHandshakeExecutor.java +++ b/modules/core/src/main/java/org/apache/ignite/spi/communication/tcp/internal/TcpHandshakeExecutor.java @@ -22,6 +22,7 @@ import java.nio.ByteOrder; import java.nio.channels.SocketChannel; import java.util.UUID; +import javax.net.ssl.SSLException; import org.apache.ignite.IgniteCheckedException; import org.apache.ignite.IgniteLogger; import org.apache.ignite.internal.util.nio.ssl.BlockingSslHandler; @@ -67,228 +68,296 @@ public TcpHandshakeExecutor(IgniteLogger log, ClusterStateProvider stateProvider * @param ch Socket channel which using for handshake. * @param rmtNodeId Expected remote node. * @param sslMeta Required data for ssl. - * @param msg Handshake message which should be send during handshake. + * @param msg Handshake message which should be sent during handshake. * @return Handshake response from predefined variants from {@link RecoveryLastReceivedMessage}. - * @throws IgniteCheckedException If not related to IO exception happened. - * @throws IOException If reading or writing to socket is failed. + * @throws IgniteCheckedException If handshake failed. */ public long tcpHandshake( SocketChannel ch, UUID rmtNodeId, GridSslMeta sslMeta, HandshakeMessage msg - ) throws IgniteCheckedException, IOException { - long rcvCnt; + ) throws IgniteCheckedException { + BlockingTransport transport = stateProvider.isSslEnabled() ? + new SslTransport(sslMeta, ch, directBuffer, log) : new TcpTransport(ch); - BlockingSslHandler sslHnd = null; + ByteBuffer buf = transport.recieveNodeId(); - ByteBuffer buf; + if (buf == null) + return NEED_WAIT; - // Step 1. Get remote node response with the remote nodeId value. - if (stateProvider.isSslEnabled()) { - assert sslMeta != null; - - sslHnd = new BlockingSslHandler(sslMeta.sslEngine(), ch, directBuffer, ByteOrder.LITTLE_ENDIAN, log); - - if (!sslHnd.handshake()) - throw new HandshakeException("SSL handshake is not completed."); - - ByteBuffer handBuff = sslHnd.applicationBuffer(); - - if (handBuff.remaining() >= DIRECT_TYPE_SIZE) { - short msgType = makeMessageType(handBuff.get(0), handBuff.get(1)); - - if (msgType == HANDSHAKE_WAIT_MSG_TYPE) - return NEED_WAIT; - } - - if (handBuff.remaining() < NodeIdMessage.MESSAGE_FULL_SIZE) { - ByteBuffer readBuf = ByteBuffer.allocate(1000); - - while (handBuff.remaining() < NodeIdMessage.MESSAGE_FULL_SIZE) { - int read = ch.read(readBuf); - - if (read == -1) - throw new HandshakeException("Failed to read remote node ID (connection closed)."); + UUID rmtNodeId0 = U.bytesToUuid(buf.array(), DIRECT_TYPE_SIZE); - readBuf.flip(); + if (!rmtNodeId.equals(rmtNodeId0)) + throw new HandshakeException("Remote node ID is not as expected [expected=" + rmtNodeId + ", rcvd=" + rmtNodeId0 + ']'); + else if (log.isDebugEnabled()) + log.debug("Received remote node ID: " + rmtNodeId0); - sslHnd.decode(readBuf); + if (log.isDebugEnabled()) + log.debug("Writing handshake message [rmtNode=" + rmtNodeId + ", msg=" + msg + ']'); - if (handBuff.remaining() >= DIRECT_TYPE_SIZE) { - break; - } + transport.sendHandshake(msg); - readBuf.flip(); - } + buf = transport.recieveAcknowledge(); - buf = handBuff; + long rcvCnt = buf.getLong(DIRECT_TYPE_SIZE); - if (handBuff.remaining() >= DIRECT_TYPE_SIZE) { - short msgType = makeMessageType(handBuff.get(0), handBuff.get(1)); + if (log.isDebugEnabled()) + log.debug("Received handshake message [rmtNode=" + rmtNodeId + ", rcvCnt=" + rcvCnt + ']'); - if (msgType == HANDSHAKE_WAIT_MSG_TYPE) - return NEED_WAIT; - } - } - else - buf = handBuff; + if (rcvCnt == -1) { + if (log.isDebugEnabled()) + log.debug("Connection rejected, will retry client creation [rmtNode=" + rmtNodeId + ']'); } - else { - buf = ByteBuffer.allocate(NodeIdMessage.MESSAGE_FULL_SIZE); - for (int i = 0; i < NodeIdMessage.MESSAGE_FULL_SIZE; ) { - int read = ch.read(buf); + transport.onHandshakeFinished(sslMeta); + + return rcvCnt; + } - if (read == -1) + /** + * Encapsulates handshake logic. + */ + private abstract static class BlockingTransport { + /** + * Receive {@link NodeIdMessage}. + * + * @return Buffer with {@link NodeIdMessage}. + * @throws IgniteCheckedException If failed. + */ + ByteBuffer recieveNodeId() throws IgniteCheckedException { + ByteBuffer buf = ByteBuffer.allocate(NodeIdMessage.MESSAGE_FULL_SIZE) + .order(ByteOrder.LITTLE_ENDIAN); + + for (int totalBytes = 0; totalBytes < NodeIdMessage.MESSAGE_FULL_SIZE; ) { + int readBytes = read(buf); + + if (readBytes == -1) throw new HandshakeException("Failed to read remote node ID (connection closed)."); - if (read >= DIRECT_TYPE_SIZE) { + if (readBytes >= DIRECT_TYPE_SIZE) { short msgType = makeMessageType(buf.get(0), buf.get(1)); if (msgType == HANDSHAKE_WAIT_MSG_TYPE) - return NEED_WAIT; + return null; } - i += read; + totalBytes += readBytes; } - } - - UUID rmtNodeId0 = U.bytesToUuid(buf.array(), DIRECT_TYPE_SIZE); - if (!rmtNodeId.equals(rmtNodeId0)) - throw new HandshakeException("Remote node ID is not as expected [expected=" + rmtNodeId + - ", rcvd=" + rmtNodeId0 + ']'); - else if (log.isDebugEnabled()) - log.debug("Received remote node ID: " + rmtNodeId0); - - if (stateProvider.isSslEnabled()) { - assert sslHnd != null; + return buf; + } - U.writeFully(ch, sslHnd.encrypt(ByteBuffer.wrap(U.IGNITE_HEADER))); + /** + * Send {@link HandshakeMessage} to remote node. + * + * @param msg Handshake message. + * @throws IgniteCheckedException If failed. + */ + void sendHandshake(HandshakeMessage msg) throws IgniteCheckedException { + ByteBuffer buf = ByteBuffer.allocate(msg.getMessageSize() + U.IGNITE_HEADER.length) + .order(ByteOrder.LITTLE_ENDIAN) + .put(U.IGNITE_HEADER); + + msg.writeTo(buf, null); + buf.flip(); + + write(buf); } - else - U.writeFully(ch, ByteBuffer.wrap(U.IGNITE_HEADER)); - // Step 2. Prepare Handshake message to send to the remote node. - if (log.isDebugEnabled()) - log.debug("Writing handshake message [rmtNode=" + rmtNodeId + ", msg=" + msg + ']'); + /** + * Receive {@link RecoveryLastReceivedMessage} acknowledge message. + * + * @return Buffer with message. + * @throws IgniteCheckedException If failed. + */ + ByteBuffer recieveAcknowledge() throws IgniteCheckedException { + ByteBuffer buf = ByteBuffer.allocate(RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE) + .order(ByteOrder.LITTLE_ENDIAN); - buf = ByteBuffer.allocate(msg.getMessageSize()); + for (int totalBytes = 0; totalBytes < RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE; ) { + int readBytes = read(buf); - buf.order(ByteOrder.LITTLE_ENDIAN); + if (readBytes == -1) + throw new HandshakeException("Failed to read remote node recovery handshake " + + "(connection closed)."); - boolean written = msg.writeTo(buf, null); + totalBytes += readBytes; + } - assert written; + return buf; + } - buf.flip(); + /** + * Read data from media. + * + * @param buf Buffer to read into. + * @return Bytes read. + * @throws IgniteCheckedException If failed. + */ + abstract int read(ByteBuffer buf) throws IgniteCheckedException; + + /** + * Write data fully. + * @param buf Buffer to write. + * @throws IgniteCheckedException If failed. + */ + abstract void write(ByteBuffer buf) throws IgniteCheckedException; + + /** + * Do some post-handshake job if needed. + * + * @param sslMeta Ssl meta. + */ + void onHandshakeFinished(GridSslMeta sslMeta) { + // No-op. + } + } - if (stateProvider.isSslEnabled()) { - assert sslHnd != null; + /** + * Tcp plaintext transport. + */ + private static class TcpTransport extends BlockingTransport { + /** */ + private final SocketChannel ch; - U.writeFully(ch, sslHnd.encrypt(buf)); + /** */ + TcpTransport(SocketChannel ch) { + this.ch = ch; } - else - U.writeFully(ch, buf); - if (log.isDebugEnabled()) - log.debug("Waiting for handshake [rmtNode=" + rmtNodeId + ']'); - - // Step 3. Waiting for response from the remote node with their receive count message. - if (stateProvider.isSslEnabled()) { - assert sslHnd != null; + /** {@inheritDoc} */ + @Override int read(ByteBuffer buf) throws IgniteCheckedException { + try { + return ch.read(buf); + } + catch (IOException e) { + throw new IgniteCheckedException("Failed to read from channel", e); + } + } - buf = ByteBuffer.allocate(1000); - buf.order(ByteOrder.LITTLE_ENDIAN); + /** {@inheritDoc} */ + @Override void write(ByteBuffer buf) throws IgniteCheckedException { + try { + U.writeFully(ch, buf); + } + catch (IOException e) { + throw new IgniteCheckedException("Failed to write to channel", e); + } + } + } - ByteBuffer decode = ByteBuffer.allocate(2 * buf.capacity()); - decode.order(ByteOrder.LITTLE_ENDIAN); + /** Ssl transport */ + private static class SslTransport extends BlockingTransport { + /** */ + private static final int READ_BUFFER_CAPACITY = 1024; - for (int i = 0; i < RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE; ) { - int read = ch.read(buf); + /** */ + private final BlockingSslHandler handler; - if (read == -1) - throw new HandshakeException("Failed to read remote node recovery handshake " + - "(connection closed)."); + /** */ + private final SocketChannel ch; - buf.flip(); + /** */ + private final ByteBuffer readBuf; - ByteBuffer decode0 = sslHnd.decode(buf); + /** */ + SslTransport(GridSslMeta meta, SocketChannel ch, boolean directBuf, IgniteLogger log) throws IgniteCheckedException { + try { + this.ch = ch; + handler = new BlockingSslHandler(meta.sslEngine(), ch, directBuf, ByteOrder.LITTLE_ENDIAN, log); - i += decode0.remaining(); + if (!handler.handshake()) + throw new HandshakeException("SSL handshake is not completed."); - decode = appendAndResizeIfNeeded(decode, decode0); + readBuf = directBuf ? ByteBuffer.allocateDirect(READ_BUFFER_CAPACITY) : ByteBuffer.allocate(READ_BUFFER_CAPACITY); - buf.clear(); + readBuf.order(ByteOrder.LITTLE_ENDIAN); + } + catch (SSLException e) { + throw new IgniteCheckedException("SSL handhshake failed", e); } + } - decode.flip(); + /** {@inheritDoc} */ + @Override int read(ByteBuffer buf) throws IgniteCheckedException { + ByteBuffer appBuff = handler.applicationBuffer(); - rcvCnt = decode.getLong(DIRECT_TYPE_SIZE); + int read = copy(appBuff, buf); - if (decode.limit() > RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE) { - decode.position(RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE); + if (read > 0) + return read; - sslMeta.decodedBuffer(decode); - } + try { + while (read == 0) { + readBuf.clear(); - ByteBuffer inBuf = sslHnd.inputBuffer(); + if (ch.read(readBuf) < 0) + return -1; - if (inBuf.position() > 0) - sslMeta.encodedBuffer(inBuf); - } - else { - buf = ByteBuffer.allocate(RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE); + readBuf.flip(); - buf.order(ByteOrder.LITTLE_ENDIAN); + handler.decode(readBuf); - for (int i = 0; i < RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE; ) { - int read = ch.read(buf); + read = copy(appBuff, buf); + } + } + catch (SSLException e) { + throw new IgniteCheckedException("Failed to decrypt data", e); + } + catch (IOException e) { + throw new IgniteCheckedException("Failed to read from channel", e); + } - if (read == -1) - throw new HandshakeException("Failed to read remote node recovery handshake " + - "(connection closed)."); + return read; + } - i += read; + /** {@inheritDoc} */ + @Override void write(ByteBuffer buf) throws IgniteCheckedException { + try { + U.writeFully(ch, handler.encrypt(buf)); + } + catch (SSLException e) { + throw new IgniteCheckedException("Failed to encrypt data", e); + } + catch (IOException e) { + throw new IgniteCheckedException("Failed to write to channel", e); } - - rcvCnt = buf.getLong(DIRECT_TYPE_SIZE); } - if (log.isDebugEnabled()) - log.debug("Received handshake message [rmtNode=" + rmtNodeId + ", rcvCnt=" + rcvCnt + ']'); + /** {@inheritDoc} */ + @Override void onHandshakeFinished(GridSslMeta sslMeta) { + ByteBuffer appBuff = handler.applicationBuffer(); + if (appBuff.hasRemaining()) + sslMeta.decodedBuffer(appBuff); - if (rcvCnt == -1) { - if (log.isDebugEnabled()) - log.debug("Connection rejected, will retry client creation [rmtNode=" + rmtNodeId + ']'); + ByteBuffer inBuf = handler.inputBuffer(); + + if (inBuf.position() > 0) + sslMeta.encodedBuffer(inBuf); } - return rcvCnt; - } + /** + * @param src Source buffer. + * @param dst Destination buffer. + * @return Bytes copied. + */ + private int copy(ByteBuffer src, ByteBuffer dst) { + int remaining = Math.min(src.remaining(), dst.remaining()); - /** - * @param target Target buffer to append to. - * @param src Source buffer to get data. - * @return Original or expanded buffer. - */ - private ByteBuffer appendAndResizeIfNeeded(ByteBuffer target, ByteBuffer src) { - if (target.remaining() < src.remaining()) { - int newSize = Math.max(target.capacity() * 2, target.capacity() + src.remaining()); + if (remaining > 0) { + int oldLimit = src.limit(); - ByteBuffer tmp = ByteBuffer.allocate(newSize); + src.limit(src.position() + remaining); - tmp.order(target.order()); + dst.put(src); - target.flip(); + src.limit(oldLimit); + } - tmp.put(target); + src.compact(); - target = tmp; + return remaining; } - - target.put(src); - - return target; } } diff --git a/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/GridTcpCommunicationSpiAbstractTest.java b/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/GridTcpCommunicationSpiAbstractTest.java index a70078d395120..5752f7909571f 100644 --- a/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/GridTcpCommunicationSpiAbstractTest.java +++ b/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/GridTcpCommunicationSpiAbstractTest.java @@ -76,7 +76,7 @@ abstract class GridTcpCommunicationSpiAbstractTest extends GridAbstractCommunica // Test idle clients remove. for (CommunicationSpi spi : spis.values()) { - ConcurrentMap clients = GridTestUtils.getFieldValue(spi, "clientPool", "clients"); + ConcurrentMap clients = GridTestUtils.getFieldValue(spi, "clientPool", "clients"); assertEquals(getSpiCount() - 1, clients.size()); diff --git a/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/TcpCommunicationHandshakeTimeoutTest.java b/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/TcpCommunicationHandshakeTimeoutTest.java index 819dce188dfc5..c8eba14075cc4 100644 --- a/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/TcpCommunicationHandshakeTimeoutTest.java +++ b/modules/core/src/test/java/org/apache/ignite/spi/communication/tcp/TcpCommunicationHandshakeTimeoutTest.java @@ -17,7 +17,6 @@ package org.apache.ignite.spi.communication.tcp; -import java.io.IOException; import java.nio.channels.SocketChannel; import java.util.Arrays; import java.util.HashSet; @@ -137,7 +136,7 @@ public DelaydTcpHandshakeExecutor(TcpHandshakeExecutor delegate, AtomicBoolean n /** {@inheritDoc} */ @Override public long tcpHandshake(SocketChannel ch, UUID rmtNodeId, GridSslMeta sslMeta, - HandshakeMessage msg) throws IgniteCheckedException, IOException { + HandshakeMessage msg) throws IgniteCheckedException { if (needToDelayd.get()) { needToDelayd.set(false);