Skip to content

Commit

Permalink
Alter image structure (#845)
Browse files Browse the repository at this point in the history
* new: assign image.image Any type in grpc

* new: update image structure, rename embed utils Path to FieldPath
  • Loading branch information
joein authored Nov 11, 2024
1 parent 6d29cc6 commit e15b21e
Show file tree
Hide file tree
Showing 15 changed files with 359 additions and 357 deletions.
20 changes: 10 additions & 10 deletions qdrant_client/async_qdrant_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
#
# ****** WARNING: THIS FILE IS AUTOGENERATED ******

import base64
import io
import uuid
import warnings
from itertools import tee
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, Set, get_args
from copy import deepcopy
from pathlib import Path
import numpy as np
from pydantic import BaseModel
from qdrant_client.async_client_base import AsyncQdrantBase
Expand All @@ -25,7 +24,7 @@
from qdrant_client.embed.embed_inspector import InspectorEmbed
from qdrant_client.embed.models import NumericVector, NumericVectorStruct
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import Path
from qdrant_client.embed.utils import FieldPath
from qdrant_client.fastembed_common import QueryResponse
from qdrant_client.http import models
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion
Expand Down Expand Up @@ -843,13 +842,13 @@ def _resolve_query_batch_request(
return [self._resolve_query_request(query) for query in requests]

def _embed_models(
self, model: BaseModel, paths: Optional[List[Path]] = None, is_query: bool = False
self, model: BaseModel, paths: Optional[List[FieldPath]] = None, is_query: bool = False
) -> Union[BaseModel, NumericVector]:
"""Embed model's fields requiring inference
Args:
model: Qdrant model containing fields to embed
paths: Path to fields to embed. E.g. [Path(current="recommend", tail=[Path(current="negative", tail=None)])]
model: Qdrant http model containing fields to embed
paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])]
is_query: Flag to determine which embed method to use. Defaults to False.
Returns:
Expand Down Expand Up @@ -999,9 +998,10 @@ def _embed_image(self, image: models.Image) -> NumericVector:
embedding_model_inst = self._get_or_init_image_model(
model_name=model_name, **image.options or {}
)
image_data = base64.b64decode(image.image)
with io.BytesIO(image_data) as buffer:
with PilImage.open(buffer) as image:
embedding = list(embedding_model_inst.embed(images=[image]))[0].tolist()
if not isinstance(image.image, (str, Path, PilImage.Image)):
raise ValueError(
f"Unsupported image type: {type(image.image)}. Image: {image.image}"
)
embedding = list(embedding_model_inst.embed(images=[image.image]))[0].tolist()
return embedding
raise ValueError(f"{model_name} is not among supported models")
12 changes: 8 additions & 4 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,15 +1043,15 @@ def convert_document(cls, model: grpc.Document) -> rest.Document:
@classmethod
def convert_image(cls, model: grpc.Image) -> rest.Image:
return rest.Image(
image=model.image,
image=value_to_json(model.image),
model=model.model if model.HasField("model") else None,
options=grpc_to_payload(model.options),
)

@classmethod
def convert_inference_object(cls, model: grpc.InferenceObject) -> rest.InferenceObject:
return rest.InferenceObject(
object=model.object,
object=value_to_json(model.object),
model=model.model if model.HasField("model") else None,
options=grpc_to_payload(model.options),
)
Expand Down Expand Up @@ -2931,13 +2931,17 @@ def convert_document(cls, model: rest.Document) -> grpc.Document:
@classmethod
def convert_image(cls, model: rest.Image) -> grpc.Image:
return grpc.Image(
image=model.image, model=model.model, options=payload_to_grpc(model.options)
image=json_to_value(model.image),
model=model.model,
options=payload_to_grpc(model.options),
)

@classmethod
def convert_inference_object(cls, model: rest.InferenceObject) -> grpc.InferenceObject:
return grpc.InferenceObject(
object=model.object, model=model.model, options=payload_to_grpc(model.options)
object=json_to_value(model.object),
model=model.model,
options=payload_to_grpc(model.options),
)

@classmethod
Expand Down
12 changes: 6 additions & 6 deletions qdrant_client/embed/embed_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.schema_parser import ModelSchemaParser

from qdrant_client.embed.utils import convert_paths, Path
from qdrant_client.embed.utils import convert_paths, FieldPath
from qdrant_client.http import models


Expand All @@ -21,14 +21,14 @@ class InspectorEmbed:
def __init__(self, parser: Optional[ModelSchemaParser] = None) -> None:
self.parser = ModelSchemaParser() if parser is None else parser

def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> List[Path]:
def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> List[FieldPath]:
"""Looks for all the paths to objects requiring inference in the received models
Args:
points: models to inspect
Returns:
list of Path objects
list of FieldPath objects
"""
paths = []
if isinstance(points, BaseModel):
Expand All @@ -45,7 +45,7 @@ def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> List[Path]:
return convert_paths(paths)

def _inspect_model(
self, mod: BaseModel, paths: Optional[List[Path]] = None, accum: Optional[str] = None
self, mod: BaseModel, paths: Optional[List[FieldPath]] = None, accum: Optional[str] = None
) -> List[str]:
"""Looks for all the paths to objects requiring inference in the received model
Expand All @@ -72,15 +72,15 @@ def _inspect_inner_models(
self,
original_model: BaseModel,
current_path: str,
tail: List[Path],
tail: List[FieldPath],
accum: Optional[str] = None,
) -> List[str]:
"""Looks for all the paths to objects requiring inference in the received model
Args:
original_model: model to inspect
current_path: the field to inspect on the current iteration
tail: list of Path objects to the fields possibly containing objects for inference
tail: list of FieldPath objects to the fields possibly containing objects for inference
accum: accumulator for the path. Path is a dot separated string of field names which we assemble recursively
Returns:
Expand Down
20 changes: 10 additions & 10 deletions qdrant_client/embed/schema_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel

from qdrant_client.embed.utils import Path, convert_paths
from qdrant_client.embed.utils import FieldPath, convert_paths


try:
Expand Down Expand Up @@ -39,20 +39,20 @@ class ModelSchemaParser:
{"Prefetch"}
_cache: cache of string paths for models containing objects for inference, e.g.:
{"Prefetch": ['prefetch.query', 'prefetch.query.context.negative', ...]}
path_cache: cache of Path objects for models containing objects for inference, e.g.:
path_cache: cache of FieldPath objects for models containing objects for inference, e.g.:
{
"Prefetch": [
Path(
FieldPath(
current="prefetch",
tail=[
Path(
FieldPath(
current="query",
tail=[
Path(
FieldPath(
current="recommend",
tail=[
Path(current="negative", tail=None),
Path(current="positive", tail=None),
FieldPath(current="negative", tail=None),
FieldPath(current="positive", tail=None),
],
),
...,
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(self) -> None:
self.name_recursive_ref_mapping: Dict[str, str] = {
k: v for k, v in NAME_RECURSIVE_REF_MAPPING.items()
}
self.path_cache: Dict[str, List[Path]] = {
self.path_cache: Dict[str, List[FieldPath]] = {
model: convert_paths(paths) for model, paths in self._cache.items()
}

Expand Down Expand Up @@ -238,10 +238,10 @@ def parse_model(self, model: Type[BaseModel]) -> None:
else:
self._excluded_recursive_refs.add(ref)

# convert str paths to Path objects which group path parts and reduce the time of the traversal
# convert str paths to FieldPath objects which group path parts and reduce the time of the traversal
self.path_cache = {model: convert_paths(paths) for model, paths in self._cache.items()}

def _persist(self, output_path: Union[Path, str] = CACHE_PATH) -> None:
def _persist(self, output_path: Union[FieldPath, str] = CACHE_PATH) -> None:
"""Persist the parser state to a file
Args:
Expand Down
6 changes: 3 additions & 3 deletions qdrant_client/embed/type_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import Path
from qdrant_client.embed.utils import FieldPath
from qdrant_client.http import models


Expand Down Expand Up @@ -41,7 +41,7 @@ def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> bool:
return True
return False

def _inspect_model(self, model: BaseModel, paths: Optional[List[Path]] = None) -> bool:
def _inspect_model(self, model: BaseModel, paths: Optional[List[FieldPath]] = None) -> bool:
if isinstance(model, INFERENCE_OBJECT_TYPES):
return True

Expand All @@ -58,7 +58,7 @@ def _inspect_model(self, model: BaseModel, paths: Optional[List[Path]] = None) -
return False

def _inspect_inner_models(
self, original_model: BaseModel, current_path: str, tail: List[Path]
self, original_model: BaseModel, current_path: str, tail: List[FieldPath]
) -> bool:
def inspect_recursive(member: BaseModel) -> bool:
recursive_paths = []
Expand Down
18 changes: 9 additions & 9 deletions qdrant_client/embed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from pydantic import BaseModel, Field


class Path(BaseModel):
class FieldPath(BaseModel):
current: str
tail: Optional[List["Path"]] = Field(default=None)
tail: Optional[List["FieldPath"]] = Field(default=None)

def as_str_list(self) -> List[str]:
"""
>>> Path(current='a', tail=[Path(current='b', tail=[Path(current='c'), Path(current='d')])]).as_str_list()
>>> FieldPath(current='a', tail=[FieldPath(current='b', tail=[FieldPath(current='c'), FieldPath(current='d')])]).as_str_list()
['a.b.c', 'a.b.d']
"""

# Recursive function to collect all paths
def collect_paths(path: Path, prefix: str = "") -> List[str]:
def collect_paths(path: FieldPath, prefix: str = "") -> List[str]:
current_path = prefix + path.current
if not path.tail:
return [current_path]
Expand All @@ -28,16 +28,16 @@ def collect_paths(path: Path, prefix: str = "") -> List[str]:
return collect_paths(self)


def convert_paths(paths: List[str]) -> List[Path]:
"""Convert string paths into Path objects
def convert_paths(paths: List[str]) -> List[FieldPath]:
"""Convert string paths into FieldPath objects
Paths which share the same root are grouped together.
Args:
paths: List[str]: List of str paths containing "." as separator
Returns:
List[Path]: List of Path objects
List[FieldPath]: List of FieldPath objects
"""
sorted_paths = sorted(paths)
prev_root = None
Expand All @@ -46,7 +46,7 @@ def convert_paths(paths: List[str]) -> List[Path]:
parts = path.split(".")
root = parts[0]
if root != prev_root:
converted_paths.append(Path(current=root))
converted_paths.append(FieldPath(current=root))
prev_root = root
current = converted_paths[-1]
for part in parts[1:]:
Expand All @@ -59,7 +59,7 @@ def convert_paths(paths: List[str]) -> List[Path]:
found = True
break
if not found:
new_tail = Path(current=part)
new_tail = FieldPath(current=part)
assert current.tail is not None
current.tail.append(new_tail)
current = new_tail
Expand Down
560 changes: 280 additions & 280 deletions qdrant_client/grpc/points_pb2.py

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions qdrant_client/http/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,12 @@ class Document(BaseModel, extra="forbid"):
"""
WARN: Work-in-progress, unimplemented Text document for embedding. Requires inference infrastructure, unimplemented.
"""
text: str = Field(..., description="Text document to be embedded by FastEmbed or Cloud inference server")
model: Optional[str] = Field(default=None, description="Model name to be used for embedding computation")

text: str = Field(..., description="Text of the document This field will be used as input for the embedding model")
model: Optional[str] = Field(
default=None,
description="Name of the model used to generate the vector List of available models depends on a provider",
)
options: Optional[Dict[str, Any]] = Field(
default=None, description="Parameters for the model Values of the parameters are model-specific"
)
Expand Down Expand Up @@ -868,7 +872,7 @@ class Image(BaseModel, extra="forbid"):
WARN: Work-in-progress, unimplemented Image object for embedding. Requires inference infrastructure, unimplemented.
"""

image: str = Field(..., description="Image data: base64 encoded image or an URL")
image: Any = Field(..., description="Image data: base64 encoded image or an URL")
model: Optional[str] = Field(
default=None,
description="Name of the model used to generate the vector List of available models depends on a provider",
Expand Down
2 changes: 1 addition & 1 deletion qdrant_client/proto/points.proto
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ message Document {
}

message Image {
string image = 1; // Image data, either base64 encoded or URL
Value image = 1; // Image data, either base64 encoded or URL
optional string model = 2; // Model name
map<string, Value> options = 3; // Model options
}
Expand Down
23 changes: 12 additions & 11 deletions qdrant_client/qdrant_fastembed.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import base64
import io
import uuid
import warnings
from itertools import tee
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, Set, get_args
from copy import deepcopy

from pathlib import Path

import numpy as np

Expand All @@ -18,7 +16,7 @@
from qdrant_client.embed.embed_inspector import InspectorEmbed
from qdrant_client.embed.models import NumericVector, NumericVectorStruct
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import Path
from qdrant_client.embed.utils import FieldPath
from qdrant_client.fastembed_common import QueryResponse
from qdrant_client.http import models
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion
Expand Down Expand Up @@ -933,14 +931,14 @@ def _resolve_query_batch_request(
def _embed_models(
self,
model: BaseModel,
paths: Optional[List[Path]] = None,
paths: Optional[List[FieldPath]] = None,
is_query: bool = False,
) -> Union[BaseModel, NumericVector]:
"""Embed model's fields requiring inference
Args:
model: Qdrant model containing fields to embed
paths: Path to fields to embed. E.g. [Path(current="recommend", tail=[Path(current="negative", tail=None)])]
model: Qdrant http model containing fields to embed
paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])]
is_query: Flag to determine which embed method to use. Defaults to False.
Returns:
Expand Down Expand Up @@ -1097,10 +1095,13 @@ def _embed_image(self, image: models.Image) -> NumericVector:
embedding_model_inst = self._get_or_init_image_model(
model_name=model_name, **(image.options or {})
)
image_data = base64.b64decode(image.image)
with io.BytesIO(image_data) as buffer:
with PilImage.open(buffer) as image:
embedding = list(embedding_model_inst.embed(images=[image]))[0].tolist()
if not isinstance(image.image, (str, Path, PilImage.Image)): # type: ignore
# PilImage is None if PIL is not installed,
# but we'll fail earlier if it's not installed.
raise ValueError(
f"Unsupported image type: {type(image.image)}. Image: {image.image}"
)
embedding = list(embedding_model_inst.embed(images=[image.image]))[0].tolist()
return embedding

raise ValueError(f"{model_name} is not among supported models")
Loading

0 comments on commit e15b21e

Please sign in to comment.