From 5474bf26a8aa7c522c63f8f7830e3a79fe91da99 Mon Sep 17 00:00:00 2001 From: jchrys Date: Thu, 8 Feb 2024 13:18:05 +0900 Subject: [PATCH] Fixed NPE with `SslMode.TUNNEL` Usage (#225) Motivation: A NPE was identified when utilizing `SslMode.TUNNEL`. The issue arises when `ConnectionContext#isMariaDb` is invoked from `SslBridgeHandler#isTls13Enabled`, leading to an NPE due to the `ConnectionContext` not being initialized at that time. Modification: Do not invoke `ConnectionContext#isMariaDb` when it is not initialized. Result: This change addresses the NPE issue, ensuring stability when `SslMode.TUNNEL` is selected. It resolves the problem reported in https://github.com/GoogleCloudPlatform/cloud-sql-jdbc-socket-factory/pull/1828 --------- Signed-off-by: jchrys Co-authored-by: Mirro Mutth --- pom.xml | 12 + .../r2dbc/mysql/ConnectionContext.java | 1 + .../r2dbc/mysql/client/SslBridgeHandler.java | 6 +- .../r2dbc/mysql/SslTunnelIntegrationTest.java | 282 ++++++++++++++++++ 4 files changed, 299 insertions(+), 2 deletions(-) create mode 100644 src/test/java/io/asyncer/r2dbc/mysql/SslTunnelIntegrationTest.java diff --git a/pom.xml b/pom.xml index 6e697487b..1f009e813 100644 --- a/pom.xml +++ b/pom.xml @@ -79,6 +79,7 @@ 0.3.0.RELEASE 3.0.2 24.1.0 + 1.77 @@ -117,6 +118,12 @@ ${java-annotations.version} provided + + org.bouncycastle + bcpkix-jdk18on + ${bouncy-castle.version} + test + @@ -240,6 +247,11 @@ jackson-annotations test + + org.bouncycastle + bcpkix-jdk18on + test + diff --git a/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java b/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java index 3039ea19b..f8f594fdc 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java @@ -57,6 +57,7 @@ public final class ConnectionContext implements CodecContext { */ private volatile short serverStatuses = ServerStatuses.AUTO_COMMIT; + @Nullable private volatile Capability capability = null; ConnectionContext(ZeroDateOption zeroDateOption, @Nullable Path localInfilePath, diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java b/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java index 823a59187..952be3917 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java @@ -220,8 +220,10 @@ static MySqlSslContextSpec forClient(MySqlSslConfiguration ssl, ConnectionContex .applicationProtocolConfig(null); String[] tlsProtocols = ssl.getTlsVersion(); - if (tlsProtocols.length > 0) { - builder.protocols(tlsProtocols); + if (tlsProtocols.length > 0 || ssl.getSslMode() == SslMode.TUNNEL) { + if (tlsProtocols.length > 0) { + builder.protocols(tlsProtocols); + } } else if (isTls13Enabled(context)) { builder.protocols(TLS_PROTOCOLS); } else { diff --git a/src/test/java/io/asyncer/r2dbc/mysql/SslTunnelIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/SslTunnelIntegrationTest.java new file mode 100644 index 000000000..9261f2514 --- /dev/null +++ b/src/test/java/io/asyncer/r2dbc/mysql/SslTunnelIntegrationTest.java @@ -0,0 +1,282 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + + +import io.asyncer.r2dbc.mysql.constant.SslMode; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.SSLException; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; + +import static org.assertj.core.api.Assertions.assertThat; + +public class SslTunnelIntegrationTest { + + private SelfSignedCertificate server; + + private SelfSignedCertificate client; + + private SslTunnelServer sslTunnelServer; + + @BeforeEach + void setUp() throws CertificateException, SSLException, InterruptedException { + server = new SelfSignedCertificate(); + client = new SelfSignedCertificate(); + final SslContext sslContext = SslContextBuilder.forServer(server.key(), server.cert()).build(); + sslTunnelServer = new SslTunnelServer("localhost", 3306, sslContext); + sslTunnelServer.setUp(); + } + + @AfterEach + void tearDown() throws InterruptedException { + server.delete(); + client.delete(); + sslTunnelServer.tearDown(); + } + + @Test + void sslTunnelConnectionTest() { + final String password = System.getProperty("test.mysql.password"); + assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty") + .isNotNull() + .isNotEmpty(); + + final MySqlConnectionConfiguration configuration = MySqlConnectionConfiguration + .builder() + .host("localhost") + .port(sslTunnelServer.getLocalPort()) + .connectTimeout(Duration.ofSeconds(3)) + .user("root") + .password(password) + .database("r2dbc") + .createDatabaseIfNotExist(true) + .sslMode(SslMode.TUNNEL) + .sslKey(client.privateKey().getAbsolutePath()) + .sslCert(client.certificate().getAbsolutePath()) + .sslCa(server.certificate().getAbsolutePath()) + .build(); + + final MySqlConnectionFactory connectionFactory = MySqlConnectionFactory.from(configuration); + + final MySqlConnection connection = connectionFactory.create().block(); + assert null != connection; + connection.createStatement("SELECT 3").execute() + .flatMap(it -> it.map((row, rowMetadata) -> row.get(0, Long.class))) + .doOnNext(it -> assertThat(it).isEqualTo(3L)) + .blockLast(); + + connection.close().block(); + } + + private static class SslTunnelServer { + + private final String remoteHost; + + private final int remotePort; + + private final SslContext sslContext; + + private volatile ChannelFuture channelFuture; + + + private SslTunnelServer(String remoteHost, int remotePort, SslContext sslContext) { + this.remoteHost = remoteHost; + this.remotePort = remotePort; + this.sslContext = sslContext; + } + + void setUp() throws InterruptedException { + // Configure the server. + ServerBootstrap b = new ServerBootstrap(); + b.localAddress(0) + .group(new NioEventLoopGroup()) + .channel(NioServerSocketChannel.class) + .childHandler(new ProxyInitializer(remoteHost, remotePort, sslContext)) + .childOption(ChannelOption.AUTO_READ, false); + + // Start the server. + channelFuture = b.bind().sync(); + } + + void tearDown() throws InterruptedException { + channelFuture.channel().close().sync(); + } + + int getLocalPort() { + return ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); + } + + } + + + private static class ProxyInitializer extends ChannelInitializer { + + private final String remoteHost; + + private final int remotePort; + + private final SslContext sslContext; + + ProxyInitializer(String remoteHost, int remotePort, SslContext sslContext) { + this.remoteHost = remoteHost; + this.remotePort = remotePort; + this.sslContext = sslContext; + } + + @Override + public void initChannel(SocketChannel ch) { + ch.pipeline().addLast(sslContext.newHandler(ch.alloc())); + ch.pipeline().addLast(new ProxyFrontendHandler(remoteHost, remotePort)); + } + } + + private static class ProxyFrontendHandler extends ChannelInboundHandlerAdapter { + + private final String remoteHost; + private final int remotePort; + + // As we use inboundChannel.eventLoop() when building the Bootstrap this does not need to be volatile as + // the outboundChannel will use the same EventLoop (and therefore Thread) as the inboundChannel. + private Channel outboundChannel; + + private ProxyFrontendHandler(String remoteHost, int remotePort) { + this.remoteHost = remoteHost; + this.remotePort = remotePort; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + final Channel inboundChannel = ctx.channel(); + + // Start the connection attempt. + Bootstrap b = new Bootstrap(); + b.group(inboundChannel.eventLoop()) + .channel(ctx.channel().getClass()) + .handler(new ProxyBackendHandler(inboundChannel)) + .option(ChannelOption.AUTO_READ, false); + ChannelFuture f = b.connect(remoteHost, remotePort); + outboundChannel = f.channel(); + f.addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + // connection complete start to read first data + inboundChannel.read(); + } else { + // Close the connection if the connection attempt has failed. + inboundChannel.close(); + } + }); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) { + if (outboundChannel.isActive()) { + outboundChannel.writeAndFlush(msg).addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + // was able to flush out data, start to read the next chunk + ctx.channel().read(); + } else { + future.channel().close(); + } + }); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + if (outboundChannel != null) { + closeOnFlush(outboundChannel); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + closeOnFlush(ctx.channel()); + } + + /** + * Closes the specified channel after all queued write requests are flushed. + */ + static void closeOnFlush(Channel ch) { + if (ch.isActive()) { + ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); + } + } + } + + private static class ProxyBackendHandler extends ChannelInboundHandlerAdapter { + + private final Channel inboundChannel; + + private ProxyBackendHandler(Channel inboundChannel) { + this.inboundChannel = inboundChannel; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + if (!inboundChannel.isActive()) { + ProxyFrontendHandler.closeOnFlush(ctx.channel()); + } else { + ctx.read(); + } + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) { + inboundChannel.writeAndFlush(msg).addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + ctx.channel().read(); + } else { + future.channel().close(); + } + }); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + ProxyFrontendHandler.closeOnFlush(inboundChannel); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ProxyFrontendHandler.closeOnFlush(ctx.channel()); + } + } + +}