diff --git a/examples/python/output_parser.py b/examples/python/output_parser.py index c07e7d0..07cbda4 100644 --- a/examples/python/output_parser.py +++ b/examples/python/output_parser.py @@ -34,9 +34,10 @@ 'name': 'middle_name', }, { - 'type': 'str', - 'description': 'The last name of the person', + 'type': 'literal', + 'description': 'The last name of the person, the value can be either of Vishnu or Satis', 'name': 'last_name', + 'values': ['Vishnu', 'Satis'], }, ], } diff --git a/flo_ai/parsers/flo_json_parser.py b/flo_ai/parsers/flo_json_parser.py index c1c11a7..0733bbd 100644 --- a/flo_ai/parsers/flo_json_parser.py +++ b/flo_ai/parsers/flo_json_parser.py @@ -1,6 +1,6 @@ import json from flo_ai.parsers.flo_parser import FloParser -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Literal from pydantic import BaseModel, Field, create_model from flo_ai.error.flo_exception import FloException from langchain_core.output_parsers import PydanticOutputParser @@ -19,14 +19,33 @@ def __init__(self, parse_contract: ParseContract): super().__init__() def __create_contract_from_json(self) -> BaseModel: - type_mapping = {'str': str, 'int': int, 'bool': bool, 'float': float} - pydantic_fields = { - field['name']: ( - type_mapping[field['type']], + type_mapping = { + 'str': str, + 'int': int, + 'bool': bool, + 'float': float, + 'literal': Literal, + } + pydantic_fields = {} + for field in self.contract.fields: + field_type = field['type'] + if field_type == 'literal': + literal_values = field.get('values', []) + if not literal_values: + raise ValueError( + f"Field '{field['name']}' of type 'literal' must specify 'values'." + ) + field_type_annotation = Literal[tuple(literal_values)] + else: + field_type_annotation = type_mapping.get(field_type) + if field_type_annotation is None: + raise ValueError(f'Unsupported type: {field_type}') + + pydantic_fields[field['name']] = ( + field_type_annotation, Field(..., description=field['description']), ) - for field in self.contract.fields - } + DynamicModel = create_model(self.contract.name, **pydantic_fields) return DynamicModel