Commit

chore: revert formatting

tellet-q committed Nov 26, 2024
1 parent 279d6fb commit b6d8dd9
Showing 4 changed files with 31 additions and 30 deletions.
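Every hunk below makes the same mechanical change: a bare tuple-unpacking target is rewritten into its parenthesized form (e.g. `documents_a, documents_b = ...` becomes `(documents_a, documents_b) = ...`). Both spellings parse to the identical AST, so the revert is formatting-only and cannot change behavior. A minimal sketch demonstrating the equivalence (function and variable names here are illustrative, not taken from the diff):

    import dis

    def bare():
        a, b = 1, 2  # bare tuple-unpacking target
        return a + b

    def parenthesized():
        (a, b) = (1, 2)  # parenthesized target, the style this commit restores
        return a + b

    # The parentheses never reach the compiled code: both functions
    # produce byte-for-byte identical instruction streams.
    print(bare.__code__.co_code == parenthesized.__code__.co_code)  # True
    dis.dis(parenthesized)  # same instructions as dis.dis(bare)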
2 changes: 1 addition & 1 deletion qdrant_client/async_qdrant_client.py
@@ -101,7 +101,7 @@ def __init__(
         super().__init__(parser=self._inference_inspector.parser, **kwargs)
         self._init_options = {
             key: value
-            for key, value in locals().items()
+            for (key, value) in locals().items()
             if key not in ("self", "__class__", "kwargs")
         }
         self._init_options.update(deepcopy(kwargs))
13 changes: 7 additions & 6 deletions qdrant_client/async_qdrant_fastembed.py
@@ -344,7 +344,7 @@ def _embed_documents(
         parallel: Optional[int] = None,
     ) -> Iterable[tuple[str, list[float]]]:
         embedding_model = self._get_or_init_model(model_name=embedding_model_name)
-        documents_a, documents_b = tee(documents, 2)
+        (documents_a, documents_b) = tee(documents, 2)
         if embed_type == "passage":
             vectors_iter = embedding_model.passage_embed(
                 documents_a, batch_size=batch_size, parallel=parallel
@@ -456,7 +456,7 @@ def _points_iterator(
             yield models.PointStruct(id=idx, payload=payload, vector=point_vector)

     def _validate_collection_info(self, collection_info: models.CollectionInfo) -> None:
-        embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
+        (embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
         vector_field_name = self.get_vector_field_name()
         assert isinstance(
             collection_info.config.params.vectors, dict
@@ -502,7 +502,7 @@ def get_fastembed_vector_params(
             Configuration for `vectors_config` argument in `create_collection` method.
         """
         vector_field_name = self.get_vector_field_name()
-        embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
+        (embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
         return {
             vector_field_name: models.VectorParams(
                 size=embeddings_size,
@@ -687,7 +687,7 @@ async def query(
             with_payload=True,
             **kwargs,
         )
-        dense_request_response, sparse_request_response = await self.search_batch(
+        (dense_request_response, sparse_request_response) = await self.search_batch(
             collection_name=collection_name, requests=[dense_request, sparse_request]
         )
         return self._scored_points_to_query_responses(
@@ -764,7 +764,7 @@ async def query_batch(
         sparse_responses = responses[len(query_texts) :]
         responses = [
             reciprocal_rank_fusion([dense_response, sparse_response], limit=limit)
-            for dense_response, sparse_response in zip(dense_responses, sparse_responses)
+            for (dense_response, sparse_response) in zip(dense_responses, sparse_responses)
         ]
         return [self._scored_points_to_query_responses(response) for response in responses]

@@ -925,7 +925,8 @@ def _embed_raw_data(
             return self._embed_image(data)
         elif isinstance(data, dict):
             return {
-                key: self._embed_raw_data(value, is_query=is_query) for key, value in data.items()
+                key: self._embed_raw_data(value, is_query=is_query)
+                for (key, value) in data.items()
             }
         elif isinstance(data, list):
             if data and isinstance(data[0], float):
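Aside: in the `_embed_documents` hunk above, `tee(documents, 2)` duplicates a single-pass iterable so one copy feeds the embedding model while the other can still be read to pair each original document with its vector. A minimal sketch of that pattern under hypothetical names (`embed_stream` and the length-based stand-in embedding are not part of the client):

    from itertools import tee

    def embed_stream(documents):
        # Generators are single-use, so split the iterator into two.
        docs_for_model, docs_for_output = tee(documents, 2)
        vectors = ([float(len(doc))] for doc in docs_for_model)  # stand-in model
        # Pair documents with vectors without re-reading the source iterable.
        return zip(docs_for_output, vectors)

    for text, vector in embed_stream(iter(["hello", "qdrant"])):
        print(text, vector)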
24 changes: 12 additions & 12 deletions qdrant_client/async_qdrant_remote.py
@@ -90,7 +90,7 @@ def __init__(
             if url.startswith("localhost"):
                 url = f"//{url}"
             parsed_url: Url = parse_url(url)
-            self._host, self._port = (parsed_url.host, parsed_url.port)
+            (self._host, self._port) = (parsed_url.host, parsed_url.port)
             if parsed_url.scheme:
                 self._https = parsed_url.scheme == "https"
                 self._scheme = parsed_url.scheme
@@ -201,7 +201,7 @@ async def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None
     @staticmethod
     def _parse_url(url: str) -> tuple[Optional[str], str, Optional[int], Optional[str]]:
         parse_result: Url = parse_url(url)
-        scheme, host, port, prefix = (
+        (scheme, host, port, prefix) = (
             parse_result.scheme,
             parse_result.host,
             parse_result.port,
@@ -1711,7 +1711,7 @@ async def delete_vectors(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
             shard_key_selector = shard_key_selector or opt_shard_key_selector
             if isinstance(ordering, models.WriteOrdering):
                 ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1733,7 +1733,7 @@ async def delete_vectors(
             assert grpc_result is not None, "Delete vectors returned None result"
             return GrpcToRest.convert_update_result(grpc_result)
         else:
-            _points, _filter = self._try_argument_to_rest_points_and_filter(points)
+            (_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
             return (
                 await self.openapi_client.points_api.delete_vectors(
                     collection_name=collection_name,
@@ -1928,7 +1928,7 @@ async def delete(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(
                 points_selector
             )
             shard_key_selector = shard_key_selector or opt_shard_key_selector
@@ -1977,7 +1977,7 @@ async def set_payload(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
             shard_key_selector = shard_key_selector or opt_shard_key_selector
             if isinstance(ordering, models.WriteOrdering):
                 ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -2000,7 +2000,7 @@ async def set_payload(
                 ).result
             )
         else:
-            _points, _filter = self._try_argument_to_rest_points_and_filter(points)
+            (_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
             result: Optional[types.UpdateResult] = (
                 await self.openapi_client.points_api.set_payload(
                     collection_name=collection_name,
@@ -2029,7 +2029,7 @@ async def overwrite_payload(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
             shard_key_selector = shard_key_selector or opt_shard_key_selector
             if isinstance(ordering, models.WriteOrdering):
                 ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -2051,7 +2051,7 @@ async def overwrite_payload(
                 ).result
             )
         else:
-            _points, _filter = self._try_argument_to_rest_points_and_filter(points)
+            (_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
             result: Optional[types.UpdateResult] = (
                 await self.openapi_client.points_api.overwrite_payload(
                     collection_name=collection_name,
@@ -2079,7 +2079,7 @@ async def delete_payload(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
             shard_key_selector = shard_key_selector or opt_shard_key_selector
             if isinstance(ordering, models.WriteOrdering):
                 ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -2101,7 +2101,7 @@ async def delete_payload(
                 ).result
             )
         else:
-            _points, _filter = self._try_argument_to_rest_points_and_filter(points)
+            (_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
             result: Optional[types.UpdateResult] = (
                 await self.openapi_client.points_api.delete_payload(
                     collection_name=collection_name,
@@ -2125,7 +2125,7 @@ async def clear_payload(
         **kwargs: Any,
     ) -> types.UpdateResult:
         if self._prefer_grpc:
-            points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
+            (points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(
                 points_selector
             )
             shard_key_selector = shard_key_selector or opt_shard_key_selector
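Aside: the `__init__` and `_parse_url` hunks above unpack fields of the `Url` object returned by `parse_url`. Assuming this is urllib3's `parse_url` (which the `Url` annotation suggests), a short illustration; the hunk cuts off before the fourth element of the real tuple, so the `path` field below is only a plausible stand-in:

    from urllib3.util.url import Url, parse_url

    parsed: Url = parse_url("https://example.com:6333/prefix")
    # Unpack in the parenthesized style used throughout the generated client.
    (scheme, host, port, path) = (
        parsed.scheme,
        parsed.host,
        parsed.port,
        parsed.path,
    )
    print(scheme, host, port, path)  # https example.com 6333 /prefix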
22 changes: 11 additions & 11 deletions qdrant_client/local/async_qdrant_local.py
@@ -132,7 +132,7 @@ def _save(self) -> None:
                 {
                     "collections": {
                         collection_name: to_dict(collection.config)
-                        for collection_name, collection in self.collections.items()
+                        for (collection_name, collection) in self.collections.items()
                     },
                     "aliases": self.aliases,
                 }
@@ -377,7 +377,7 @@ def _resolve_prefetch_input(
         if prefetch.query is None:
             return prefetch
         prefetch = deepcopy(prefetch)
-        query, mentioned_ids = self._resolve_query_input(
+        (query, mentioned_ids) = self._resolve_query_input(
             collection_name, prefetch.query, prefetch.using, prefetch.lookup_from
         )
         prefetch.query = query
@@ -403,7 +403,7 @@ async def query_points(
     ) -> types.QueryResponse:
         collection = self._get_collection(collection_name)
         if query is not None:
-            query, mentioned_ids = self._resolve_query_input(
+            (query, mentioned_ids) = self._resolve_query_input(
                 collection_name, query, using, lookup_from
             )
             query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
@@ -476,7 +476,7 @@ async def query_points_groups(
     ) -> types.GroupsResult:
         collection = self._get_collection(collection_name)
         if query is not None:
-            query, mentioned_ids = self._resolve_query_input(
+            (query, mentioned_ids) = self._resolve_query_input(
                 collection_name, query, using, lookup_from
             )
             query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
@@ -836,7 +836,7 @@ async def get_collection_aliases(
         return types.CollectionsAliasesResponse(
             aliases=[
                 rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
-                for alias_name, name in self.aliases.items()
+                for (alias_name, name) in self.aliases.items()
                 if name == collection_name
             ]
         )
@@ -847,7 +847,7 @@ async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
         return types.CollectionsAliasesResponse(
             aliases=[
                 rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
-                for alias_name, name in self.aliases.items()
+                for (alias_name, name) in self.aliases.items()
             ]
         )

@@ -857,7 +857,7 @@ async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
         return types.CollectionsResponse(
             collections=[
                 rest_models.CollectionDescription(name=name)
-                for name, _ in self.collections.items()
+                for (name, _) in self.collections.items()
             ]
         )

@@ -898,7 +898,7 @@ async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
         del _collection
         self.aliases = {
             alias_name: name
-            for alias_name, name in self.aliases.items()
+            for (alias_name, name) in self.aliases.items()
             if name != collection_name
         }
         collection_path = self._collection_path(collection_name)
@@ -939,12 +939,12 @@ async def create_collection(
         self.collections[collection_name] = collection
         if src_collection and from_collection_name:
             batch_size = 100
-            records, next_offset = await self.scroll(
+            (records, next_offset) = await self.scroll(
                 from_collection_name, limit=2, with_vectors=True
             )
             self.upload_records(collection_name, records)
             while next_offset is not None:
-                records, next_offset = await self.scroll(
+                (records, next_offset) = await self.scroll(
                     from_collection_name, offset=next_offset, limit=batch_size, with_vectors=True
                 )
                 self.upload_records(collection_name, records)
@@ -1020,7 +1020,7 @@ def uuid_generator() -> Generator[str, None, None]:
                 vector=(vector.tolist() if isinstance(vector, np.ndarray) else vector) or {},
                 payload=payload or {},
             )
-            for point_id, vector, payload in zip(
+            for (point_id, vector, payload) in zip(
                 ids or uuid_generator(), iter(vectors), payload or itertools.cycle([{}])
             )
         ]
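Aside: all four files touched here are the auto-generated async variants of the client, so one plausible origin of the parenthesized style is round-tripping source through Python's `ast` module, whose unparser emits parentheses around tuple assignment targets on some interpreter versions. That is an assumption about the code-generation toolchain, not something this commit states. A minimal sketch:

    import ast

    # Round-tripping a bare tuple assignment through parse/unparse
    # reproduces the parenthesized style restored throughout this diff.
    # (Assumption: observed with older unparsers, e.g. Python 3.9/3.10;
    # newer versions may drop the redundant parentheses.)
    src = "documents_a, documents_b = tee(documents, 2)"
    print(ast.unparse(ast.parse(src)))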
