Avoid storing trainer in ModelCardCallback and SentenceTransformerModelCardData

Storing the trainer prevents proper cleanup.

tomaarsen committed Dec 23, 2024
1 parent cfb883c commit 80fb41e
Showing 15 changed files with 60 additions and 60 deletions.
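For context: holding the trainer on the callback and on the model card data creates a reference cycle (trainer → callback → trainer), so neither object can be reclaimed by reference counting alone. A minimal sketch of that problem with hypothetical stand-in classes, not the actual sentence-transformers code:

import gc
import weakref

class Callback:
    def __init__(self, trainer):
        self.trainer = trainer  # back-reference: callback -> trainer

class Trainer:
    def __init__(self):
        self.callbacks = [Callback(self)]  # trainer -> callback -> trainer

trainer = Trainer()
ref = weakref.ref(trainer)
del trainer
print(ref() is None)  # False: the cycle keeps both objects alive
gc.collect()          # only the cyclic garbage collector can reclaim them
print(ref() is None)  # True

The commit breaks the cycle by handing the trainer to the callback transiently via event kwargs and by passing loss/epoch/step values explicitly where they are needed.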
2 changes: 1 addition & 1 deletion sentence_transformers/SentenceTransformer.py
@@ -1252,7 +1252,7 @@ def _create_model_card(

         # If we loaded a Sentence Transformer model from the Hub, and no training was done, then
         # we don't generate a new model card, but reuse the old one instead.
-        if self._model_card_text and self.model_card_data.trainer is None:
+        if self._model_card_text and "generated_from_trainer" not in self.model_card_data.tags:
             model_card = self._model_card_text
             if self.model_card_data.model_id:
                 # If the original model card was saved without a model_id, we replace the model_id with the new model_id
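The reuse check above no longer needs a stored trainer: a downloaded card is regenerated only if some training run tagged it. A rough sketch of the new condition, with a hypothetical tags value:

# e.g. tags == ["sentence-transformers", "generated_from_trainer"] after a training run
reuse_existing_card = bool(model._model_card_text) and "generated_from_trainer" not in model.model_card_data.tags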
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/BinaryClassificationEvaluator.py
@@ -212,7 +212,7 @@ def __call__(
                 self.primary_metric = f"{self.similarity_fn_names[0]}_ap"

         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics

     def compute_metrices(self, model: SentenceTransformer) -> dict[str, dict[str, float]]:
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
@@ -249,7 +249,7 @@ def __call__(
                 self.primary_metric = f"spearman_{self.similarity_fn_names[0]}"

         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics

     @property
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/InformationRetrievalEvaluator.py
@@ -277,7 +277,7 @@ def __call__(
             for k, value in values.items()
         }
         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics

     def compute_metrices(
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/LabelAccuracyEvaluator.py
@@ -91,5 +91,5 @@ def __call__(

         metrics = {"accuracy": accuracy}
         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/MSEEvaluator.py
@@ -138,7 +138,7 @@ def __call__(self, model: SentenceTransformer, output_path: str = None, epoch=-1
         # Return negative score as SentenceTransformers maximizes the performance
         metrics = {"negative_mse": -mse}
         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics

     @property
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/MSEEvaluatorFromDataFrame.py
@@ -120,7 +120,7 @@ def __call__(
         # Return negative score as SentenceTransformers maximizes the performance
         metrics = {"negative_mse": -np.mean(mse_scores).item()}
         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics

     @property
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/NanoBEIREvaluator.py
@@ -377,7 +377,7 @@ def __call__(

         # TODO: Ensure this primary_metric works as expected, also with bolding the right thing in the model card
         agg_results = self.prefix_name_to_metrics(agg_results, self.name)
-        self.store_metrics_in_model_card_data(model, agg_results)
+        self.store_metrics_in_model_card_data(model, agg_results, epoch, steps)

         per_dataset_results.update(agg_results)

2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/ParaphraseMiningEvaluator.py
@@ -241,7 +241,7 @@ def __call__(
             "threshold": threshold,
         }
         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics

     @staticmethod
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/RerankingEvaluator.py
@@ -151,7 +151,7 @@ def __call__(
             f"ndcg@{self.at_k}": mean_ndcg,
         }
         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics

     def compute_metrices(self, model):
6 changes: 4 additions & 2 deletions sentence_transformers/evaluation/SentenceEvaluator.py
@@ -62,8 +62,10 @@ def prefix_name_to_metrics(self, metrics: dict[str, float], name: str) -> dict[str, float]:
             self.primary_metric = name + "_" + self.primary_metric
         return metrics

-    def store_metrics_in_model_card_data(self, model: SentenceTransformer, metrics: dict[str, Any]) -> None:
-        model.model_card_data.set_evaluation_metrics(self, metrics)
+    def store_metrics_in_model_card_data(
+        self, model: SentenceTransformer, metrics: dict[str, Any], epoch: int = 0, step: int = 0
+    ) -> None:
+        model.model_card_data.set_evaluation_metrics(self, metrics, epoch, step)

     @property
     def description(self) -> str:
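Every evaluator now forwards the epoch and steps it was called with, instead of store_metrics_in_model_card_data reading them back from a stored trainer. A sketch of a custom evaluator following the new pattern (the class name and metric are placeholders):

from typing import Any

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import SentenceEvaluator

class ConstantEvaluator(SentenceEvaluator):
    def __init__(self, name: str = "toy"):
        super().__init__()
        self.name = name
        self.primary_metric = "score"

    def __call__(
        self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1
    ) -> dict[str, Any]:
        metrics = {"score": 1.0}  # placeholder: compute a real metric here
        metrics = self.prefix_name_to_metrics(metrics, self.name)
        # epoch/steps travel explicitly through the call; no trainer reference required
        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
        return metrics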
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/TranslationEvaluator.py
@@ -183,5 +183,5 @@ def __call__(
             "mean_accuracy": (acc_src2trg + acc_trg2src) / 2,
         }
         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/TripletEvaluator.py
@@ -261,5 +261,5 @@ def __call__(
                 self.primary_metric = f"{self.similarity_fn_names[0]}_accuracy"

         metrics = self.prefix_name_to_metrics(metrics, self.name)
-        self.store_metrics_in_model_card_data(model, metrics)
+        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
         return metrics
86 changes: 42 additions & 44 deletions sentence_transformers/model_card.py
@@ -32,7 +32,7 @@
 from sentence_transformers.util import fullname, is_accelerate_available, is_datasets_available

 if is_datasets_available():
-    from datasets import Dataset, DatasetDict, IterableDataset, Value
+    from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, Value

 logger = logging.getLogger(__name__)

@@ -43,47 +43,45 @@


 class ModelCardCallback(TrainerCallback):
-    def __init__(self, trainer: SentenceTransformerTrainer, default_args_dict: dict[str, Any]) -> None:
+    def __init__(self, default_args_dict: dict[str, Any]) -> None:
         super().__init__()
-        self.trainer = trainer
         self.default_args_dict = default_args_dict

-        callbacks = [
-            callback
-            for callback in self.trainer.callback_handler.callbacks
-            if isinstance(callback, CodeCarbonCallback)
-        ]
-        if callbacks:
-            trainer.model.model_card_data.code_carbon_callback = callbacks[0]
-
-        trainer.model.model_card_data.trainer = trainer
-        trainer.model.model_card_data.add_tags("generated_from_trainer")
-
     def on_init_end(
         self,
         args: SentenceTransformerTrainingArguments,
         state: TrainerState,
         control: TrainerControl,
         model: SentenceTransformer,
+        trainer: SentenceTransformerTrainer,
         **kwargs,
     ) -> None:
         from sentence_transformers.losses import AdaptiveLayerLoss, Matryoshka2dLoss, MatryoshkaLoss

+        model.model_card_data.add_tags("generated_from_trainer")
+
+        # Try to set the code carbon callback if it exists
+        callbacks = [
+            callback for callback in trainer.callback_handler.callbacks if isinstance(callback, CodeCarbonCallback)
+        ]
+        if callbacks:
+            model.model_card_data.code_carbon_callback = callbacks[0]
+
         # Try to infer the dataset "name", "id" and "revision" from the dataset cache files
-        if self.trainer.train_dataset:
+        if trainer.train_dataset:
             model.model_card_data.train_datasets = model.model_card_data.extract_dataset_metadata(
-                self.trainer.train_dataset, model.model_card_data.train_datasets, "train"
+                trainer.train_dataset, model.model_card_data.train_datasets, trainer.loss, "train"
             )

-        if self.trainer.eval_dataset:
+        if trainer.eval_dataset:
             model.model_card_data.eval_datasets = model.model_card_data.extract_dataset_metadata(
-                self.trainer.eval_dataset, model.model_card_data.eval_datasets, "eval"
+                trainer.eval_dataset, model.model_card_data.eval_datasets, trainer.loss, "eval"
             )

-        if isinstance(self.trainer.loss, dict):
-            losses = list(self.trainer.loss.values())
+        if isinstance(trainer.loss, dict):
+            losses = list(trainer.loss.values())
         else:
-            losses = [self.trainer.loss]
+            losses = [trainer.loss]
         # Some losses are known to use other losses internally, e.g. MatryoshkaLoss, AdaptiveLayerLoss and Matryoshka2dLoss
         # So, verify for `loss` attributes in the losses
         loss_idx = 0
@@ -99,6 +97,10 @@ def on_init_end(

         model.model_card_data.set_losses(losses)

+        # Extract some meaningful examples from the evaluation or training dataset to showcase the performance
+        if not model.model_card_data.widget and (dataset := trainer.eval_dataset or trainer.train_dataset):
+            model.model_card_data.set_widget_examples(dataset)
+
     def on_train_begin(
         self,
         args: SentenceTransformerTrainingArguments,
@@ -107,7 +109,6 @@ def on_train_begin(
         model: SentenceTransformer,
         **kwargs,
     ) -> None:
-        # model.model_card_data.hyperparameters = extract_hyperparameters_from_trainer(self.trainer)
         ignore_keys = {
             "output_dir",
             "logging_dir",
@@ -304,7 +305,6 @@ class SentenceTransformerModelCardData(CardData):
     code_carbon_callback: CodeCarbonCallback | None = field(default=None, init=False)
     citations: dict[str, str] = field(default_factory=dict, init=False)
     best_model_step: int | None = field(default=None, init=False)
-    trainer: SentenceTransformerTrainer | None = field(default=None, init=False, repr=False)
     datasets: list[str] = field(default_factory=list, init=False, repr=False)

     # Utility fields
@@ -335,7 +335,9 @@ def __post_init__(self) -> None:
             )
             self.model_id = None

-    def validate_datasets(self, dataset_list, infer_languages: bool = True) -> None:
+    def validate_datasets(
+        self, dataset_list: list[dict[str, Any]], infer_languages: bool = True
+    ) -> list[dict[str, Any]]:
         output_dataset_list = []
         for dataset in dataset_list:
             if "name" not in dataset:
@@ -403,7 +405,7 @@ def set_best_model_step(self, step: int) -> None:
         self.best_model_step = step

     def set_widget_examples(self, dataset: Dataset | DatasetDict) -> None:
-        if isinstance(dataset, IterableDataset):
+        if isinstance(dataset, (IterableDataset, IterableDatasetDict)):
             # We can't set widget examples from an IterableDataset without losing data
             return

@@ -417,6 +419,10 @@ def set_widget_examples(self, dataset: Dataset | DatasetDict) -> None:
         for dataset_name, num_samples in tqdm(
             dataset_names.items(), desc="Computing widget examples", unit="example", leave=False
         ):
+            if isinstance(dataset[dataset_name], IterableDataset):
+                # We can't set widget examples from an IterableDataset without losing data
+                continue
+
             # Sample 1000 examples from the dataset, sort them by length, and pick the shortest examples as the core
             # examples for the widget
             columns = [
@@ -472,7 +478,9 @@ def set_widget_examples(self, dataset: Dataset | DatasetDict) -> None:
             )
             self.predict_example = sentences[:3]

-    def set_evaluation_metrics(self, evaluator: SentenceEvaluator, metrics: dict[str, Any]) -> None:
+    def set_evaluation_metrics(
+        self, evaluator: SentenceEvaluator, metrics: dict[str, Any], epoch: int = 0, step: int = 0
+    ) -> None:
         from sentence_transformers.evaluation import SequentialEvaluator

         self.eval_results_dict[evaluator] = copy(metrics)
@@ -484,12 +492,6 @@ def set_evaluation_metrics(self, evaluator: SentenceEvaluator, metrics: dict[str, Any]) -> None:
         elif isinstance(primary_metrics, str):
             primary_metrics = [primary_metrics]

-        if self.trainer is None:
-            step = 0
-            epoch = 0
-        else:
-            step = self.trainer.state.global_step
-            epoch = self.trainer.state.epoch
         training_log_metrics = {key: value for key, value in metrics.items() if key in primary_metrics}

         if self.training_logs and self.training_logs[-1]["Step"] == step:
@@ -681,8 +683,12 @@ def to_html_list(data: dict):
         return dataset_info

     def extract_dataset_metadata(
-        self, dataset: Dataset | DatasetDict, dataset_metadata, dataset_type: Literal["train", "eval"]
-    ) -> dict[str, Any]:
+        self,
+        dataset: Dataset | DatasetDict,
+        dataset_metadata: list[dict[str, Any]],
+        loss: nn.Module | dict[str, nn.Module],
+        dataset_type: Literal["train", "eval"],
+    ) -> list[dict[str, Any]]:
         if dataset:
             if dataset_metadata and (
                 (isinstance(dataset, DatasetDict) and len(dataset_metadata) != len(dataset))
@@ -702,14 +708,14 @@ def extract_dataset_metadata(
                 self.compute_dataset_metrics(
                     dataset_value,
                     dataset_info,
-                    self.trainer.loss[dataset_name] if isinstance(self.trainer.loss, dict) else self.trainer.loss,
+                    loss[dataset_name] if isinstance(loss, dict) else loss,
                 )
                 for dataset_name, dataset_value, dataset_info in zip(
                     dataset.keys(), dataset.values(), dataset_metadata
                 )
             ]
         else:
-            dataset_metadata = [self.compute_dataset_metrics(dataset, dataset_metadata[0], self.trainer.loss)]
+            dataset_metadata = [self.compute_dataset_metrics(dataset, dataset_metadata[0], loss)]

         # Try to get the number of training samples
         if dataset_type == "train":
@@ -939,14 +945,6 @@ def get_codecarbon_data(self) -> dict[Literal["co2_eq_emissions"], dict[str, Any]]:
         return results

     def to_dict(self) -> dict[str, Any]:
-        # Extract some meaningful examples from the evaluation or training dataset to showcase the performance
-        if (
-            not self.widget
-            and self.trainer is not None
-            and (dataset := self.trainer.eval_dataset or self.trainer.train_dataset)
-        ):
-            self.set_widget_examples(dataset)
-
         # Try to set the base model
         if self.first_save and not self.base_model:
             try:
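On the model card side, set_evaluation_metrics now records whatever epoch/step values the caller provides rather than reading trainer.state, so evaluations run outside a Trainer still land in the card. A sketch with toy data:

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")
evaluator = TripletEvaluator(
    anchors=["A plane is taking off."],          # toy example data
    positives=["An air plane is taking off."],
    negatives=["A man is playing a flute."],
    name="toy-dev",
)
# No trainer involved: the evaluator simply forwards its epoch/steps
# arguments (here the defaults) into model.model_card_data.
metrics = evaluator(model)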
4 changes: 2 additions & 2 deletions sentence_transformers/trainer.py
@@ -295,9 +295,9 @@ def add_model_card_callback(self, default_args_dict: dict[str, Any]) -> None:
         This method can be overriden by subclassing the trainer to remove/customize this callback in custom uses cases
         """

-        model_card_callback = ModelCardCallback(self, default_args_dict)
+        model_card_callback = ModelCardCallback(default_args_dict)
         self.add_callback(model_card_callback)
-        model_card_callback.on_init_end(self.args, self.state, self.control, self.model)
+        model_card_callback.on_init_end(self.args, self.state, self.control, self.model, trainer=self)

     def call_model_init(self, trial=None) -> SentenceTransformer:
         model = super().call_model_init(trial=trial)
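Because the callback no longer takes the trainer in its constructor, the extension point the docstring describes stays self-contained: a subclass can drop or replace the callback without any trainer wiring. A sketch of opting out entirely (hypothetical subclass name):

from typing import Any

from sentence_transformers import SentenceTransformerTrainer

class QuietTrainer(SentenceTransformerTrainer):
    def add_model_card_callback(self, default_args_dict: dict[str, Any]) -> None:
        # Skip the automated model card bookkeeping for this trainer
        pass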
