Skip to content

Commit

Permalink
new: make model name mandatory in inference structures
Browse files (browse the repository at this point in the history)
  • Loading branch information
joein committed Nov 12, 2024
1 parent e15b21e commit cdc1f8d
Show file tree
Hide file tree
Showing 7 changed files with 302 additions and 320 deletions.
5 changes: 1 addition & 4 deletions qdrant_client/async_qdrant_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def _resolve_query(
Optional[models.Query]: query as it was, models.Query(nearest=query) or None
Raises:
ValueError: if query is not of supported type or query is models.Document without `model` field
ValueError: if query is not of supported type
"""
if isinstance(query, get_args(types.Query)) or isinstance(query, grpc.Query):
return query
Expand All @@ -807,9 +807,6 @@ def _resolve_query(
)
return models.NearestQuery(nearest=query)
if isinstance(query, INFERENCE_OBJECT_TYPES):
model_name = query.model
if model_name is None:
raise ValueError(f"`model` field has to be set explicitly in the {type(query)}")
return models.NearestQuery(nearest=query)
if query is None:
return None
Expand Down
6 changes: 3 additions & 3 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,23 +1036,23 @@ def convert_multi_dense_vector(cls, model: grpc.MultiDenseVector) -> List[List[f
def convert_document(cls, model: grpc.Document) -> rest.Document:
return rest.Document(
text=model.text,
model=model.model if model.HasField("model") else None,
model=model.model,
options=grpc_to_payload(model.options),
)

@classmethod
def convert_image(cls, model: grpc.Image) -> rest.Image:
return rest.Image(
image=value_to_json(model.image),
model=model.model if model.HasField("model") else None,
model=model.model,
options=grpc_to_payload(model.options),
)

@classmethod
def convert_inference_object(cls, model: grpc.InferenceObject) -> rest.InferenceObject:
return rest.InferenceObject(
object=value_to_json(model.object),
model=model.model if model.HasField("model") else None,
model=model.model,
options=grpc_to_payload(model.options),
)

Expand Down
576 changes: 288 additions & 288 deletions qdrant_client/grpc/points_pb2.py

Large diffs are not rendered by default.

15 changes: 6 additions & 9 deletions qdrant_client/http/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,8 @@ class Document(BaseModel, extra="forbid"):
"""

text: str = Field(..., description="Text of the document This field will be used as input for the embedding model")
model: Optional[str] = Field(
default=None,
description="Name of the model used to generate the vector List of available models depends on a provider",
model: str = Field(
..., description="Name of the model used to generate the vector List of available models depends on a provider"
)
options: Optional[Dict[str, Any]] = Field(
default=None, description="Parameters for the model Values of the parameters are model-specific"
Expand Down Expand Up @@ -873,9 +872,8 @@ class Image(BaseModel, extra="forbid"):
"""

image: Any = Field(..., description="Image data: base64 encoded image or an URL")
model: Optional[str] = Field(
default=None,
description="Name of the model used to generate the vector List of available models depends on a provider",
model: str = Field(
..., description="Name of the model used to generate the vector List of available models depends on a provider"
)
options: Optional[Dict[str, Any]] = Field(
default=None, description="Parameters for the model Values of the parameters are model-specific"
Expand Down Expand Up @@ -921,9 +919,8 @@ class InferenceObject(BaseModel, extra="forbid"):
...,
description="Arbitrary data, used as input for the embedding model Used if the model requires more than one input or a custom input",
)
model: Optional[str] = Field(
default=None,
description="Name of the model used to generate the vector List of available models depends on a provider",
model: str = Field(
..., description="Name of the model used to generate the vector List of available models depends on a provider"
)
options: Optional[Dict[str, Any]] = Field(
default=None, description="Parameters for the model Values of the parameters are model-specific"
Expand Down
6 changes: 3 additions & 3 deletions qdrant_client/proto/points.proto
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ message SparseIndices {

message Document {
string text = 1; // Text of the document
optional string model = 3; // Model name
string model = 3; // Model name
map<string, Value> options = 4; // Model options
}

message Image {
Value image = 1; // Image data, either base64 encoded or URL
optional string model = 2; // Model name
string model = 2; // Model name
map<string, Value> options = 3; // Model options
}

message InferenceObject {
Value object = 1; // Object to infer
optional string model = 2; // Model name
string model = 2; // Model name
map<string, Value> options = 3; // Model options
}

Expand Down
5 changes: 1 addition & 4 deletions qdrant_client/qdrant_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ def _resolve_query(
Optional[models.Query]: query as it was, models.Query(nearest=query) or None
Raises:
ValueError: if query is not of supported type or query is models.Document without `model` field
ValueError: if query is not of supported type
"""
if isinstance(query, get_args(types.Query)) or isinstance(query, grpc.Query):
return query
Expand All @@ -892,9 +892,6 @@ def _resolve_query(
return models.NearestQuery(nearest=query)

if isinstance(query, INFERENCE_OBJECT_TYPES):
model_name = query.model
if model_name is None:
raise ValueError(f"`model` field has to be set explicitly in the {type(query)}")
return models.NearestQuery(nearest=query)

if query is None:
Expand Down
9 changes: 0 additions & 9 deletions tests/conversions/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,21 +312,18 @@
text="random text", model="bert", options=payload_to_grpc({"a": 2, "b": [1, 2], "c": "useful"})
)
document_without_options = grpc.Document(text="random text", model="bert")
document_only_text = grpc.Document(text="random text")
image_with_options = grpc.Image(
image=json_to_value("path_to_image"),
model="resnet",
options=payload_to_grpc({"a": 2, "b": [1, 2], "c": "useful"}),
)
image_without_options = grpc.Image(image=json_to_value("path_to_image"), model="resnet")
image_only_image = grpc.Image(image=json_to_value("path_to_image"))
inference_object_with_options = grpc.InferenceObject(
object=json_to_value("path_to_image"),
model="bert",
options=payload_to_grpc({"a": 2, "b": [1, 2], "c": "useful"}),
)
inference_object_without_options = grpc.InferenceObject(object=json_to_value("text"), model="bert")
inference_object_only_object = grpc.InferenceObject(object=json_to_value("text"))
order_value_int = grpc.OrderValue(int=42)
order_value_float = grpc.OrderValue(float=42.0)
single_vector_output = grpc.VectorsOutput(vector=grpc.VectorOutput(data=[1.0, 2.0, 3.0, 4.0]))
Expand Down Expand Up @@ -1254,13 +1251,10 @@
vector_input_multi = grpc.VectorInput(multi_dense=multi_dense_vector)
vector_input_doc_with_options = grpc.VectorInput(document=document_with_options)
vector_input_doc_without_options = grpc.VectorInput(document=document_without_options)
vector_input_doc_only_text = grpc.VectorInput(document=document_only_text)
vector_input_image_with_options = grpc.VectorInput(image=image_with_options)
vector_input_image_without_options = grpc.VectorInput(image=image_without_options)
vector_input_image_only_image = grpc.VectorInput(image=image_only_image)
vector_input_inference_with_options = grpc.VectorInput(object=inference_object_with_options)
vector_input_inference_without_options = grpc.VectorInput(object=inference_object_without_options)
vector_input_inference_only_object = grpc.VectorInput(object=inference_object_only_object)

recommend_input = grpc.RecommendInput(
positive=[
Expand All @@ -1272,15 +1266,12 @@
positive=[
vector_input_doc_with_options,
vector_input_doc_without_options,
vector_input_doc_only_text,
],
negative=[
vector_input_image_with_options,
vector_input_image_without_options,
vector_input_image_only_image,
vector_input_inference_with_options,
vector_input_inference_without_options,
vector_input_inference_only_object,
],
)
recommend_input_strategy = grpc.RecommendInput(
Expand Down

0 comments on commit cdc1f8d

Please sign in to comment.