diff --git a/src/MySqlConnector/Core/IServerCapabilities.cs b/src/MySqlConnector/Core/IServerCapabilities.cs new file mode 100644 index 000000000..233e63cd0 --- /dev/null +++ b/src/MySqlConnector/Core/IServerCapabilities.cs @@ -0,0 +1,7 @@ +namespace MySqlConnector.Core; + +internal interface IServerCapabilities +{ + bool SupportsDeprecateEof { get; } + bool SupportsSessionTrack { get; } +} diff --git a/src/MySqlConnector/Core/ResultSet.cs b/src/MySqlConnector/Core/ResultSet.cs index a0887aa2c..2b93879ca 100644 --- a/src/MySqlConnector/Core/ResultSet.cs +++ b/src/MySqlConnector/Core/ResultSet.cs @@ -38,7 +38,7 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior) var firstByte = payload.HeaderByte; if (firstByte == OkPayload.Signature) { - var ok = OkPayload.Create(payload.Span, Session.SupportsDeprecateEof, Session.SupportsSessionTrack); + var ok = OkPayload.Create(payload.Span, Session); // if we've read a result set header then this is a SELECT statement, so we shouldn't overwrite RecordsAffected // (which should be -1 for SELECT) unless the server reports a non-zero count @@ -252,9 +252,9 @@ public async Task ReadAsync(IOBehavior ioBehavior, CancellationToken cance if (payload.HeaderByte == EofPayload.Signature) { - if (Session.SupportsDeprecateEof && OkPayload.IsOk(payload.Span, Session.SupportsDeprecateEof)) + if (Session.SupportsDeprecateEof && OkPayload.IsOk(payload.Span, Session)) { - var ok = OkPayload.Create(payload.Span, Session.SupportsDeprecateEof, Session.SupportsSessionTrack); + var ok = OkPayload.Create(payload.Span, Session); BufferState = (ok.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData; return null; } diff --git a/src/MySqlConnector/Core/ServerSession.cs b/src/MySqlConnector/Core/ServerSession.cs index dec7faea3..dcaf1e69f 100644 --- a/src/MySqlConnector/Core/ServerSession.cs +++ b/src/MySqlConnector/Core/ServerSession.cs @@ -23,7 +23,7 @@ namespace MySqlConnector.Core; #pragma warning disable CA1001 // Types that own disposable fields should be disposable -internal sealed partial class ServerSession +internal sealed partial class ServerSession : IServerCapabilities { public ServerSession(ILogger logger) : this(logger, null, 0, Interlocked.Increment(ref s_lastId)) @@ -320,7 +320,7 @@ public void FinishQuerying() SendAsync(payload, IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); payload = ReceiveReplyAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); #pragma warning restore CA2012 - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, this); } lock (m_lock) @@ -532,19 +532,30 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella payload = await SwitchAuthenticationAsync(cs, password, payload, ioBehavior, cancellationToken).ConfigureAwait(false); } - var ok = OkPayload.Create(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + var ok = OkPayload.Create(payload.Span, this); var statusInfo = ok.StatusInfo; if (m_useCompression) m_payloadHandler = new CompressedPayloadHandler(m_payloadHandler.ByteHandler); - // set 'collation_connection' to the server default - await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false); - payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + // send 'SET NAMES' to set the character set and collation unless the server reports that it's already using the desired character set (e.g., MariaDB >= 11.5) + if (ok.NewCharacterSet != (ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4 ? CharacterSet.Utf8Mb4Binary : CharacterSet.Utf8Mb3Binary)) + { + // set 'collation_connection' to the server default + await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false); + payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); + OkPayload.Verify(payload.Span, this); + } if (ShouldGetRealServerDetails(cs)) + { await GetRealServerDetailsAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); + } + else if (ok.NewConnectionId is int newConnectionId && newConnectionId != ConnectionId) + { + Log.ChangingConnectionId(m_logger, Id, ConnectionId, newConnectionId, ServerVersion.OriginalString, ServerVersion.OriginalString); + ConnectionId = newConnectionId; + } m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout; return statusInfo; @@ -584,10 +595,10 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConn // read two OK replies payload = await ReceiveReplyAsync(1, ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, this); payload = await ReceiveReplyAsync(1, ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, this); return true; } @@ -595,7 +606,7 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConn Log.SendingResetConnectionRequest(m_logger, Id, ServerVersion.OriginalString); await SendAsync(ResetConnectionPayload.Instance, ioBehavior, cancellationToken).ConfigureAwait(false); payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, this); } else { @@ -619,13 +630,13 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConn Log.OptimisticReauthenticationFailed(m_logger, Id); payload = await SwitchAuthenticationAsync(cs, password, payload, ioBehavior, cancellationToken).ConfigureAwait(false); } - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, this); } // set 'collation_connection' to the server default await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false); payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, this); return true; } @@ -684,7 +695,7 @@ private async Task SwitchAuthenticationAsync(ConnectionSettings cs, payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); // OK payload can be sent immediately (e.g., if password is empty) bypassing even the fast authentication path - if (OkPayload.IsOk(payload.Span, SupportsDeprecateEof)) + if (OkPayload.IsOk(payload.Span, this)) return payload; var cachingSha2ServerResponsePayload = CachingSha2ServerResponsePayload.Create(payload.Span); @@ -824,7 +835,7 @@ public async ValueTask TryPingAsync(bool logInfo, IOBehavior ioBehavior, C Log.PingingServer(m_logger, Id); await SendAsync(PingPayload.Instance, ioBehavior, cancellationToken).ConfigureAwait(false); var payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, this); Log.SuccessfullyPingedServer(m_logger, logInfo ? LogLevel.Information : LogLevel.Trace, Id); return true; } @@ -1662,8 +1673,8 @@ static void ReadRow(ReadOnlySpan span, out int? connectionId, out ServerVe // OK/EOF payload payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); - if (OkPayload.IsOk(payload.Span, SupportsDeprecateEof)) - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + if (OkPayload.IsOk(payload.Span, this)) + OkPayload.Verify(payload.Span, this); else EofPayload.Create(payload.Span); diff --git a/src/MySqlConnector/MySqlConnection.cs b/src/MySqlConnector/MySqlConnection.cs index e9b21042e..7a835d202 100644 --- a/src/MySqlConnector/MySqlConnection.cs +++ b/src/MySqlConnector/MySqlConnection.cs @@ -161,10 +161,10 @@ private async ValueTask BeginTransactionAsync(IsolationLevel i // read the two OK replies var payload = await m_session.ReceiveReplyAsync(1, ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session); payload = await m_session.ReceiveReplyAsync(1, ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session); } else { @@ -172,12 +172,12 @@ private async ValueTask BeginTransactionAsync(IsolationLevel i await m_session.SendAsync(new Protocol.PayloadData(startTransactionPayload.Slice(4, startTransactionPayload.Span[0])), ioBehavior, cancellationToken).ConfigureAwait(false); var payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session); await m_session.SendAsync(new Protocol.PayloadData(startTransactionPayload.Slice(8 + startTransactionPayload.Span[0], startTransactionPayload.Span[startTransactionPayload.Span[0] + 4])), ioBehavior, cancellationToken).ConfigureAwait(false); payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session); } var transaction = new MySqlTransaction(this, isolationLevel, m_transactionLogger); @@ -487,7 +487,9 @@ private async Task ChangeDatabaseAsync(IOBehavior ioBehavior, string databaseNam using (var initDatabasePayload = InitDatabasePayload.Create(databaseName)) await m_session!.SendAsync(initDatabasePayload, ioBehavior, cancellationToken).ConfigureAwait(false); var payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session); + + // for non session tracking servers m_session.DatabaseOverride = databaseName; } @@ -603,7 +605,7 @@ public async ValueTask ResetConnectionAsync(CancellationToken cancellationToken Log.ResettingConnection(m_logger, session.Id); await session.SendAsync(ResetConnectionPayload.Instance, AsyncIOBehavior, cancellationToken).ConfigureAwait(false); var payload = await session.ReceiveReplyAsync(AsyncIOBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, session.SupportsDeprecateEof, session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, session); } [AllowNull] diff --git a/src/MySqlConnector/Protocol/Payloads/OkPayload.cs b/src/MySqlConnector/Protocol/Payloads/OkPayload.cs index d5d4a1aa2..50be1db79 100644 --- a/src/MySqlConnector/Protocol/Payloads/OkPayload.cs +++ b/src/MySqlConnector/Protocol/Payloads/OkPayload.cs @@ -1,4 +1,6 @@ +using System.Buffers.Text; using System.Text; +using MySqlConnector.Core; using MySqlConnector.Protocol.Serialization; using MySqlConnector.Utilities; @@ -12,6 +14,8 @@ internal sealed class OkPayload public int WarningCount { get; } public string? StatusInfo { get; } public string? NewSchema { get; } + public CharacterSet? NewCharacterSet { get; } + public int? NewConnectionId { get; } public const byte Signature = 0x00; @@ -20,56 +24,56 @@ internal sealed class OkPayload * https://mariadb.com/kb/en/the-mariadb-library/resultset/ * https://github.com/MariaDB/mariadb-connector-j/blob/5fa814ac6e1b4c9cb6d141bd221cbd5fc45c8a78/src/main/java/org/mariadb/jdbc/internal/com/read/resultset/SelectResultSet.java#L443-L444 */ - public static bool IsOk(ReadOnlySpan span, bool deprecateEof) => + public static bool IsOk(ReadOnlySpan span, IServerCapabilities serverCapabilities) => span.Length > 0 && (span.Length > 6 && span[0] == Signature || - deprecateEof && span.Length < 0xFF_FFFF && span[0] == EofPayload.Signature); + serverCapabilities.SupportsDeprecateEof && span.Length < 0xFF_FFFF && span[0] == EofPayload.Signature); /// /// Creates an from the given , or throws /// if the bytes do not represent a valid . /// /// The bytes from which to read an OK packet. - /// Whether the flag was set on the connection. - /// Whether flag was set on the connection. + /// The server capabilities. /// A with the contents of the OK packet. /// Thrown when the bytes are not a valid OK packet. - public static OkPayload Create(ReadOnlySpan span, bool deprecateEof, bool clientSessionTrack) => - Read(span, deprecateEof, clientSessionTrack, true)!; + public static OkPayload Create(ReadOnlySpan span, IServerCapabilities serverCapabilities) => + Read(span, serverCapabilities, true)!; /// /// Verifies that the bytes in the given form a valid , or throws /// if they do not. /// /// The bytes from which to read an OK packet. - /// Whether the flag was set on the connection. - /// Whether flag was set on the connection. + /// The server capabilities. /// Thrown when the bytes are not a valid OK packet. - public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clientSessionTrack) => - Read(span, deprecateEof, clientSessionTrack, createPayload: false); + public static void Verify(ReadOnlySpan span, IServerCapabilities serverCapabilities) => + Read(span, serverCapabilities, createPayload: false); - private static OkPayload? Read(ReadOnlySpan span, bool deprecateEof, bool clientSessionTrack, bool createPayload) + private static OkPayload? Read(ReadOnlySpan span, IServerCapabilities serverCapabilities, bool createPayload) { var reader = new ByteArrayReader(span); var signature = reader.ReadByte(); - if (signature != Signature && (!deprecateEof || signature != EofPayload.Signature)) + if (signature != Signature && (!serverCapabilities.SupportsDeprecateEof || signature != EofPayload.Signature)) throw new FormatException($"Expected to read 0x00 or 0xFE but got 0x{signature:X2}"); var affectedRowCount = reader.ReadLengthEncodedInteger(); var lastInsertId = reader.ReadLengthEncodedInteger(); var serverStatus = (ServerStatus) reader.ReadUInt16(); var warningCount = (int) reader.ReadUInt16(); string? newSchema = null; + CharacterSet clientCharacterSet = default; + CharacterSet connectionCharacterSet = default; + CharacterSet resultsCharacterSet = default; + int? connectionId = null; ReadOnlySpan statusBytes; - if (clientSessionTrack) + if (serverCapabilities.SupportsSessionTrack) { if (reader.BytesRemaining > 0) { statusBytes = reader.ReadLengthEncodedByteString(); // human-readable info - - if ((serverStatus & ServerStatus.SessionStateChanged) == ServerStatus.SessionStateChanged && reader.BytesRemaining > 0) + while (reader.BytesRemaining > 0) { - // implies ProtocolCapabilities.SessionTrack var sessionStateChangeDataLength = checked((int) reader.ReadLengthEncodedInteger()); var endOffset = reader.Offset + sessionStateChangeDataLength; while (reader.Offset < endOffset) @@ -82,6 +86,38 @@ public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clien newSchema = Encoding.UTF8.GetString(reader.ReadLengthEncodedByteString()); break; + case SessionTrackKind.SystemVariables: + var systemVariablesEndOffset = reader.Offset + dataLength; + do + { + var systemVariableName = reader.ReadLengthEncodedByteString(); + var systemVariableValueLength = reader.ReadLengthEncodedIntegerOrNull(); + var systemVariableValue = systemVariableValueLength == -1 ? default : reader.ReadByteString(systemVariableValueLength); + if (systemVariableName.SequenceEqual("character_set_client"u8) && systemVariableValueLength != 0) + { + clientCharacterSet = systemVariableValue.SequenceEqual("utf8mb4"u8) ? CharacterSet.Utf8Mb4Binary : + systemVariableValue.SequenceEqual("utf8"u8) ? CharacterSet.Utf8Mb3Binary : + CharacterSet.None; + } + else if (systemVariableName.SequenceEqual("character_set_connection"u8) && systemVariableValueLength != 0) + { + connectionCharacterSet = systemVariableValue.SequenceEqual("utf8mb4"u8) ? CharacterSet.Utf8Mb4Binary : + systemVariableValue.SequenceEqual("utf8"u8) ? CharacterSet.Utf8Mb3Binary : + CharacterSet.None; + } + else if (systemVariableName.SequenceEqual("character_set_results"u8) && systemVariableValueLength != 0) + { + resultsCharacterSet = systemVariableValue.SequenceEqual("utf8mb4"u8) ? CharacterSet.Utf8Mb4Binary : + systemVariableValue.SequenceEqual("utf8"u8) ? CharacterSet.Utf8Mb3Binary : + CharacterSet.None; + } + else if (systemVariableName.SequenceEqual("connection_id"u8)) + { + connectionId = Utf8Parser.TryParse(systemVariableValue, out int parsedConnectionId, out var bytesConsumed) && bytesConsumed == systemVariableValue.Length ? parsedConnectionId : default(int?); + } + } while (reader.Offset < systemVariablesEndOffset); + break; + default: reader.Offset += dataLength; break; @@ -109,7 +145,12 @@ public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clien { var statusInfo = statusBytes.Length == 0 ? null : Encoding.UTF8.GetString(statusBytes); - if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null) + // detect the connection character set as utf8mb4 (or utf8) if all three system variables are set to the same value + var characterSet = clientCharacterSet == CharacterSet.Utf8Mb4Binary && connectionCharacterSet == CharacterSet.Utf8Mb4Binary && resultsCharacterSet == CharacterSet.Utf8Mb4Binary ? CharacterSet.Utf8Mb4Binary : + clientCharacterSet == CharacterSet.Utf8Mb3Binary && connectionCharacterSet == CharacterSet.Utf8Mb3Binary && resultsCharacterSet == CharacterSet.Utf8Mb3Binary ? CharacterSet.Utf8Mb3Binary : + CharacterSet.None; + + if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null && clientCharacterSet is CharacterSet.None && connectionId is null) { if (serverStatus == ServerStatus.AutoCommit) return s_autoCommitOk; @@ -117,7 +158,7 @@ public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clien return s_autoCommitSessionStateChangedOk; } - return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema); + return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema, characterSet, connectionId); } else { @@ -125,7 +166,7 @@ public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clien } } - private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema) + private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema, CharacterSet newCharacterSet, int? connectionId) { AffectedRowCount = affectedRowCount; LastInsertId = lastInsertId; @@ -133,8 +174,10 @@ private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serve WarningCount = warningCount; StatusInfo = statusInfo; NewSchema = newSchema; + NewCharacterSet = newCharacterSet; + NewConnectionId = connectionId; } - private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, null, null); - private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, null, null); + private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, default, default, default, default); + private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, default, default, default, default); }