Skip to content

Commit

Permalink
ES-1867 | force_one_shard_attribute_value param (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna authored Jan 5, 2024
1 parent 5e93203 commit e94750e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
15 changes: 13 additions & 2 deletions arango/aql.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def execute(
fill_block_cache: Optional[bool] = None,
allow_dirty_read: bool = False,
allow_retry: bool = False,
force_one_shard_attribute_value: Optional[str] = None,
) -> Result[Cursor]:
"""Execute the query and return the result cursor.
Expand Down Expand Up @@ -373,6 +374,16 @@ def execute(
:param allow_retry: Make it possible to retry fetching the latest batch
from a cursor.
:type allow_retry: bool
:param force_one_shard_attribute_value: (Enterprise Only) Explicitly set
a shard key value that will be used during query snippet distribution
to limit the query to a specific server in the cluster. This query option
can be used in complex queries in case the query optimizer cannot
automatically detect that the query can be limited to only a single
server (e.g. in a disjoint smart graph case). If the option is set
incorrectly, i.e. to a wrong shard key value, then the query may be
shipped to a wrong DB server and may not return results
(i.e. empty result set). Use at your own risk.
:param force_one_shard_attribute_value: str | None
:return: Result cursor.
:rtype: arango.cursor.Cursor
:raise arango.exceptions.AQLQueryExecuteError: If execute fails.
Expand Down Expand Up @@ -418,10 +429,10 @@ def execute(
options["skipInaccessibleCollections"] = skip_inaccessible_cols
if max_runtime is not None:
options["maxRuntime"] = max_runtime

# New in 3.11
if allow_retry is not None:
options["allowRetry"] = allow_retry
if force_one_shard_attribute_value is not None:
options["forceOneShardAttributeValue"] = force_one_shard_attribute_value

if options:
data["options"] = options
Expand Down
32 changes: 31 additions & 1 deletion tests/test_aql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
AQLQueryTrackingSetError,
AQLQueryValidateError,
)
from tests.helpers import assert_raises, extract
from tests.helpers import assert_raises, extract, generate_col_name


def test_aql_attributes(db, username):
Expand Down Expand Up @@ -246,6 +246,36 @@ def test_aql_query_management(db_version, db, bad_db, col, docs):
assert err.value.error_code in {11, 1228}


def test_aql_query_force_one_shard_attribute_value(db, db_version, enterprise, cluster):
if db_version < version.parse("3.10") or not enterprise or not cluster:
return

name = generate_col_name()
col = db.create_collection(name, shard_fields=["foo"], shard_count=3)

doc = {"foo": "bar"}
col.insert(doc)

cursor = db.aql.execute(
"FOR d IN @@c RETURN d",
bind_vars={"@c": name},
force_one_shard_attribute_value="bar",
)

results = [doc for doc in cursor]
assert len(results) == 1
assert results[0]["foo"] == "bar"

cursor = db.aql.execute(
"FOR d IN @@c RETURN d",
bind_vars={"@c": name},
force_one_shard_attribute_value="ooo",
)

results = [doc for doc in cursor]
assert len(results) == 0


def test_aql_function_management(db, bad_db):
fn_group = "functions::temperature"
fn_name_1 = "functions::temperature::celsius_to_fahrenheit"
Expand Down

0 comments on commit e94750e

Please sign in to comment.