From 7c0673f22b57de07b8c2cb25491e730733ee59fd Mon Sep 17 00:00:00 2001 From: Tim Adams Date: Thu, 17 Oct 2024 12:52:28 +0200 Subject: [PATCH] Add function to directly retrieve embedding from data dictionary closes #32 --- datastew/process/parsing.py | 61 +++++++++++++++++++++++++++++++------ tests/test_parser.py | 5 +++ 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/datastew/process/parsing.py b/datastew/process/parsing.py index 4d7f4b4..576ed46 100644 --- a/datastew/process/parsing.py +++ b/datastew/process/parsing.py @@ -1,4 +1,6 @@ from abc import ABC +from typing import Dict +from embedding import EmbeddingModel, MPNetAdapter import pandas as pd import numpy as np @@ -52,16 +54,34 @@ def to_dataframe(self) -> pd.DataFrame: class DataDictionarySource(Source): - """ - Contains mapping of variable -> description - """ def __init__(self, file_path: str, variable_field: str, description_field: str): - self.file_path = file_path - self.variable_field = variable_field - self.description_field = description_field - - def to_dataframe(self) -> pd.DataFrame: + """ + Initialize the DataDictionarySource with the path to the data dictionary file + and the fields that represent the variables and their descriptions. + + :param file_path: Path to the data dictionary file. + :param variable_field: The column that contains the variable names. + :param description_field: The column that contains the variable descriptions. + """ + self.file_path: str = file_path + self.variable_field: str = variable_field + self.description_field: str = description_field + + def to_dataframe(self, dropna: bool = True) -> pd.DataFrame: + """ + Load the data dictionary file into a pandas DataFrame, select the variable and + description fields, and ensure they exist. Optionally remove rows with missing + variables or descriptions based on the 'dropna' parameter. + + :param dropna: If True, rows with missing 'variable' or 'description' values are + dropped. Defaults to True. + :return: A DataFrame containing two columns: + - 'variable': The variable names from the data dictionary. + - 'description': The descriptions corresponding to each variable. + :raises ValueError: If either the variable field or the description field is not + found in the data dictionary file. + """ df = super().to_dataframe() # sanity check if self.variable_field not in df.columns: @@ -70,9 +90,30 @@ def to_dataframe(self) -> pd.DataFrame: raise ValueError(f"Description field {self.description_field} not found in {self.file_path}") df = df[[self.variable_field, self.description_field]] df = df.rename(columns={self.variable_field: "variable", self.description_field: "description"}) - df.dropna(subset=["variable", "description"], inplace=True) + if dropna: + df.dropna(subset=["variable", "description"], inplace=True) return df - + + def get_embeddings(self, embedding_model: EmbeddingModel = MPNetAdapter) -> Dict[str, list]: + """ + Compute embedding vectors for each description in the data dictionary. The + resulting vectors are mapped to their respective variables and returned as a + dictionary. + + :param embedding_model: The embedding model used to compute embeddings for the descriptions. + Defaults to MPNetAdapter. + :return: A dictionary where each key is a variable name and the value is the + embedding vector for the corresponding description. + :rtype: Dict[str, list] + """ + # Compute vectors for all descriptions + df: pd.DataFrame = self.to_dataframe() + descriptions: list[str] = df["description"].tolist() + embeddings: list = embedding_model.encode(descriptions) + # variable identfy descriptins -> variable to embedding + variable_to_embedding: Dict[str, list] = dict(zip(df["variable"], embeddings)) + return variable_to_embedding + class EmbeddingSource: def __init__(self, source_path: str): diff --git a/tests/test_parser.py b/tests/test_parser.py index 9159e45..ea8ce58 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -39,3 +39,8 @@ def test_parse_data_dict_excel(self): mapping_table.add_descriptions(data_dictionary_source) mappings = mapping_table.get_mappings() self.assertEqual(11, len(mappings)) + + def test_get_embeddings(self): + vectors = self.data_dictionary_source.get_embeddings() + self.assertEqual(len(vectors), 11) + self.assertIn("Q_8", vectors)