Skip to content

Commit

Permalink
fix: fix model json schema compat, fix fields inspected for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Nov 12, 2024
1 parent e15b21e commit e5956a1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
9 changes: 9 additions & 0 deletions qdrant_client/_pydantic_compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

from typing import Any, Dict, Type, TypeVar

from pydantic import BaseModel
Expand Down Expand Up @@ -44,3 +46,10 @@ def model_fields_set(model: BaseModel) -> set:
return model.model_fields_set
else:
return model.__fields_set__


def model_json_schema(model: BaseModel, *args: Any, **kwargs: Any) -> dict[str, Any]:
if PYDANTIC_V2:
return model.model_json_schema(*args, **kwargs)
else:
return json.loads(model.schema_json(*args, **kwargs))
5 changes: 3 additions & 2 deletions qdrant_client/embed/schema_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Type, Dict, Union, Any, Set, Optional

from pydantic import BaseModel
from pydantic.json_schema import model_json_schema

from qdrant_client.embed.utils import FieldPath, convert_paths

Expand Down Expand Up @@ -67,7 +68,7 @@ class ModelSchemaParser:
"""

CACHE_PATH = "_inspection_cache.py"
INFERENCE_OBJECT_NAMES = {"Document", "Image"}
INFERENCE_OBJECT_NAMES = {"Document", "Image", "InferenceObject"}

def __init__(self) -> None:
self._defs: Dict[str, Union[Dict[str, Any], List[Dict[str, Any]]]] = deepcopy(DEFS) # type: ignore[arg-type]
Expand Down Expand Up @@ -223,7 +224,7 @@ def parse_model(self, model: Type[BaseModel]) -> None:
if model_name in self._cache:
return None

schema = model.model_json_schema()
schema = model_json_schema(model)
self._defs.update(schema.get("$defs", {}))

defs = self._replace_refs(schema)
Expand Down

0 comments on commit e5956a1

Please sign in to comment.