Skip to content

Commit

Permalink
Merge pull request #344 from markstur/embed_same_test
Browse files Browse the repository at this point in the history
Embedding: add a test that would have helped
  • Loading branch information
evaline-ju authored Apr 4, 2024
2 parents d34987a + 55b07f8 commit c12cb82
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,3 +1054,43 @@ def test_encode_extensions(loaded_model):
BOOTSTRAPPED_MODEL._encode_with_retry(
"text here"
) # and no KeyError trying to remove non-existing keys


@pytest.mark.parametrize(
    "truncate_input_tokens",
    [0, 1, 2, 3, 4, 5, 99, 100, 101, 300, 510, 511, 512, 513, 1000, -1],
)
def test_same_same(loaded_model: EmbeddingModule, truncate_input_tokens):
    """Check that identical text embeds identically, alone or in a batch.

    Each input is embedded one at a time with ``run_embedding`` and then all
    inputs are embedded together with ``run_embeddings``. The test asserts:

    * both paths return the same number of vectors;
    * vectors at the same position are numerically close (``np.allclose``);
    * within each path, the two identical sentences produce exactly equal
      vectors while the third, different sentence does not.

    The checks are repeated for every parametrized ``truncate_input_tokens``.
    """
    repeated_text = "What is generative ai?"
    inputs = [repeated_text, repeated_text, "different"]

    # One call per text first, then a single batched call over all texts.
    separate_embeddings = []
    for text in inputs:
        separate_embeddings.append(
            loaded_model.run_embedding(
                text=text, truncate_input_tokens=truncate_input_tokens
            )
        )
    combined_embeddings = loaded_model.run_embeddings(
        texts=inputs, truncate_input_tokens=truncate_input_tokens
    )

    # Pull the raw value lists out of the two result shapes.
    separate_vectors = []
    for embedding in separate_embeddings:
        separate_vectors.append(embedding.to_dict()["result"]["data"]["values"])

    combined_vectors = []
    for vector in combined_embeddings.to_dict()["results"]["vectors"]:
        combined_vectors.append(vector["data"]["values"])

    assert len(separate_vectors) == len(
        combined_vectors
    ), "expected the same number separate and combined embeddings"

    # Lengths match at this point, so pairing by position also checks order.
    for single, batched in zip(separate_vectors, combined_vectors):
        assert np.allclose(single, batched)

    # Within each path: the duplicates match exactly, the third text does not.
    for vectors in (combined_vectors, separate_vectors):
        first, second, third = vectors[0], vectors[1], vectors[2]
        assert np.array_equal(first, second)
        assert not np.array_equal(second, third)

0 comments on commit c12cb82

Please sign in to comment.