Skip to content

Commit

Permalink
Support similarity scores in Document API
Browse files Browse the repository at this point in the history
Document
* Introduced “score” attribute in Document API. It stores the similarity score.
* Consolidated “distance” metadata for Documents. It stores the distance measurement.
* Adopted prefix-less naming convention in Document.Builder and deprecated old methods.
* Deprecated the many overloaded Document constructors in favour of Document.Builder.

Vector Stores
* Every vector store implementation now configures a “score” attribute with the similarity score of the Document embedding. It also includes the “distance” metadata with the distance measurement.
* Fixed error in Elasticsearch where distance and similarity were mixed up.
* Added missing integration tests for SimpleVectorStore.
* The Azure Vector Store and HanaDB Vector Store do not include those measurements because the product documentation does not include information about how the similarity score is returned, and without access to the cloud products I could not verify that via debugging.
* Improved tests to actually assert the result of the similarity search based on the returned score.

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
  • Loading branch information
ThomasVitale committed Nov 26, 2024
1 parent e7f37a0 commit efca3c7
Show file tree
Hide file tree
Showing 53 changed files with 868 additions and 397 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
Expand All @@ -31,6 +32,7 @@
import org.springframework.ai.document.id.RandomIdGenerator;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.MediaContent;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

Expand Down Expand Up @@ -61,7 +63,15 @@ public class Document implements MediaContent {
* Metadata for the document. It should not be nested and values should be restricted
* to string, int, float, boolean for simple use with Vector Dbs.
*/
private Map<String, Object> metadata;
private final Map<String, Object> metadata;

/**
* Measure of similarity between the document embedding and the query vector. The
* higher the score, the more they are similar. It's the opposite of the distance
* measure.
*/
@Nullable
private Double score;

/**
* Embedding of the document. Note: ephemeral field.
Expand All @@ -80,31 +90,61 @@ public Document(@JsonProperty("content") String content) {
this(content, new HashMap<>());
}

/**
* Create a document from the given content and metadata. Delegates to the
* constructor taking an {@link IdGenerator}, passing a new
* {@link RandomIdGenerator}.
* @param content the document content
* @param metadata the document metadata
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String content, Map<String, Object> metadata) {
this(content, metadata, new RandomIdGenerator());
}

/**
* Create a document from the given content, media and metadata. The id is
* produced by a new {@link RandomIdGenerator} from the content and metadata.
* @param content the document content
* @param media the media attached to the document
* @param metadata the document metadata
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String content, Collection<Media> media, Map<String, Object> metadata) {
this(new RandomIdGenerator().generateId(content, metadata), content, media, metadata);
}

/**
* Create a document from the given content and metadata. The id is produced by
* the supplied generator from the content and metadata.
* @param content the document content
* @param metadata the document metadata
* @param idGenerator the generator used to compute the document id
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String content, Map<String, Object> metadata, IdGenerator idGenerator) {
this(idGenerator.generateId(content, metadata), content, metadata);
}

/**
* Create a document with an explicit id, content and metadata. Delegates to the
* media-accepting constructor with an empty media list.
* @param id the document id
* @param content the document content
* @param metadata the document metadata
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String id, String content, Map<String, Object> metadata) {
this(id, content, List.of(), metadata);
}

/**
* @deprecated Use builder instead: {@link Document#builder()}.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Document(String id, String content, Collection<Media> media, Map<String, Object> metadata) {
Assert.hasText(id, "id must not be null or empty");
Assert.notNull(content, "content must not be null");
Assert.notNull(metadata, "metadata must not be null");
this(id, content, media, metadata, null);
}

public Document(String id, String content, @Nullable Collection<Media> media,
@Nullable Map<String, Object> metadata, @Nullable Double score) {
Assert.hasText(id, "id cannot be null or empty");
Assert.notNull(content, "content cannot be null");
Assert.notNull(media, "media cannot be null");
Assert.noNullElements(media, "media cannot have null elements");
Assert.notNull(metadata, "metadata cannot be null");
Assert.noNullElements(metadata.keySet(), "metadata cannot have null keys");
Assert.noNullElements(metadata.values(), "metadata cannot have null values");

this.id = id;
this.content = content;
this.media = media;
this.metadata = metadata;
this.media = media != null ? media : List.of();
this.metadata = metadata != null ? metadata : new HashMap<>();
this.score = score;
}

public static Builder builder() {
Expand Down Expand Up @@ -149,6 +189,15 @@ public Map<String, Object> getMetadata() {
return this.metadata;
}

/**
* Return the score currently stored on this document.
* @return the score, or {@code null} if none is set
*/
@Nullable
public Double getScore() {
return this.score;
}

/**
* Replace the score stored on this document.
* @param score the new score; may be {@code null}
*/
public void setScore(@Nullable Double score) {
this.score = score;
}

/**
* Return the embedding that was calculated.
* @deprecated We are considering getting rid of this, please comment on
Expand Down Expand Up @@ -186,57 +235,24 @@ public void setContentFormatter(ContentFormatter contentFormatter) {

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((this.id == null) ? 0 : this.id.hashCode());
result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode());
result = prime * result + ((this.content == null) ? 0 : this.content.hashCode());
return result;
return Objects.hash(id, content, media, metadata);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
public boolean equals(Object o) {
if (this == o)
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
Document other = (Document) obj;
if (this.id == null) {
if (other.id != null) {
return false;
}
}
else if (!this.id.equals(other.id)) {
return false;
}
if (this.metadata == null) {
if (other.metadata != null) {
return false;
}
}
else if (!this.metadata.equals(other.metadata)) {
if (o == null || getClass() != o.getClass())
return false;
}
if (this.content == null) {
if (other.content != null) {
return false;
}
}
else if (!this.content.equals(other.content)) {
return false;
}
return true;
Document document = (Document) o;
return Objects.equals(id, document.id) && Objects.equals(content, document.content)
&& Objects.equals(media, document.media) && Objects.equals(metadata, document.metadata);
}

@Override
public String toString() {
return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content
+ '\'' + ", media=" + this.media + '}';
return "Document{" + "id='" + id + '\'' + ", content='" + content + '\'' + ", media=" + media + ", metadata="
+ metadata + ", score=" + score + '}';
}

public static class Builder {
Expand All @@ -249,56 +265,102 @@ public static class Builder {

private Map<String, Object> metadata = new HashMap<>();

private float[] embedding = new float[0];

private Double score;

private IdGenerator idGenerator = new RandomIdGenerator();

public Builder withIdGenerator(IdGenerator idGenerator) {
Assert.notNull(idGenerator, "idGenerator must not be null");
public Builder idGenerator(IdGenerator idGenerator) {
Assert.notNull(idGenerator, "idGenerator cannot be null");
this.idGenerator = idGenerator;
return this;
}

public Builder withId(String id) {
Assert.hasText(id, "id must not be null or empty");
public Builder id(String id) {
Assert.hasText(id, "id cannot be null or empty");
this.id = id;
return this;
}

public Builder withContent(String content) {
Assert.notNull(content, "content must not be null");
public Builder content(String content) {
this.content = content;
return this;
}

public Builder withMedia(List<Media> media) {
Assert.notNull(media, "media must not be null");
public Builder media(List<Media> media) {
this.media = media;
return this;
}

public Builder withMedia(Media media) {
Assert.notNull(media, "media must not be null");
this.media.add(media);
public Builder media(Media... media) {
Assert.noNullElements(media, "media cannot contain null elements");
this.media.addAll(List.of(media));
return this;
}

public Builder withMetadata(Map<String, Object> metadata) {
Assert.notNull(metadata, "metadata must not be null");
public Builder metadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}

public Builder withMetadata(String key, Object value) {
Assert.notNull(key, "key must not be null");
Assert.notNull(value, "value must not be null");
public Builder metadata(String key, Object value) {
this.metadata.put(key, value);
return this;
}

/**
* Set the embedding that {@link #build()} applies to the created document via
* {@code setEmbedding}.
* @param embedding the embedding values
* @return this builder
*/
public Builder embedding(float[] embedding) {
this.embedding = embedding;
return this;
}

/**
* Set the score that {@link #build()} passes to the created document.
* <p>
* The parameter is marked {@link Nullable} because the document's score is itself
* nullable; without the annotation the package-level non-null default would
* wrongly declare {@code null} an invalid argument.
* @param score the score, or {@code null} to build a document without one
* @return this builder
*/
public Builder score(@Nullable Double score) {
	this.score = score;
	return this;
}

/**
* @deprecated Use {@link #idGenerator(IdGenerator)} instead.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withIdGenerator(IdGenerator idGenerator) {
return idGenerator(idGenerator);
}

/**
* @deprecated Use {@link #id(String)} instead.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withId(String id) {
return id(id);
}

/**
* @deprecated Use {@link #content(String)} instead.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withContent(String content) {
return content(content);
}

/**
* @deprecated Use {@link #media(List)} instead.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withMedia(List<Media> media) {
return media(media);
}

/**
* @deprecated Use {@link #media(Media...)} instead.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withMedia(Media media) {
return media(media);
}

/**
* @deprecated Use {@link #metadata(Map)} instead.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withMetadata(Map<String, Object> metadata) {
return metadata(metadata);
}

/**
* @deprecated Use {@link #metadata(String, Object)} instead.
*/
@Deprecated(since = "1.0.0-M5", forRemoval = true)
public Builder withMetadata(String key, Object value) {
return metadata(key, value);
}

public Document build() {
if (!StringUtils.hasText(this.id)) {
this.id = this.idGenerator.generateId(this.content, this.metadata);
}
return new Document(this.id, this.content, this.media, this.metadata);
var document = new Document(this.id, this.content, this.media, this.metadata, this.score);
document.setEmbedding(this.embedding);
return document;
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.document;

import org.springframework.ai.vectorstore.VectorStore;

/**
* Common set of metadata keys used in {@link Document}s by {@link DocumentReader}s and
* {@link VectorStore}s.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public enum DocumentMetadata {

	// @formatter:off

	/**
	 * Distance between the document embedding and the query vector. A smaller
	 * distance means the two are more similar, making it the inverse of the
	 * similarity score.
	 */
	DISTANCE("distance");

	/** String form of this metadata entry. */
	private final String key;

	DocumentMetadata(String key) {
		this.key = key;
	}

	/**
	 * Return the string associated with this constant.
	 * @return the string passed to the constant's constructor
	 */
	public String value() {
		return key;
	}

	// @formatter:on

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

@NonNullApi
@NonNullFields
package org.springframework.ai.document;

import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;
Loading

0 comments on commit efca3c7

Please sign in to comment.