generated from Tauffer-Consulting/domino_pieces_repository_template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new file information extraction piece
- Loading branch information
Showing
3 changed files
with
187 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
{ | ||
"name": "FileInformationExtractionPiece", | ||
"description": "Extracts user-defined information from the input text.", | ||
"dependency": { | ||
"dockerfile": "Dockerfile_01" | ||
}, | ||
"tags": [ | ||
"text", | ||
"information extraction", | ||
"openai" | ||
], | ||
"style": { | ||
"node_label": "File Information Extraction", | ||
"icon_class_name": "fa-solid:align-right" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from pydantic import BaseModel, Field, ConfigDict | ||
from enum import Enum | ||
from typing import List | ||
from domino.models import OutputModifierModel, OutputModifierItemType | ||
from typing import Optional | ||
|
||
class LLMModelType(str, Enum):
    """Identifiers of the OpenAI chat models this piece can use."""

    gpt_3_5_turbo = "gpt-3.5-turbo-1106"
    gpt_4 = "gpt-4"
|
||
|
||
class InputModel(BaseModel):
    """
    InformationExtractionPiece Input model.

    User-facing inputs for the extraction piece: the source file path,
    optional free-text hints, the OpenAI model to use, and the list of
    items to extract from the text.
    """
    input_file_path: str = Field(
        description='Source text from where information should be extracted.',
        json_schema_extra={"from_upstream": "always"}
    )
    additional_information: Optional[str] = Field(
        default=None,
        description='Additional useful information to help with the extraction.',
    )
    openai_model: LLMModelType = Field(
        default=LLMModelType.gpt_3_5_turbo,
        description="OpenAI model name to use for information extraction.",
    )
    extract_items: List[OutputModifierModel] = Field(
        # default_factory instead of a literal default list: avoids the
        # mutable-default pitfall (one shared list object) and builds a
        # fresh default per instance.
        default_factory=lambda: [
            OutputModifierModel(name="name", type=OutputModifierItemType.string, description="Name of the person."),
            OutputModifierModel(name="age", type=OutputModifierItemType.integer, description="Age of the person."),
        ],
        description='Information items to be extracted from source text.',
        json_schema_extra={"from_upstream": "never"}
    )
|
||
|
||
class OutputModel(BaseModel):
    """
    InformationExtractionPiece Output Model.

    Carries the extraction result as a list of plain dictionaries, one
    per processed content item.
    """
    output_data: List[dict] = Field(description="Extracted information as JSON.")
|
||
class SecretsModel(BaseModel):
    """
    InformationExtractionPiece Secrets model.

    Credentials needed to call the OpenAI API.
    """
    OPENAI_API_KEY: str = Field(description="Your OpenAI API key.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from domino.base_piece import BasePiece | ||
from .models import InputModel, OutputModel, SecretsModel | ||
from openai import OpenAI | ||
import json | ||
from typing import Union, Optional | ||
import pickle | ||
|
||
|
||
class FileInformationExtractionPiece(BasePiece):
    """
    Piece that extracts user-defined information items from a pickled
    input file using an OpenAI chat model.

    The input file is unpickled; each content item (a dict, or each
    element of a list of dicts) is sent to the OpenAI chat-completions
    API with a JSON-mode prompt, and the parsed results are returned as
    a list of dicts and rendered as a Markdown table display result.
    """

    def piece_function(self, input_data: InputModel, secrets_data: SecretsModel):
        """
        Entry point: load the input file, run extraction per content
        item, register the display result and return an OutputModel.

        Raises:
            Exception: if the OpenAI API key is missing, or the API
                returns no choices / empty message content.
        """
        # OpenAI settings
        if secrets_data.OPENAI_API_KEY is None:
            raise Exception("OPENAI_API_KEY not found in ENV vars. Please add it to the secrets section of the Piece.")

        self.client = OpenAI(api_key=secrets_data.OPENAI_API_KEY)

        # SECURITY NOTE(review): pickle.load executes arbitrary code when
        # the file is untrusted — input_file_path must only ever point to
        # files produced by trusted upstream pieces.
        with open(input_data.input_file_path, "rb") as file:
            file_content = pickle.load(file)

        result = []

        # A single dict is processed once; a list is processed item by
        # item. Any other pickled type silently yields an empty result.
        if isinstance(file_content, dict):
            output_json = self.extract_json_infos(
                content=file_content,
                extract_items=input_data.extract_items,
                model=input_data.openai_model,
                additional_info=input_data.additional_information,
            )
            result.append(output_json)
        elif isinstance(file_content, list):
            for i, item_content in enumerate(file_content):
                self.logger.info(f"Extract item i:{i}")
                output_json = self.extract_json_infos(
                    content=item_content,
                    extract_items=input_data.extract_items,
                    model=input_data.openai_model,
                    additional_info=input_data.additional_information,
                )
                result.append(output_json)

        self.logger.info(result)

        # Return extracted information
        self.logger.info("Returning extracted information")

        self.format_display_result_table(input_data, result)
        return OutputModel(output_data=result)

    def extract_json_infos(self, content: dict, extract_items: list, model: str, additional_info: Optional[str] = None):
        """
        Ask the OpenAI model to extract `extract_items` from `content`.

        Args:
            content: one content item to extract information from.
            extract_items: list of item definitions (OutputModifierModel)
                interpolated into the prompt.  (Annotation fixed: the
                original declared `dict`, but callers pass a list.)
            model: OpenAI model name.
            additional_info: optional free-text hints for the model.

        Returns:
            The JSON-decoded response (dict, or list of dicts).

        Raises:
            Exception: if the API returns no choices or empty content.
        """
        optional_additional_info = "" if not additional_info else f"You can use the following additional information to help filling the request items to be extracted. Additional information: {additional_info}."

        prompt = f"""Extract the following information from the text below as JSON.
The output can be a simple json or a list of jsons but never a nested json.
{optional_additional_info}
Use the items to be extracted as guidance to identify the right information to extract:
---
Input text: {content}
Items to be extracted:
{extract_items}
"""

        response = self.client.chat.completions.create(
            # JSON mode: forces the model to emit valid JSON.
            response_format={
                "type": "json_object"
            },
            temperature=0,
            model=model,
            messages=[
                {"role": "user", "content": prompt}
            ],
        )
        if not response.choices:
            raise Exception("No response from OpenAI")

        if response.choices[0].message.content is None:
            raise Exception("No response from OpenAI")

        return json.loads(response.choices[0].message.content)

    def format_display_result_object(self, input_data: InputModel, result: dict):
        """
        Render a single result dict as a Markdown headings list and set
        it as the display result.  NOTE(review): not invoked by
        piece_function in this file — kept for backward compatibility.
        """
        md_text = """## Extracted Information\n"""

        for item in input_data.extract_items:
            md_text += f"""### {item.name}:\n{result.get(item.name)}\n"""
        file_path = f"{self.results_path}/display_result.md"
        with open(file_path, "w") as f:
            f.write(md_text)
        self.display_result = {
            "file_type": "md",
            "file_path": file_path
        }

    def format_display_result_table(self, input_data: InputModel, result: Union[dict, list]):
        """
        Render the result(s) as a Markdown table (one column per extract
        item, one row per result) and set it as the display result.
        """
        # Headers from extract_items
        headers = [item.name for item in input_data.extract_items]
        md_text = "## Extracted Information\n\n"
        md_text += "| " + " | ".join(headers) + " |\n"
        md_text += "|---" * len(headers) + "|\n"

        # Normalize a single object to a one-row table; both branches of
        # the original produced identical headers, so they are merged.
        rows_source = result if isinstance(result, list) else [result]
        for res in rows_source:
            row = [
                str(res.get(item.name))
                for item in input_data.extract_items
            ]
            md_text += "| " + " | ".join(row) + " |\n"

        file_path = f"{self.results_path}/display_result.md"
        with open(file_path, "w") as f:
            f.write(md_text)
        self.display_result = {
            "file_type": "md",
            "file_path": file_path
        }