Skip to content

Commit

Permalink
embedding/openai: fix method-dependent embedding discrepancy (#357)
Browse files Browse the repository at this point in the history
Fix method-dependent embedding discrepancy for OpenAI

For #356
  • Loading branch information
eliben authored Nov 18, 2023
1 parent 09a09b3 commit 65725eb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
6 changes: 6 additions & 0 deletions embeddings/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ func (e OpenAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float32
return nil, err
}

// If the size of this batch is 1, don't average/combine the vectors.
if len(texts) == 1 {
emb = append(emb, curTextEmbeddings[0])
continue
}

textLengths := make([]int, 0, len(texts))
for _, text := range texts {
textLengths = append(textLengths, len(text))
Expand Down
24 changes: 24 additions & 0 deletions embeddings/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,30 @@ func TestOpenaiEmbeddings(t *testing.T) {
assert.Len(t, embeddings, 3)
}

func TestOpenaiEmbeddingsQueryVsDocuments(t *testing.T) {
// Verifies that we get the same embedding for the same string, regardless
// of which method we call.
t.Parallel()

if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
t.Skip("OPENAI_API_KEY not set")
}
e, err := NewOpenAI()
require.NoError(t, err)

text := "hi there"

eq, err := e.EmbedQuery(context.Background(), text)
require.NoError(t, err)

eb, err := e.EmbedDocuments(context.Background(), []string{text})
require.NoError(t, err)

// Using strict equality should be OK here because we expect the same values
// for the same string, deterministically.
assert.Equal(t, eq, eb[0])
}

func TestOpenaiEmbeddingsWithOptions(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 65725eb

Please sign in to comment.