diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java index 4d19c87ba5..45c2a5602f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/Document.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -31,6 +32,7 @@ import org.springframework.ai.document.id.RandomIdGenerator; import org.springframework.ai.model.Media; import org.springframework.ai.model.MediaContent; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -61,7 +63,15 @@ public class Document implements MediaContent { * Metadata for the document. It should not be nested and values should be restricted * to string, int, float, boolean for simple use with Vector Dbs. */ - private Map metadata; + private final Map metadata; + + /** + * Measure of similarity between the document embedding and the query vector. The + * higher the score, the more they are similar. It's the opposite of the distance + * measure. + */ + @Nullable + private Double score; /** * Embedding of the document. Note: ephemeral field. @@ -80,31 +90,61 @@ public Document(@JsonProperty("content") String content) { this(content, new HashMap<>()); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String content, Map metadata) { this(content, metadata, new RandomIdGenerator()); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String content, Collection media, Map metadata) { this(new RandomIdGenerator().generateId(content, metadata), content, media, metadata); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String content, Map metadata, IdGenerator idGenerator) { this(idGenerator.generateId(content, metadata), content, metadata); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String id, String content, Map metadata) { this(id, content, List.of(), metadata); } + /** + * @deprecated Use builder instead: {@link Document#builder()}. + */ + @Deprecated(since = "1.0.0-M5", forRemoval = true) public Document(String id, String content, Collection media, Map metadata) { - Assert.hasText(id, "id must not be null or empty"); - Assert.notNull(content, "content must not be null"); - Assert.notNull(metadata, "metadata must not be null"); + this(id, content, media, metadata, null); + } + + public Document(String id, String content, @Nullable Collection media, + @Nullable Map metadata, @Nullable Double score) { + Assert.hasText(id, "id cannot be null or empty"); + Assert.notNull(content, "content cannot be null"); + Assert.notNull(media, "media cannot be null"); + Assert.noNullElements(media, "media cannot have null elements"); + Assert.notNull(metadata, "metadata cannot be null"); + Assert.noNullElements(metadata.keySet(), "metadata cannot have null keys"); + Assert.noNullElements(metadata.values(), "metadata cannot have null values"); this.id = id; this.content = content; - this.media = media; - this.metadata = metadata; + this.media = media != null ? media : List.of(); + this.metadata = metadata != null ? metadata : new HashMap<>(); + this.score = score; } public static Builder builder() { @@ -149,6 +189,15 @@ public Map getMetadata() { return this.metadata; } + @Nullable + public Double getScore() { + return this.score; + } + + public void setScore(@Nullable Double score) { + this.score = score; + } + /** * Return the embedding that were calculated. * @deprecated We are considering getting rid of this, please comment on @@ -186,57 +235,24 @@ public void setContentFormatter(ContentFormatter contentFormatter) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.id == null) ? 0 : this.id.hashCode()); - result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode()); - result = prime * result + ((this.content == null) ? 0 : this.content.hashCode()); - return result; + return Objects.hash(id, content, media, metadata); } @Override - public boolean equals(Object obj) { - if (this == obj) { + public boolean equals(Object o) { + if (this == o) return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - Document other = (Document) obj; - if (this.id == null) { - if (other.id != null) { - return false; - } - } - else if (!this.id.equals(other.id)) { - return false; - } - if (this.metadata == null) { - if (other.metadata != null) { - return false; - } - } - else if (!this.metadata.equals(other.metadata)) { + if (o == null || getClass() != o.getClass()) return false; - } - if (this.content == null) { - if (other.content != null) { - return false; - } - } - else if (!this.content.equals(other.content)) { - return false; - } - return true; + Document document = (Document) o; + return Objects.equals(id, document.id) && Objects.equals(content, document.content) + && Objects.equals(media, document.media) && Objects.equals(metadata, document.metadata); } @Override public String toString() { - return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content - + '\'' + ", media=" + this.media + '}'; + return "Document{" + "id='" + id + '\'' + ", content='" + content + '\'' + ", media=" + media + ", metadata=" + + metadata + ", score=" + score + '}'; } public static class Builder { @@ -249,56 +265,102 @@ public static class Builder { private Map metadata = new HashMap<>(); + private float[] embedding = new float[0]; + + private Double score; + private IdGenerator idGenerator = new RandomIdGenerator(); - public Builder withIdGenerator(IdGenerator idGenerator) { - Assert.notNull(idGenerator, "idGenerator must not be null"); + public Builder idGenerator(IdGenerator idGenerator) { + Assert.notNull(idGenerator, "idGenerator cannot be null"); this.idGenerator = idGenerator; return this; } - public Builder withId(String id) { - Assert.hasText(id, "id must not be null or empty"); + public Builder id(String id) { + Assert.hasText(id, "id cannot be null or empty"); this.id = id; return this; } - public Builder withContent(String content) { - Assert.notNull(content, "content must not be null"); + public Builder content(String content) { this.content = content; return this; } - public Builder withMedia(List media) { - Assert.notNull(media, "media must not be null"); + public Builder media(List media) { this.media = media; return this; } - public Builder withMedia(Media media) { - Assert.notNull(media, "media must not be null"); - this.media.add(media); + public Builder media(Media... media) { + Assert.noNullElements(media, "media cannot contain null elements"); + this.media.addAll(List.of(media)); return this; } - public Builder withMetadata(Map metadata) { - Assert.notNull(metadata, "metadata must not be null"); + public Builder metadata(Map metadata) { this.metadata = metadata; return this; } - public Builder withMetadata(String key, Object value) { - Assert.notNull(key, "key must not be null"); - Assert.notNull(value, "value must not be null"); + public Builder metadata(String key, Object value) { this.metadata.put(key, value); return this; } + public Builder embedding(float[] embedding) { + this.embedding = embedding; + return this; + } + + public Builder score(Double score) { + this.score = score; + return this; + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withIdGenerator(IdGenerator idGenerator) { + return idGenerator(idGenerator); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withId(String id) { + return id(id); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withContent(String content) { + return content(content); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withMedia(List media) { + return media(media); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withMedia(Media media) { + return media(media); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withMetadata(Map metadata) { + return metadata(metadata); + } + + @Deprecated(since = "1.0.0-M5", forRemoval = true) + public Builder withMetadata(String key, Object value) { + return metadata(key, value); + } + public Document build() { if (!StringUtils.hasText(this.id)) { this.id = this.idGenerator.generateId(this.content, this.metadata); } - return new Document(this.id, this.content, this.media, this.metadata); + var document = new Document(this.id, this.content, this.media, this.metadata, this.score); + document.setEmbedding(this.embedding); + return document; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentMetadata.java new file mode 100644 index 0000000000..0d2ee89555 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentMetadata.java @@ -0,0 +1,50 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.document; + +import org.springframework.ai.vectorstore.VectorStore; + +/** + * Common set of metadata keys used in {@link Document}s by {@link DocumentReader}s and + * {@link VectorStore}s. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public enum DocumentMetadata { + +// @formatter:off + + /** + * Measure of distance between the document embedding and the query vector. + * The lower the distance, the more they are similar. + * It's the opposite of the similarity score. + */ + DISTANCE("distance"); + + private final String value; + + DocumentMetadata(String value) { + this.value = value; + } + public String value() { + return this.value; + } + +// @formatter:on + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/document/package-info.java new file mode 100644 index 0000000000..fdd9362647 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.document; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java index a009e56f78..fcf72b0e63 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java @@ -22,6 +22,7 @@ import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -30,6 +31,7 @@ * instance and then apply the 'with' methods to alter the default values. * * @author Christian Tzolov + * @author Thomas Vitale */ public final class SearchRequest { @@ -51,6 +53,7 @@ public final class SearchRequest { private double similarityThreshold = SIMILARITY_THRESHOLD_ACCEPT_ALL; + @Nullable private Filter.Expression filterExpression; private SearchRequest(String query) { @@ -186,7 +189,7 @@ public SearchRequest withSimilarityThresholdAll() { * filter criteria. The 'null' value stands for no expression filters. * @return this builder. */ - public SearchRequest withFilterExpression(Filter.Expression expression) { + public SearchRequest withFilterExpression(@Nullable Filter.Expression expression) { this.filterExpression = expression; return this; } @@ -225,7 +228,7 @@ public SearchRequest withFilterExpression(Filter.Expression expression) { * 'null' value stands for no expression filters. * @return this.builder */ - public SearchRequest withFilterExpression(String textExpression) { + public SearchRequest withFilterExpression(@Nullable String textExpression) { this.filterExpression = (textExpression != null) ? new FilterExpressionTextParser().parse(textExpression) : null; return this; @@ -243,6 +246,7 @@ public double getSimilarityThreshold() { return this.similarityThreshold; } + @Nullable public Filter.Expression getFilterExpression() { return this.filterExpression; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java index c941439208..28cbfc2eb5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java @@ -43,6 +43,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -68,6 +69,7 @@ * @author Christian Tzolov * @author Sebastien Deleuze * @author Ilayaperumal Gopinathan + * @author Thomas Vitale */ public class SimpleVectorStore extends AbstractObservationVectorStore { @@ -127,12 +129,11 @@ public List doSimilaritySearch(SearchRequest request) { float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery()); return this.store.values() .stream() - .map(entry -> new Similarity(entry, - EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding()))) - .filter(s -> s.score >= request.getSimilarityThreshold()) - .sorted(Comparator.comparingDouble(s -> s.score).reversed()) + .map(content -> content + .toDocument(EmbeddingMath.cosineSimilarity(userQueryEmbedding, content.getEmbedding()))) + .filter(document -> document.getScore() >= request.getSimilarityThreshold()) + .sorted(Comparator.comparing(Document::getScore).reversed()) .limit(request.getTopK()) - .map(s -> s.getDocument()) .toList(); } @@ -235,28 +236,7 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value()); } - public static class Similarity { - - private SimpleVectorStoreContent content; - - private double score; - - public Similarity(SimpleVectorStoreContent content, double score) { - this.content = content; - this.score = score; - } - - Document getDocument() { - return Document.builder() - .withId(this.content.getId()) - .withContent(this.content.getContent()) - .withMetadata(this.content.getMetadata()) - .build(); - } - - } - - public final class EmbeddingMath { + public static final class EmbeddingMath { private EmbeddingMath() { throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java index c54a96d003..41aa8e569c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java @@ -25,6 +25,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.document.id.IdGenerator; import org.springframework.ai.document.id.RandomIdGenerator; import org.springframework.ai.model.Content; @@ -135,6 +137,12 @@ public float[] getEmbedding() { return Arrays.copyOf(this.embedding, this.embedding.length); } + public Document toDocument(Double score) { + var metadata = new HashMap<>(this.metadata); + metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - score); + return Document.builder().id(this.id).content(this.content).metadata(metadata).score(score).build(); + } + @Override public boolean equals(Object o) { if (this == o) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/package-info.java new file mode 100644 index 0000000000..3edee23fc8 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.vectorstore; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java index ebaeef3890..2408aebf8b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/document/DocumentBuilderTests.java @@ -42,13 +42,11 @@ private static List getMediaList() { URL mediaUrl2 = new URL("http://type2"); Media media1 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl1); Media media2 = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl2); - List mediaList = List.of(media1, media2); - return mediaList; + return List.of(media1, media2); } catch (MalformedURLException e) { throw new RuntimeException(e); } - } @BeforeEach @@ -58,32 +56,26 @@ void setUp() { @Test void testWithIdGenerator() { - IdGenerator mockGenerator = new IdGenerator() { - - @Override - public String generateId(Object... contents) { - return "mockedId"; - } - }; + IdGenerator mockGenerator = contents -> "mockedId"; - Document.Builder result = this.builder.withIdGenerator(mockGenerator); + Document.Builder result = this.builder.idGenerator(mockGenerator); assertThat(result).isSameAs(this.builder); - Document document = result.withContent("Test content").withMetadata("key", "value").build(); + Document document = result.content("Test content").metadata("key", "value").build(); assertThat(document.getId()).isEqualTo("mockedId"); } @Test void testWithIdGeneratorNull() { - assertThatThrownBy(() -> this.builder.withIdGenerator(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("idGenerator must not be null"); + assertThatThrownBy(() -> this.builder.idGenerator(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("idGenerator cannot be null"); } @Test void testWithId() { - Document.Builder result = this.builder.withId("testId"); + Document.Builder result = this.builder.id("testId"); assertThat(result).isSameAs(this.builder); assertThat(result.build().getId()).isEqualTo("testId"); @@ -91,16 +83,16 @@ void testWithId() { @Test void testWithIdNullOrEmpty() { - assertThatThrownBy(() -> this.builder.withId(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("id must not be null or empty"); + assertThatThrownBy(() -> this.builder.id(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("id cannot be null or empty"); - assertThatThrownBy(() -> this.builder.withId("")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("id must not be null or empty"); + assertThatThrownBy(() -> this.builder.id("").build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("id cannot be null or empty"); } @Test void testWithContent() { - Document.Builder result = this.builder.withContent("Test content"); + Document.Builder result = this.builder.content("Test content"); assertThat(result).isSameAs(this.builder); assertThat(result.build().getContent()).isEqualTo("Test content"); @@ -108,14 +100,14 @@ void testWithContent() { @Test void testWithContentNull() { - assertThatThrownBy(() -> this.builder.withContent(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("content must not be null"); + assertThatThrownBy(() -> this.builder.content(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("content cannot be null"); } @Test void testWithMediaList() { List mediaList = getMediaList(); - Document.Builder result = this.builder.withMedia(mediaList); + Document.Builder result = this.builder.media(mediaList); assertThat(result).isSameAs(this.builder); assertThat(result.build().getMedia()).isEqualTo(mediaList); @@ -123,9 +115,9 @@ void testWithMediaList() { @Test void testWithMediaListNull() { - assertThatThrownBy(() -> this.builder.withMedia((List) null)) + assertThatThrownBy(() -> this.builder.media((List) null).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("media must not be null"); + .hasMessageContaining("media cannot be null"); } @Test @@ -133,7 +125,7 @@ void testWithMediaSingle() throws MalformedURLException { URL mediaUrl = new URL("http://test"); Media media = new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl); - Document.Builder result = this.builder.withMedia(media); + Document.Builder result = this.builder.media(media); assertThat(result).isSameAs(this.builder); assertThat(result.build().getMedia()).contains(media); @@ -141,8 +133,8 @@ void testWithMediaSingle() throws MalformedURLException { @Test void testWithMediaSingleNull() { - assertThatThrownBy(() -> this.builder.withMedia((Media) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("media must not be null"); + assertThatThrownBy(() -> this.builder.media((Media) null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("media cannot contain null elements"); } @Test @@ -150,7 +142,7 @@ void testWithMetadataMap() { Map metadata = new HashMap<>(); metadata.put("key1", "value1"); metadata.put("key2", 2); - Document.Builder result = this.builder.withMetadata(metadata); + Document.Builder result = this.builder.metadata(metadata); assertThat(result).isSameAs(this.builder); assertThat(result.build().getMetadata()).isEqualTo(metadata); @@ -158,47 +150,51 @@ void testWithMetadataMap() { @Test void testWithMetadataMapNull() { - assertThatThrownBy(() -> this.builder.withMetadata((Map) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("metadata must not be null"); + assertThatThrownBy(() -> this.builder.metadata(null).build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("metadata cannot be null"); } @Test void testWithMetadataKeyValue() { - Document.Builder result = this.builder.withMetadata("key", "value"); + Document.Builder result = this.builder.metadata("key", "value"); assertThat(result).isSameAs(this.builder); assertThat(result.build().getMetadata()).containsEntry("key", "value"); } @Test - void testWithMetadataKeyValueNull() { - assertThatThrownBy(() -> this.builder.withMetadata(null, "value")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("key must not be null"); + void testWithMetadataKeyNull() { + assertThatThrownBy(() -> this.builder.metadata(null, "value").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("metadata cannot have null keys"); + } - assertThatThrownBy(() -> this.builder.withMetadata("key", null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("value must not be null"); + @Test + void testWithMetadataValueNull() { + assertThatThrownBy(() -> this.builder.metadata("key", null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("metadata cannot have null values"); } @Test void testBuildWithoutId() { - Document document = this.builder.withContent("Test content").build(); + Document document = this.builder.content("Test content").build(); assertThat(document.getId()).isNotNull().isNotEmpty(); assertThat(document.getContent()).isEqualTo("Test content"); } @Test - void testBuildWithAllProperties() throws MalformedURLException { + void testBuildWithAllProperties() { List mediaList = getMediaList(); Map metadata = new HashMap<>(); metadata.put("key", "value"); - Document document = this.builder.withId("customId") - .withContent("Test content") - .withMedia(mediaList) - .withMetadata(metadata) + Document document = this.builder.id("customId") + .content("Test content") + .media(mediaList) + .metadata(metadata) .build(); assertThat(document.getId()).isEqualTo("customId"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java index 3447e5f7cc..499d0269af 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java @@ -27,6 +27,7 @@ /** * @author Ilayaperumal Gopinathan + * @author Thomas Vitale */ public class SimpleVectorStoreSimilarityTests { @@ -38,8 +39,7 @@ public void testSimilarity() { SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent("1", "hello, how are you?", metadata, testEmbedding); - SimpleVectorStore.Similarity similarity = new SimpleVectorStore.Similarity(storeContent, 0.6d); - Document document = similarity.getDocument(); + Document document = storeContent.toDocument(0.6); assertThat(document).isNotNull(); assertThat(document.getId()).isEqualTo("1"); assertThat(document.getContent()).isEqualTo("hello, how are you?"); diff --git a/spring-ai-integration-tests/pom.xml b/spring-ai-integration-tests/pom.xml index 3ca4c4ace4..b3cbb4cde7 100644 --- a/spring-ai-integration-tests/pom.xml +++ b/spring-ai-integration-tests/pom.xml @@ -54,7 +54,6 @@ test - org.springframework.ai spring-ai-openai-spring-boot-starter @@ -76,6 +75,13 @@ test + + org.springframework.ai + spring-ai-test + ${project.parent.version} + test + + org.springframework.boot spring-boot-testcontainers diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java index 7b3f01292b..df17af1fd4 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java @@ -16,7 +16,10 @@ package org.springframework.ai.integration.tests; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.SimpleVectorStore; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; /** @@ -28,4 +31,9 @@ @Import(TestcontainersConfiguration.class) public class TestApplication { + @Bean + SimpleVectorStore simpleVectorStore(EmbeddingModel embeddingModel) { + return new SimpleVectorStore(embeddingModel); + } + } diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/vectorstore/SimpleVectorStoreIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/vectorstore/SimpleVectorStoreIT.java new file mode 100644 index 0000000000..c6bb774db2 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/vectorstore/SimpleVectorStoreIT.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.integration.tests.vectorstore; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.DefaultResourceLoader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link SimpleVectorStore}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +public class SimpleVectorStoreIT { + + @Autowired + private SimpleVectorStore vectorStore; + + List documents = List.of( + Document.builder() + .id("471a8c78-549a-4b2c-bce5-ef3ae6579be3") + .content(getText("classpath:/test/data/spring.ai.txt")) + .metadata(Map.of("meta1", "meta1")) + .build(), + Document.builder() + .id("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f") + .content(getText("classpath:/test/data/time.shelter.txt")) + .metadata(Map.of()) + .build(), + Document.builder() + .id("d0237682-1150-44ff-b4d2-1be9b1731ee5") + .content(getText("classpath:/test/data/great.depression.txt")) + .metadata(Map.of("meta2", "meta2")) + .build()); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @AfterEach + void setUp() { + vectorStore.delete(this.documents.stream().map(Document::getId).toList()); + } + + @Test + public void searchWithThreshold() { + Document document = Document.builder() + .id(UUID.randomUUID().toString()) + .content("Spring AI rocks!!") + .metadata("meta1", "meta1") + .build(); + + vectorStore.add(List.of(document)); + + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); + assertThat(resultDoc.getMetadata()).containsKey("meta1"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + + Document sameIdDocument = Document.builder() + .id(document.getId()) + .content("The World is Big and Salvation Lurks Around the Corner") + .metadata("meta2", "meta2") + .build(); + + vectorStore.add(List.of(sameIdDocument)); + + results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5)); + + assertThat(results).hasSize(1); + resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); + assertThat(resultDoc.getMetadata()).containsKey("meta2"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + + vectorStore.delete(List.of(document.getId())); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java index a40804a650..bf9f2e9cfd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreProperties.java @@ -18,6 +18,7 @@ import java.time.Duration; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.vectorstore.PineconeVectorStore; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -25,6 +26,7 @@ * Configuration properties for Pinecone Vector Store. * * @author Christian Tzolov + * @author Thomas Vitale */ @ConfigurationProperties(PineconeVectorStoreProperties.CONFIG_PREFIX) public class PineconeVectorStoreProperties { @@ -43,7 +45,7 @@ public class PineconeVectorStoreProperties { private String contentFieldName = PineconeVectorStore.CONTENT_FIELD_NAME; - private String distanceMetadataFieldName = PineconeVectorStore.DISTANCE_METADATA_FIELD_NAME; + private String distanceMetadataFieldName = DocumentMetadata.DISTANCE.value(); private Duration serverSideTimeout = Duration.ofSeconds(20); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java index ce450bac02..3083442f87 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStorePropertiesTests.java @@ -20,12 +20,14 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.vectorstore.PineconeVectorStore; import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov + * @author Thomas Vitale */ public class PineconeVectorStorePropertiesTests { @@ -39,7 +41,7 @@ public void defaultValues() { assertThat(props.getIndexName()).isNull(); assertThat(props.getServerSideTimeout()).isEqualTo(Duration.ofSeconds(20)); assertThat(props.getContentFieldName()).isEqualTo(PineconeVectorStore.CONTENT_FIELD_NAME); - assertThat(props.getDistanceMetadataFieldName()).isEqualTo(PineconeVectorStore.DISTANCE_METADATA_FIELD_NAME); + assertThat(props.getDistanceMetadataFieldName()).isEqualTo(DocumentMetadata.DISTANCE.value()); } @Test diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java index c0ad68e360..11dc700673 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java @@ -74,6 +74,7 @@ * * @author Theo van Kraay * @author Soby Chacko + * @author Thomas Vitale * @since 1.0.0 */ public class CosmosDBVectorStore extends AbstractObservationVectorStore implements AutoCloseable { @@ -338,7 +339,7 @@ public List doSimilaritySearch(SearchRequest request) { .block(); // Convert JsonNode to Document List docs = documents.stream() - .map(doc -> new Document(doc.get("id").asText(), doc.get("content").asText(), new HashMap<>())) + .map(doc -> Document.builder().id(doc.get("id").asText()).content(doc.get("content").asText()).build()) .collect(Collectors.toList()); return docs != null ? docs : List.of(); diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java index 97ac8081c7..fd8b044c0f 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java @@ -42,6 +42,7 @@ /** * @author Theo van Kraay + * @author Thomas Vitale * @since 1.0.0 */ @EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+") diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDbImage.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDbImage.java new file mode 100644 index 0000000000..6bbbdc71cd --- /dev/null +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDbImage.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import org.testcontainers.utility.DockerImageName; + +/** + * @author Thomas Vitale + */ +public final class CosmosDbImage { + + // It must always be "latest" or else Azure locks the image after a while. See: + // https://github.com/Azure/azure-cosmos-db-emulator-docker/issues/60 + public static final DockerImageName DEFAULT_IMAGE = DockerImageName + .parse("mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:latest"); + + private CosmosDbImage() { + + } + +} diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index 6200738b92..f1893bd20f 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -47,6 +47,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -96,8 +97,6 @@ public class AzureVectorStore extends AbstractObservationVectorStore implements private static final String METADATA_FIELD_NAME = "metadata"; - private static final String DISTANCE_METADATA_FIELD_NAME = "distance"; - private static final int DEFAULT_TOP_K = 4; private static final Double DEFAULT_SIMILARITY_THRESHOLD = 0.0; @@ -321,13 +320,15 @@ public List doSimilaritySearch(SearchRequest request) { }) : Map.of(); - metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - (float) result.getScore()); - - final Document doc = new Document(entry.id(), entry.content(), metadata); - doc.setEmbedding(EmbeddingUtils.toPrimitive(entry.embedding())); - - return doc; + metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - result.getScore()); + return Document.builder() + .id(entry.id()) + .content(entry.content) + .metadata(metadata) + .score(result.getScore()) + .embedding(EmbeddingUtils.toPrimitive(entry.embedding)) + .build(); }) .collect(Collectors.toList()); } diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index dc87de60fe..50911ad10c 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -35,6 +35,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.vectorstore.SearchRequest; @@ -52,6 +53,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_AI_SEARCH_ENDPOINT", matches = ".+") @@ -103,7 +105,7 @@ public void addAndSearchTest() { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -224,7 +226,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -245,7 +247,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -271,21 +273,22 @@ public void searchThresholdTest() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 3ad7f5a916..af105efc03 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -45,6 +45,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -106,8 +107,6 @@ */ public class CassandraVectorStore extends AbstractObservationVectorStore implements AutoCloseable { - public static final String SIMILARITY_FIELD_NAME = "similarity_score"; - public static final String DRIVER_PROFILE_UPDATES = "spring-ai-updates"; public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search"; @@ -252,14 +251,19 @@ public List doSimilaritySearch(SearchRequest request) { break; } Map docFields = new HashMap<>(); - docFields.put(SIMILARITY_FIELD_NAME, score); + docFields.put(DocumentMetadata.DISTANCE.value(), 1 - score); for (var metadata : this.conf.schema.metadataColumns()) { var value = row.get(metadata.name(), metadata.javaType()); if (null != value) { docFields.put(metadata.name(), value); } } - Document doc = new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields); + Document doc = Document.builder() + .id(getDocumentId(row)) + .content(row.getString(this.conf.schema.content())) + .metadata(docFields) + .score((double) score) + .build(); if (this.conf.returnEmbeddings) { doc.setEmbedding(EmbeddingUtils diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java index e3a3ee65fe..4a31c8dbe1 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -226,8 +227,7 @@ void addAndSearch() { assertThat(resultDoc.getMetadata()).hasSize(3); - assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", - CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value()); // Remove all documents from the createStore store.delete(documents.stream().map(doc -> doc.getId()).toList()); @@ -494,8 +494,7 @@ void documentUpdate() { assertThat(resultDoc.getId()).isNotEqualTo(sameIdDocument.getId()); assertThat(resultDoc.getContent()).doesNotContain(newContent); - assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", - CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value()); } }); } @@ -509,16 +508,15 @@ void searchWithThreshold() { List fullResult = store .similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get(CassandraVectorStore.SIMILARITY_FIELD_NAME)) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = store.similaritySearch( - SearchRequest.query(URANUS_ORBIT_QUERY).withTopK(5).withSimilarityThreshold(threshold)); + List results = store.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY) + .withTopK(5) + .withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -526,8 +524,8 @@ void searchWithThreshold() { assertThat(resultDoc.getContent()).contains(URANUS_ORBIT_QUERY); - assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", - CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("id", "revision", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); } }); } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java index e17091d5ea..13f7e7d726 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java @@ -30,6 +30,7 @@ import com.datastax.oss.driver.api.core.type.DataTypes; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -138,7 +139,7 @@ void addAndSearch() { "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store store.delete(documents().stream().map(doc -> doc.getId()).toList()); @@ -174,7 +175,7 @@ void addAndSearchReturnEmbeddings() { "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(1); - assertThat(resultDoc.getMetadata()).containsKey(CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store store.delete(documents().stream().map(doc -> doc.getId()).toList()); @@ -359,7 +360,7 @@ void documentUpdate() { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); store.delete(List.of(document.getId())); } @@ -375,16 +376,14 @@ void searchWithThreshold() { List fullResult = store .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get(CassandraVectorStore.SIMILARITY_FIELD_NAME)) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = store - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(threshold)); + List results = store.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -393,7 +392,8 @@ void searchWithThreshold() { assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); } }); } diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 49deb147d6..eafe1e37ad 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -32,6 +32,7 @@ import org.springframework.ai.chroma.ChromaApi.DeleteEmbeddingsRequest; import org.springframework.ai.chroma.ChromaApi.Embedding; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -58,11 +59,10 @@ * @author Fu Cheng * @author Sebastien Deleuze * @author Soby Chacko + * @author Thomas Vitale */ public class ChromaVectorStore extends AbstractObservationVectorStore implements InitializingBean { - public static final String DISTANCE_FIELD_NAME = "distance"; - public static final String DEFAULT_COLLECTION_NAME = "SpringAiCollection"; private final EmbeddingModel embeddingModel; @@ -192,9 +192,14 @@ public Optional doDelete(@NonNull List idList) { if (metadata == null) { metadata = new HashMap<>(); } - metadata.put(DISTANCE_FIELD_NAME, distance); - Document document = new Document(id, content, metadata); - document.setEmbedding(chromaEmbedding.embedding()); + metadata.put(DocumentMetadata.DISTANCE.value(), distance); + Document document = Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .embedding(chromaEmbedding.embedding()) + .score(1.0 - distance) + .build(); responseDocuments.add(document); } } @@ -244,8 +249,7 @@ public void afterPropertiesSet() throws Exception { @NonNull String operationName) { return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName) .withDimensions(this.embeddingModel.dimensions()) - .withCollectionName(this.collectionName + ":" + this.collectionId) - .withFieldName(this.initializeSchema ? DISTANCE_FIELD_NAME : null); + .withCollectionName(this.collectionName + ":" + this.collectionId); } public static class Builder { diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java index a2f6f00009..c308202848 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -81,7 +82,7 @@ public void addAndSearch() { assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store assertThat(vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList())) @@ -179,7 +180,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -194,7 +195,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -213,13 +214,13 @@ public void searchThresholdTest() { var request = SearchRequest.query("Great").withTopK(5); List fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore.similaritySearch(request.withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch(request.withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -227,7 +228,8 @@ public void searchThresholdTest() { assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java index cd9e973b54..72fc9238e0 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreObservationIT.java @@ -107,7 +107,6 @@ void observationVectorStoreAddAndQueryOperations() { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(), "TestCollection:" + vectorStore.getCollectionId()) .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_NAMESPACE.asString()) - .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString(), "distance") .doesNotHaveHighCardinalityKeyValueWithKey( HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString()) .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString()) @@ -141,7 +140,6 @@ void observationVectorStoreAddAndQueryOperations() { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(), "TestCollection:" + vectorStore.getCollectionId()) .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_NAMESPACE.asString()) - .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString(), "distance") .doesNotHaveHighCardinalityKeyValueWithKey( HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString()) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString(), "1") diff --git a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java index e9c64e7946..29ee2bfdff 100644 --- a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java +++ b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/CoherenceVectorStore.java @@ -37,6 +37,7 @@ import com.tangosol.util.Filter; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.beans.factory.InitializingBean; @@ -62,6 +63,7 @@ * * * @author Aleks Seovic + * @author Thomas Vitale * @since 1.0.0 */ public class CoherenceVectorStore implements VectorStore, InitializingBean { @@ -211,8 +213,13 @@ public List similaritySearch(SearchRequest request) { if (this.distanceType != DistanceType.COSINE || (1 - r.getDistance()) >= request.getSimilarityThreshold()) { DocumentChunk.Id id = r.getKey(); DocumentChunk chunk = r.getValue(); - chunk.metadata().put("distance", r.getDistance()); - documents.add(new Document(id.docId(), chunk.text(), chunk.metadata())); + chunk.metadata().put(DocumentMetadata.DISTANCE.value(), r.getDistance()); + documents.add(Document.builder() + .id(id.docId()) + .content(chunk.text()) + .metadata(chunk.metadata()) + .score(1 - r.getDistance()) + .build()); } } return documents; diff --git a/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java b/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java index 8b1a0c52bf..776d449b51 100644 --- a/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java +++ b/vector-stores/spring-ai-coherence-store/src/test/java/org/springframework/ai/vectorstore/CoherenceVectorStoreIT.java @@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser; @@ -125,7 +126,7 @@ public void addAndSearch(CoherenceVectorStore.DistanceType distanceType, Coheren assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -223,7 +224,7 @@ public void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -236,7 +237,7 @@ public void documentUpdate() { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName()); }); @@ -257,18 +258,18 @@ public void searchWithThreshold() { assertThat(isSortedByDistance(fullResult)).isTrue(); - List distances = fullResult.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - double threshold = 1d - (distances.get(0) + distances.get(1)) / 2f; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName()); }); @@ -276,7 +277,7 @@ public void searchWithThreshold() { private static boolean isSortedByDistance(final List documents) { final List distances = documents.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) + .map(doc -> (Double) doc.getMetadata().get(DocumentMetadata.DISTANCE.value())) .toList(); if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 9c058936b4..895c1f924c 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -40,6 +40,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -210,20 +211,23 @@ private String getElasticsearchQueryString(Filter.Expression filterExpression) { private Document toDocument(Hit hit) { Document document = hit.source(); - document.getMetadata().put("distance", calculateDistance(hit.score().floatValue())); + if (hit.score() != null) { + document.getMetadata().put(DocumentMetadata.DISTANCE.value(), 1 - normalizeSimilarityScore(hit.score())); + document.setScore(normalizeSimilarityScore(hit.score())); + } return document; } // more info on score/distance calculation // https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search - private float calculateDistance(Float score) { + private double normalizeSimilarityScore(double score) { switch (this.options.getSimilarity()) { case l2_norm: // the returned value of l2_norm is the opposite of the other functions // (closest to zero means more accurate), so to make it consistent // with the other functions the reverse is returned applying a "1-" // to the standard transformation - return (float) (1 - (java.lang.Math.sqrt((1 / score) - 1))); + return (1 - (java.lang.Math.sqrt((1 / score) - 1))); // cosine and dot_product default: return (2 * score) - 1; diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java index c262c9a4af..2372a5d9c3 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreIT.java @@ -42,6 +42,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.elasticsearch.ElasticsearchContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -165,7 +166,7 @@ public void addAndSearchTest(String similarityFunction) { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); @@ -299,7 +300,7 @@ public void documentUpdateTest(String similarityFunction) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2")); @@ -318,7 +319,7 @@ public void documentUpdateTest(String similarityFunction) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -343,21 +344,22 @@ public void searchThresholdTest(String similarityFunction) { List fullResult = vectorStore.similaritySearch(query); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float thresholdResult = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; List results = vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(thresholdResult)); + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java index 753eb1e7bd..24e47aedf3 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java @@ -30,6 +30,7 @@ import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentMetadata; import reactor.util.annotation.NonNull; import org.springframework.ai.document.Document; @@ -172,8 +173,6 @@ public String[] getFields() { // Query Defaults private static final String QUERY = "/query"; - private static final String DISTANCE_METADATA_FIELD_NAME = "distance"; - /** * Initializes the GemFireVectorStore after properties are set. This method is called * after all bean properties have been set and allows the bean to perform any @@ -271,9 +270,9 @@ public List doSimilaritySearch(SearchRequest request) { metadata = new HashMap<>(); metadata.put(DOCUMENT_FIELD, "--Deleted--"); } - metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - r.score); String content = (String) metadata.remove(DOCUMENT_FIELD); - return new Document(r.key, content, metadata); + return Document.builder().id(r.key).content(content).metadata(metadata).score((double) r.score).build(); }) .collectList() .onErrorMap(WebClientException.class, this::handleHttpClientException) diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java index fd93afe058..1430520fe2 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java @@ -34,6 +34,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.boot.SpringBootConfiguration; @@ -134,7 +135,7 @@ public void addAndSearchTest() { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939)" + " was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); }); } @@ -156,7 +157,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks " + "Around the Corner", @@ -171,7 +172,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation" + " Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); }); } @@ -191,12 +192,12 @@ public void searchThresholdTest() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); - assertThat(distances).hasSize(3); + List scores = fullResult.stream().map(Document::getScore).toList(); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold)); + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; + List results = vectorStore.similaritySearch( + SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); @@ -204,7 +205,8 @@ public void searchThresholdTest() { assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression " + "(1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java index 9ed8d8b015..18c5f26ea0 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java @@ -54,6 +54,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -97,7 +98,7 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements public static final String EMBEDDING_FIELD_NAME = "embedding"; // Metadata, automatically assigned by Milvus. - public static final String DISTANCE_FIELD_NAME = "distance"; + private static final String DISTANCE_FIELD_NAME = "distance"; private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); @@ -258,13 +259,18 @@ public List doSimilaritySearch(SearchRequest request) { try { metadata = (JSONObject) rowRecord.get(this.config.metadataFieldName); // inject the distance into the metadata. - metadata.put(DISTANCE_FIELD_NAME, 1 - getResultSimilarity(rowRecord)); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - getResultSimilarity(rowRecord)); } catch (ParamException e) { // skip the ParamException if metadata doesn't exist for the custom // collection } - return new Document(docId, content, (metadata != null) ? metadata.getInnerMap() : Map.of()); + return Document.builder() + .id(docId) + .content(content) + .metadata((metadata != null) ? metadata.getInnerMap() : Map.of()) + .score((double) getResultSimilarity(rowRecord)) + .build(); }) .toList(); } diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java index 98a88e7b7d..757b5c1b15 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.milvus.MilvusContainer; @@ -106,7 +107,7 @@ public void addAndSearch(String metricType) { assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -200,7 +201,7 @@ public void documentUpdate(String metricType) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -215,7 +216,7 @@ public void documentUpdate(String metricType) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); @@ -238,24 +239,22 @@ public void searchWithThreshold(String metricType) { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); - + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java index 2e40db851c..3ab60f4cc6 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java @@ -26,6 +26,7 @@ import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -172,12 +173,17 @@ private org.bson.Document createSearchIndexDefinition() { private Document mapMongoDocument(org.bson.Document mongoDocument, float[] queryEmbedding) { String id = mongoDocument.getString(ID_FIELD_NAME); String content = mongoDocument.getString(CONTENT_FIELD_NAME); + double score = mongoDocument.getDouble(SCORE_FIELD_NAME); Map metadata = mongoDocument.get(METADATA_FIELD_NAME, org.bson.Document.class); - - Document document = new Document(id, content, metadata); - document.setEmbedding(queryEmbedding); - - return document; + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - score); + + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score(score) + .embedding(queryEmbedding) + .build(); } @Override diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java index d908e84364..4eced845ed 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStoreIT.java @@ -16,6 +16,8 @@ package org.springframework.ai.vectorstore; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -27,6 +29,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.DocumentMetadata; +import org.springframework.core.io.DefaultResourceLoader; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; @@ -198,6 +202,55 @@ void searchWithFilters() { }); } + @Test + public void searchWithThreshold() { + this.contextRunner.run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + var documents = List.of( + new Document("471a8c78-549a-4b2c-bce5-ef3ae6579be3", getText("classpath:/test/data/spring.ai.txt"), + Map.of("meta1", "meta1")), + new Document("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f", + getText("classpath:/test/data/time.shelter.txt"), Map.of()), + new Document("d0237682-1150-44ff-b4d2-1be9b1731ee5", + getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + vectorStore.add(documents); + Thread.sleep(5000); // Await a second for the document to be indexed + + List fullResult = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); + assertThat(fullResult).hasSize(3); + + List scores = fullResult.stream().map(Document::getScore).toList(); + + assertThat(scores).hasSize(3); + + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; + + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); + + }); + } + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @SpringBootConfiguration @EnableAutoConfiguration public static class TestApplication { diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java index 0493d893ca..97a74a175a 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java @@ -29,6 +29,7 @@ import org.neo4j.driver.Values; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -224,15 +225,19 @@ private Document recordToDocument(org.neo4j.driver.Record neoRecord) { var node = neoRecord.get("node").asNode(); var score = neoRecord.get("score").asFloat(); var metaData = new HashMap(); - metaData.put("distance", 1 - score); + metaData.put(DocumentMetadata.DISTANCE.value(), 1 - score); node.keys().forEach(key -> { if (key.startsWith("metadata.")) { metaData.put(key.substring(key.indexOf(".") + 1), node.get(key).asObject()); } }); - return new Document(node.get(this.config.idProperty).asString(), node.get("text").asString(), - Map.copyOf(metaData)); + return Document.builder() + .id(node.get(this.config.idProperty).asString()) + .content(node.get("text").asString()) + .metadata(Map.copyOf(metaData)) + .score((double) score) + .build(); } @Override diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java index a707a30622..b5bc62ef9f 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java @@ -28,6 +28,7 @@ import org.neo4j.driver.AuthTokens; import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.Neo4jContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -91,7 +92,7 @@ void addAndSearchTest() { assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); @@ -203,7 +204,7 @@ void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -218,7 +219,7 @@ void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); }); } @@ -235,14 +236,14 @@ void searchThresholdTest() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Great").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Great").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Great").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -250,8 +251,8 @@ void searchThresholdTest() { assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); - + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java index 653a78df78..6199040f85 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -40,6 +40,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -231,7 +232,10 @@ private List similaritySearch(org.opensearch.client.opensearch.core.Se private Document toDocument(Hit hit) { Document document = hit.source(); - document.getMetadata().put("distance", 1 - hit.score().floatValue()); + if (hit.score() != null) { + document.setScore(hit.score()); + document.getMetadata().put(DocumentMetadata.DISTANCE.value(), 1 - hit.score().floatValue()); + } return document; } diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java index 6652ecd788..5a4a5b20ac 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java @@ -39,6 +39,7 @@ import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; import org.opensearch.testcontainers.OpensearchContainer; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -144,7 +145,7 @@ public void addAndSearchTest(String similarityFunction) { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); @@ -281,7 +282,7 @@ public void documentUpdateTest(String similarityFunction) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", Map.of("meta2", "meta2")); @@ -300,7 +301,7 @@ public void documentUpdateTest(String similarityFunction) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -330,21 +331,22 @@ public void searchThresholdTest(String similarityFunction) { List fullResult = vectorStore.similaritySearch(query); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; List results = vectorStore.similaritySearch( - SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(1 - threshold)); + SearchRequest.query("Great Depression").withTopK(50).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(Document::getId).toList()); diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java index a9e3b63eb7..f519024dcc 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java @@ -39,6 +39,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -80,6 +81,7 @@ * @author Loïc Lefèvre * @author Christian Tzolov * @author Soby Chacko + * @author Thomas Vitale */ public class OracleVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -649,12 +651,16 @@ private static class DocumentRowMapper implements RowMapper { @Override public Document mapRow(ResultSet rs, int rowNum) throws SQLException { final Map metadata = getMap(rs.getObject(3, OracleJsonValue.class)); - metadata.put("distance", rs.getDouble(5)); + metadata.put(DocumentMetadata.DISTANCE.value(), rs.getDouble(5)); - final Document document = new Document(rs.getString(1), rs.getString(2), metadata); final float[] embedding = rs.getObject(4, float[].class); - document.setEmbedding(embedding); - return document; + return Document.builder() + .id(rs.getString(1)) + .content(rs.getString(2)) + .metadata(metadata) + .score(1 - rs.getDouble(5)) + .embedding(embedding) + .build(); } private Map getMap(OracleJsonValue value) { diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java index 35727822d1..2c3fcee85c 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java @@ -32,6 +32,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.oracle.OracleContainer; @@ -94,21 +95,19 @@ private static void dropTable(ApplicationContext context, String tableName) { jdbcTemplate.execute("DROP TABLE IF EXISTS " + tableName + " PURGE"); } - private static boolean isSortedByDistance(final List documents) { - final List distances = documents.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + private static boolean isSortedBySimilarity(final List documents) { + final List scores = documents.stream().map(Document::getScore).toList(); - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + if (CollectionUtils.isEmpty(scores) || scores.size() == 1) { return true; } - Iterator iter = distances.iterator(); + Iterator iter = scores.iterator(); Double current; Double previous = iter.next(); while (iter.hasNext()) { current = iter.next(); - if (previous > current) { + if (previous < current) { return false; } previous = current; @@ -134,7 +133,7 @@ public void addAndSearch(String distanceType) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -243,7 +242,7 @@ public void documentUpdate(String distanceType) { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -256,7 +255,7 @@ public void documentUpdate(String distanceType) { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); dropTable(context, ((OracleVectorStore) vectorStore).getTableName()); }); @@ -279,20 +278,19 @@ public void searchWithThreshold(String distanceType) { assertThat(fullResult).hasSize(3); - assertThat(isSortedByDistance(fullResult)).isTrue(); + assertThat(isSortedBySimilarity(fullResult)).isTrue(); - List distances = fullResult.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - double threshold = (distances.get(0) + distances.get(1)) / 2d; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2d; List results = vectorStore.similaritySearch( - SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1d - threshold)); + SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); dropTable(context, ((OracleVectorStore) vectorStore).getTableName()); }); diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 69cd79b41c..f74e3fe94b 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -502,12 +503,15 @@ public Document mapRow(ResultSet rs, int rowNum) throws SQLException { Float distance = rs.getFloat(COLUMN_DISTANCE); Map metadata = toMap(pgMetadata); - metadata.put(COLUMN_DISTANCE, distance); - - Document document = new Document(id, content, metadata); - document.setEmbedding(toFloatArray(embedding)); - - return document; + metadata.put(DocumentMetadata.DISTANCE.value(), distance); + + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score(1.0 - distance) + .embedding(toFloatArray(embedding)) + .build(); } private float[] toFloatArray(PGobject embedding) throws SQLException { diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java index 3366c88296..83c2e49e78 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java @@ -34,6 +34,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -113,20 +114,20 @@ static Stream provideFilters() { ); } - private static boolean isSortedByDistance(List docs) { + private static boolean isSortedBySimilarity(List docs) { - List distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = docs.stream().map(Document::getScore).toList(); - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + if (CollectionUtils.isEmpty(scores) || scores.size() == 1) { return true; } - Iterator iter = distances.iterator(); - Float current; - Float previous = iter.next(); + Iterator iter = scores.iterator(); + Double current; + Double previous = iter.next(); while (iter.hasNext()) { current = iter.next(); - if (previous > current) { + if (previous < current) { return false; } previous = current; @@ -150,7 +151,7 @@ public void addAndSearch(String distanceType) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -289,7 +290,7 @@ public void documentUpdate(String distanceType) { Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -303,7 +304,7 @@ public void documentUpdate(String distanceType) { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); dropTable(context); }); @@ -326,20 +327,19 @@ public void searchWithThreshold(String distanceType) { assertThat(fullResult).hasSize(3); - assertThat(isSortedByDistance(fullResult)).isTrue(); + assertThat(isSortedBySimilarity(fullResult)).isTrue(); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; List results = vectorStore.similaritySearch( - SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1 - threshold)); + SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); dropTable(context); }); diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java index 1093656370..d9ff49721a 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java @@ -38,6 +38,7 @@ import io.pinecone.proto.Vector; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -60,13 +61,12 @@ * @author Christian Tzolov * @author Adam Bchouti * @author Soby Chacko + * @author Thomas Vitale */ public class PineconeVectorStore extends AbstractObservationVectorStore { public static final String CONTENT_FIELD_NAME = "document_content"; - public static final String DISTANCE_METADATA_FIELD_NAME = "distance"; - public final FilterExpressionConverter filterExpressionConverter = new PineconeFilterExpressionConverter(); private final EmbeddingModel embeddingModel; @@ -236,7 +236,12 @@ public List similaritySearch(SearchRequest request, String namespace) var content = metadataStruct.getFieldsOrThrow(this.pineconeContentFieldName).getStringValue(); Map metadata = extractMetadata(metadataStruct); metadata.put(this.pineconeDistanceMetadataFieldName, 1 - scoredVector.getScore()); - return new Document(id, content, metadata); + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score((double) scoredVector.getScore()) + .build(); }) .toList(); } @@ -298,6 +303,8 @@ public static final class PineconeVectorStoreConfig { private final String contentFieldName; + // TODO: Why is this field configurable? Can we remove this after standardizing + // the key? private final String distanceMetadataFieldName; private final PineconeConnectionConfig connectionConfig; @@ -357,7 +364,7 @@ public static final class Builder { private String contentFieldName = CONTENT_FIELD_NAME; - private String distanceMetadataFieldName = DISTANCE_METADATA_FIELD_NAME; + private String distanceMetadataFieldName = DocumentMetadata.DISTANCE.value(); /** * Optional server-side timeout in seconds for all operations. Default: 20 diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java index 6e1eddc5f3..712efa19a6 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; import org.springframework.ai.vectorstore.PineconeVectorStore.PineconeVectorStoreConfig; @@ -46,6 +47,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @EnabledIfEnvironmentVariable(named = "PINECONE_API_KEY", matches = ".+") public class PineconeVectorStoreIT { @@ -109,7 +111,7 @@ public void addAndSearchTest() { assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -193,7 +195,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -214,7 +216,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(List.of(document.getId())); @@ -240,21 +242,22 @@ public void searchThresholdTest() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Depression").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java index 825c50f7b5..05e6e69c4a 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java @@ -35,6 +35,7 @@ import io.qdrant.client.grpc.Points.UpdateStatus; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -57,6 +58,7 @@ * @author Eddú Meléndez * @author Josh Long * @author Soby Chacko + * @author Thomas Vitale * @since 0.8.1 */ public class QdrantVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -65,8 +67,6 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements private static final String CONTENT_FIELD_NAME = "doc_content"; - private static final String DISTANCE_FIELD_NAME = "distance"; - private final EmbeddingModel embeddingModel; private final QdrantClient qdrantClient; @@ -200,20 +200,25 @@ public List doSimilaritySearch(SearchRequest request) { } /** - * Extracts metadata from a Protobuf Struct. - * @param metadataStruct The Protobuf Struct containing metadata. + * Extracts metadata from a ScoredPoint. + * @param point The ScoredPoint. * @return The metadata as a map. */ private Document toDocument(ScoredPoint point) { try { var id = point.getId().getUuid(); - var payload = QdrantObjectFactory.toObjectMap(point.getPayloadMap()); - payload.put(DISTANCE_FIELD_NAME, 1 - point.getScore()); + var metadata = QdrantObjectFactory.toObjectMap(point.getPayloadMap()); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - point.getScore()); - var content = (String) payload.remove(CONTENT_FIELD_NAME); + var content = (String) metadata.remove(CONTENT_FIELD_NAME); - return new Document(id, content, payload); + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score((double) point.getScore()) + .build(); } catch (Exception e) { throw new RuntimeException(e); diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java index 995cd8cc3a..0cb14fde45 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStoreIT.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.qdrant.QdrantContainer; @@ -107,7 +108,7 @@ public void addAndSearch() { assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -185,7 +186,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -200,7 +201,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); }); @@ -218,13 +219,13 @@ public void searchThresholdTest() { var request = SearchRequest.query("Great").withTopK(5); List fullResult = vectorStore.similaritySearch(request.withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore.similaritySearch(request.withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch(request.withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); @@ -232,7 +233,8 @@ public void searchThresholdTest() { assertThat(resultDoc.getContent()).isEqualTo( "Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java index df27706716..8759437926 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java @@ -30,6 +30,7 @@ import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentMetadata; import redis.clients.jedis.JedisPooled; import redis.clients.jedis.Pipeline; import redis.clients.jedis.json.Path2; @@ -240,14 +241,21 @@ public List doSimilaritySearch(SearchRequest request) { private Document toDocument(redis.clients.jedis.search.Document doc) { var id = doc.getId().substring(this.config.prefix.length()); - var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) - : null; + var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) : ""; Map metadata = this.config.metadataFields.stream() .map(MetadataField::name) .filter(doc::hasProperty) .collect(Collectors.toMap(Function.identity(), doc::getString)); + // TODO: this seems wrong. The key is named "vector_store", but the value is the + // distance. Can we remove this after standardizing the metadata? metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc)); - return new Document(id, content, metadata); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - similarityScore(doc)); + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .score((double) similarityScore(doc)) + .build(); } private float similarityScore(redis.clients.jedis.search.Document doc) { diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java index 4153a82ea6..6c389b672d 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java @@ -26,6 +26,7 @@ import com.redis.testcontainers.RedisStackContainer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import redis.clients.jedis.JedisPooled; @@ -50,6 +51,7 @@ /** * @author Julien Ruaux * @author Eddú Meléndez + * @author Thomas Vitale */ @Testcontainers class RedisVectorStoreIT { @@ -105,8 +107,9 @@ void addAndSearch() { assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME); + assertThat(resultDoc.getMetadata()).hasSize(3); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME, + DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -190,6 +193,7 @@ void documentUpdate() { assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); assertThat(resultDoc.getMetadata()).containsKey(RedisVectorStore.DISTANCE_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -205,6 +209,7 @@ void documentUpdate() { assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey(RedisVectorStore.DISTANCE_FIELD_NAME); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); @@ -223,24 +228,23 @@ void searchWithThreshold() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get(RedisVectorStore.DISTANCE_FIELD_NAME)) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME); - + assertThat(resultDoc.getMetadata()).containsKeys("meta1", RedisVectorStore.DISTANCE_FIELD_NAME, + DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); } diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java index b0fc72f136..b805be6212 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseVectorStore.java @@ -26,6 +26,7 @@ import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.document.DocumentMetadata; import org.typesense.api.Client; import org.typesense.api.FieldTypes; import org.typesense.model.CollectionResponse; @@ -57,6 +58,7 @@ * @author Pablo Sanchidrian Herrera * @author Soby Chacko * @author Christian Tzolov + * @author Thomas Vitale */ public class TypesenseVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -212,8 +214,13 @@ public List doSimilaritySearch(SearchRequest request) { String content = rawDocument.get(CONTENT_FIELD_NAME).toString(); Map metadata = rawDocument.get(METADATA_FIELD_NAME) instanceof Map ? (Map) rawDocument.get(METADATA_FIELD_NAME) : Map.of(); - metadata.put("distance", hit.getVectorDistance()); - return new Document(docId, content, metadata); + metadata.put(DocumentMetadata.DISTANCE.value(), hit.getVectorDistance()); + return Document.builder() + .id(docId) + .content(content) + .metadata(metadata) + .score(1.0 - hit.getVectorDistance()) + .build(); })) .toList(); diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java index 23cd9f03bb..6d05a17fab 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseVectorStoreIT.java @@ -26,6 +26,7 @@ import java.util.UUID; import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.GenericContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -96,7 +97,7 @@ void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -114,7 +115,7 @@ void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); @@ -211,21 +212,22 @@ void searchWithThreshold() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); ((TypesenseVectorStore) vectorStore).dropCollection(); diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java index 9ce01c9f05..049d4a4f1f 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java @@ -45,6 +45,7 @@ import io.weaviate.client.v1.graphql.query.fields.Fields; import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; @@ -73,11 +74,10 @@ * @author Eddú Meléndez * @author Josh Long * @author Soby Chacko + * @author Thomas Vitale */ public class WeaviateVectorStore extends AbstractObservationVectorStore { - public static final String DOCUMENT_METADATA_DISTANCE_KEY_NAME = "distance"; - private static final String METADATA_FIELD_PREFIX = "meta_"; private static final String CONTENT_FIELD_NAME = "content"; @@ -367,7 +367,7 @@ private Document toDocument(Map item) { // Metadata Map metadata = new HashMap<>(); - metadata.put(DOCUMENT_METADATA_DISTANCE_KEY_NAME, 1 - certainty); + metadata.put(DocumentMetadata.DISTANCE.value(), 1 - certainty); try { String metadataJson = (String) item.get(METADATA_FIELD_NAME); @@ -382,10 +382,13 @@ private Document toDocument(Map item) { // Content String content = (String) item.get(CONTENT_FIELD_NAME); - var document = new Document(id, content, metadata); - document.setEmbedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding))); - - return document; + return Document.builder() + .id(id) + .content(content) + .metadata(metadata) + .embedding(EmbeddingUtils.toPrimitive(EmbeddingUtils.doubleToFloat(embedding))) + .score(certainty) + .build(); } @Override diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java index b474cdaede..7239b319b6 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java @@ -26,6 +26,7 @@ import io.weaviate.client.Config; import io.weaviate.client.WeaviateClient; import org.junit.jupiter.api.Test; +import org.springframework.ai.document.DocumentMetadata; import org.testcontainers.containers.wait.strategy.Wait; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -101,7 +102,7 @@ public void addAndSearch() { assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); assertThat(resultDoc.getMetadata()).hasSize(2); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -186,7 +187,7 @@ public void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); assertThat(resultDoc.getMetadata()).containsKey("meta1"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -201,7 +202,7 @@ public void documentUpdate() { assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); - assertThat(resultDoc.getMetadata()).containsKey("distance"); + assertThat(resultDoc.getMetadata()).containsKey(DocumentMetadata.DISTANCE.value()); vectorStore.delete(List.of(document.getId())); @@ -222,23 +223,22 @@ public void searchWithThreshold() { List fullResult = vectorStore .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); - List distances = fullResult.stream() - .map(doc -> (Double) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - assertThat(distances).hasSize(3); + assertThat(scores).hasSize(3); - double threshold = (distances.get(0) + distances.get(1)) / 2; + double similarityThreshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore - .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(similarityThreshold)); assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(0).getId()); assertThat(resultDoc.getContent()).contains( "Spring AI provides abstractions that serve as the foundation for developing AI applications."); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", DocumentMetadata.DISTANCE.value()); + assertThat(resultDoc.getScore()).isGreaterThanOrEqualTo(similarityThreshold); }); }