diff --git a/embeddings/openai/openai.go b/embeddings/openai/openai.go index 5ed7cd1e1..5c114d361 100644 --- a/embeddings/openai/openai.go +++ b/embeddings/openai/openai.go @@ -42,6 +42,12 @@ func (e OpenAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float32 return nil, err } + // If the size of this batch is 1, don't average/combine the vectors. + if len(texts) == 1 { + emb = append(emb, curTextEmbeddings[0]) + continue + } + textLengths := make([]int, 0, len(texts)) for _, text := range texts { textLengths = append(textLengths, len(text)) diff --git a/embeddings/openai/openai_test.go b/embeddings/openai/openai_test.go index c43d9c7ef..0331c8c06 100644 --- a/embeddings/openai/openai_test.go +++ b/embeddings/openai/openai_test.go @@ -27,6 +27,30 @@ func TestOpenaiEmbeddings(t *testing.T) { assert.Len(t, embeddings, 3) } +func TestOpenaiEmbeddingsQueryVsDocuments(t *testing.T) { + // Verifies that we get the same embedding for the same string, regardless + // of which method we call. + t.Parallel() + + if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" { + t.Skip("OPENAI_API_KEY not set") + } + e, err := NewOpenAI() + require.NoError(t, err) + + text := "hi there" + + eq, err := e.EmbedQuery(context.Background(), text) + require.NoError(t, err) + + eb, err := e.EmbedDocuments(context.Background(), []string{text}) + require.NoError(t, err) + + // Using strict equality should be OK here because we expect the same values + // for the same string, deterministically. + assert.Equal(t, eq, eb[0]) +} + func TestOpenaiEmbeddingsWithOptions(t *testing.T) { t.Parallel()