Skip to content

Commit

Permalink
ignore failed embeddings when adding embeddings to buster, add option…
Browse files Browse the repository at this point in the history
… to save them
  • Loading branch information
jerpint committed Nov 9, 2023
1 parent 008e2ba commit b65fa49
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
36 changes: 30 additions & 6 deletions buster/documents_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,21 @@
logging.basicConfig(level=logging.INFO)


def split_df_by_nans(df: pd.DataFrame, column: str) -> (pd.DataFrame, pd.DataFrame):
"""
Splits a DataFrame into two DataFrames, one with rows containing NaN and the other without NaNs.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
(pd.DataFrame, pd.DataFrame): A tuple of two DataFrames, the first with NaN rows, the second without NaNs.
"""
df_with_nans = df[df[column].isna()]
df_without_nans = df.dropna()
return df_without_nans, df_with_nans


@dataclass
class DocumentsManager(ABC):
def __init__(self, required_columns: Optional[list[str]] = None):
Expand All @@ -41,7 +56,7 @@ def _checkpoint_csv(self, df, csv_filename: str, csv_overwrite: bool = True):

if csv_overwrite:
df.to_csv(csv_filename)
logger.info(f"Saved DataFrame with embeddings to {csv_filename}")
logger.debug(f"Saved DataFrame to {csv_filename}")

else:
if os.path.exists(csv_filename):
Expand All @@ -52,14 +67,15 @@ def _checkpoint_csv(self, df, csv_filename: str, csv_overwrite: bool = True):
# will create the new file
append_df = df.copy()
append_df.to_csv(csv_filename)
logger.info(f"Appending DataFrame embeddings to {csv_filename}")
logger.debug(f"Appending DataFrame to {csv_filename}")

def add(
self,
df: pd.DataFrame,
num_workers: int = 16,
embedding_fn: callable = get_openai_embedding,
csv_filename: Optional[str] = None,
csv_embeddings_filename: Optional[str] = None,
csv_errors_filename: Optional[str] = None,
csv_overwrite: bool = True,
**add_kwargs,
):
Expand All @@ -77,7 +93,7 @@ def add(
embedding_fn (callable, optional): A function that computes embeddings for a given input string.
Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
csv_embeddings_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite the file with a new file.
**add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
Expand All @@ -91,8 +107,16 @@ def add(
if "embedding" not in df.columns:
df["embedding"] = compute_embeddings_parallelized(df, embedding_fn=embedding_fn, num_workers=num_workers)

if csv_filename is not None:
self._checkpoint_csv(df, csv_filename=csv_filename, csv_overwrite=csv_overwrite)
# errors with embeddings computation will be NaNs, so we filter them out and the user can recompute them later on...
df, df_errors = split_df_by_nans(df, column="embedding")

if len(df_errors) > 0:
logger.warning(f"{len(df_errors)} errors have occured during embedding generation.")
if csv_errors_filename is not None:
self._checkpoint_csv(df_errors, csv_filename=csv_errors_filename, csv_overwrite=csv_overwrite)

if csv_embeddings_filename is not None:
self._checkpoint_csv(df, csv_filename=csv_embeddings_filename, csv_overwrite=csv_overwrite)

self._add_documents(df, **add_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion buster/examples/generate_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main(csv):
dm = DeepLakeDocumentsManager(vector_store_path="deeplake_store", overwrite=True, required_columns=REQUIRED_COLUMNS)

# Generate the embeddings for our documents and store them to the deeplake store
dm.add(df, csv_filename="embeddings.csv")
dm.add(df, csv_embeddings_filename="embeddings.csv", csv_errors_filename="missing_embeddings.csv")

# Save it to a zip file
dm.to_zip()
Expand Down

0 comments on commit b65fa49

Please sign in to comment.