diff --git a/src/Connectors.Memory.SqlServer.Tests/SqlServerMemoryDbTests.cs b/src/Connectors.Memory.SqlServer.Tests/SqlServerMemoryDbTests.cs index 0783a9f..29b0d2b 100644 --- a/src/Connectors.Memory.SqlServer.Tests/SqlServerMemoryDbTests.cs +++ b/src/Connectors.Memory.SqlServer.Tests/SqlServerMemoryDbTests.cs @@ -486,11 +486,14 @@ public async Task GetSimilarListShouldNotReturnExpectedWithFiltersAsync() var filter = new MemoryFilter().ByTag("test", "record1"); + _ = this._textEmbeddingGeneratorMock.Setup(x => x.GenerateEmbeddingAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(compareEmbedding); + // Act double threshold = -1; var topNResults = memoryDb.GetSimilarListAsync( - index: collection, - text: "Sample", + index: collection, + text: "Sample", limit: 4, minRelevance: threshold, filters: new[] { filter }) @@ -501,4 +504,125 @@ public async Task GetSimilarListShouldNotReturnExpectedWithFiltersAsync() Assert.NotNull(topNResults); Assert.Empty(topNResults); } + + /// + /// Test that get similar list should return expected results. 
+    ///
+    ///
+    /// <remarks>
+    /// Stores two records that share the tag value "test" but differ in their other tag value,
+    /// then searches with ONE filter carrying two conditions. Conditions chained on a single
+    /// filter must all hold, so only the first record may come back.
+    /// The records need distinct ids: upserting twice with the same id replaces the first record,
+    /// which would leave a single row in the table and make the assertion pass vacuously.
+    /// </remarks>
+    [Fact]
+    public async Task GetSimilarListShouldNotReturnExpectedWithFiltersWithANDClauseAsync()
+    {
+        // Arrange
+        var memoryDb = this.CreateMemoryDb();
+
+        var compareEmbedding = new ReadOnlyMemory<float>(new float[] { 1, 1, 1 });
+        string collection = "test_collection";
+        await memoryDb.CreateIndexAsync(collection, 1536);
+        int i = 0;
+
+        // Record "test0": satisfies both filter conditions (test=record0 AND test=test).
+        MemoryRecord testRecord = new MemoryRecord()
+        {
+            Id = "test" + i++,
+            Vector = new float[] { 1, 1, 1 }
+        };
+
+        testRecord.Tags.Add("test", "record0");
+        testRecord.Tags.Add("test", "test");
+
+        _ = await memoryDb.UpsertAsync(collection, testRecord);
+
+        // Record "test1": satisfies only one of the two conditions (test=test), so it must be filtered out.
+        testRecord = new MemoryRecord()
+        {
+            Id = "test" + i,
+            Vector = new float[] { 1, 1, 1 }
+        };
+
+        testRecord.Tags.Add("test", "record1");
+        testRecord.Tags.Add("test", "test");
+
+        _ = await memoryDb.UpsertAsync(collection, testRecord);
+
+        _ = this._textEmbeddingGeneratorMock.Setup(x => x.GenerateEmbeddingAsync(It.IsAny<string>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(compareEmbedding);
+
+        // Act
+        // A threshold of -1 lets every record through the relevance check, so only the tag filter decides.
+        double threshold = -1;
+        var topNResults = memoryDb.GetSimilarListAsync(
+                index: collection,
+                text: "Sample",
+                limit: 4,
+                minRelevance: threshold,
+                filters: new[] {
+                    new MemoryFilter()
+                        .ByTag("test", "record0")
+                        .ByTag("test", "test")
+                })
+            .ToEnumerable()
+            .ToArray();
+
+        // Assert
+        Assert.NotNull(topNResults);
+        Assert.Single(topNResults);
+    }
+
+    ///
+    /// Test that get similar list should return expected results. 
+    ///
+    ///
+    /// <remarks>
+    /// Two single-condition filters are passed as separate array elements; each stored record
+    /// satisfies exactly one of them, and the test expects both records back.
+    /// </remarks>
+    [Fact]
+    public async Task GetSimilarListShouldNotReturnExpectedWithFiltersWithORClauseAsync()
+    {
+        // Arrange
+        string collection = "test_collection";
+        string[] distinguishingTags = { "record0", "record1" };
+        var queryEmbedding = new ReadOnlyMemory<float>(new float[] { 1, 1, 1 });
+
+        var memoryDb = this.CreateMemoryDb();
+        await memoryDb.CreateIndexAsync(collection, 1536);
+
+        // One record per distinguishing tag value; every record also carries the shared "test" value.
+        for (int n = 0; n < distinguishingTags.Length; n++)
+        {
+            var record = new MemoryRecord()
+            {
+                Id = "test" + n,
+                Vector = new float[] { 1, 1, 1 }
+            };
+
+            record.Tags.Add("test", distinguishingTags[n]);
+            record.Tags.Add("test", "test");
+
+            _ = await memoryDb.UpsertAsync(collection, record);
+        }
+
+        _ = this._textEmbeddingGeneratorMock
+            .Setup(x => x.GenerateEmbeddingAsync(It.IsAny<string>(), It.IsAny<CancellationToken>()))
+            .ReturnsAsync(queryEmbedding);
+
+        // Act
+        double threshold = -1;
+        var eitherRecord = new[]
+        {
+            new MemoryFilter().ByTag("test", "record0"),
+            new MemoryFilter().ByTag("test", "record1")
+        };
+
+        var similar = memoryDb.GetSimilarListAsync(
+            index: collection,
+            text: "Sample",
+            limit: 4,
+            minRelevance: threshold,
+            filters: eitherRecord);
+        var topNResults = similar.ToEnumerable().ToArray();
+
+        // Assert
+        Assert.NotNull(topNResults);
+        Assert.Equal(2, topNResults.Length);
+    }
 }
diff --git a/src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs b/src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs index c35f33d..a1256b5 100644 --- a/src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs +++ b/src/KernelMemory.MemoryStorage.SqlServer/SqlServerMemory.cs @@ -10,6 +10,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -249,28 +250,10 @@ SELECT TOP (@limit) {queryColumns} FROM {this.GetFullTableName(MemoryTableName)} - LEFT JOIN - 
{this.GetFullTableName($"{TagsTableName}_{index}")} ON {this.GetFullTableName(MemoryTableName)}.[id] = {this.GetFullTableName($"{TagsTableName}_{index}")}.[memory_id] - LEFT JOIN [filters] ON - {this.GetFullTableName($"{TagsTableName}_{index}")}.[name] = [filters].[name] AND {this.GetFullTableName($"{TagsTableName}_{index}")}.[value] = [filters].[value] - WHERE 1=1 - AND {this.GetFullTableName(MemoryTableName)}.[collection] = @index"; + WHERE 1=1 + AND {this.GetFullTableName(MemoryTableName)}.[collection] = @index + {GenerateFilters(index, cmd.Parameters, filters)};"; - if (filters is not null) - { - filters.ToList() - .ForEach(c => c.ToList().ForEach(x => - { - tagFilters.Add(x.Key, x.Value); - - cmd.CommandText += $@" - AND [filters].[name] = @filter_name_{x.Key} - AND [filters].[value] = @filter_value_{x.Key}"; - - cmd.Parameters.AddWithValue($"@filter_name_{x.Key}", x.Key); - cmd.Parameters.AddWithValue($"@filter_value_{x.Key}", JsonSerializer.Serialize(x.Value)); - })); - } cmd.Parameters.AddWithValue("@index", index); cmd.Parameters.AddWithValue("@limit", limit); @@ -304,14 +287,8 @@ await connection.OpenAsync(cancellationToken) using SqlCommand cmd = connection.CreateCommand(); cmd.CommandText = $@" - WITH [filters] AS - ( - SELECT - cast([filters].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [name], - cast([filters].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [value] - FROM openjson(@filters) [filters] - ) - ,[embedding] as + WITH + [embedding] as ( SELECT cast([key] AS INT) AS [vector_value_id], @@ -339,54 +316,24 @@ GROUP BY ORDER BY cosine_similarity DESC ) - SELECT + SELECT DISTINCT {this.GetFullTableName(MemoryTableName)}.[id], {this.GetFullTableName(MemoryTableName)}.[key], {this.GetFullTableName(MemoryTableName)}.[payload], {this.GetFullTableName(MemoryTableName)}.[tags], - {this.GetFullTableName(MemoryTableName)}.[embedding], - ( - SELECT - [vector_value] - FROM 
{this.GetFullTableName($"{EmbeddingsTableName}_{index}")} - WHERE {this.GetFullTableName(MemoryTableName)}.[id] = {this.GetFullTableName($"{EmbeddingsTableName}_{index}")}.[memory_id] - ORDER BY vector_value_id - FOR JSON AUTO - ) AS [embeddings], [similarity].[cosine_similarity] FROM [similarity] INNER JOIN {this.GetFullTableName(MemoryTableName)} ON [similarity].[memory_id] = {this.GetFullTableName(MemoryTableName)}.[id] - LEFT JOIN {this.GetFullTableName($"{TagsTableName}_{index}")} ON {this.GetFullTableName(MemoryTableName)}.[id] = {this.GetFullTableName($"{TagsTableName}_{index}")}.[memory_id] - LEFT JOIN [filters] ON {this.GetFullTableName($"{TagsTableName}_{index}")}.[name] = [filters].[name] AND {this.GetFullTableName($"{TagsTableName}_{index}")}.[value] = [filters].[value] WHERE 1=1 AND cosine_similarity >= @min_relevance_score - "; - - var tagFilters = new TagCollection(); - - if (filters is not null) - { - filters.ToList() - .ForEach(c => c.ToList().ForEach(x => - { - tagFilters.Add(x.Key, x.Value); - - cmd.CommandText += $@" - AND [filters].[name] = @filter_name_{x.Key} - AND [filters].[value] = @filter_value_{x.Key}"; - - cmd.Parameters.AddWithValue($"@filter_name_{x.Key}", x.Key); - cmd.Parameters.AddWithValue($"@filter_value_{x.Key}", JsonSerializer.Serialize(x.Value)); - })); - } + {GenerateFilters(index, cmd.Parameters, filters)}"; cmd.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(embedding.Data.ToArray())); cmd.Parameters.AddWithValue("@index", index); cmd.Parameters.AddWithValue("@min_relevance_score", minRelevance); cmd.Parameters.AddWithValue("@limit", limit); - cmd.Parameters.AddWithValue("@filters", JsonSerializer.Serialize(tagFilters)); using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); @@ -438,26 +385,32 @@ WHEN NOT MATCHED THEN [src].[vector_value_id], [src].[vector_value] ); + DELETE FROM [tgt] + FROM {this.GetFullTableName($"{TagsTableName}_{index}")} AS [tgt] + INNER JOIN 
{this.GetFullTableName(MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(MemoryTableName)}.[id] + WHERE {this.GetFullTableName(MemoryTableName)}.[key] = @key + AND {this.GetFullTableName(MemoryTableName)}.[collection] = @index; + MERGE {this.GetFullTableName($"{TagsTableName}_{index}")} AS [tgt] USING ( SELECT {this.GetFullTableName(MemoryTableName)}.[id], cast([tags].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name], - cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_value] + [tag_value].[value] AS [value] FROM {this.GetFullTableName(MemoryTableName)} - CROSS APPLY - openjson(@tags) [tags] + CROSS APPLY openjson(@tags) [tags] + CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value] WHERE {this.GetFullTableName(MemoryTableName)}.[key] = @key AND {this.GetFullTableName(MemoryTableName)}.[collection] = @index ) AS [src] ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name] WHEN MATCHED THEN - UPDATE SET [tgt].[value] = [src].[tag_value] + UPDATE SET [tgt].[value] = [src].[value] WHEN NOT MATCHED THEN INSERT ([memory_id], [name], [value]) VALUES ([src].[id], [src].[tag_name], - [src].[tag_value]);"; + [src].[value]);"; cmd.Parameters.AddWithValue("@index", index); cmd.Parameters.AddWithValue("@key", record.Id); @@ -487,9 +440,9 @@ IF OBJECT_ID(N'{this.GetFullTableName(MemoryTableName)}', N'U') IS NULL ( [id] UNIQUEIDENTIFIER NOT NULL, [key] NVARCHAR(256) NOT NULL, [collection] NVARCHAR(256) NOT NULL, - [payload] TEXT, - [tags] TEXT, - [embedding] TEXT, + [payload] NVARCHAR(MAX), + [tags] NVARCHAR(MAX), + [embedding] NVARCHAR(MAX), PRIMARY KEY ([id]), FOREIGN KEY ([collection]) REFERENCES {this.GetFullTableName(MemoryCollectionTableName)}([id]) ON DELETE CASCADE, CONSTRAINT UK_{MemoryTableName} UNIQUE([collection], [key]) @@ -538,6 +491,57 @@ private string GetFullTableName(string tableName) return 
$"[{this._config.Schema}].[{tableName}]"; }
+    /// <summary>
+    /// Generates the filters as a SQL predicate and sets the SQL parameters it refers to.
+    /// </summary>
+    /// <remarks>
+    /// The filters are combined with OR. Inside one filter, every (name, value) condition must be
+    /// present among the record's tags (AND). A filter that carries no condition matches no record.
+    /// When there is nothing to filter on, an empty string is returned so that the caller's
+    /// WHERE clause stays syntactically valid.
+    /// Each filter is passed as JSON in its own <c>@filter_N</c> parameter; no filter content is
+    /// concatenated into the SQL text.
+    /// </remarks>
+    /// <param name="index">The index name.</param>
+    /// <param name="parameters">The SQL parameters to populate.</param>
+    /// <param name="filters">The filters to apply.</param>
+    /// <returns>A predicate starting with <c>AND (</c>, or an empty string when <paramref name="filters"/> is null or empty.</returns>
+    private string GenerateFilters(string index, SqlParameterCollection parameters, ICollection<MemoryFilter>? filters = null)
+    {
+        // An empty collection used to produce "AND ( )", which SQL Server rejects.
+        if (filters is null || filters.Count == 0)
+        {
+            return string.Empty;
+        }
+
+        string tagsTable = this.GetFullTableName($"{TagsTableName}_{index}");
+        string memoryTable = this.GetFullTableName(MemoryTableName);
+
+        var filterBuilder = new StringBuilder();
+        filterBuilder.Append(@"AND (
+            ");
+
+        int i = 0;
+
+        // Plain enumeration: ElementAt(i) on an ICollection may walk the sequence again on every call.
+        foreach (var filter in filters)
+        {
+            if (i > 0)
+            {
+                filterBuilder.Append(" OR ");
+            }
+
+            // Expands the JSON filter ({ name: [value, ...] }) into one row per (name, value) condition.
+            string conditions = $@"(
+                    SELECT
+                        cast([filters].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [name],
+                        [tag_value].[value] AS [value]
+                    FROM openjson(@filter_{i}) [filters]
+                    CROSS APPLY openjson(cast([filters].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value]
+                ) AS [filter]";
+
+            // Relational division: the record matches when the filter has at least one condition
+            // and there is NO condition without a matching tag row. A plain EXISTS over a join
+            // would accept a record that matches just one of the conditions.
+            filterBuilder.Append($@"(
+                EXISTS (
+                    SELECT 1
+                    FROM {conditions}
+                )
+                AND NOT EXISTS (
+                    SELECT 1
+                    FROM {conditions}
+                    WHERE NOT EXISTS (
+                        SELECT 1
+                        FROM {tagsTable} AS [tags]
+                        WHERE [tags].[memory_id] = {memoryTable}.[id]
+                            AND [tags].[name] = [filter].[name]
+                            AND [tags].[value] = [filter].[value]
+                    )
+                )
+            )
+            ");
+
+            parameters.AddWithValue($"@filter_{i}", JsonSerializer.Serialize(filter));
+            i++;
+        }
+
+        filterBuilder.Append(@"
+            )");
+
+        return filterBuilder.ToString();
+    }
+
 private async Task ReadEntryAsync(SqlDataReader dataReader, bool withEmbedding, CancellationToken cancellationToken = default) { var entry = new MemoryRecord();