Skip to content

Commit

Permalink
Fix filtering and ordering indexes search results (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored Sep 21, 2023
1 parent 96e602d commit aa668f9
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 48 deletions.
4 changes: 2 additions & 2 deletions examples/embeddings/simpleVector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func main() {
}
}

query := "Describe within a paragraph what is the purpose of the NATO Alliance."
query := "What is the purpose of the NATO Alliance?"
similarities, err := docsVectorIndex.Query(
context.Background(),
query,
Expand Down Expand Up @@ -83,7 +83,7 @@ func ingestData(docsVectorIndex *simplevectorindex.Index, openaiEmbedder index.E
return err
}

textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(2000, 100)
textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20)

documentChunks := textSplitter.SplitDocuments(documents)

Expand Down
15 changes: 0 additions & 15 deletions index/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package index
import (
"context"
"errors"
"sort"

"github.com/henomis/lingoose/document"
"github.com/henomis/lingoose/embedder"
Expand Down Expand Up @@ -58,20 +57,6 @@ type Embedder interface {
Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error)
}

// FilterSearchResults orders searchResults by descending Score and returns at
// most topK of the leading entries.
//
// The sort is performed in place, so the caller's slice is reordered, and the
// returned slice shares its backing array with searchResults. A topK larger
// than len(searchResults) returns every result. A negative topK is treated as
// zero and yields an empty slice rather than panicking on an out-of-range
// slice bound.
func FilterSearchResults(searchResults SearchResults, topK int) SearchResults {
	// Highest similarity score first.
	sort.Slice(searchResults, func(i, j int) bool {
		return searchResults[i].Score > searchResults[j].Score
	})

	// Clamp the cut-off into [0, len(searchResults)].
	limit := topK
	if limit < 0 {
		limit = 0
	}
	if limit > len(searchResults) {
		limit = len(searchResults)
	}

	return searchResults[:limit]
}

func DeepCopyMetadata(metadata types.Meta) types.Meta {
metadataCopy := make(types.Meta)
for k, v := range metadata {
Expand Down
12 changes: 6 additions & 6 deletions index/pinecone/pinecone.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Index struct {
namespace string
embedder index.Embedder
includeContent bool
includeValues bool
batchUpsertSize int

createIndex *CreateIndexOptions
Expand All @@ -44,6 +45,7 @@ type Options struct {
IndexName string
Namespace string
IncludeContent bool
IncludeValues bool
BatchUpsertSize *int
CreateIndex *CreateIndexOptions
}
Expand All @@ -65,6 +67,7 @@ func New(options Options, embedder index.Embedder) *Index {
embedder: embedder,
namespace: options.Namespace,
includeContent: options.IncludeContent,
includeValues: options.IncludeValues,
batchUpsertSize: batchUpsertSize,
createIndex: options.CreateIndex,
}
Expand Down Expand Up @@ -165,9 +168,7 @@ func (p *Index) Search(ctx context.Context, values []float64, opts ...option.Opt
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent)

return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil
return buildSearchResultsFromPineconeMatches(matches, p.includeContent), nil
}

func (p *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {
Expand All @@ -188,9 +189,7 @@ func (p *Index) Query(ctx context.Context, query string, opts ...option.Option)
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent)

return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil
return buildSearchResultsFromPineconeMatches(matches, p.includeContent), nil
}

func (p *Index) query(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) {
Expand Down Expand Up @@ -222,6 +221,7 @@ func (p *Index) similaritySearch(
TopK: int32(opts.TopK),
Vector: values,
IncludeMetadata: &includeMetadata,
IncludeValues: &p.includeValues,
Namespace: &p.namespace,
Filter: opts.Filter.(map[string]string),
},
Expand Down
12 changes: 6 additions & 6 deletions index/qdrant/qdrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Index struct {
collectionName string
embedder index.Embedder
includeContent bool
includeValues bool
batchUpsertSize int

createCollection *CreateCollectionOptions
Expand All @@ -47,6 +48,7 @@ type CreateCollectionOptions struct {
type Options struct {
CollectionName string
IncludeContent bool
IncludeValues bool
BatchUpsertSize *int
CreateCollection *CreateCollectionOptions
}
Expand All @@ -67,6 +69,7 @@ func New(options Options, embedder index.Embedder) *Index {
collectionName: options.CollectionName,
embedder: embedder,
includeContent: options.IncludeContent,
includeValues: options.IncludeValues,
batchUpsertSize: batchUpsertSize,
createCollection: options.CreateCollection,
}
Expand Down Expand Up @@ -150,9 +153,7 @@ func (q *Index) Search(ctx context.Context, values []float64, opts ...option.Opt
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent)

return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil
return buildSearchResultsFromQdrantMatches(matches, q.includeContent), nil
}

func (q *Index) Query(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {
Expand All @@ -169,9 +170,7 @@ func (q *Index) Query(ctx context.Context, query string, opts ...option.Option)
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent)

return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil
return buildSearchResultsFromQdrantMatches(matches, q.includeContent), nil
}

func (q *Index) similaritySearch(
Expand All @@ -192,6 +191,7 @@ func (q *Index) similaritySearch(
Limit: opts.TopK,
Vector: values,
WithPayload: &includeMetadata,
WithVector: &q.includeValues,
Filter: opts.Filter.(qdrantrequest.Filter),
},
res,
Expand Down
74 changes: 55 additions & 19 deletions index/simpleVectorIndex/simpleVectorIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package simplevectorindex
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"os"
"sort"
"strings"

"github.com/google/uuid"
Expand Down Expand Up @@ -211,7 +213,10 @@ func (s *Index) similaritySearch(
opts *option.Options,
) (index.SearchResults, error) {
_ = ctx
scores := s.cosineSimilarityBatch(embedding)
scores, err := s.cosineSimilarityBatch(embedding)
if err != nil {
return nil, fmt.Errorf("%w: %w", index.ErrInternal, err)
}

searchResults := make([]index.SearchResult, len(scores))

Expand All @@ -230,33 +235,64 @@ func (s *Index) similaritySearch(
searchResults = opts.Filter.(FilterFn)(searchResults)
}

return index.FilterSearchResults(searchResults, opts.TopK), nil
return filterSearchResults(searchResults, opts.TopK), nil
}

func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 {
dotProduct := float64(0.0)
normA := float64(0.0)
normB := float64(0.0)

for i := 0; i < len(a); i++ {
dotProduct += a[i] * b[i]
normA += a[i] * a[i]
normB += b[i] * b[i]
func (s *Index) cosineSimilarity(a []float64, b []float64) (cosine float64, err error) {
var count int
lengthA := len(a)
lengthB := len(b)
if lengthA > lengthB {
count = lengthA
} else {
count = lengthB
}

if normA == 0 || normB == 0 {
return float64(0.0)
sumA := 0.0
s1 := 0.0
s2 := 0.0
for k := 0; k < count; k++ {
if k >= lengthA {
s2 += math.Pow(b[k], 2)
continue
}
if k >= lengthB {
s1 += math.Pow(a[k], 2)
continue
}
sumA += a[k] * b[k]
s1 += math.Pow(a[k], 2)
s2 += math.Pow(b[k], 2)
}

return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
if s1 == 0 || s2 == 0 {
return 0.0, errors.New("vectors should not be null (all zeros)")
}
return sumA / (math.Sqrt(s1) * math.Sqrt(s2)), nil
}

func (s *Index) cosineSimilarityBatch(a embedder.Embedding) []float64 {
func (s *Index) cosineSimilarityBatch(a embedder.Embedding) ([]float64, error) {
var err error
scores := make([]float64, len(s.data))

for i := range s.data {
scores[i] = s.cosineSimilarity(a, s.data[i].Values)
scores[i], err = s.cosineSimilarity(a, s.data[i].Values)
if err != nil {
return nil, err
}
}

return scores, nil
}

// filterSearchResults orders searchResults by descending Score and returns at
// most topK of the leading entries.
//
// The sort is performed in place, so the caller's slice is reordered, and the
// returned slice shares its backing array with searchResults. A topK larger
// than len(searchResults) returns every result. A negative topK is treated as
// zero and yields an empty slice rather than panicking on an out-of-range
// slice bound.
func filterSearchResults(searchResults index.SearchResults, topK int) index.SearchResults {
	// Compare scores directly: going through (1 - score) would round very
	// small, distinct scores to the same value and lose their ordering.
	sort.Slice(searchResults, func(i, j int) bool {
		return searchResults[i].Score > searchResults[j].Score
	})

	// Clamp the cut-off into [0, len(searchResults)].
	limit := topK
	if limit < 0 {
		limit = 0
	}
	if limit > len(searchResults) {
		limit = len(searchResults)
	}

	return searchResults[:limit]
}

0 comments on commit aa668f9

Please sign in to comment.