Skip to content

Commit

Permalink
Implemented the mechanism to use literals in output parser
Browse files Browse the repository at this point in the history
  • Loading branch information
vizsatiz committed Dec 16, 2024
1 parent d3c1873 commit 92964d4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
5 changes: 3 additions & 2 deletions examples/python/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
'name': 'middle_name',
},
{
'type': 'str',
'description': 'The last name of the person',
'type': 'literal',
'description': 'The last name of the person, the value can be either of Vishnu or Satis',
'name': 'last_name',
'values': ['Vishnu', 'Satis'],
},
],
}
Expand Down
33 changes: 26 additions & 7 deletions flo_ai/parsers/flo_json_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from flo_ai.parsers.flo_parser import FloParser
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Literal
from pydantic import BaseModel, Field, create_model
from flo_ai.error.flo_exception import FloException
from langchain_core.output_parsers import PydanticOutputParser
Expand All @@ -19,14 +19,33 @@ def __init__(self, parse_contract: ParseContract):
super().__init__()

def __create_contract_from_json(self) -> BaseModel:
type_mapping = {'str': str, 'int': int, 'bool': bool, 'float': float}
pydantic_fields = {
field['name']: (
type_mapping[field['type']],
type_mapping = {
'str': str,
'int': int,
'bool': bool,
'float': float,
'literal': Literal,
}
pydantic_fields = {}
for field in self.contract.fields:
field_type = field['type']
if field_type == 'literal':
literal_values = field.get('values', [])
if not literal_values:
raise ValueError(
f"Field '{field['name']}' of type 'literal' must specify 'values'."
)
field_type_annotation = Literal[tuple(literal_values)]
else:
field_type_annotation = type_mapping.get(field_type)
if field_type_annotation is None:
raise ValueError(f'Unsupported type: {field_type}')

pydantic_fields[field['name']] = (
field_type_annotation,
Field(..., description=field['description']),
)
for field in self.contract.fields
}

DynamicModel = create_model(self.contract.name, **pydantic_fields)
return DynamicModel

Expand Down

0 comments on commit 92964d4

Please sign in to comment.