diff --git a/buster/documents_manager/base.py b/buster/documents_manager/base.py index 13cf19b..a1f7dd7 100644 --- a/buster/documents_manager/base.py +++ b/buster/documents_manager/base.py @@ -93,7 +93,8 @@ def add( embedding_fn (callable, optional): A function that computes embeddings for a given input string. Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model. - csv_embeddings_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use. + csv_embeddings_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings (for later use). + csv_errors_filename: (str, optional) = Path to save a copy of the dataframe for which embeddings failed to compute. User can decide how to handle it later. csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite the file with a new file. **add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method. @@ -107,9 +108,11 @@ def add( if "embedding" not in df.columns: df["embedding"] = compute_embeddings_parallelized(df, embedding_fn=embedding_fn, num_workers=num_workers) - # errors with embeddings computation will be NaNs, so we filter them out and the user can recompute them later on... + # errors with embeddings computation will be NaNs df, df_errors = split_df_by_nans(df, column="embedding") + # If errors were detected, and the csv_errors_filename was specified, the errors will be added to that file + # This file will be the same as the original so a user can just reuse the script again on the new .csv file if len(df_errors) > 0: logger.warning(f"{len(df_errors)} errors have occured during embedding generation.") if csv_errors_filename is not None: