Skip to content

Commit

Permalink
Add function to directly retrieve embedding from data dictionary
Browse files Browse the repository at this point in the history
closes #32
  • Loading branch information
tiadams committed Oct 17, 2024
1 parent 2d6bcb3 commit 7c0673f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 10 deletions.
61 changes: 51 additions & 10 deletions datastew/process/parsing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC
from typing import Dict
from embedding import EmbeddingModel, MPNetAdapter

import pandas as pd
import numpy as np
Expand Down Expand Up @@ -52,16 +54,34 @@ def to_dataframe(self) -> pd.DataFrame:


class DataDictionarySource(Source):
"""
Contains mapping of variable -> description
"""

def __init__(self, file_path: str, variable_field: str, description_field: str):
self.file_path = file_path
self.variable_field = variable_field
self.description_field = description_field

def to_dataframe(self) -> pd.DataFrame:
"""
Initialize the DataDictionarySource with the path to the data dictionary file
and the fields that represent the variables and their descriptions.
:param file_path: Path to the data dictionary file.
:param variable_field: The column that contains the variable names.
:param description_field: The column that contains the variable descriptions.
"""
self.file_path: str = file_path
self.variable_field: str = variable_field
self.description_field: str = description_field

def to_dataframe(self, dropna: bool = True) -> pd.DataFrame:
"""
Load the data dictionary file into a pandas DataFrame, select the variable and
description fields, and ensure they exist. Optionally remove rows with missing
variables or descriptions based on the 'dropna' parameter.
:param dropna: If True, rows with missing 'variable' or 'description' values are
dropped. Defaults to True.
:return: A DataFrame containing two columns:
- 'variable': The variable names from the data dictionary.
- 'description': The descriptions corresponding to each variable.
:raises ValueError: If either the variable field or the description field is not
found in the data dictionary file.
"""
df = super().to_dataframe()
# sanity check
if self.variable_field not in df.columns:
Expand All @@ -70,9 +90,30 @@ def to_dataframe(self) -> pd.DataFrame:
raise ValueError(f"Description field {self.description_field} not found in {self.file_path}")
df = df[[self.variable_field, self.description_field]]
df = df.rename(columns={self.variable_field: "variable", self.description_field: "description"})
df.dropna(subset=["variable", "description"], inplace=True)
if dropna:
df.dropna(subset=["variable", "description"], inplace=True)
return df


def get_embeddings(self, embedding_model: EmbeddingModel = MPNetAdapter) -> Dict[str, list]:
    """
    Compute an embedding vector for each description in the data dictionary and
    return the vectors keyed by their variable name.

    The data dictionary is loaded via :meth:`to_dataframe` with its default
    arguments, so rows with a missing variable or description are not embedded.
    If a variable name occurs more than once, the embedding of its last
    occurrence is kept.

    :param embedding_model: The embedding model used to encode the descriptions.
        Either a ready-to-use model instance or a model class; a class (such as
        the default ``MPNetAdapter``) is instantiated without arguments on demand.
    :return: A dictionary where each key is a variable name and the value is the
        embedding vector for the corresponding description.
    :rtype: Dict[str, list]
    """
    # The default argument is the adapter *class*. Calling ``encode`` on the class
    # would bind the descriptions list as ``self``, so build an instance first.
    # Instantiating here (rather than in the signature) also avoids loading a
    # model at import time.
    # NOTE(review): assumes the adapter class can be constructed without
    # arguments -- confirm against the embedding module.
    if isinstance(embedding_model, type):
        embedding_model = embedding_model()

    df: pd.DataFrame = self.to_dataframe()
    descriptions: list = df["description"].tolist()

    # NOTE(review): relies on the model exposing ``encode(list_of_str)`` and
    # returning one vector per input, in order -- confirm against EmbeddingModel.
    embeddings: list = embedding_model.encode(descriptions)

    # Descriptions and variables come from the same rows, so zipping them in
    # order pairs every variable with the embedding of its own description.
    variable_to_embedding: Dict[str, list] = dict(zip(df["variable"], embeddings))
    return variable_to_embedding


class EmbeddingSource:
def __init__(self, source_path: str):
Expand Down
5 changes: 5 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ def test_parse_data_dict_excel(self):
mapping_table.add_descriptions(data_dictionary_source)
mappings = mapping_table.get_mappings()
self.assertEqual(11, len(mappings))

def test_get_embeddings(self):
vectors = self.data_dictionary_source.get_embeddings()
self.assertEqual(len(vectors), 11)
self.assertIn("Q_8", vectors)

0 comments on commit 7c0673f

Please sign in to comment.