Skip to content

Commit

Permalink
Make 'Context' immutable.
Browse files Browse the repository at this point in the history
Move mutable properties back to ServerSession.

Signed-off-by: Bradley Grainger <bgrainger@gmail.com>
  • Loading branch information
bgrainger committed Jul 17, 2024
1 parent ef40088 commit 77ad191
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/MySqlConnector/Core/ConnectionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public async ValueTask<ServerSession> GetSessionAsync(MySqlConnection connection
}
else
{
if (ConnectionSettings.ConnectionReset || !session.Context.IsInitialDatabase())
if (ConnectionSettings.ConnectionReset || session.DatabaseOverride is not null)
{
if (timeoutMilliseconds != 0)
session.SetTimeout(Math.Max(1, timeoutMilliseconds - Utility.GetElapsedMilliseconds(startingTimestamp)));
Expand Down
12 changes: 1 addition & 11 deletions src/MySqlConnector/Core/Context.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,16 @@ namespace MySqlConnector.Core;

internal sealed class Context
{
public Context(ProtocolCapabilities protocolCapabilities, string? database, int connectionId)
public Context(ProtocolCapabilities protocolCapabilities)
{
SupportsDeprecateEof = (protocolCapabilities & ProtocolCapabilities.DeprecateEof) != 0;
SupportsCachedPreparedMetadata = (protocolCapabilities & ProtocolCapabilities.MariaDbCacheMetadata) != 0;
SupportsQueryAttributes = (protocolCapabilities & ProtocolCapabilities.QueryAttributes) != 0;
SupportsSessionTrack = (protocolCapabilities & ProtocolCapabilities.SessionTrack) != 0;
ConnectionId = connectionId;
Database = database;
m_initialDatabase = database;
}

public bool SupportsDeprecateEof { get; }
public bool SupportsQueryAttributes { get; }
public bool SupportsSessionTrack { get; }
public bool SupportsCachedPreparedMetadata { get; }
public string? ClientCharset { get; set; }

public string? Database { get; set; }
private readonly string? m_initialDatabase;
public bool IsInitialDatabase() => string.Equals(m_initialDatabase, Database, StringComparison.Ordinal);

public int ConnectionId { get; set; }
}
2 changes: 2 additions & 0 deletions src/MySqlConnector/Core/ResultSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior)
if (ok.LastInsertId != 0)
Command?.SetLastInsertedId((long) ok.LastInsertId);
WarningCount = ok.WarningCount;
if (ok.NewSchema is not null)
Connection.Session.DatabaseOverride = ok.NewSchema;
m_columnDefinitions = default;
State = (ok.ServerStatus & ServerStatus.MoreResultsExist) == 0
? ResultSetState.NoMoreData
Expand Down
44 changes: 26 additions & 18 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,22 @@ public ServerSession(ILogger logger, ConnectionPool? pool, int poolGeneration, i
m_activityTags = [];
DataReader = new();
Log.CreatedNewSession(m_logger, Id);
Context = new Context(0, null, 0);
Context = new Context(default);
}

public string Id { get; }
public ServerVersion ServerVersion { get; set; }
public bool SupportsPerQueryVariables => ServerVersion.IsMariaDb && ServerVersion.Version >= ServerVersions.MariaDbSupportsPerQueryVariables;
public int ActiveCommandId { get; private set; }
public int CancellationTimeout { get; private set; }
public int ConnectionId { get; set; }
public byte[]? AuthPluginData { get; set; }
public long CreatedTimestamp { get; }
public ConnectionPool? Pool { get; }
public int PoolGeneration { get; }
public long LastLeasedTimestamp { get; set; }
public long LastReturnedTimestamp { get; private set; }
public string? DatabaseOverride { get; set; }

public string HostName { get; private set; }
public IPEndPoint? IPEndPoint => m_tcpClient?.Client.RemoteEndPoint as IPEndPoint;
Expand Down Expand Up @@ -338,8 +340,8 @@ public void FinishQuerying()
var activity = ActivitySourceHelper.StartActivity(name, m_activityTags);
if (activity is { IsAllDataRequested: true })
{
if (!Context.IsInitialDatabase())
activity.SetTag(ActivitySourceHelper.DatabaseNameTagName, Context.Database);
if (DatabaseOverride is not null)
activity.SetTag(ActivitySourceHelper.DatabaseNameTagName, DatabaseOverride);
if (tagName1 is not null)
activity.SetTag(tagName1, tagValue1);
}
Expand Down Expand Up @@ -452,15 +454,16 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
}

ServerVersion = new(initialHandshake.ServerVersion);
Context = new Context(initialHandshake.ProtocolCapabilities, cs.Database, initialHandshake.ConnectionId);
ConnectionId = initialHandshake.ConnectionId;
Context = new Context(initialHandshake.ProtocolCapabilities);
AuthPluginData = initialHandshake.AuthPluginData;
m_useCompression = cs.UseCompression && (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.Compress) != 0;
CancellationTimeout = cs.CancellationTimeout;
UserID = cs.UserID;

// set activity tags
{
var connectionId = Context.ConnectionId.ToString(CultureInfo.InvariantCulture);
var connectionId = ConnectionId.ToString(CultureInfo.InvariantCulture);
m_activityTags[ActivitySourceHelper.DatabaseConnectionIdTagName] = connectionId;
if (activity is { IsAllDataRequested: true })
activity.SetTag(ActivitySourceHelper.DatabaseConnectionIdTagName, connectionId);
Expand Down Expand Up @@ -499,7 +502,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
}
}

Log.SessionMadeConnection(m_logger, Id, ServerVersion.OriginalString, Context.ConnectionId, m_useCompression, m_supportsConnectionAttributes, Context.SupportsDeprecateEof, Context.SupportsCachedPreparedMetadata, serverSupportsSsl, Context.SupportsSessionTrack, m_supportsPipelining, Context.SupportsQueryAttributes);
Log.SessionMadeConnection(m_logger, Id, ServerVersion.OriginalString, ConnectionId, m_useCompression, m_supportsConnectionAttributes, Context.SupportsDeprecateEof, Context.SupportsCachedPreparedMetadata, serverSupportsSsl, Context.SupportsSessionTrack, m_supportsPipelining, Context.SupportsQueryAttributes);

if (cs.SslMode != MySqlSslMode.None && (cs.SslMode != MySqlSslMode.Preferred || serverSupportsSsl))
{
Expand Down Expand Up @@ -532,18 +535,23 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
if (m_useCompression)
m_payloadHandler = new CompressedPayloadHandler(m_payloadHandler.ByteHandler);

// set 'collation_connection' to the server default
if (Context.ClientCharset == null || ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4
? !string.Equals(Context.ClientCharset, "utf8mb4", StringComparison.Ordinal)
: !string.Equals(Context.ClientCharset, "utf8", StringComparison.Ordinal))
if (ok.ClientCharacterSet != (ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4 ? "utf8mb4" : "utf8"))
{
// set 'collation_connection' to the server default
await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false);
payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
OkPayload.Verify(payload.Span, Context);
}

if (ShouldGetRealServerDetails(cs))
{
await GetRealServerDetailsAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
}
else if (ok.ConnectionId is int newConnectionId && newConnectionId != ConnectionId)
{
Log.ChangingConnectionId(m_logger, Id, ConnectionId, newConnectionId, ServerVersion.OriginalString, ServerVersion.OriginalString);
ConnectionId = newConnectionId;
}

m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout;
return statusInfo;
Expand All @@ -570,9 +578,9 @@ public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, MySqlConn
ClearPreparedStatements();

PayloadData payload;
if (Context.IsInitialDatabase() &&
((!ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0) ||
(ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.MariaDbSupportsResetConnection) >= 0)))
if (DatabaseOverride is null &&
((!ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0) ||
(ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.MariaDbSupportsResetConnection) >= 0)))
{
if (m_supportsPipelining)
{
Expand All @@ -599,14 +607,14 @@ public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, MySqlConn
else
{
// optimistically hash the password with the challenge from the initial handshake (supported by MariaDB; doesn't appear to be supported by MySQL)
if (Context.IsInitialDatabase())
if (DatabaseOverride is null)
{
Log.SendingChangeUserRequest(m_logger, Id, ServerVersion.OriginalString);
}
else
{
Log.SendingChangeUserRequestDueToChangedDatabase(m_logger, Id, Context.Database!);
Context.Database = cs.Database;
Log.SendingChangeUserRequestDueToChangedDatabase(m_logger, Id, DatabaseOverride);
DatabaseOverride = null;
}
var password = GetPassword(cs, connection);
var hashedPassword = AuthenticationUtility.CreateAuthenticationResponse(AuthPluginData!, password);
Expand Down Expand Up @@ -1668,8 +1676,8 @@ static void ReadRow(ReadOnlySpan<byte> span, out int? connectionId, out ServerVe

if (connectionId is int newConnectionId && serverVersion is not null)
{
Log.ChangingConnectionId(m_logger, Id, Context.ConnectionId, newConnectionId, ServerVersion.OriginalString, serverVersion.OriginalString);
Context.ConnectionId = newConnectionId;
Log.ChangingConnectionId(m_logger, Id, ConnectionId, newConnectionId, ServerVersion.OriginalString, serverVersion.OriginalString);
ConnectionId = newConnectionId;
ServerVersion = serverVersion;
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/MySqlConnector/MySqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ private async Task ChangeDatabaseAsync(IOBehavior ioBehavior, string databaseNam
OkPayload.Verify(payload.Span, m_session.Context);

// for non session tracking servers
m_session.Context.Database = databaseName;
m_session.DatabaseOverride = databaseName;
}

public new MySqlCommand CreateCommand() => (MySqlCommand) base.CreateCommand();
Expand Down Expand Up @@ -628,7 +628,7 @@ public override string ConnectionString
}
}

public override string Database => m_session?.Context.Database ?? GetConnectionSettings().Database;
public override string Database => m_session?.DatabaseOverride ?? GetConnectionSettings().Database;

public override ConnectionState State => m_connectionState;

Expand All @@ -639,7 +639,7 @@ public override string ConnectionString
/// <summary>
/// The connection ID from MySQL Server.
/// </summary>
public int ServerThread => Session.Context.ConnectionId;
public int ServerThread => Session.ConnectionId;

/// <summary>
/// Gets or sets the delegate used to provide client certificates for connecting to a server.
Expand Down
25 changes: 17 additions & 8 deletions src/MySqlConnector/Protocol/Payloads/OkPayload.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ internal sealed class OkPayload
public ServerStatus ServerStatus { get; }
public int WarningCount { get; }
public string? StatusInfo { get; }
public string? NewSchema { get; }
public string? ClientCharacterSet { get; }
public int? ConnectionId { get; }

public const byte Signature = 0x00;

Expand Down Expand Up @@ -57,6 +60,9 @@ public static void Verify(ReadOnlySpan<byte> span, Context context) =>
var lastInsertId = reader.ReadLengthEncodedInteger();
var serverStatus = (ServerStatus) reader.ReadUInt16();
var warningCount = (int) reader.ReadUInt16();
string? newSchema = null;
string? clientCharacterSet = null;
int? connectionId = null;
ReadOnlySpan<byte> statusBytes;

if (context.SupportsSessionTrack)
Expand All @@ -75,7 +81,7 @@ public static void Verify(ReadOnlySpan<byte> span, Context context) =>
switch (kind)
{
case SessionTrackKind.Schema:
context.Database = Encoding.UTF8.GetString(reader.ReadLengthEncodedByteString());
newSchema = Encoding.UTF8.GetString(reader.ReadLengthEncodedByteString());
break;

case SessionTrackKind.SystemVariables:
Expand All @@ -90,10 +96,10 @@ public static void Verify(ReadOnlySpan<byte> span, Context context) =>
switch (variableSv)
{
case "character_set_client":
context.ClientCharset = valueSv;
clientCharacterSet = valueSv;
break;
case "connection_id":
context.ConnectionId = Convert.ToInt32(valueSv, CultureInfo.InvariantCulture);
connectionId = Convert.ToInt32(valueSv, CultureInfo.InvariantCulture);
break;
}
} while (reader.Offset < systemVariableOffset);
Expand Down Expand Up @@ -126,31 +132,34 @@ public static void Verify(ReadOnlySpan<byte> span, Context context) =>
{
var statusInfo = statusBytes.Length == 0 ? null : Encoding.UTF8.GetString(statusBytes);

if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null)
if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null && clientCharacterSet is null && connectionId is null)
{
if (serverStatus == ServerStatus.AutoCommit)
return s_autoCommitOk;
if (serverStatus == (ServerStatus.AutoCommit | ServerStatus.SessionStateChanged))
return s_autoCommitSessionStateChangedOk;
}

return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo);
return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema, clientCharacterSet, connectionId);
}
else
{
return null;
}
}

private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo)
private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema, string? clientCharacterSet, int? connectionId)
{
AffectedRowCount = affectedRowCount;
LastInsertId = lastInsertId;
ServerStatus = serverStatus;
WarningCount = warningCount;
StatusInfo = statusInfo;
NewSchema = newSchema;
ClientCharacterSet = clientCharacterSet;
ConnectionId = connectionId;
}

private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, null);
private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, null);
private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, default, default, default, default);
private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, default, default, default, default);
}

0 comments on commit 77ad191

Please sign in to comment.