Skip to content

Commit

Permalink
Relates #101 - open connection when doing an action. (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
kbeaugrand authored Jan 31, 2024
1 parent fc10350 commit 6d288b2
Showing 1 changed file with 74 additions and 62 deletions.
136 changes: 74 additions & 62 deletions src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace KernelMemory.MemoryStorage.SqlServer;
/// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities",
Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")]
public class SqlServerMemory : IMemoryDb, IDisposable
public class SqlServerMemory : IMemoryDb
{
/// <summary>
/// The SQL Server configuration.
Expand All @@ -40,12 +40,6 @@ public class SqlServerMemory : IMemoryDb, IDisposable
/// </summary>
private readonly ILogger<SqlServerMemory> _log;

/// <summary>
/// The SQL connection.
/// </summary>
SqlConnection _connection;


/// <summary>
/// Initializes a new instance of the <see cref="SqlServerMemory"/> class.
/// </summary>
Expand All @@ -62,9 +56,6 @@ public SqlServerMemory(

this._config = config;

this._connection = new SqlConnection(this._config.ConnectionString);
this._connection.Open();

if (this._embeddingGenerator == null)
{
throw new SqlServerMemoryException("Embedding generator not configured");
Expand All @@ -84,7 +75,10 @@ public async Task CreateIndexAsync(string index, int vectorSize, CancellationTok
return;
}

using (SqlCommand command = this._connection.CreateCommand())
using var connection = new SqlConnection(this._config.ConnectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

using (SqlCommand command = connection.CreateCommand())
{
command.CommandText = $@"
INSERT INTO {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id])
Expand Down Expand Up @@ -131,9 +125,12 @@ public async Task DeleteAsync(string index, MemoryRecord record, CancellationTok
return;
}

using SqlCommand cmd = this._connection.CreateCommand();

cmd.CommandText = $@"
using var connection = new SqlConnection(this._config.ConnectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

using (SqlCommand command = connection.CreateCommand())
{
command.CommandText = $@"
DELETE [embeddings]
FROM {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} [embeddings]
INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [embeddings].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
Expand All @@ -149,12 +146,14 @@ DELETE [tags]
AND {this.GetFullTableName(this._config.MemoryTableName)}.[key]=@key;
DELETE FROM {this.GetFullTableName(this._config.MemoryTableName)} WHERE [collection] = @index AND [key]=@key;
";
";

command.Parameters.AddWithValue("@index", index);
command.Parameters.AddWithValue("@key", record.Id);

Check warning on line 152 in src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs

View workflow job for this annotation

GitHub Actions / build

In externally visible method 'Task SqlServerMemory.DeleteAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default(CancellationToken))', validate parameter 'record' is non-null before using it. If appropriate, throw an 'ArgumentNullException' when the argument is 'null'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1062)

Check warning on line 152 in src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs

View workflow job for this annotation

GitHub Actions / Publish (KernelMemory.MemoryStorage.SqlServer)

In externally visible method 'Task SqlServerMemory.DeleteAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default(CancellationToken))', validate parameter 'record' is non-null before using it. If appropriate, throw an 'ArgumentNullException' when the argument is 'null'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1062)

Check warning on line 152 in src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs

View workflow job for this annotation

GitHub Actions / Publish (KernelMemory.MemoryStorage.SqlServer)

In externally visible method 'Task SqlServerMemory.DeleteAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default(CancellationToken))', validate parameter 'record' is non-null before using it. If appropriate, throw an 'ArgumentNullException' when the argument is 'null'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1062)

cmd.Parameters.AddWithValue("@index", index);
cmd.Parameters.AddWithValue("@key", record.Id);
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);

await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
}

/// <inheritdoc/>
Expand All @@ -168,7 +167,10 @@ public async Task DeleteIndexAsync(string index, CancellationToken cancellationT
return;
}

using (SqlCommand command = this._connection.CreateCommand())
using var connection = new SqlConnection(this._config.ConnectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

using (SqlCommand command = connection.CreateCommand())
{
command.CommandText = $@"DELETE FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}
WHERE [id] = @index;
Expand All @@ -188,7 +190,10 @@ public async Task<IEnumerable<string>> GetIndexesAsync(CancellationToken cancell
{
List<string> indexes = new();

using (SqlCommand command = this._connection.CreateCommand())
using var connection = new SqlConnection(this._config.ConnectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

using (SqlCommand command = connection.CreateCommand())
{
command.CommandText = $"SELECT [id] FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}";

Expand Down Expand Up @@ -226,11 +231,14 @@ public async IAsyncEnumerable<MemoryRecord> GetListAsync(string index, ICollecti
limit = int.MaxValue;
}

using SqlCommand cmd = this._connection.CreateCommand();
using var connection = new SqlConnection(this._config.ConnectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

var tagFilters = new TagCollection();
using (SqlCommand command = connection.CreateCommand())
{
var tagFilters = new TagCollection();

cmd.CommandText = $@"
command.CommandText = $@"
WITH [filters] AS
(
SELECT
Expand All @@ -244,18 +252,20 @@ SELECT TOP (@limit)
{this.GetFullTableName(this._config.MemoryTableName)}
WHERE 1=1
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
{GenerateFilters(index, cmd.Parameters, filters)};";
{GenerateFilters(index, command.Parameters, filters)};";


cmd.Parameters.AddWithValue("@index", index);
cmd.Parameters.AddWithValue("@limit", limit);
cmd.Parameters.AddWithValue("@filters", JsonSerializer.Serialize(tagFilters));
command.Parameters.AddWithValue("@index", index);
command.Parameters.AddWithValue("@limit", limit);
command.Parameters.AddWithValue("@filters", JsonSerializer.Serialize(tagFilters));

using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);

while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
}

while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
}
}

Expand All @@ -279,9 +289,12 @@ SELECT TOP (@limit)
queryColumns += ", [embedding]";
}

using SqlCommand cmd = this._connection.CreateCommand();
using var connection = new SqlConnection(this._config.ConnectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

cmd.CommandText = $@"
using (SqlCommand command = connection.CreateCommand())
{
command.CommandText = $@"
WITH
[embedding] as
(
Expand Down Expand Up @@ -323,20 +336,21 @@ INNER JOIN
{this.GetFullTableName(this._config.MemoryTableName)} ON [similarity].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
WHERE 1=1
AND [cosine_similarity] >= @min_relevance_score
{GenerateFilters(index, cmd.Parameters, filters)}
{GenerateFilters(index, command.Parameters, filters)}
ORDER BY [cosine_similarity] desc";

cmd.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(embedding.Data.ToArray()));
cmd.Parameters.AddWithValue("@index", index);
cmd.Parameters.AddWithValue("@min_relevance_score", minRelevance);
cmd.Parameters.AddWithValue("@limit", limit);
command.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(embedding.Data.ToArray()));
command.Parameters.AddWithValue("@index", index);
command.Parameters.AddWithValue("@min_relevance_score", minRelevance);
command.Parameters.AddWithValue("@limit", limit);

using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);

while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
double cosineSimilarity = dataReader.GetDouble(dataReader.GetOrdinal("cosine_similarity"));
yield return (await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false), cosineSimilarity);
while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
double cosineSimilarity = dataReader.GetDouble(dataReader.GetOrdinal("cosine_similarity"));
yield return (await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false), cosineSimilarity);
}
}
}

Expand All @@ -351,9 +365,12 @@ public async Task<string> UpsertAsync(string index, MemoryRecord record, Cancell
return string.Empty;
}

using SqlCommand cmd = this._connection.CreateCommand();
using var connection = new SqlConnection(this._config.ConnectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

cmd.CommandText = $@"
using (SqlCommand command = connection.CreateCommand())
{
command.CommandText = $@"
MERGE INTO {this.GetFullTableName(this._config.MemoryTableName)}
USING (SELECT @key) as [src]([key])
ON {this.GetFullTableName(this._config.MemoryTableName)}.[key] = [src].[key]
Expand Down Expand Up @@ -411,15 +428,16 @@ WHEN NOT MATCHED THEN
[src].[tag_name],
[src].[value]);";

cmd.Parameters.AddWithValue("@index", index);
cmd.Parameters.AddWithValue("@key", record.Id);
cmd.Parameters.AddWithValue("@payload", JsonSerializer.Serialize(record.Payload) ?? (object)DBNull.Value);
cmd.Parameters.AddWithValue("@tags", JsonSerializer.Serialize(record.Tags) ?? (object)DBNull.Value);
cmd.Parameters.AddWithValue("@embedding", JsonSerializer.Serialize(record.Vector.Data.ToArray()));
command.Parameters.AddWithValue("@index", index);
command.Parameters.AddWithValue("@key", record.Id);

Check warning on line 432 in src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs

View workflow job for this annotation

GitHub Actions / build

In externally visible method 'Task<string> SqlServerMemory.UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default(CancellationToken))', validate parameter 'record' is non-null before using it. If appropriate, throw an 'ArgumentNullException' when the argument is 'null'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1062)

Check warning on line 432 in src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs

View workflow job for this annotation

GitHub Actions / Publish (KernelMemory.MemoryStorage.SqlServer)

In externally visible method 'Task<string> SqlServerMemory.UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default(CancellationToken))', validate parameter 'record' is non-null before using it. If appropriate, throw an 'ArgumentNullException' when the argument is 'null'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1062)
command.Parameters.AddWithValue("@payload", JsonSerializer.Serialize(record.Payload) ?? (object)DBNull.Value);
command.Parameters.AddWithValue("@tags", JsonSerializer.Serialize(record.Tags) ?? (object)DBNull.Value);
command.Parameters.AddWithValue("@embedding", JsonSerializer.Serialize(record.Vector.Data.ToArray()));

await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);

return record.Id;
return record.Id;
}
}

/// <summary>
Expand Down Expand Up @@ -452,7 +470,10 @@ FOREIGN KEY ([collection]) REFERENCES {this.GetFullTableName(this._config.Memory
);
";

using (SqlCommand command = this._connection.CreateCommand())
using var connection = new SqlConnection(this._config.ConnectionString);
connection.Open();

using (SqlCommand command = connection.CreateCommand())
{
command.CommandText = sql;
command.ExecuteNonQuery();
Expand Down Expand Up @@ -593,13 +614,4 @@ private static string NormalizeIndexName(string index)

return index;
}

public void Dispose()
{
if (this._connection != null)
{
this._connection.Dispose();
this._connection = null!;
}
}
}

0 comments on commit 6d288b2

Please sign in to comment.