-
Notifications
You must be signed in to change notification settings - Fork 26
/
VectorTextIndex.cs
152 lines (129 loc) · 5.78 KB
/
VectorTextIndex.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Copyright (c) Microsoft. All rights reserved.
using Microsoft.TypeChat.Embeddings;
namespace Microsoft.TypeChat;
/// <summary>
/// VectorTextIndex is an in-memory vector text index that automatically vectorizes given items using a model.
/// All embeddings are normalized on entry so that nearest-neighbor search can use the cheaper dot product
/// (equivalent to cosine similarity for unit-length vectors).
/// Each item T has an associated text description; it is this description that is indexed using embeddings.
///
/// The VectorTextIndex is also a TextRequestRouter that uses embeddings to route text requests.
/// </summary>
/// <typeparam name="T">type of item stored in the index</typeparam>
public class VectorTextIndex<T> : ITextRequestRouter<T>
{
    private readonly TextEmbeddingModel _model;
    private readonly VectorizedList<T> _list;

    /// <summary>
    /// Create a new VectorTextIndex backed by a new, empty vector list
    /// </summary>
    /// <param name="model">embedding model used to vectorize item text</param>
    public VectorTextIndex(TextEmbeddingModel model)
        : this(model, new VectorizedList<T>())
    {
    }

    /// <summary>
    /// Create a new VectorTextIndex
    /// </summary>
    /// <param name="model">model to use</param>
    /// <param name="list">vector list to use</param>
    /// <exception cref="ArgumentNullException">if model or list is null</exception>
    public VectorTextIndex(TextEmbeddingModel model, VectorizedList<T> list)
    {
        ArgumentVerify.ThrowIfNull(model, nameof(model));
        ArgumentVerify.ThrowIfNull(list, nameof(list));
        _model = model;
        _list = list;
    }

    /// <summary>
    /// Items in this index
    /// </summary>
    public VectorizedList<T> Items => _list;

    /// <summary>
    /// Route the given request to the semantically nearest T.
    /// Does so by comparing the embedding of the request to that of all registered T.
    /// </summary>
    /// <param name="request">request text to route</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <returns>the nearest item</returns>
    public Task<T> RouteRequestAsync(string request, CancellationToken cancelToken = default)
    {
        return NearestAsync(request, cancelToken);
    }

    /// <summary>
    /// Add an item to the collection. Its associated text representation is vectorized into an embedding.
    /// </summary>
    /// <param name="item">item to add to the index</param>
    /// <param name="textRepresentation">the text representation of the item; it is transformed into an embedding</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <exception cref="ArgumentException">if textRepresentation is null or empty</exception>
    public async Task AddAsync(T item, string textRepresentation, CancellationToken cancelToken = default)
    {
        ArgumentVerify.ThrowIfNullOrEmpty(textRepresentation, nameof(textRepresentation));
        var embedding = await GetNormalizedEmbeddingAsync(textRepresentation, cancelToken).ConfigureAwait(false);
        _list.Add(item, embedding);
    }

    /// <summary>
    /// Add multiple items to the collection.
    /// If the associated embedding model supports batching, this can be much faster.
    /// </summary>
    /// <param name="items">items to add to the collection</param>
    /// <param name="textRepresentations">the text representations of these items, parallel to items</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <exception cref="ArgumentException">if the arrays are null or have different lengths</exception>
    /// <exception cref="InvalidOperationException">if the model returned the wrong number of embeddings</exception>
    public async Task AddAsync(T[] items, string[] textRepresentations, CancellationToken cancelToken = default)
    {
        ArgumentVerify.ThrowIfNull(items, nameof(items));
        ArgumentVerify.ThrowIfNull(textRepresentations, nameof(textRepresentations));
        if (items.Length != textRepresentations.Length)
        {
            throw new ArgumentException("items and their representations must be of the same length");
        }
        Embedding[] embeddings = await GetNormalizedEmbeddingAsync(textRepresentations, cancelToken).ConfigureAwait(false);
        // Defend against a misbehaving model that returns a partial batch
        if (embeddings.Length != items.Length)
        {
            throw new InvalidOperationException($"Embedding length {embeddings.Length} does not match items length {items.Length}");
        }
        for (int i = 0; i < items.Length; ++i)
        {
            _list.Add(items[i], embeddings[i]);
        }
    }

    /// <summary>
    /// Find the nearest match to the given text
    /// </summary>
    /// <param name="text">text to search for</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <returns>nearest item</returns>
    /// <exception cref="ArgumentException">if text is null or empty</exception>
    public async Task<T> NearestAsync(string text, CancellationToken cancelToken = default)
    {
        ArgumentVerify.ThrowIfNullOrEmpty(text, nameof(text));
        var embedding = await GetNormalizedEmbeddingAsync(text, cancelToken).ConfigureAwait(false);
        // Embeddings are normalized at insertion and query time, so Dot is a valid similarity measure
        return _list.Nearest(embedding, EmbeddingDistance.Dot);
    }

    /// <summary>
    /// Return up to maxMatches items from the collection closest to the given text
    /// </summary>
    /// <param name="text">text to search for</param>
    /// <param name="maxMatches">max matches</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <returns>list of matches</returns>
    /// <exception cref="ArgumentException">if text is null or empty</exception>
    public async Task<List<T>> NearestAsync(string text, int maxMatches, CancellationToken cancelToken = default)
    {
        ArgumentVerify.ThrowIfNullOrEmpty(text, nameof(text));
        var embedding = await GetNormalizedEmbeddingAsync(text, cancelToken).ConfigureAwait(false);
        return _list.Nearest(embedding, maxMatches, EmbeddingDistance.Dot).ToList();
    }

    // Vectorize a single text and normalize the resulting embedding in place
    async Task<Embedding> GetNormalizedEmbeddingAsync(string text, CancellationToken cancelToken)
    {
        var embedding = await _model.GenerateEmbeddingAsync(text, cancelToken).ConfigureAwait(false);
        embedding.NormalizeInPlace();
        return embedding;
    }

    // Vectorize a batch of texts and normalize each resulting embedding in place
    async Task<Embedding[]> GetNormalizedEmbeddingAsync(string[] texts, CancellationToken cancelToken)
    {
        var embeddings = await _model.GenerateEmbeddingsAsync(texts, cancelToken).ConfigureAwait(false);
        for (int i = 0; i < embeddings.Length; ++i)
        {
            embeddings[i].NormalizeInPlace();
        }
        return embeddings;
    }
}