Skip to content

Commit

Permalink
Fix #70 -filters doesn't apply correctly in AND clause and bug when u…
Browse files Browse the repository at this point in the history
…sing OR clause (#71)

* Add Kernel Memory support

* Prepare for publish

* Fix #70 - fix tag filter for OR and AND clauses
  • Loading branch information
kbeaugrand authored Jan 9, 2024
1 parent 4cf2535 commit 603466c
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 71 deletions.
128 changes: 126 additions & 2 deletions src/Connectors.Memory.SqlServer.Tests/SqlServerMemoryDbTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,14 @@ public async Task GetSimilarListShouldNotReturnExpectedWithFiltersAsync()

var filter = new MemoryFilter().ByTag("test", "record1");

_ = this._textEmbeddingGeneratorMock.Setup(x => x.GenerateEmbeddingAsync(It.IsAny<string>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(compareEmbedding);

// Act
double threshold = -1;
var topNResults = memoryDb.GetSimilarListAsync(
index: collection,
text: "Sample",
index: collection,
text: "Sample",
limit: 4,
minRelevance: threshold,
filters: new[] { filter })
Expand All @@ -501,4 +504,125 @@ public async Task GetSimilarListShouldNotReturnExpectedWithFiltersAsync()
Assert.NotNull(topNResults);
Assert.Empty(topNResults);
}

/// <summary>
/// Verifies that a similarity search using one filter that carries two tag conditions
/// (<c>test=record0</c> and <c>test=test</c>) yields exactly one result.
/// </summary>
/// <returns>A task that completes when the test has finished.</returns>
[Fact]
public async Task GetSimilarListShouldNotReturnExpectedWithFiltersWithANDClauseAsync()
{
    // Arrange
    var memoryDb = this.CreateMemoryDb();

    string collection = "test_collection";
    await memoryDb.CreateIndexAsync(collection, 1536);

    // The counter is never advanced, so both upserts below use the same Id ("test0").
    int recordIndex = 0;

    foreach (string recordTag in new[] { "record0", "record1" })
    {
        var record = new MemoryRecord()
        {
            Id = "test" + recordIndex,
            Vector = new float[] { 1, 1, 1 }
        };

        record.Tags.Add("test", recordTag);
        record.Tags.Add("test", "test");

        _ = await memoryDb.UpsertAsync(collection, record);
    }

    // Whatever text is searched for, the generator hands back this fixed embedding.
    var compareEmbedding = new ReadOnlyMemory<float>(new float[] { 1, 1, 1 });

    _ = this._textEmbeddingGeneratorMock
        .Setup(x => x.GenerateEmbeddingAsync(It.IsAny<string>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(compareEmbedding);

    // A single filter holding both tag conditions.
    var andFilter = new MemoryFilter()
        .ByTag("test", "record0")
        .ByTag("test", "test");

    // Act
    double threshold = -1;
    var similarRecords = memoryDb.GetSimilarListAsync(
        index: collection,
        text: "Sample",
        limit: 4,
        minRelevance: threshold,
        filters: new[] { andFilter });

    var topNResults = similarRecords
        .ToEnumerable()
        .ToArray();

    // Assert
    Assert.NotNull(topNResults);
    Assert.Single(topNResults);
}

/// <summary>
/// Verifies that a similarity search using two separate filters
/// (<c>test=record0</c> and <c>test=record1</c>) returns both stored records.
/// </summary>
/// <returns>A task that completes when the test has finished.</returns>
[Fact]
public async Task GetSimilarListShouldNotReturnExpectedWithFiltersWithORClauseAsync()
{
    // Arrange
    var memoryDb = this.CreateMemoryDb();

    string collection = "test_collection";
    await memoryDb.CreateIndexAsync(collection, 1536);

    // Two records, "test0" and "test1", tagged "record0" and "record1" respectively;
    // both additionally carry the tag value "test".
    for (int recordIndex = 0; recordIndex < 2; recordIndex++)
    {
        var record = new MemoryRecord()
        {
            Id = "test" + recordIndex,
            Vector = new float[] { 1, 1, 1 }
        };

        record.Tags.Add("test", "record" + recordIndex);
        record.Tags.Add("test", "test");

        _ = await memoryDb.UpsertAsync(collection, record);
    }

    // Whatever text is searched for, the generator hands back this fixed embedding.
    var compareEmbedding = new ReadOnlyMemory<float>(new float[] { 1, 1, 1 });

    _ = this._textEmbeddingGeneratorMock
        .Setup(x => x.GenerateEmbeddingAsync(It.IsAny<string>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(compareEmbedding);

    // Two independent filters, one per record-specific tag value.
    var firstFilter = new MemoryFilter().ByTag("test", "record0");
    var secondFilter = new MemoryFilter().ByTag("test", "record1");

    // Act
    double threshold = -1;
    var similarRecords = memoryDb.GetSimilarListAsync(
        index: collection,
        text: "Sample",
        limit: 4,
        minRelevance: threshold,
        filters: new[] { firstFilter, secondFilter });

    var topNResults = similarRecords
        .ToEnumerable()
        .ToArray();

    // Assert
    Assert.NotNull(topNResults);
    Assert.Equal(2, topNResults.Length);
}
}
142 changes: 73 additions & 69 deletions src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -249,28 +250,10 @@ SELECT TOP (@limit)
{queryColumns}
FROM
{this.GetFullTableName(MemoryTableName)}
LEFT JOIN
{this.GetFullTableName($"{TagsTableName}_{index}")} ON {this.GetFullTableName(MemoryTableName)}.[id] = {this.GetFullTableName($"{TagsTableName}_{index}")}.[memory_id]
LEFT JOIN [filters] ON
{this.GetFullTableName($"{TagsTableName}_{index}")}.[name] = [filters].[name] AND {this.GetFullTableName($"{TagsTableName}_{index}")}.[value] = [filters].[value]
WHERE 1=1
AND {this.GetFullTableName(MemoryTableName)}.[collection] = @index";
WHERE 1=1
AND {this.GetFullTableName(MemoryTableName)}.[collection] = @index
{GenerateFilters(index, cmd.Parameters, filters)};";

if (filters is not null)
{
filters.ToList()
.ForEach(c => c.ToList().ForEach(x =>
{
tagFilters.Add(x.Key, x.Value);

cmd.CommandText += $@"
AND [filters].[name] = @filter_name_{x.Key}
AND [filters].[value] = @filter_value_{x.Key}";

cmd.Parameters.AddWithValue($"@filter_name_{x.Key}", x.Key);
cmd.Parameters.AddWithValue($"@filter_value_{x.Key}", JsonSerializer.Serialize(x.Value));
}));
}

cmd.Parameters.AddWithValue("@index", index);
cmd.Parameters.AddWithValue("@limit", limit);
Expand Down Expand Up @@ -304,14 +287,8 @@ await connection.OpenAsync(cancellationToken)
using SqlCommand cmd = connection.CreateCommand();

cmd.CommandText = $@"
WITH [filters] AS
(
SELECT
cast([filters].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [name],
cast([filters].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [value]
FROM openjson(@filters) [filters]
)
,[embedding] as
WITH
[embedding] as
(
SELECT
cast([key] AS INT) AS [vector_value_id],
Expand Down Expand Up @@ -339,54 +316,24 @@ GROUP BY
ORDER BY
cosine_similarity DESC
)
SELECT
SELECT DISTINCT
{this.GetFullTableName(MemoryTableName)}.[id],
{this.GetFullTableName(MemoryTableName)}.[key],
{this.GetFullTableName(MemoryTableName)}.[payload],
{this.GetFullTableName(MemoryTableName)}.[tags],
{this.GetFullTableName(MemoryTableName)}.[embedding],
(
SELECT
[vector_value]
FROM {this.GetFullTableName($"{EmbeddingsTableName}_{index}")}
WHERE {this.GetFullTableName(MemoryTableName)}.[id] = {this.GetFullTableName($"{EmbeddingsTableName}_{index}")}.[memory_id]
ORDER BY vector_value_id
FOR JSON AUTO
) AS [embeddings],
[similarity].[cosine_similarity]
FROM
[similarity]
INNER JOIN
{this.GetFullTableName(MemoryTableName)} ON [similarity].[memory_id] = {this.GetFullTableName(MemoryTableName)}.[id]
LEFT JOIN {this.GetFullTableName($"{TagsTableName}_{index}")} ON {this.GetFullTableName(MemoryTableName)}.[id] = {this.GetFullTableName($"{TagsTableName}_{index}")}.[memory_id]
LEFT JOIN [filters] ON {this.GetFullTableName($"{TagsTableName}_{index}")}.[name] = [filters].[name] AND {this.GetFullTableName($"{TagsTableName}_{index}")}.[value] = [filters].[value]
WHERE 1=1
AND cosine_similarity >= @min_relevance_score
";

var tagFilters = new TagCollection();

if (filters is not null)
{
filters.ToList()
.ForEach(c => c.ToList().ForEach(x =>
{
tagFilters.Add(x.Key, x.Value);

cmd.CommandText += $@"
AND [filters].[name] = @filter_name_{x.Key}
AND [filters].[value] = @filter_value_{x.Key}";

cmd.Parameters.AddWithValue($"@filter_name_{x.Key}", x.Key);
cmd.Parameters.AddWithValue($"@filter_value_{x.Key}", JsonSerializer.Serialize(x.Value));
}));
}
{GenerateFilters(index, cmd.Parameters, filters)}";

cmd.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(embedding.Data.ToArray()));
cmd.Parameters.AddWithValue("@index", index);
cmd.Parameters.AddWithValue("@min_relevance_score", minRelevance);
cmd.Parameters.AddWithValue("@limit", limit);
cmd.Parameters.AddWithValue("@filters", JsonSerializer.Serialize(tagFilters));

using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -438,26 +385,32 @@ WHEN NOT MATCHED THEN
[src].[vector_value_id],
[src].[vector_value] );
DELETE FROM [tgt]
FROM {this.GetFullTableName($"{TagsTableName}_{index}")} AS [tgt]
INNER JOIN {this.GetFullTableName(MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(MemoryTableName)}.[id]
WHERE {this.GetFullTableName(MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(MemoryTableName)}.[collection] = @index;
MERGE {this.GetFullTableName($"{TagsTableName}_{index}")} AS [tgt]
USING (
SELECT
{this.GetFullTableName(MemoryTableName)}.[id],
cast([tags].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name],
cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_value]
[tag_value].[value] AS [value]
FROM {this.GetFullTableName(MemoryTableName)}
CROSS APPLY
openjson(@tags) [tags]
CROSS APPLY openjson(@tags) [tags]
CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value]
WHERE {this.GetFullTableName(MemoryTableName)}.[key] = @key
AND {this.GetFullTableName(MemoryTableName)}.[collection] = @index
) AS [src]
ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name]
WHEN MATCHED THEN
UPDATE SET [tgt].[value] = [src].[tag_value]
UPDATE SET [tgt].[value] = [src].[value]
WHEN NOT MATCHED THEN
INSERT ([memory_id], [name], [value])
VALUES ([src].[id],
[src].[tag_name],
[src].[tag_value]);";
[src].[value]);";

cmd.Parameters.AddWithValue("@index", index);
cmd.Parameters.AddWithValue("@key", record.Id);

Check warning on line 416 in src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs

View workflow job for this annotation

GitHub Actions / Publish (KernelMemory.MemoryStorage.SqlServer)

In externally visible method 'Task<string> SqlServerMemory.UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default(CancellationToken))', validate parameter 'record' is non-null before using it. If appropriate, throw an 'ArgumentNullException' when the argument is 'null'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1062)
Expand Down Expand Up @@ -487,9 +440,9 @@ IF OBJECT_ID(N'{this.GetFullTableName(MemoryTableName)}', N'U') IS NULL
( [id] UNIQUEIDENTIFIER NOT NULL,
[key] NVARCHAR(256) NOT NULL,
[collection] NVARCHAR(256) NOT NULL,
[payload] TEXT,
[tags] TEXT,
[embedding] TEXT,
[payload] NVARCHAR(MAX),
[tags] NVARCHAR(MAX),
[embedding] NVARCHAR(MAX),
PRIMARY KEY ([id]),
FOREIGN KEY ([collection]) REFERENCES {this.GetFullTableName(MemoryCollectionTableName)}([id]) ON DELETE CASCADE,
CONSTRAINT UK_{MemoryTableName} UNIQUE([collection], [key])
Expand Down Expand Up @@ -538,6 +491,57 @@ private string GetFullTableName(string tableName)
return $"[{this._config.Schema}].[{tableName}]";
}

/// <summary>
/// Builds the SQL predicate that restricts a query to memories whose tags match the given
/// filters, and registers one JSON parameter (<c>@filter_0</c>, <c>@filter_1</c>, ...) per
/// filter on <paramref name="parameters"/>.
/// </summary>
/// <remarks>
/// Every filter is turned into one <c>EXISTS</c> sub-query and the sub-queries are joined
/// with <c>OR</c>. The returned text starts with <c>AND</c>, so it is meant to be appended
/// after an existing <c>WHERE</c> condition of a query that selects from the memory table.
/// </remarks>
/// <param name="index">The index name, used to resolve the per-index tags table.</param>
/// <param name="parameters">The SQL parameters to populate.</param>
/// <param name="filters">The filters to apply; <c>null</c> or an empty collection means no filtering.</param>
/// <returns>The SQL fragment to append, or an empty string when there is nothing to filter on.</returns>
private string GenerateFilters(string index, SqlParameterCollection parameters, ICollection<MemoryFilter>? filters = null)
{
    // An empty (but non-null) collection previously produced "AND ( )", which is not valid
    // T-SQL and made the whole command fail. Treat it exactly like "no filters".
    if (filters is null || filters.Count == 0)
    {
        return string.Empty;
    }

    // Loop-invariant table names, resolved once.
    string tagsTable = this.GetFullTableName($"{TagsTableName}_{index}");
    string memoryTable = this.GetFullTableName(MemoryTableName);

    var filterBuilder = new StringBuilder();
    filterBuilder.AppendLine("AND (");

    // Enumerate directly instead of calling ElementAt(i), which may rescan the collection
    // on every access when it is not an indexable list.
    int position = 0;
    foreach (MemoryFilter filter in filters)
    {
        if (position > 0)
        {
            filterBuilder.Append(" OR ");
        }

        string parameterName = $"@filter_{position}";

        // The filter is passed as JSON and expanded twice with openjson: first into its
        // keys, then into the values stored under each key, giving (name, value) rows.
        // Presumably the serialized filter is an object whose values are arrays - verify
        // against MemoryFilter's JSON shape.
        // NOTE(review): because of the INNER JOIN inside EXISTS, this predicate is satisfied
        // as soon as ONE (name, value) row of the filter matches a tag of the memory.
        // Confirm that this is the intended behaviour for filters carrying several pairs.
        filterBuilder.Append($@"EXISTS (
    SELECT
        1
    FROM (
        SELECT
            cast([filters].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [name],
            [tag_value].[value] AS [value]
        FROM openjson({parameterName}) [filters]
        CROSS APPLY openjson(cast([filters].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value]
    ) AS [filter]
    INNER JOIN {tagsTable} AS [tags] ON [filter].[name] = [tags].[name] AND [filter].[value] = [tags].[value]
    WHERE
        [tags].[memory_id] = {memoryTable}.[id]
)
");

        parameters.AddWithValue(parameterName, JsonSerializer.Serialize(filter));
        position++;
    }

    filterBuilder.Append(')');

    return filterBuilder.ToString();
}

private async Task<MemoryRecord> ReadEntryAsync(SqlDataReader dataReader, bool withEmbedding, CancellationToken cancellationToken = default)

Check warning on line 545 in src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs

View workflow job for this annotation

GitHub Actions / Publish (KernelMemory.MemoryStorage.SqlServer)

Member 'ReadEntryAsync' does not access instance data and can be marked as static (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1822)
{
var entry = new MemoryRecord();
Expand Down

0 comments on commit 603466c

Please sign in to comment.