Releases: UKPLab/sentence-transformers

v3.0.1 - Patch introducing new Trainer features, model card improvements and evaluator fixes

07 Jun 13:01

This patch release introduces some improvements for the SentenceTransformerTrainer, as well as some updates for the automatic model card generation. It also patches some minor evaluator bugs and a bug with MatryoshkaLoss. Lastly, every single Sentence Transformer model can now be saved and loaded with the safer model.safetensors files.

Install this version with

# Full installation:
pip install sentence-transformers[train]==3.0.1

# Inference only:
pip install sentence-transformers==3.0.1

SentenceTransformerTrainer improvements

  • Implement gradient checkpointing for lower memory usage during training (#2717)
  • Implement support for push_to_hub=True Training Argument, also implement trainer.push_to_hub(...) (#2718)

Model Cards

This patch release improves on the automatically generated model cards in several ways:

  • Your training datasets are now automatically linked if they're on Hugging Face (#2711)
  • A new generated_from_trainer tag is now also added (#2710)
  • The automatically included widget examples are now improved, especially for question-answering. Previously, the widget could give examples of comparing two questions with each other (#2713)
  • If you save a model locally, then load it again and upload it, it would previously still show
...
# Download from the 🤗 Hub
model = SentenceTransformer("sentence_transformers_model_id")
...

This now gets replaced with your new model ID on Hugging Face (#2714)

  • The exact training dataset size is now included in the model metadata, rather than as a bucket of e.g. 1K<n<10K (#2728)
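To illustrate what "a bucket" means here, the following is a minimal sketch of mapping an exact dataset size to a coarse label. Only the label 1K<n<10K appears in the notes above; the other labels, the boundaries, and the size_bucket helper are assumptions made for illustration and are not the library's or the Hub's code.

```python
def size_bucket(num_rows: int) -> str:
    """Map an exact dataset size to a coarse label such as '1K<n<10K'."""
    # Boundaries grow by a factor of 10; labels other than '1K<n<10K' are assumed.
    bounds = [(1_000, "1K"), (10_000, "10K"), (100_000, "100K"), (1_000_000, "1M")]
    if num_rows < bounds[0][0]:
        return f"n<{bounds[0][1]}"
    for (low, low_name), (high, high_name) in zip(bounds, bounds[1:]):
        if low <= num_rows < high:
            return f"{low_name}<n<{high_name}"
    return f"n>{bounds[-1][1]}"


print(size_bucket(5_749))    # 1K<n<10K
print(size_bucket(392_702))  # 100K<n<1M
```

With the exact size stored in the metadata, a bucket like this can always be derived later, whereas the reverse is not possible.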

Evaluators fixes

  • The primary metric of evaluators in SequentialEvaluator would be ignored in the scores calculation (#2700)
  • Fix confusing print statement in TranslationEvaluator when using print_wrong_matches=True (#1894)
  • Fix bug that prevents you from customizing the primary_metric in InformationRetrievalEvaluator (#2701)
  • Allow passing a list of evaluators to the STTrainer rather than a SequentialEvaluator (#2716)

Losses fixes

  • Fix MatryoshkaLoss crash if the first dimension is not the biggest (#2719)

Security

  • Integrate safetensors with all modules, including Dense, LSTM, CNN, etc. to prevent needing pickled pytorch_model.bin anymore (#2722)
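As background on why pickled files are considered less safe, here is a small standard-library sketch showing that unpickling can invoke a callable chosen by whoever wrote the file. The Payload class is hypothetical, and int stands in for an arbitrary (potentially harmful) callable. A safetensors file, by contrast, stores only tensor data and metadata.

```python
import pickle


class Payload:
    # __reduce__ tells pickle which callable to invoke when the bytes are loaded.
    def __reduce__(self):
        return (int, ("42",))  # harmless stand-in for an arbitrary callable


data = pickle.dumps(Payload())
restored = pickle.loads(data)  # calls int("42") instead of rebuilding a Payload
print(type(restored).__name__, restored)  # int 42
```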

All changes

  • updating to evaluation_strategy by @higorsilvaa in #2686
  • fix loss link by @Samoed in #2690
  • Fix bug that restricts users from specifying custom primary_function in InformationRetrievalEvaluator by @hetulvp in #2701
  • Fix a bug in SequentialEvaluator to use primary_metric if defined in evaluator. by @hetulvp in #2700
  • [fix] Always override the originally saved version in the ST config by @tomaarsen in #2709
  • [model cards] Also include HF datasets in the model card metadata by @tomaarsen in #2711
  • Add "generated_from_trainer" tag to auto-generated model cards by @tomaarsen in #2710
  • Fix confusing print statement in TranslationEvaluator by @NathanS-Git in #1894
  • [model cards] Improve the widget example selection: not based on embeddings, better for QA by @tomaarsen in #2713
  • [model cards] Replace 'sentence_transformers_model_id' from reused model if possible by @tomaarsen in #2714
  • [feat] Allow passing a list of evaluators to the Trainer by @tomaarsen in #2716
  • [fix] Fix gradient checkpointing to allow for much lower memory usage by @tomaarsen in #2717
  • [fix] Implement create_model_card on the Trainer, allowing args.push_to_hub=True by @tomaarsen in #2718
  • [fix] Fix MatryoshkaLoss crash if the first dimension is not the biggest by @tomaarsen in #2719
  • Update models_en_sentence_embeddings.html by @saikartheekb in #2720
  • [typing] Improve typing for many functions & add py.typed to satisfy mypy by @tomaarsen in #2724
  • [fix] Fix edge case with evaluator being None by @tomaarsen in #2726
  • [simplify] Set can_return_loss=True globally, instead of via the data collator by @tomaarsen in #2727
  • [feat] Integrate safetensors with Dense, etc. modules too. by @tomaarsen in #2722
  • [model cards] Specify the exact dataset size as a tag, will be bucketized by HF by @tomaarsen in #2728

Full Changelog: v3.0.0...v3.0.1

v3.0.0 - Sentence Transformer Training Refactor; new similarity methods; hyperparameter optimization; 50+ datasets release

28 May 11:54

This release consists of a major refactor that overhauls the training approach (introducing multi-gpu training, bf16, loss logging, callbacks, and much more), adds convenient similarity and similarity_pairwise methods, adds extra keyword arguments, introduces Hyperparameter Optimization, and includes a massive reformatting and release of 50+ datasets for training embedding models. In total, this is the largest Sentence Transformers update since the project was first created.

Install this version with

# Full installation:
pip install sentence-transformers[train]==3.0.0

# Inference only:
pip install sentence-transformers==3.0.0

Sentence Transformer training refactor (#2449)

The v3.0 release centers around this huge modernization of the training approach for SentenceTransformer models. Whereas training before v3.0 used to be all about InputExample, DataLoader and model.fit, the new training approach relies on 5 new components. You can learn more about these components in our Training and Finetuning Embedding Models with Sentence Transformers v3 blogpost. Additionally, you can read the new Training Overview, check out the Training Examples, or read this summary:

  1. Dataset
    A training Dataset or DatasetDict. This class is much more suited for sharing & efficient modifications than lists/DataLoaders of InputExample instances. A Dataset can contain multiple text columns that will be fed in order to the corresponding loss function. So, if the loss expects (anchor, positive, negative) triplets, then your dataset should also have 3 columns. The names of these columns are irrelevant. If there is a "label" or "score" column, it is treated separately, and used as the labels during training.
    A DatasetDict can be used to train with multiple datasets at once, e.g.:
    DatasetDict({
        multi_nli: Dataset({
            features: ['premise', 'hypothesis', 'label'],
            num_rows: 392702
        })
        snli: Dataset({
            features: ['snli_premise', 'hypothesis', 'label'],
            num_rows: 549367
        })
        stsb: Dataset({
            features: ['sentence1', 'sentence2', 'label'],
            num_rows: 5749
        })
    })
    When a DatasetDict is used, the loss parameter to the SentenceTransformerTrainer must also be a dictionary with these dataset keys, e.g.:
    {
        'multi_nli': SoftmaxLoss(...),
        'snli': SoftmaxLoss(...),
        'stsb': CosineSimilarityLoss(...),
    }
  2. Loss Function
    A loss function, or a dictionary of loss functions as described above. These loss functions do not require changes compared to before this PR.
  3. Training Arguments
    A SentenceTransformerTrainingArguments instance, a subclass of the transformers TrainingArguments class. This powerful class controls the specific details of the training.
  4. Evaluator
    An optional SentenceEvaluator instance. Unlike before, models can now be evaluated on an evaluation dataset with some loss function, with a SentenceEvaluator instance, or with both.
  5. Trainer
    The new SentenceTransformerTrainer instance, based on the transformers Trainer. This instance is provided with a SentenceTransformer model, a SentenceTransformerTrainingArguments instance, a SentenceEvaluator, a training and evaluation Dataset/DatasetDict and a loss function/dict of loss functions. Most of these parameters are optional. Once provided, all you have to do is call trainer.train().
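The two dataset conventions from the summary above (text columns are fed to the loss in order, with "label" or "score" handled separately, and the loss dictionary uses the dataset keys) can be sketched in plain Python. The helpers split_columns and check_loss_keys are hypothetical and only mirror the description; in particular, the exact-match check may be stricter than what the library enforces.

```python
def split_columns(features):
    """Separate label columns from the text columns, keeping text column order."""
    label_names = {"label", "score"}
    inputs = [name for name in features if name not in label_names]
    labels = [name for name in features if name in label_names]
    return inputs, labels


def check_loss_keys(dataset_names, losses):
    """Raise if the loss dictionary does not cover exactly the dataset keys."""
    missing = set(dataset_names) - set(losses)
    extra = set(losses) - set(dataset_names)
    if missing or extra:
        raise ValueError(
            f"missing losses: {sorted(missing)}, unused losses: {sorted(extra)}"
        )


features = {
    "multi_nli": ["premise", "hypothesis", "label"],
    "snli": ["snli_premise", "hypothesis", "label"],
    "stsb": ["sentence1", "sentence2", "label"],
}
# Plain strings stand in for the loss objects here.
losses = {"multi_nli": "SoftmaxLoss", "snli": "SoftmaxLoss", "stsb": "CosineSimilarityLoss"}

check_loss_keys(features, losses)
print(split_columns(features["snli"]))  # (['snli_premise', 'hypothesis'], ['label'])
```

Because only column order matters, 'snli_premise' and 'premise' are treated identically: each is simply the first text input to its loss.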

Some of the major features that are now implemented include:

  • MultiGPU Training (Data Parallelism (DP) and Distributed Data Parallelism (DDP))
  • bf16 training support
  • Loss logging
  • Evaluation datasets + evaluation loss
  • Improved callback support (built-in via Weights and Biases, TensorBoard, CodeCarbon, etc., as well as custom callbacks)
  • Gradient checkpointing
  • Gradient accumulation
  • Improved model card generation
  • Warmup ratio
  • Pushing to the Hugging Face Hub on every model checkpoint
  • Resuming from a training checkpoint
  • Hyperparameter Optimization

This script is a minimal example (no evaluator, no training arguments) of training mpnet-base on a part of the all-nli dataset using MultipleNegativesRankingLoss:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# 1. Load a model to finetune
model = SentenceTransformer("microsoft/mpnet-base")

# 2. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(10_000))
eval_dataset = dataset["dev"].select(range(1_000))

# 3. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 4. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
)
trainer.train()

# 5. Save the trained model
model.save_pretrained("models/mpnet-base-all-nli")

Additionally, trained models now automatically produce extensive model cards. Each of the following models was trained using a script from the Training Examples, and the model cards were not edited manually whatsoever:

Prior to the Sentence Transformer v3 release, all models would be trained using the SentenceTransformer.fit method. Rather than deprecating this method, starting from v3.0, this method will use the SentenceTransformerTrainer behind the scenes. This means that your old training code should still work, and should even be upgraded with the new features such as multi-gpu training, loss logging, etc. That said, the new training approach is much more powerful, so it is recommended to write new training scripts using the new approach.

Many of the old training scripts were updated to use the new Trainer-based approach, but not all have been updated yet. We accept help via Pull Requests to assist in updating the scripts.

Similarity Score (#2615, #2490)

Sentence Transformers v3.0 introduces two new useful methods, similarity and similarity_pairwise, and one property, similarity_fn_name.

These can be used to calculate the similarity between embeddings, and to specify which similarity function should be used, for example:

>>> from sentence_transformers import SentenceTransformer
>>> model = SentenceTransformer("all-mpnet-base-v2")
>>> sentences = [
...     "The weather is so nice!",
...     "It's so sunny outside.",
...     "He's driving to the movie theater.",
...     "She's going to the cinema.",
... ]
>>> embeddings = model.encode(sentences, normalize_embeddings=True)
>>> model.similarity(embeddings, embeddings)
tensor([[1.0000, 0.7235, 0.0290, 0.1309],
        [0.7235, 1.0000, 0.0613, 0.1129],
        [0.0290, 0.0613, 1.0000, 0.5027],
        [0.1309, 0.1129, 0.5027, 1.0000]])
>>> model.similarity_fn_name
"cosine"
>>> model.similarity_fn_name = "euclidean"
>>> model.similarity(embedd...
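The difference between the all-pairs and the pairwise computation can be sketched in plain Python. The functions below are standalone illustrations that reuse the method names and hard-code cosine similarity. They are not the library implementation, which operates on tensors and follows similarity_fn_name.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def similarity(xs, ys):
    """All-pairs scores: one row per vector in xs, one column per vector in ys."""
    return [[cosine(x, y) for y in ys] for x in xs]


def similarity_pairwise(xs, ys):
    """Element-wise scores: xs[i] is compared only with ys[i]."""
    return [cosine(x, y) for x, y in zip(xs, ys)]


a = [[1.0, 0.0], [0.0, 1.0]]
b = [[1.0, 0.0], [1.0, 1.0]]
print(similarity(a, b))           # 2 x 2 matrix of scores
print(similarity_pairwise(a, b))  # 2 scores, one per aligned pair
```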

v2.7.0 - CachedGISTEmbedLoss, easy Matryoshka inference & evaluation, CrossEncoder, Intel Gaudi2

17 Apr 13:16

This release introduces a promising new loss function, easier inference for Matryoshka models, new functionality for CrossEncoders, and inference on Intel Gaudi2, along with much more.

Install this version with

pip install sentence-transformers==2.7.0

New loss function: CachedGISTEmbedLoss (#2592)

For a number of years, MultipleNegativesRankingLoss (also known as SimCSE, InfoNCE, in-batch negatives loss) has been the state of the art in embedding model training. Notably, this loss function performs better with a larger batch size.
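As a rough sketch of the in-batch negatives idea: given a matrix of similarities between every anchor and every positive in a batch, each row is treated as a classification problem whose correct answer lies on the diagonal. The function name and the scale value of 20.0 are assumptions made for illustration, and the real loss operates on model embeddings rather than a precomputed matrix.

```python
import math


def in_batch_negatives_loss(sim, scale=20.0):
    """Mean cross-entropy where row i should pick column i of the similarity matrix."""
    total = 0.0
    for i, row in enumerate(sim):
        logits = [scale * s for s in row]
        m = max(logits)
        # Numerically stable log-sum-exp over all candidates in the batch.
        log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum - logits[i]
    return total / len(sim)


# sim[i][j] is the similarity between anchor_i and positive_j.
good = [[0.9, 0.1], [0.2, 0.8]]
confused = [[0.5, 0.5], [0.5, 0.5]]
print(in_batch_negatives_loss(good))
print(in_batch_negatives_loss(confused))
```

Every extra row in the batch adds one more negative per anchor, which is one way to see why a larger batch size gives a stronger training signal.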

Recently, various improvements have been introduced:

  1. CachedMultipleNegativesRankingLoss was introduced, which allows you to pick much higher batch sizes (e.g. 65536) with constant memory.
  2. GISTEmbedLoss takes a guide model to guide the in-batch negative sample selection. This prevents false negatives, resulting in a stronger training signal.

Now, @JacksonCakes has combined these two approaches to produce the best of both worlds: CachedGISTEmbedLoss. This loss function allows for high batch sizes with constant memory usage, while also using a guide model to assist with the in-batch negative sample selection.

As can be seen in our Loss Overview, this loss function should be used with (anchor, positive) pairs or (anchor, positive, negative) triplets, much like MultipleNegativesRankingLoss, CachedMultipleNegativesRankingLoss, and GISTEmbedLoss. In short, any example using those loss functions can be updated to use CachedGISTEmbedLoss! Feel free to experiment, e.g. with this training script.

Automatic Matryoshka model truncation (#2573)

Sentence Transformers v2.4.0 introduced Matryoshka models: models whose embeddings are still useful after truncation. Since then, many useful Matryoshka models have been trained.

As of this release, the truncation for these Matryoshka embedding models can be done automatically via a new truncate_dim constructor argument:

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

matryoshka_dim = 64
model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True, truncate_dim=matryoshka_dim)

embeddings = model.encode(
    [
        "search_query: What is TSNE?",
        "search_document: t-distributed stochastic neighbor embedding (t-SNE) is a statistical method for visualizing high-dimensional data by giving each datapoint a location in a two or three-dimensional map.",
        "search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
    ]
)
print(embeddings.shape)
# => (3, 64)

similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7839, 0.4933]])
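Conceptually, truncation keeps only the leading values of each embedding. The plain-Python sketch below uses made-up 6-dimensional vectors and a hypothetical truncate helper to show how a similarity score can be recomputed at smaller sizes. It is not the library's implementation.

```python
import math


def truncate(embedding, dim):
    """Keep only the first `dim` values of an embedding."""
    return embedding[:dim]


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Made-up vectors; real Matryoshka embeddings have hundreds of dimensions.
query = [0.5, 0.4, 0.1, -0.2, 0.05, 0.01]
doc = [0.45, 0.5, -0.1, -0.1, 0.0, 0.02]

for dim in (6, 4, 2):
    score = cosine(truncate(query, dim), truncate(doc, dim))
    print(dim, round(score, 4))
```

Cosine similarity normalizes its inputs, so the truncated vectors do not need to be re-normalized before comparison.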

Model truncation in all evaluators (#2582)

Alongside easier inference with Matryoshka models, evaluating them is now also much easier. You can also pass truncate_dim to any Evaluator. This way you can easily check the performance of any Sentence Transformer model at various truncated dimensions (even if the model was not trained with MatryoshkaLoss!)

from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers import SentenceTransformer
import datasets

model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka")

stsb = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

for dim in [768, 512, 256, 128, 64, 32, 16, 8, 4]:
    evaluator = EmbeddingSimilarityEvaluator(
        stsb["sentence1"],
        stsb["sentence2"],
        [score / 5 for score in stsb["score"]],
        name=f"sts-test-{dim}",
        truncate_dim=dim,
    )
    print(f"dim={dim:<3}: {evaluator(model) * 100:.2f} Spearman Correlation")
dim=768: 86.81 Spearman Correlation
dim=512: 86.76 Spearman Correlation
dim=256: 86.66 Spearman Correlation
dim=128: 86.20 Spearman Correlation
dim=64 : 85.40 Spearman Correlation
dim=32 : 82.42 Spearman Correlation
dim=16 : 79.31 Spearman Correlation
dim=8  : 72.82 Spearman Correlation
dim=4  : 63.44 Spearman Correlation
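For readers unfamiliar with the reported metric: Spearman correlation is the Pearson correlation of the ranks of two score lists, so it measures whether the model orders sentence pairs the same way as the gold labels. The following is a minimal standard-library sketch with made-up scores. The evaluator itself relies on an established implementation rather than code like this.

```python
def average_ranks(values):
    """Rank values from 1..n, giving tied values the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1
        for k in range(start, end + 1):
            ranks[order[k]] = mean_rank
        start = end + 1
    return ranks


def pearson(x, y):
    mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    std_x = sum((a - mean_x) ** 2 for a in x) ** 0.5
    std_y = sum((b - mean_y) ** 2 for b in y) ** 0.5
    return cov / (std_x * std_y)


def spearman(x, y):
    return pearson(average_ranks(x), average_ranks(y))


gold = [0.1, 0.4, 0.35, 0.8]       # made-up gold similarity scores
predicted = [0.2, 0.5, 0.6, 0.9]   # made-up model similarities
print(round(spearman(gold, predicted), 2))  # 0.8
```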

Here are some example training scripts that use this new truncate_dim option to assist with training Matryoshka models:

CrossEncoder improvements

This release improves the support for CrossEncoder reranker models.

push_to_hub (#2524)

You can now push trained CrossEncoder models to the 🤗 Hugging Face Hub!

from sentence_transformers import CrossEncoder

...

model = CrossEncoder("distilroberta-base")

# Train the model
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
)

model.push_to_hub("tomaarsen/distilroberta-base-stsb-cross-encoder")

trust_remote_code for custom models (#2595)

You can now load custom models from the Hugging Face Hub, i.e. models with custom modelling code that requires trust_remote_code to load.

from sentence_transformers import CrossEncoder

# Note: this model does not require `trust_remote_code=True` - there are currently no models that require it yet.
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", trust_remote_code=True)

# We want to compute the similarity between the query sentence
query = "A man is eating pasta."

# With all sentences in the corpus
corpus = [
    "A man is eating food.",
    "A man is eating a piece of bread.",
    "The girl is carrying a baby.",
    "A man is riding a horse.",
    "A woman is playing violin.",
    "Two men pushed carts through the woods.",
    "A man is riding a white horse on an enclosed ground.",
    "A monkey is playing drums.",
    "A cheetah is running behind its prey.",
]

# We rank all sentences in the corpus for the query
ranks = model.rank(query, corpus)

# Print the scores
print("Query:", query)
for rank in ranks:
    print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")

Inference on Intel Gaudi2 (#2557)

From this release onwards, you will be able to perform inference on Intel Gaudi2 accelerators. No modifications are needed, as the library will automatically detect the hpu device and configure the model accordingly. Thanks to Intel Habana for the support here.

All changes

  • [docs] Add simple Makefile for building docs by @tomaarsen in #2566
  • [examples] Add Matryoshka evaluation plot by @kddubey in #2564
  • Adding push_to_hub to CrossEncoder by @imvladikon in #2524
  • Fix semantic_search_usearch() for single query by @karmi in #2572
  • [requirements] Set minimum transformers version to 4.34.0 for is_nltk_available by @tomaarsen in #2574
  • [docs] Update link: retrieve_rerank_simple_wikipedia.py -> .ipynb by @tomaarsen in #2580
  • Document dev reqs, add ruff pre-commit by @kddubey in #2576
  • Enable Sentence Transformer Inference with Intel Gaudi2 GPU Supported ( 'hpu' ) by @ZhengHongming888 in #2557
  • [feat] Add truncation support by @kddubey in #2573
  • [examples] Add model upload for training_nli_v3 with GISTEmbedLoss by @tomaarsen in #2584
  • Add truncation support in evaluators by @kddubey in #2582
  • Add ST annotation to evaluators by @kddubey in #2586
  • [fix] Matryoshka training always patch original forward, and check matryoshka_dims by @kddubey in #2593
  • corrected comment from kmeans to...

v2.6.1 - Fix Quantized Semantic Search rescoring

26 Mar 09:01

This is a patch release to fix a bug in semantic_search_faiss and semantic_search_usearch that caused the scores to not correspond to the returned corpus indices. Additionally, you can now evaluate embedding models after quantizing their embeddings.

Precision support in EmbeddingSimilarityEvaluator

You can now pass precision to the EmbeddingSimilarityEvaluator to evaluate the performance after quantization:

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
import datasets

model = SentenceTransformer("all-mpnet-base-v2")

stsb = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

print("Spearman correlation based on Cosine Similarity on the STS Benchmark test set:")
for precision in ["float32", "uint8", "int8", "ubinary", "binary"]:
    evaluator = EmbeddingSimilarityEvaluator(
        stsb["sentence1"],
        stsb["sentence2"],
        [score / 5 for score in stsb["score"]],
        main_similarity=SimilarityFunction.COSINE,
        name="sts-test",
        precision=precision,
    )
    print(precision, evaluator(model))
Spearman correlation based on Cosine Similarity on the STS Benchmark test set:
float32 0.8342190421330611
uint8 0.8260094846238505
int8 0.8312754408857808
ubinary 0.8244338431442343
binary 0.8244338431442343

All changes

  • Add 'precision' support to the EmbeddingSimilarityEvaluator by @tomaarsen in #2559
  • [hotfix] Quantization patch; fix semantic_search_faiss/semantic_search_usearch rescoring by @tomaarsen in #2558
  • Fix a typo in a docstring in CosineSimilarityLoss.py by @bryant1410 in #2553

Full Changelog: v2.6.0...v2.6.1

v2.6.0 - Embedding Quantization, GISTEmbedLoss

22 Mar 15:29

This release brings embedding quantization: a way to heavily speed up retrieval & other tasks, and a new powerful loss function: GISTEmbedLoss.

Install this version with

pip install sentence-transformers==2.6.0

Embedding Quantization

Embeddings may be challenging to scale up, which leads to expensive solutions and high latencies. However, there is a new approach to counter this problem; it entails reducing the size of each of the individual values in the embedding: Quantization. Experiments on quantization have shown that we can maintain a large amount of performance while significantly speeding up computation and saving on memory, storage, and costs.

To be specific, binary quantization may retain 96% of the retrieval performance while speeding up retrieval by 25x and reducing memory and disk usage by 32x. Do not underestimate this approach! Read more about Embedding Quantization in our extensive blogpost.

Binary and Scalar Quantization

Two forms of quantization exist at this time: binary and scalar (int8). These quantize embedding values from float32 into binary and int8, respectively. For Binary quantization, you can use the following snippet:

from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings

# 1. Load an embedding model
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# 2a. Encode some text using "binary" quantization
binary_embeddings = model.encode(
    ["I am driving to the lake.", "It is a beautiful day."],
    precision="binary",
)

# 2b. or, encode some text without quantization & apply quantization afterwards
embeddings = model.encode(["I am driving to the lake.", "It is a beautiful day."])
binary_embeddings = quantize_embeddings(embeddings, precision="binary")
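To make the idea behind binary quantization concrete, here is a plain-Python sketch that thresholds each value at 0, packs 8 bits into one byte, and compares two packed vectors with Hamming distance. The helpers pack_bits and hamming_distance are hypothetical, the sketch assumes the dimension is a multiple of 8, and the library's own output format may differ in detail.

```python
def pack_bits(embedding):
    """Threshold each value at 0 and pack every 8 resulting bits into one byte."""
    packed = []
    for start in range(0, len(embedding), 8):
        byte = 0
        for value in embedding[start:start + 8]:
            byte = (byte << 1) | (1 if value > 0 else 0)
        packed.append(byte)
    return packed


def hamming_distance(a, b):
    """Count the bits that differ between two packed vectors."""
    return sum(bin(x ^ y).count("1") for x, y in zip(a, b))


first = [0.3, -0.1, 0.8, -0.5, 0.2, 0.9, -0.7, 0.4]
second = [0.1, 0.2, -0.3, -0.4, 0.5, 0.6, -0.1, -0.2]

print(pack_bits(first), pack_bits(second))                     # [173] [204]
print(hamming_distance(pack_bits(first), pack_bits(second)))   # 3

# Size arithmetic for an assumed 1024-dimensional embedding:
float32_bytes = 1024 * 4
binary_bytes = 1024 // 8
print(float32_bytes // binary_bytes)  # 32
```

The last lines show where a 32x saving comes from: each 32-bit float is replaced by a single bit.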

GISTEmbedLoss

GISTEmbedLoss, as introduced in Solatorio (2024), is a guided variant of the more standard in-batch negatives (MultipleNegativesRankingLoss) loss. Both loss functions are provided with a list of (anchor, positive) pairs, but while MultipleNegativesRankingLoss uses anchor_i and positive_i as positive pair and all positive_j with i != j as negative pairs, GISTEmbedLoss uses a second model to guide the in-batch negative sample selection.

This can be very useful, because it is plausible that anchor_i and positive_j are actually quite semantically similar. In this case, GISTEmbedLoss would not consider them a negative pair, while MultipleNegativesRankingLoss would. When finetuning MPNet-base on the AllNLI dataset, these are the Spearman correlations based on cosine similarity on the STS Benchmark dev set (higher is better):

[Figure: Spearman correlation on the STS Benchmark dev set for MultipleNegativesRankingLoss and GISTEmbedLoss]
The blue line is MultipleNegativesRankingLoss, whereas the grey line is GISTEmbedLoss with the small all-MiniLM-L6-v2 as the guide model. Note that all-MiniLM-L6-v2 by itself does not reach 88 Spearman correlation on this dataset, so this is really the effect of two models (mpnet-base and all-MiniLM-L6-v2) reaching a performance that they could not reach separately.
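The guided selection of in-batch negatives can be sketched as a mask over the batch. The rule used below (drop positive_j as a negative for anchor_i whenever the guide model rates it more similar to anchor_i than the true positive_i) is an assumption chosen for illustration, as is the guided_negative_mask helper. The exact criterion in GISTEmbedLoss may differ.

```python
def guided_negative_mask(guide_sim):
    """mask[i][j] is True when positive_j may serve as a negative for anchor_i."""
    mask = []
    for i, row in enumerate(guide_sim):
        threshold = row[i]  # guide similarity of the true (anchor_i, positive_i) pair
        mask.append([j != i and s <= threshold for j, s in enumerate(row)])
    return mask


# guide_sim[i][j]: similarity of anchor_i and positive_j according to the guide model.
guide_sim = [
    [0.80, 0.85, 0.10],  # positive_1 looks closer to anchor_0 than positive_0 does
    [0.20, 0.90, 0.30],
    [0.15, 0.25, 0.70],
]

for row in guided_negative_mask(guide_sim):
    print(row)
# [False, False, True]
# [True, False, True]
# [True, True, False]
```

In the first row, positive_1 is excluded as a likely false negative, whereas plain in-batch negatives would have pushed it away from anchor_0.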

Soft save_to_hub Deprecation

Most codebases that allow for pushing models to the Hugging Face Hub adopt a push_to_hub method instead of a save_to_hub method, and now Sentence Transformers will follow that convention. The push_to_hub method will now be the recommended approach, although save_to_hub will continue to exist for the time being: it will simply call push_to_hub internally.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")

...

# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=1000,
    warmup_steps=warmup_steps,
)

# Push the model to Hugging Face
model.push_to_hub("tomaarsen/mpnet-base-nli-stsb")

All changes

  • Add GISTEmbedLoss by @avsolatorio in #2535
  • [feat] Add 'get_config_dict' method to GISTEmbedLoss for better model cards by @tomaarsen in #2543
  • Enable saving modules as pytorch_model.bin by @CKeibel in #2542
  • [deprecation] Deprecate save_to_hub in favor of push_to_hub; add safe_serialization support to push_to_hub by @tomaarsen in #2544
  • Fix SentenceTransformer encode documentation return type default (numpy vectors) by @CKeibel in #2546
  • [docs] Update return docstring of encode_multi_process by @tomaarsen in #2548
  • [feat] Add binary & scalar embedding quantization support to Sentence Transformers by @tomaarsen in #2549

Full Changelog: v2.5.1...v2.6.0

v2.5.1 - fix CrossEncoder.rank bug with default top_k

01 Mar 08:18

This is a patch release to fix a bug in CrossEncoder.rank that caused the last value to be discarded when using the default top_k=-1.

CrossEncoder.rank patch:

from sentence_transformers.cross_encoder import CrossEncoder

# Pre-trained cross encoder
model = CrossEncoder("cross-encoder/stsb-distilroberta-base")

# We want to compute the similarity between the query sentence
query = "A man is eating pasta."

# With all sentences in the corpus
corpus = [
    "A man is eating food.",
    "A man is eating a piece of bread.",
    "The girl is carrying a baby.",
    "A man is riding a horse.",
    "A woman is playing violin.",
    "Two men pushed carts through the woods.",
    "A man is riding a white horse on an enclosed ground.",
    "A monkey is playing drums.",
    "A cheetah is running behind its prey.",
]

# We rank all sentences in the corpus for the query
ranks = model.rank(query, corpus)

# Print the scores
print("Query:", query)
for rank in ranks:
    print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")
Query: A man is eating pasta.
0.67    A man is eating food.
0.34    A man is eating a piece of bread.
0.08    A man is riding a horse.
0.07    A man is riding a white horse on an enclosed ground.
0.01    The girl is carrying a baby.
0.01    Two men pushed carts through the woods.
0.01    A monkey is playing drums.
0.01    A woman is playing violin.
0.01    A cheetah is running behind its prey.

Previously, the lowest score document would be removed from the output.
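One plausible mechanism for this kind of bug, shown here purely as an assumption and not as the library's actual code, is using a sentinel of -1 directly as a slice bound: in Python, items[:-1] means "everything except the last item". The functions rank_buggy and rank_fixed are hypothetical.

```python
def rank_buggy(scores, top_k=-1):
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:top_k]  # with top_k=-1 this slice silently drops the last item


def rank_fixed(scores, top_k=None):
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order if top_k is None else order[:top_k]


scores = [0.67, 0.34, 0.01, 0.08]
print(rank_buggy(scores))       # [0, 1, 3]
print(rank_fixed(scores))       # [0, 1, 3, 2]
print(rank_fixed(scores, 2))    # [0, 1]
```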

All changes

  • [examples] Update model repo_id in 2dMatryoshka example by @tomaarsen in #2515
  • [feat] Add get_config_dict to new Matryoshka2dLoss & AdaptiveLayerLoss by @tomaarsen in #2516
  • [chore] Update to ruff 0.3.0; update ruff.toml by @tomaarsen in #2517
  • [example] Don't always normalize the embeddings in clustering example by @tomaarsen in #2520
  • Fix CrossEncoder.rank default value for top_k by @xenova in #2518

Full Changelog: v2.5.0...v2.5.1

v2.5.0 - 2D Matryoshka & Adaptive Layer models, CrossEncoder (re)ranking

29 Feb 14:07

This release brings two new loss functions, a new way to (re)rank with CrossEncoder models, and more fixes.

Install this version with

pip install sentence-transformers==2.5.0

2D Matryoshka & Adaptive Layer models (#2506)

Embedding models are often encoder models with numerous layers, such as 12 (e.g. all-mpnet-base-v2) or 6 (e.g. all-MiniLM-L6-v2). To get embeddings, every single one of these layers must be traversed. 2D Matryoshka Sentence Embeddings (2DMSE) revisits this concept by proposing an approach to train embedding models that will perform well when only using a selection of all layers. This results in faster inference speeds at relatively low performance costs.

For example, using Sentence Transformers, you can train an Adaptive Layer model that can be sped up by 2x at a 15% reduction in performance, or 5x on GPU & 10x on CPU for a 20% reduction in performance. The 2DMSE paper highlights scenarios where this is superior to using a smaller model.

Training

Training with Adaptive Layer support is quite elementary: rather than applying some loss function on only the last layer, we also apply that same loss function on the pooled embeddings from previous layers. Additionally, we employ a KL-divergence loss that aims to make the embeddings of the non-last layers match that of the last layer. This can be seen as a form of knowledge distillation, but with the last layer as the teacher model and the prior layers as the student models.

For example, with the 12-layer microsoft/mpnet-base, it will now be trained such that the model produces meaningful embeddings after each of the 12 layers.

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, AdaptiveLayerLoss

model = SentenceTransformer("microsoft/mpnet-base")

base_loss = CoSENTLoss(model=model)
loss = AdaptiveLayerLoss(model=model, loss=base_loss)
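The training objective described in this section can be sketched numerically in plain Python. Everything below is a simplified illustration under stated assumptions: the toy squared-distance loss stands in for a real loss such as CoSENTLoss, every layer is used with equal weight, the KL term is computed over a softmax of the embedding values, and kl_weight is a made-up parameter. AdaptiveLayerLoss itself may weight and sample layers differently.

```python
import math


def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with strictly positive entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def adaptive_layer_objective(layer_embeddings, base_loss, kl_weight=1.0):
    """Base loss on every layer, plus a KL term pulling earlier layers toward the last."""
    last = layer_embeddings[-1]
    total = base_loss(last)
    teacher = softmax(last)  # the last layer acts as the teacher
    for emb in layer_embeddings[:-1]:
        total += base_loss(emb) + kl_weight * kl_divergence(teacher, softmax(emb))
    return total


# Toy stand-in for a real loss: squared distance to a target vector.
target = [1.0, 0.0, -1.0]


def toy_loss(emb):
    return sum((e - t) ** 2 for e, t in zip(emb, target))


# Made-up pooled embeddings after layers 1, 2 and 3 of a 3-layer model.
layers = [[0.2, 0.1, -0.1], [0.6, 0.0, -0.5], [0.9, 0.0, -0.9]]
print(adaptive_layer_objective(layers, toy_loss))
```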

Additionally, this can be combined with MatryoshkaLoss so that the resulting model can be reduced both in the number of layers and in the size of the output dimensions. See also Matryoshka Embeddings for more information on reducing output dimensions. In Sentence Transformers, the combination of these two losses is called Matryoshka2dLoss, and a shorthand is provided for simpler training.

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, Matryoshka2dLoss

model = SentenceTransformer("microsoft/mpnet-base")

base_loss = CoSENTLoss(model=model)
loss = Matryoshka2dLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512, 256, 128, 64])

Results

Let's look at the performance that we may be able to expect from an Adaptive Layer embedding model versus a regular embedding model. For this experiment, I have trained two models:

Both of these models were trained on the AllNLI dataset, which is a concatenation of the SNLI and MultiNLI datasets. I have evaluated these models on the STSBenchmark test set using different numbers of layers. The results are plotted in the following figure:

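[Figure: adaptive_layer_results, three panels: performance by number of layers, share of performance preserved, and speedup ratio on GPU and CPU]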

The first figure shows that the Adaptive Layer model stays much more performant when reducing the number of layers in the model. This is also clearly shown in the second figure, which displays that 80% of the performance is preserved when the number of layers is reduced all the way to 1.

Lastly, the third figure shows the expected speedup ratio for GPU & CPU devices in my tests. As you can see, removing half of the layers results in roughly a 2x speedup, at a cost of ~15% performance on STSB (~86 -> ~75 Spearman correlation). When removing even more layers, the speedup gets larger for CPUs, and speedups between 5x and 10x are very feasible with a 20% loss in performance.

Inference

After a model has been trained using the Adaptive Layer loss, you can then truncate the model layers to your desired layer count. Note that this requires doing a bit of surgery on the model itself, and each model is structured a bit differently, so the steps are slightly different depending on the model.

First of all, we will load the model & access the underlying transformers model like so:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("tomaarsen/mpnet-base-nli-adaptive-layer")

# We can access the underlying model with `model[0].auto_model`
print(model[0].auto_model)
MPNetModel(
  (embeddings): MPNetEmbeddings(
    (word_embeddings): Embedding(30527, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MPNetEncoder(
    (layer): ModuleList(
      (0-11): 12 x MPNetLayer(
        (attention): MPNetAttention(
          (attn): MPNetSelfAttention(
            (q): Linear(in_features=768, out_features=768, bias=True)
            (k): Linear(in_features=768, out_features=768, bias=True)
            (v): Linear(in_features=768, out_features=768, bias=True)
            (o): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (intermediate): MPNetIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): MPNetOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (relative_attention_bias): Embedding(32, 12)
  )
  (pooler): MPNetPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

This output will differ depending on the model. We will look for the repeated layers in the encoder. For this MPNet model, this is stored under model[0].auto_model.encoder.layer. Then we can slice the model to only keep the first few layers to speed up the model:

new_num_layers = 3
model[0].auto_model.encoder.layer = model[0].auto_model.encoder.layer[:new_num_layers]

Then we can run inference with it using SentenceTransformer.encode.

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

model = SentenceTransformer("tomaarsen/mpnet-base-nli-adaptive-layer")
new_num_layers = 3
model[0].auto_model.encoder.layer = model[0].auto_model.encoder.layer[:new_num_layers]

embeddings = model.encode(
    [
        "The weather is so nice!",
        "It's so sunny outside!",
        "He drove to the stadium.",
    ]
)
# Similarity of the first sentence with the other two
similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7761, 0.1655]])
# compared to tensor([[ 0.7547, -0.0162]]) for the full model

As you can see, the similarity between the related sentences is much higher than the similarity with the unrelated sentence, despite only using 3 layers. Feel free to copy this script locally, modify new_num_layers, and observe the difference in similarities.
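The slicing step itself can be tried without downloading any model. The following is a minimal sketch, assuming a hypothetical ToyEncoder that only mimics the `encoder.layer` structure shown above; it is not the MPNet architecture. It shows that slicing a torch `nn.ModuleList` yields another `nn.ModuleList`, so the truncated stack can be assigned back in place:

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a transformer encoder with repeated layers."""

    def __init__(self, num_layers: int = 12, dim: int = 16):
        super().__init__()
        self.layer = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layer:
            x = torch.tanh(layer(x))
        return x


encoder = ToyEncoder()
params_before = sum(p.numel() for p in encoder.parameters())

# Same pattern as in the release notes: keep only the first few layers
new_num_layers = 3
encoder.layer = encoder.layer[:new_num_layers]
params_after = sum(p.numel() for p in encoder.parameters())

# The truncated encoder still runs a forward pass
output = encoder(torch.randn(2, 16))
print(len(encoder.layer), params_before, params_after, tuple(output.shape))
```

For a real model, the attribute path to the repeated layers has to be looked up in the printed module tree first, as described above.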

Extra information:

Example training scripts:

v2.4.0 - Matryoshka models, SOTA loss functions, prompt templates, INSTRUCTOR support

23 Feb 13:12

This release introduces numerous notable features that are well worth learning about!

Install this version with

pip install sentence-transformers==2.4.0

MatryoshkaLoss (#2485)

Dense embedding models typically produce embeddings with a fixed size, such as 768 or 1024. All further computations (clustering, classification, semantic search, retrieval, reranking, etc.) must then be done on these full embeddings. Matryoshka Representation Learning revisits this idea, and proposes a solution to train embedding models whose embeddings are still useful after truncation to much smaller sizes. This allows for considerably faster (bulk) processing.

Training

Training using Matryoshka Representation Learning (MRL) is quite elementary: rather than applying some loss function on only the full-size embeddings, we also apply that same loss function on truncated portions of the embeddings. For example, if a model has an embedding dimension of 768 by default, it can now be trained on 768, 512, 256, 128, 64 and 32. Each of these losses will be added together, optionally with some weight:

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, MatryoshkaLoss

model = SentenceTransformer("microsoft/mpnet-base")

base_loss = CoSENTLoss(model=model)
loss = MatryoshkaLoss(model=model, loss=base_loss, matryoshka_dims=[768, 512, 256, 128, 64])
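To make the "same loss on truncated portions" idea concrete, here is a rough sketch in plain torch. It is not the library's MatryoshkaLoss implementation: `base_loss` and `matryoshka_loss` are hypothetical helpers, the base loss is a simple MSE on cosine similarities chosen only for illustration, and the random tensors stand in for model outputs.

```python
import torch
import torch.nn.functional as F


def base_loss(emb_a: torch.Tensor, emb_b: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a real loss: MSE between cosine similarity and gold score
    return F.mse_loss(F.cosine_similarity(emb_a, emb_b), labels)


def matryoshka_loss(emb_a, emb_b, labels, dims, weights=None) -> torch.Tensor:
    # Apply the same base loss to each truncated prefix and add the (weighted) results
    weights = weights or [1.0] * len(dims)
    total = torch.zeros(())
    for dim, weight in zip(dims, weights):
        total = total + weight * base_loss(emb_a[..., :dim], emb_b[..., :dim], labels)
    return total


torch.manual_seed(0)
emb_a, emb_b = torch.randn(4, 768), torch.randn(4, 768)  # stand-ins for sentence embeddings
labels = torch.tensor([1.0, 0.3, 0.8, 0.0])
dims = [768, 512, 256, 128, 64]

loss = matryoshka_loss(emb_a, emb_b, labels, dims)
print(loss)
```

Because every truncated prefix contributes to the total, the leading dimensions are pushed to carry most of the useful information.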

Inference

After a model has been trained using a Matryoshka loss, you can then run inference with it using SentenceTransformer.encode. You must then truncate the resulting embeddings, and it is recommended to renormalize the embeddings.

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch.nn.functional as F

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

matryoshka_dim = 64
embeddings = model.encode(
    [
        "search_query: What is TSNE?",
        "search_document: t-distributed stochastic neighbor embedding (t-SNE) is a statistical method for visualizing high-dimensional data by giving each datapoint a location in a two or three-dimensional map.",
        "search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
    ]
)
embeddings = embeddings[..., :matryoshka_dim]  # Shrink the embedding dimensions

similarities = cos_sim(embeddings[0], embeddings[1:])
# => tensor([[0.7839, 0.4933]])

As you can see, the similarity between the search query and the correct document is much higher than that of an unrelated document, despite the very small matryoshka dimension applied. Feel free to copy this script locally, modify the matryoshka_dim, and observe the difference in similarities.

Note: Despite the embeddings being smaller, training and inference of a Matryoshka model are not faster or more memory-efficient, and the model itself is not smaller. Only the processing and storage of the resulting embeddings will be faster and cheaper.
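The truncate-then-renormalize step can be sketched with NumPy alone. This is an illustration under assumptions: the random matrix below is a stand-in for the output of model.encode, and the variable names are made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins for model.encode(...) output, scaled to unit length
full = rng.normal(size=(3, 768)).astype(np.float32)
full /= np.linalg.norm(full, axis=1, keepdims=True)

matryoshka_dim = 64
truncated = full[..., :matryoshka_dim]  # no longer unit length after truncation

# Renormalize so that a plain dot product equals the cosine similarity again
renormalized = truncated / np.linalg.norm(truncated, axis=1, keepdims=True)

dot_scores = renormalized[0] @ renormalized[1:].T
print(np.linalg.norm(truncated, axis=1), np.linalg.norm(renormalized, axis=1), dot_scores)
```

Cosine similarity is unaffected by this step because it normalizes internally; renormalizing mainly matters when downstream code relies on dot products or Euclidean distances.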

Extra information:

Example training scripts:

CoSENTLoss (#2454)

CoSENTLoss was introduced by Jianlin Su, 2022 as a drop-in replacement of CosineSimilarityLoss. Experiments have shown that it produces a stronger learning signal than CosineSimilarityLoss.

from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample

model = SentenceTransformer('bert-base-uncased')
train_examples = [
    InputExample(texts=['My first sentence', 'My second sentence'], label=1.0),
    InputExample(texts=['My third sentence', 'Unrelated sentence'], label=0.3)
]

train_batch_size = 16  # example value; adjust to your hardware
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CoSENTLoss(model=model)

You can update training_stsbenchmark.py by replacing CosineSimilarityLoss with CoSENTLoss & you can observe the improved performance.
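For intuition, the ranking idea behind CoSENT can be sketched in a few lines of torch. This is a simplified sketch based on the published CoSENT formulation, not the library's code; the function name `cosent_loss_sketch` and the `scale=20.0` default are assumptions made for this example. Whenever one pair has a lower gold score than another, the loss penalizes the first pair's cosine similarity for exceeding the second's.

```python
import torch


def cosent_loss_sketch(cos_sims: torch.Tensor, labels: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    scores = cos_sims * scale
    # diff[i, j] = scores[i] - scores[j]
    diff = scores[:, None] - scores[None, :]
    # Only keep (i, j) where pair i has a lower gold score than pair j,
    # i.e. where scores[i] should end up below scores[j]
    mask = labels[:, None] < labels[None, :]
    # Prepending 0 turns logsumexp into log(1 + sum(exp(...)))
    return torch.logsumexp(torch.cat([torch.zeros(1), diff[mask]]), dim=0)


labels = torch.tensor([1.0, 0.3])
loss_good = cosent_loss_sketch(torch.tensor([0.9, 0.1]), labels)  # order matches the labels
loss_bad = cosent_loss_sketch(torch.tensor([0.1, 0.9]), labels)   # order contradicts the labels
print(loss_good, loss_bad)
```

Only the relative order of the similarities matters here, which is the main difference from regressing each cosine similarity onto its label.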

AnglELoss (#2471)

AnglELoss was introduced in Li and Li, 2023. It is an adaptation of the CoSENTLoss, and also acts as a strong drop-in replacement of CosineSimilarityLoss. Compared to CoSENTLoss, AnglELoss uses a different similarity function which aims to avoid vanishing gradients.

Like with CoSENTLoss, you can use it just like CosineSimilarityLoss.

from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample

model = SentenceTransformer('bert-base-uncased')
train_examples = [
    InputExample(texts=['My first sentence', 'My second sentence'], label=1.0),
    InputExample(texts=['My third sentence', 'Unrelated sentence'], label=0.3)
]

train_batch_size = 16  # example value; adjust to your hardware
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.AnglELoss(model=model)

You can update training_stsbenchmark.py by replacing CosineSimilarityLoss with AnglELoss & you can observe the improved performance.

Prompt Templates (#2477)

Some models require using specific text prompts to achieve optimal performance. For example, with intfloat/multilingual-e5-large you should prefix all queries with query: and all passages with passage: . Another example is BAAI/bge-large-en-v1.5, which performs best for retrieval when the input texts are prefixed with Represent this sentence for searching relevant passages: .

Sentence Transformer models can now be initialized with prompts and default_prompt_name parameters:

  • prompts is an optional argument that accepts a dictionary of prompts with prompt names to prompt texts. The prompt will be prepended to the input text during inference. For example,
    model = SentenceTransformer(
        "intfloat/multilingual-e5-large",
        prompts={
            "classification": "Classify the following text: ",
            "retrieval": "Retrieve semantically similar text: ",
            "clustering": "Identify the topic or theme based on the text: ",
        },
    )
    # or
    model.prompts = {
        "classification": "Classify the following text: ",
        "retrieval": "Retrieve semantically similar text: ",
        "clustering": "Identify the topic or theme based on the text: ",
    }
  • default_prompt_name is an optional argument that determines the default prompt to be used. It has to correspond with a prompt name from prompts. If None, then no prompt is used by default. For example,
    model = SentenceTransformer(
        "intfloat/multilingual-e5-large",
        prompts={
            "classification": "Classify the following text: ",
            "retrieval": "Retrieve semantically similar text: ",
            "clustering": "Identify the topic or theme based on the text: ",
        },
        default_prompt_name="retrieval",
    )
    # or
    model.default_prompt_name="retrieval"

Both of these parameters can also be specified in the config_sentence_transformers.json file of a saved model. That way, you won't have to specify these options manually when loading. When you save a Sentence Transformer model, these options will be automatically saved as well.

During inference, prompts can be applied in a few different ways. All of these scenarios result in identical texts being embedded:

  1. Explicitly using the prompt option in SentenceTransformer.encode:
    embeddings = model.encode("How to bake a strawberry cake", prompt="Retrieve semantically similar text: ")
  2. Explicitly using the prompt_name option in SentenceTransformer.encode by relying on the prompts loaded from a) initialization or b) the model config.
    embeddings = model.encode("How to bake a strawberry cake", prompt_name="retrieval")
  3. If neither prompt nor prompt_name is specified in SentenceTransformer.encode, then the prompt specified by default_prompt_name will be applied. If it is None, then no prompt will be applied.
    embeddings = model.encode("How to bake a strawberry cake")
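The three scenarios above boil down to a small lookup. The following sketch uses a hypothetical helper, `resolve_prompt`, that is not part of the library; it only mirrors the behavior described here. That an explicit prompt wins over prompt_name when both are given is an assumption of this sketch.

```python
from typing import Dict, Optional


def resolve_prompt(
    prompts: Dict[str, str],
    default_prompt_name: Optional[str] = None,
    prompt: Optional[str] = None,
    prompt_name: Optional[str] = None,
) -> Optional[str]:
    # 1. An explicit prompt text is used as-is (assumed to take priority)
    if prompt is not None:
        return prompt
    # 2. Otherwise look up prompt_name, 3. falling back to default_prompt_name
    name = prompt_name if prompt_name is not None else default_prompt_name
    if name is None:
        return None  # no prompt is applied
    if name not in prompts:
        raise ValueError(f"Unknown prompt name: {name!r}")
    return prompts[name]


prompts = {
    "classification": "Classify the following text: ",
    "retrieval": "Retrieve semantically similar text: ",
}
text = "How to bake a strawberry cake"
resolved = [
    resolve_prompt(prompts, "retrieval", prompt="Retrieve semantically similar text: "),
    resolve_prompt(prompts, "retrieval", prompt_name="retrieval"),
    resolve_prompt(prompts, "retrieval"),
]
# All three scenarios lead to the same text being embedded
embedded_texts = [p + text for p in resolved]
print(embedded_texts[0])
```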

Instructor support (#2477)

Some INSTRUCTOR models, such as [hkunlp/instructor-large](ht...

v2.3.1 - Patch for local models with Normalize modules

30 Jan 19:43

This release patches a niche bug when loading a Sentence Transformer model which:

  1. is local
  2. uses a Normalize module as specified in modules.json
  3. does not contain the directory specified in the model configuration

This only occurs when a model with Normalize is downloaded from the Hugging Face hub and then later used locally.
See #2458 and #2459 for more details.

Release highlights

Full Changelog: v2.3.0...v2.3.1

v2.3.0 - Bug fixes, improved model loading & Cached MNRL

29 Jan 08:32

This release focuses on various bug fixes & improvements to keep up with adjacent works like transformers and huggingface_hub. These are the key changes in the release:

Pushing models to the Hugging Face Hub (#2376)

Prior to Sentence Transformers v2.3.0, saving models to the Hugging Face Hub may have resulted in various errors depending on the versions of the dependencies. Sentence Transformers v2.3.0 introduces a refactor to save_to_hub to resolve these issues.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
...
model.save_to_hub("tomaarsen/all-MiniLM-L6-v2-quora")
pytorch_model.bin: 100%|██████████| 90.9M/90.9M [00:06<00:00, 13.7MB/s]
Upload 1 LFS files: 100%|██████████| 1/1 [00:07<00:00,  7.11s/it]

Model Loading

Efficient model loading (#2345)

Recently, transformers has shifted towards using safetensors files as their primary model file formats. Additionally, various other file formats are commonly used, such as PyTorch (pytorch_model.bin), Rust (rust_model.ot), Tensorflow (tf_model.h5) and ONNX (model.onnx).

Prior to Sentence Transformers v2.3.0, almost all files of a repository would be downloaded, even if they were not strictly required. Since v2.3.0, only the strictly required files will be downloaded. For example, when loading sentence-transformers/all-MiniLM-L6-v2 which has its model weights in three formats (pytorch_model.bin, rust_model.ot, tf_model.h5), only pytorch_model.bin will be downloaded. Additionally, when downloading intfloat/multilingual-e5-small with two formats (model.safetensors, pytorch_model.bin), only model.safetensors will be downloaded.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
Downloading modules.json: 100%|██████████| 349/349 [00:00<?, ?B/s]
Downloading (…)ce_transformers.json: 100%|██████████| 116/116 [00:00<?, ?B/s]
Downloading README.md: 100%|██████████| 10.6k/10.6k [00:00<?, ?B/s]
Downloading (…)nce_bert_config.json: 100%|██████████| 53.0/53.0 [00:00<?, ?B/s]
Downloading config.json: 100%|██████████| 612/612 [00:00<?, ?B/s]
Downloading pytorch_model.bin: 100%|██████████| 90.9M/90.9M [00:06<00:00, 15.0MB/s]
Downloading tokenizer_config.json: 100%|██████████| 350/350 [00:00<?, ?B/s]
Downloading vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.37MB/s]
Downloading tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 4.61MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 112/112 [00:00<?, ?B/s]
Downloading 1_Pooling/config.json: 100%|██████████| 190/190 [00:00<?, ?B/s]
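The selection behavior described above can be approximated as follows. This is only a sketch of the idea, not the library's actual download logic: `select_files` and both constants are hypothetical, and the preference order is inferred from the two examples in this section.

```python
# Hypothetical preference order, inferred from the examples above
WEIGHT_FILES_BY_PREFERENCE = ["model.safetensors", "pytorch_model.bin"]
# Weight formats mentioned above that are not needed for loading with PyTorch
OTHER_FORMATS = {"rust_model.ot", "tf_model.h5", "model.onnx"}


def select_files(repo_files):
    """Keep non-weight files plus the single most preferred weight file."""
    chosen = next((name for name in WEIGHT_FILES_BY_PREFERENCE if name in repo_files), None)
    skipped = (set(WEIGHT_FILES_BY_PREFERENCE) | OTHER_FORMATS) - {chosen}
    return [name for name in repo_files if name not in skipped]


minilm_like = ["config.json", "pytorch_model.bin", "rust_model.ot", "tf_model.h5"]
e5_like = ["config.json", "model.safetensors", "pytorch_model.bin"]

minilm_selected = select_files(minilm_like)
e5_selected = select_files(e5_like)
print(minilm_selected, e5_selected)
```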

Note

This release updates the default cache location from ~/.cache/torch/sentence_transformers to the default cache location of transformers, i.e. ~/.cache/huggingface. You can still specify custom cache locations via the SENTENCE_TRANSFORMERS_HOME environment variable or the cache_folder argument.
Additionally, by supporting newer versions of various dependencies (e.g. huggingface_hub), the cache format changed. A consequence is that the old cached models cannot be used in v2.3.0 onwards, and those models need to be redownloaded. Once redownloaded, an airgapped machine can load the model like normal despite having no internet access.

Loading custom models (#2398)

This release brings models with custom code to Sentence Transformers through trust_remote_code, such as jinaai/jina-embeddings-v2-base-en.

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

model = SentenceTransformer("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
embeddings = model.encode(['How is the weather today?', 'What is the current weather like today?'])

print(cos_sim(embeddings[0], embeddings[1]))
# => tensor([[0.9341]])

Loading specific revisions (#2419)

If an embedding model is ever updated, it would invalidate all of the embeddings that you have created with the prior version of that model. We promise to never update the weights of any sentence-transformers/... model, but we cannot offer this guarantee for models by the community.

That is why this version introduces a revision keyword, allowing you to specify exactly which revision or branch you'd like to load:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-small-en-v1.5", revision="982532469af0dff5df8e70b38075b0940e863662")
# or a branch:
model = SentenceTransformer("BAAI/bge-small-en-v1.5", revision="main")

Soft deprecation of use_auth_token, use token instead (#2376)

Following updates from transformers & huggingface_hub, Sentence Transformers now recommends that you use the token argument to provide your Hugging Face authentication token to download private models.

from sentence_transformers import SentenceTransformer

# new:
model = SentenceTransformer("tomaarsen/all-mpnet-base-v2", token="hf_...")
# old, still works, but throws a warning to upgrade to "token"
model = SentenceTransformer("tomaarsen/all-mpnet-base-v2", use_auth_token="hf_...")

Note

The recommended way to include your Hugging Face authentication token is to run huggingface-cli login & paste your User Access Token from your Hugging Face Settings. See these docs for more information. Then, you don't have to include the token argument at all; it'll be automatically read from your filesystem.

Device patch (#2351)

Prior to this release, SentenceTransformer.device would not always correspond to the device on which embeddings were computed, or on which a model gets trained. This release brings a few fixes:

  • SentenceTransformer.device now always corresponds to the device that the model is on, and on which it will do its computations.
  • Models are now immediately moved to their specified device, rather than lazily whenever the model is being used.
  • SentenceTransformer.to(...), SentenceTransformer.cpu(), SentenceTransformer.cuda(), etc. will now work as expected, rather than being ignored.

Cached Multiple Negatives Ranking Loss (CMNRL) (#1759)

MultipleNegativesRankingLoss (MNRL) is a powerful loss function that is commonly applied to train embedding models. It uses in-batch negative sampling to produce a large number of negative pairs, giving the model a training signal to push the embeddings of these pairs apart. It is commonly shown that a larger batch size results in better-performing models (Qu et al., 2021, Li et al., 2023), but a larger batch size also requires more VRAM in practice.

To counteract that, @kwang2049 has implemented a slightly modified GradCache technique that is able to separate the batch computation into mini-batches without any reduction in training quality. This allows the common practitioner to train with competitive batch sizes, e.g. 65536!
The downside is that training with Cached MNRL (CMNRL) is roughly 2 to 2.4 times slower than using normal MNRL.
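The core GradCache trick can be demonstrated on a toy encoder. The sketch below is not the library's implementation: the `nn.Linear` encoder, the helper `mnrl_loss` and the `scale=20.0` value are assumptions chosen for illustration. It computes all embeddings without gradients, backpropagates the loss only down to those embeddings, and then replays each mini-batch through the encoder with the cached embedding gradients. The second forward pass also hints at why training is slower.

```python
import torch
import torch.nn.functional as F
from torch import nn


def mnrl_loss(anchors: torch.Tensor, positives: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    # In-batch negatives: for anchor i, positives[i] is correct and all others are negatives
    scores = F.normalize(anchors, dim=-1) @ F.normalize(positives, dim=-1).T * scale
    return F.cross_entropy(scores, torch.arange(len(scores)))


torch.manual_seed(0)
encoder = nn.Linear(8, 4)  # toy stand-in for an embedding model
anchor_inputs, positive_inputs = torch.randn(16, 8), torch.randn(16, 8)

# Reference: regular full-batch backward pass
encoder.zero_grad()
mnrl_loss(encoder(anchor_inputs), encoder(positive_inputs)).backward()
full_grad = encoder.weight.grad.clone()

# GradCache-style: step 1, embed everything in mini-batches without storing activations
mini_batch_size = 4
encoder.zero_grad()
with torch.no_grad():
    a_emb = torch.cat([encoder(chunk) for chunk in anchor_inputs.split(mini_batch_size)])
    p_emb = torch.cat([encoder(chunk) for chunk in positive_inputs.split(mini_batch_size)])

# Step 2: backpropagate the full-batch loss only down to the embeddings
a_emb.requires_grad_()
p_emb.requires_grad_()
mnrl_loss(a_emb, p_emb).backward()

# Step 3: replay each mini-batch with gradients and feed in the cached embedding gradients
for inputs, cached in ((anchor_inputs, a_emb.grad), (positive_inputs, p_emb.grad)):
    for chunk, grad_chunk in zip(inputs.split(mini_batch_size), cached.split(mini_batch_size)):
        encoder(chunk).backward(gradient=grad_chunk)
cached_grad = encoder.weight.grad.clone()

print(torch.allclose(full_grad, cached_grad, atol=1e-5))
```

Only one mini-batch of activations is held in memory at a time, while the loss still sees every in-batch negative.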

CachedMultipleNegativesRankingLoss is a drop-in replacement for MultipleNegativesRankingLoss, but with a new mini_batch_size argument. I recommend trying out CMNRL with a large batch size and a fairly small mini_batch_size: the largest mini-batch size that will fit into memory.

from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader

model = SentenceTransformer("distilbert-base-uncased")
train_examples = [
    InputExample(texts=['Anchor 1', 'Positive 1']),
    InputExample(texts=['Anchor 2', 'Positive 2']),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024)  # Here we can try much larger batch sizes!
train_loss = losses.CachedMultipleNegativesRankingLoss(model=model, mini_batch_size=32)

model.fit([(t...