From be7db3a07045e6a74b8748f613834902d882d9cf Mon Sep 17 00:00:00 2001 From: guoxu <13910754971@163.com> Date: Sun, 29 Sep 2024 17:50:39 +0800 Subject: [PATCH] [fix][dingo-executor] fix issues mysql client hanging and chinese character encoding. --- .../driver/mysql/command/MysqlCommands.java | 47 +++++++++++----- .../mysql/command/MysqlResponseHandler.java | 54 +++++++++---------- .../driver/mysql/netty/HandshakeHandler.java | 4 +- .../driver/mysql/packet/ERRPacket.java | 52 ++++++++++++++++-- .../driver/mysql/process/MessageProcess.java | 20 +++++-- 5 files changed, 124 insertions(+), 53 deletions(-) diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlCommands.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlCommands.java index 54aabfe85..eee0518ed 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlCommands.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlCommands.java @@ -70,20 +70,23 @@ public class MysqlCommands { MysqlPacketFactory mysqlPacketFactory = MysqlPacketFactory.getInstance(); public static void executeShowFields(String table, AtomicLong packetId, MysqlConnection mysqlConnection) { + String connCharSet = null; try { + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); ResultSet rs = mysqlConnection.getConnection().getMetaData().getColumns(null, null, table, null); MysqlResponseHandler.responseShowField(rs, packetId, mysqlConnection); } catch (SQLException e) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e, connCharSet); } } public void execute(QueryPacket queryPacket, MysqlConnection mysqlConnection) { String sql; + String characterSet = null; try { - String characterSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); + characterSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); characterSet = getCharacterSet(characterSet); sql = new String(queryPacket.message, characterSet); } catch (Exception e) { @@ -92,7 +95,8 @@ public void execute(QueryPacket queryPacket, AtomicLong packetId = new AtomicLong(queryPacket.packetId + 1); LogUtils.debug(log, "dingo connection:{}, receive sql:{}", mysqlConnection.getConnection().toString(), sql); if (mysqlConnection.passwordExpire && !doExpire(mysqlConnection, sql, packetId)) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, ErrorCode.ER_PASSWORD_EXPIRE); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, ErrorCode.ER_PASSWORD_EXPIRE, + characterSet); return; } executeSingleQuery(sql, packetId, mysqlConnection); @@ -115,10 +119,14 @@ private static boolean doExpire(MysqlConnection mysqlConnection, String sql, Ato } } - if (sql.startsWith(setPwdSql1) || sql.startsWith(alterUserPwdSql1)) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, - ErrorCode.ER_PASSWORD_EXPIRE); - return true; + try { + if (sql.startsWith(setPwdSql1) || sql.startsWith(alterUserPwdSql1)) { + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, + ErrorCode.ER_PASSWORD_EXPIRE, mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET)); + return true; + } + } catch(SQLException e) { + LogUtils.info(log, e.getMessage(), e); } return false; } @@ -126,9 +134,12 @@ private static boolean doExpire(MysqlConnection mysqlConnection, String sql, Ato public void prepare(MysqlConnection mysqlConnection, String sql) { DingoConnection connection = (DingoConnection) mysqlConnection.getConnection(); AtomicLong packetId = new AtomicLong(2); + String connCharSet = null; + try { DingoPreparedStatement preparedStatement = (DingoPreparedStatement) connection .prepareStatement(sql); + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); Meta.StatementHandle statementHandle = preparedStatement.handle; String placeholder = "?"; int i = 0; @@ -173,7 +184,7 @@ public void prepare(MysqlConnection mysqlConnection, String sql) { MysqlResponseHandler.responsePrepare(preparePacket, mysqlConnection.channel); } catch (SQLException e) { LogUtils.info(log, e.getMessage(), e); - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e, connCharSet); } } @@ -181,8 +192,11 @@ public void executeSingleQuery(String sql, AtomicLong packetId, MysqlConnection mysqlConnection) { Statement statement = null; boolean hasResults; + String connCharSet = null; + try { statement = mysqlConnection.getConnection().createStatement(); + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); hasResults = statement.execute(sql); if (hasResults) { // select @@ -212,10 +226,9 @@ public void executeSingleQuery(String sql, AtomicLong packetId, } catch (SQLException sqlException) { LogUtils.error(log, "sql exception sqlstate:" + sqlException.getSQLState() + ", code:" + sqlException.getErrorCode() + ", message:" + sqlException.getMessage()); - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, sqlException); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, sqlException, connCharSet); } catch (Exception e) { LogUtils.error(log, e.getMessage(), e); - //MysqlResponseHandler.responseError(bakPacketId, mysqlConnection.channel, ErrorCode.ER_UNKNOWN_ERROR, ""); throw e; } finally { try { @@ -242,9 +255,13 @@ public void executeStatement(ExecuteStatementPacket statementPacket, AtomicLong packetId, MysqlConnection mysqlConnection ) { + String connectionCharSet = null; try { + connectionCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); statementPacket.paramValMap.forEach((k, v) -> { + String connCharSet = null; try { + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); switch (v.getType()) { case MysqlType.FIELD_TYPE_TINY: byte byteVal = v.getValue()[0]; @@ -323,15 +340,17 @@ public void executeStatement(ExecuteStatementPacket statementPacket, preparedStatement.setObject(k, charVal); } } catch (SQLException e) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e, connCharSet); } }); if (statementType == Meta.StatementType.SELECT) { + String connCharSet = null; try (ResultSet resultSet = preparedStatement.executeQuery()) { + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); MysqlResponseHandler.responsePrepareExecute(resultSet, packetId, mysqlConnection); } catch (SQLException e) { LogUtils.error(log, e.getMessage(), e); - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e, connCharSet); } } else if (statementType == Meta.StatementType.OTHER_DDL) { OKPacket okPacket = mysqlPacketFactory.getOkPacket(0, packetId, null); @@ -353,9 +372,9 @@ public void executeStatement(ExecuteStatementPacket statementPacket, MysqlResponseHandler.responseOk(okPacket, mysqlConnection.channel); } } catch (SQLException e) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e, connectionCharSet); } catch (Exception e) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, ErrorCode.ER_UNKNOWN_ERROR, ""); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, ErrorCode.ER_UNKNOWN_ERROR, connectionCharSet); LogUtils.error(log, e.getMessage(), e); } } diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlResponseHandler.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlResponseHandler.java index e8930263f..b0653b419 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlResponseHandler.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/command/MysqlResponseHandler.java @@ -64,7 +64,9 @@ public static void responseShowField(ResultSet resultSet, MysqlConnection mysqlConnection) { // 1. column packet // 2. ok packet + String connCharSet = null; try { + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); List columnPackets = factory.getColumnPackets(packetId, resultSet, true); ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); for (ColumnPacket columnPacket : columnPackets) { @@ -74,7 +76,7 @@ public static void responseShowField(ResultSet resultSet, okPacket.write(buffer); mysqlConnection.channel.writeAndFlush(buffer); } catch (SQLException e) { - responseError(packetId, mysqlConnection.channel, e); + responseError(packetId, mysqlConnection.channel, e, connCharSet); } } @@ -94,7 +96,9 @@ public static void responseResultSet(ResultSet resultSet, // 5. eof packet boolean deprecateEof = (mysqlConnection.authPacket.extendClientFlags & ExtendedClientCapabilities.CLIENT_DEPRECATE_EOF) != 0; + String connCharSet = null; try { + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); ResultSetMetaData metaData = resultSet.getMetaData(); ColumnsNumberPacket columnsNumberPacket = new ColumnsNumberPacket(); @@ -127,7 +131,7 @@ public static void responseResultSet(ResultSet resultSet, mysqlConnection.channel.writeAndFlush(buffer); } catch (SQLException e) { - responseError(packetId, mysqlConnection.channel, e); + responseError(packetId, mysqlConnection.channel, e, connCharSet); } } @@ -204,19 +208,23 @@ private static void handlerPrepareRowPacket(ResultSet resultSet, public static void responseError(AtomicLong packetId, SocketChannel channel, - io.dingodb.common.mysql.constant.ErrorCode errorCode) { - responseError(packetId, channel, errorCode, errorCode.message); + io.dingodb.common.mysql.constant.ErrorCode errorCode, + String characterSet) { + responseError(packetId, channel, errorCode, errorCode.message, characterSet); } public static void responseError(AtomicLong packetId, SocketChannel channel, - io.dingodb.common.mysql.constant.ErrorCode errorCode, String message) { + io.dingodb.common.mysql.constant.ErrorCode errorCode, + String message, + String characterSet) { ERRPacket ep = new ERRPacket(); ep.packetId = (byte) packetId.getAndIncrement(); ep.capabilities = MysqlServer.getServerCapabilities(); ep.errorCode = errorCode.code; ep.sqlState = errorCode.sqlState; ep.errorMessage = message; + ep.characterSet = characterSet; ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); ep.write(buffer); channel.writeAndFlush(buffer); @@ -224,13 +232,16 @@ public static void responseError(AtomicLong packetId, public static void responseError(AtomicLong packetId, SocketChannel channel, - SQLException exception) { + SQLException exception, + String characterSet) { exception = errorDingo2Mysql(exception); ERRPacket ep = new ERRPacket(); ep.packetId = (byte) packetId.getAndIncrement(); ep.capabilities = MysqlServer.getServerCapabilities(); ep.errorCode = exception.getErrorCode(); ep.sqlState = exception.getSQLState(); + ep.characterSet = characterSet; + if (exception.getMessage() == null) { ep.errorMessage = ""; } else if (exception.getMessage().startsWith("Encountered")) { @@ -243,23 +254,6 @@ public static void responseError(AtomicLong packetId, channel.writeAndFlush(buffer); } - public static String encodeResponseError(String src) { - if(src == null) { - return null; - } - - StringBuilder builder = new StringBuilder(); - for(char c : src.toCharArray()) { - if(c > 0x7f) { - //deal with chinese character. - builder.append("\\u").append(String.format("%04x", (int)c)); - } else { - builder.append(c); - } - } - return builder.toString(); - } - public static SQLException errorDingo2Mysql(SQLException e) { if (e.getMessage() != null && e.getMessage().contains("Duplicate data")) { return new SQLException("Duplicate data for key 'PRIMARY'", "23000", 1062); @@ -268,11 +262,11 @@ public static SQLException errorDingo2Mysql(SQLException e) { return new SQLException("Query execution was interrupted", "70100", 1317); } else if (e.getMessage() != null && e.getMessage().contains("Statement canceled")) { LogUtils.info(log, e.getMessage(), e); - return new SQLException(encodeResponseError(e.getMessage()), "HY000", 1105); + return new SQLException(e.getMessage(), "HY000", 1105); } else if (e.getErrorCode() == 9001 && e.getSQLState().equals("45000")) { int code = 1105; String state = "HY000"; - String reason = encodeResponseError(e.getMessage()); + String reason = e.getMessage(); if (reason.contains("Duplicate entry")) { code = 1062; state = "23000"; @@ -282,15 +276,15 @@ public static SQLException errorDingo2Mysql(SQLException e) { } return new SQLException(reason, state, code); } else if (e.getErrorCode() == 1054 && e.getSQLState().equals("42S22")) { - return new SQLException(encodeResponseError(e.getMessage()), "HY000", 1105); + return new SQLException(e.getMessage(), "HY000", 1105); } else if (e.getErrorCode() == 5001 && e.getSQLState().equals("45000")) { if (e.getMessage().contains("Syntax Error")) { return new SQLException( - encodeResponseError(e.getMessage()), + e.getMessage(), e.getSQLState(), e.getErrorCode()); } else if (e.getMessage().contains("io.dingodb.store.api.transaction.exception.DuplicateEntryException:")) { - return new SQLException(encodeResponseError(e.getMessage()) + return new SQLException(e.getMessage() .replace("io.dingodb.store.api.transaction.exception.DuplicateEntryException:", ""), "23000", 1062); } @@ -328,7 +322,9 @@ public static void responsePrepareExecute(ResultSet resultSet, // 5. eof packet boolean deprecateEof = (mysqlConnection.authPacket.extendClientFlags & ExtendedClientCapabilities.CLIENT_DEPRECATE_EOF) != 0; + String connCharSet = null; try { + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); ResultSetMetaData metaData = resultSet.getMetaData(); ColumnsNumberPacket columnsNumberPacket = new ColumnsNumberPacket(); @@ -361,7 +357,7 @@ public static void responsePrepareExecute(ResultSet resultSet, mysqlConnection.channel.writeAndFlush(buffer); } catch (SQLException e) { - responseError(packetId, mysqlConnection.channel, e); + responseError(packetId, mysqlConnection.channel, e, connCharSet); } } } diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/netty/HandshakeHandler.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/netty/HandshakeHandler.java index a6fe6f430..da744b099 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/netty/HandshakeHandler.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/netty/HandshakeHandler.java @@ -160,7 +160,7 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { String error = String.format(ErrorCode.ER_ACCESS_DENIED_ERROR.message, user, ip, "YES"); MysqlResponseHandler.responseError(packetId, - mysqlConnection.channel, ErrorCode.ER_ACCESS_DENIED_ERROR, error); + mysqlConnection.channel, ErrorCode.ER_ACCESS_DENIED_ERROR, error, null); if (mysqlConnection.channel.isActive()) { mysqlConnection.channel.close(); } @@ -217,7 +217,7 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { String error = String.format(ErrorCode.ER_ACCESS_DENIED_ERROR.message, user, ip, "YES"); MysqlResponseHandler.responseError(packetId, - mysqlConnection.channel, ErrorCode.ER_ACCESS_DENIED_ERROR, error); + mysqlConnection.channel, ErrorCode.ER_ACCESS_DENIED_ERROR, error, null); if (mysqlConnection.channel.isActive()) { mysqlConnection.channel.close(); } diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/ERRPacket.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/ERRPacket.java index 18880ecb9..1737ec7ca 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/ERRPacket.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/packet/ERRPacket.java @@ -22,8 +22,12 @@ import io.dingodb.driver.mysql.NativeConstants; import io.dingodb.driver.mysql.util.BufferUtil; import io.netty.buffer.ByteBuf; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import java.io.UnsupportedEncodingException; + +@Slf4j public class ERRPacket extends MysqlPacket { public int header; @@ -38,6 +42,10 @@ public class ERRPacket extends MysqlPacket { public int capabilities; + public String characterSet; + + private byte[] errMsgBytes; + @Override public void read(byte[] data) { MysqlMessage message = new MysqlMessage(data); @@ -52,8 +60,41 @@ public void read(byte[] data) { errorMessage = message.readString(); } + public static String encodeResponseError(String src) { + if(src == null) { + return null; + } + + StringBuilder builder = new StringBuilder(); + for(char c : src.toCharArray()) { + if(c > 0x7f) { + //deal with chinese character. + builder.append("\\u").append(String.format("%04x", (int)c)); + } else { + builder.append(c); + } + } + return builder.toString(); + } + @Override public void write(ByteBuf buffer) { + byte[] stringBytes = null; + + if (errorMessage != null) { + if(characterSet != null) { + try { + errMsgBytes = errorMessage.getBytes(characterSet); + stringBytes = errMsgBytes; + } catch (UnsupportedEncodingException e) { + log.error("Error message encoding Exception:{}", e.toString()); + stringBytes = encodeResponseError(errorMessage).getBytes(); + } + } else { + stringBytes = errorMessage.getBytes(); + } + } + BufferUtil.writeUB3(buffer, calcPacketSize()); buffer.writeByte(packetId); buffer.writeByte(NativeConstants.TYPE_ID_ERROR); @@ -65,8 +106,9 @@ public void write(ByteBuf buffer) { } buffer.writeBytes(sqlState.getBytes()); } - if (errorMessage != null) { - buffer.writeBytes(errorMessage.getBytes()); + + if(stringBytes != null) { + buffer.writeBytes(stringBytes); } } @@ -74,7 +116,11 @@ public void write(ByteBuf buffer) { public int calcPacketSize() { int size = 9; if (errorMessage != null) { - size += errorMessage.length(); + if(characterSet != null && errMsgBytes != null) { + size += errMsgBytes.length; + } else { + size += errorMessage.length(); + } } return size; } diff --git a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/process/MessageProcess.java b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/process/MessageProcess.java index 811415351..5ba9fc757 100644 --- a/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/process/MessageProcess.java +++ b/dingo-driver/mysql-service/src/main/java/io/dingodb/driver/mysql/process/MessageProcess.java @@ -41,6 +41,8 @@ import java.sql.SQLException; import java.util.concurrent.atomic.AtomicLong; +import static io.dingodb.calcite.operation.SetOptionOperation.CONNECTION_CHARSET; + @Slf4j public final class MessageProcess { @@ -57,8 +59,16 @@ public static void process(ByteBuf msg, MysqlConnection mysqlConnection) { byte packetIdByte = array[0]; AtomicLong packetId = new AtomicLong(packetIdByte); packetId.incrementAndGet(); + String connCharSet = null; + + try { + connCharSet = mysqlConnection.getConnection().getClientInfo(CONNECTION_CHARSET); + } catch(SQLException e) { + log.error("Fail to get connection characterSet, {}", e.toString()); + } + if (flg != NativeConstants.COM_QUIT && flg != NativeConstants.COM_QUERY && mysqlConnection.passwordExpire) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, ErrorCode.ER_PASSWORD_EXPIRE); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, ErrorCode.ER_PASSWORD_EXPIRE, connCharSet); return; } switch (flg) { @@ -82,7 +92,7 @@ public static void process(ByteBuf msg, MysqlConnection mysqlConnection) { String error = String.format(ErrorCode.ER_ACCESS_DB_DENIED_ERROR.message, user, host, usedSchema); MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, - ErrorCode.ER_ACCESS_DB_DENIED_ERROR, error); + ErrorCode.ER_ACCESS_DB_DENIED_ERROR, error, connCharSet); return; } // todo: current version, ignore name case @@ -94,7 +104,7 @@ public static void process(ByteBuf msg, MysqlConnection mysqlConnection) { MysqlResponseHandler.responseOk(okPacket, mysqlConnection.channel); } else { MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, - ErrorCode.ER_NO_DATABASE_ERROR); + ErrorCode.ER_NO_DATABASE_ERROR, connCharSet); } break; case NativeConstants.COM_QUERY: @@ -224,7 +234,7 @@ public static void process(ByteBuf msg, MysqlConnection mysqlConnection) { } catch (NoSuchStatementException e) { throw new RuntimeException(e); } catch (SQLException e) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e, connCharSet); } break; case NativeConstants.COM_STMT_RESET: @@ -240,7 +250,7 @@ public static void process(ByteBuf msg, MysqlConnection mysqlConnection) { } catch (NoSuchStatementException e) { throw new RuntimeException(e); } catch (SQLException e) { - MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e); + MysqlResponseHandler.responseError(packetId, mysqlConnection.channel, e, connCharSet); } okPacket = MysqlPacketFactory.getInstance().getOkPacket(0, packetId, null); MysqlResponseHandler.responseOk(okPacket, mysqlConnection.channel);