From 1355fa10466fde4e1d3d30b0fed015065bd78f73 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Nov 2024 10:12:16 +0100 Subject: [PATCH 1/4] Add support for parallel runtime --- neomodel/async_/core.py | 12 ++++++++++++ neomodel/async_/match.py | 36 +++++++++++++++++++++++++++++++---- neomodel/sync_/core.py | 9 +++++++++ neomodel/sync_/match.py | 28 ++++++++++++++++++++++++++- test/async_/test_match_api.py | 33 +++++++++++++++++++++++++++++++- test/sync_/test_match_api.py | 30 ++++++++++++++++++++++++++++- 6 files changed, 141 insertions(+), 7 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 5773da12..c569eadf 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -598,6 +598,18 @@ async def edition_is_enterprise(self) -> bool: edition = await self.database_edition return edition == "enterprise" + @ensure_connection + async def parallel_runtime_available(self) -> bool: + """Returns true if the database supports parallel runtime + + Returns: + bool: True if the database supports parallel runtime + """ + return ( + await self.version_is_higher_than("5.13") + and await self.edition_is_enterprise() + ) + async def change_neo4j_password(self, user, new_password): await self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index a3718d46..82eaf67b 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1,6 +1,7 @@ import inspect import re import string +import warnings from dataclasses import dataclass from typing import Any, Dict, List from typing import Optional as TOptional @@ -396,6 +397,7 @@ def __init__( lookup: TOptional[str] = None, additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, + use_parallel_runtime: TOptional[bool] = False, ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] @@ -409,6 +411,7 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count + self.use_parallel_runtime = use_parallel_runtime self.subgraph: Dict = {} @@ -432,6 +435,19 @@ async def build_ast(self) -> "AsyncQueryBuilder": self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): self._ast.limit = self.node_set.limit + if hasattr(self.node_set, "use_parallel_runtime"): + if ( + self.node_set.use_parallel_runtime + and not await adb.parallel_runtime_available() + ): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.node_set.use_parallel_runtime = False + else: + self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime return self @@ -589,9 +605,11 @@ def build_traversal_from_path( } else: existing_rhs_name = subgraph[part][ - "rel_variable_name" - if relation.get("relation_filtering") - else "variable_name" + ( + "rel_variable_name" + if relation.get("relation_filtering") + else "variable_name" + ) ] if relation["include_in_return"] and not already_present: self._additional_return(rel_ident) @@ -812,6 +830,8 @@ def lookup_query_variable( def build_query(self) -> str: query: str = "" + if self._ast.use_parallel_runtime: + query += "CYPHER runtime=parallel " if self._ast.lookup: query += self._ast.lookup @@ -973,7 +993,9 @@ async def _execute(self, lazy: bool = False, dict_output: bool = False): ] query = self.build_query() results, prop_names = await adb.cypher_query( - query, self._query_params, resolve_objects=True + query, + self._query_params, + resolve_objects=True, ) if dict_output: for item in results: @@ -1236,6 +1258,8 @@ def __init__(self, source) -> None: self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] + self.use_parallel_runtime = False + def __await__(self): return self.all().__await__() @@ -1564,6 +1588,10 @@ def intermediate_transform( ) return self + def parallel_runtime(self) -> "AsyncNodeSet": + self.use_parallel_runtime = True + return self + class AsyncTraversal(AsyncBaseSet): """ diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 6c72908a..2175fa02 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -596,6 +596,15 @@ def edition_is_enterprise(self) -> bool: edition = self.database_edition return edition == "enterprise" + @ensure_connection + def parallel_runtime_available(self) -> bool: + """Returns true if the database supports parallel runtime + + Returns: + bool: True if the database supports parallel runtime + """ + return self.version_is_higher_than("5.13") and self.edition_is_enterprise() + def change_neo4j_password(self, user, new_password): self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index cd9a7f43..c2c9539d 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1,6 +1,7 @@ import inspect import re import string +import warnings from dataclasses import dataclass from typing import Any, Dict, List from typing import Optional as TOptional @@ -396,6 +397,7 @@ def __init__( lookup: TOptional[str] = None, additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, + use_parallel_runtime: TOptional[bool] = False, ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] @@ -409,6 +411,7 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count + self.use_parallel_runtime = use_parallel_runtime self.subgraph: Dict = {} @@ -432,6 +435,19 @@ def build_ast(self) -> "QueryBuilder": self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): self._ast.limit = self.node_set.limit + if hasattr(self.node_set, "use_parallel_runtime"): + if ( + self.node_set.use_parallel_runtime + and not db.parallel_runtime_available() + ): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.node_set.use_parallel_runtime = False + else: + self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime return self @@ -814,6 +830,8 @@ def lookup_query_variable( def build_query(self) -> str: query: str = "" + if self._ast.use_parallel_runtime: + query += "CYPHER runtime=parallel " if self._ast.lookup: query += self._ast.lookup @@ -973,7 +991,9 @@ def _execute(self, lazy: bool = False, dict_output: bool = False): ] query = self.build_query() results, prop_names = db.cypher_query( - query, self._query_params, resolve_objects=True + query, + self._query_params, + resolve_objects=True, ) if dict_output: for item in results: @@ -1236,6 +1256,8 @@ def __init__(self, source) -> None: self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] + self.use_parallel_runtime = False + def __await__(self): return self.all().__await__() @@ -1562,6 +1584,10 @@ def intermediate_transform( ) return self + def parallel_runtime(self) -> "NodeSet": + self.use_parallel_runtime = True + return self + class Traversal(BaseSet): """ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index c83d826f..7df6f7d7 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1,8 +1,9 @@ import re +import warnings from datetime import datetime from test._async_compat import mark_async_test -from pytest import raises +from pytest import raises, warns from neomodel import ( INCOMING, @@ -1113,3 +1114,33 @@ async def test_async_iterator(): # assert that generator runs loop above assert counter == n + + +@mark_async_test +async def test_parallel_runtime(): + await Coffee(name="Java", price=99).save() + + node_set = AsyncNodeSet(Coffee).parallel_runtime() + + assert node_set.use_parallel_runtime + + if ( + not await adb.version_is_higher_than("5.13") + or not await adb.edition_is_enterprise() + ): + assert not await adb.parallel_runtime_available() + with warns( + UserWarning, + match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", + ): + qb = await AsyncQueryBuilder(node_set).build_ast() + assert not qb._ast.use_parallel_runtime + assert not qb.build_query().startswith("CYPHER runtime=parallel") + else: + assert await adb.parallel_runtime_available() + qb = await AsyncQueryBuilder(node_set).build_ast() + assert qb._ast.use_parallel_runtime + assert qb.build_query().startswith("CYPHER runtime=parallel") + + results = [node async for node in qb._execute()] + assert len(results) == 1 diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 4a5684ea..2b148601 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1,8 +1,9 @@ import re +import warnings from datetime import datetime from test._async_compat import mark_sync_test -from pytest import raises +from pytest import raises, warns from neomodel import ( INCOMING, @@ -1097,3 +1098,30 @@ def test_async_iterator(): # assert that generator runs loop above assert counter == n + + +@mark_sync_test +def test_parallel_runtime(): + Coffee(name="Java", price=99).save() + + node_set = NodeSet(Coffee).parallel_runtime() + + assert node_set.use_parallel_runtime + + if not db.version_is_higher_than("5.13") or not db.edition_is_enterprise(): + assert not db.parallel_runtime_available() + with warns( + UserWarning, + match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", + ): + qb = QueryBuilder(node_set).build_ast() + assert not qb._ast.use_parallel_runtime + assert not qb.build_query().startswith("CYPHER runtime=parallel") + else: + assert db.parallel_runtime_available() + qb = QueryBuilder(node_set).build_ast() + assert qb._ast.use_parallel_runtime + assert qb.build_query().startswith("CYPHER runtime=parallel") + + results = [node for node in qb._execute()] + assert len(results) == 1 From 990815afdabdf80a5d63c11502433325a7c46f2b Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Nov 2024 15:43:37 +0100 Subject: [PATCH 2/4] Use context manager instead. Add doc --- doc/source/transactions.rst | 24 +++++++++-- neomodel/async_/core.py | 23 ++++++++++- neomodel/async_/match.py | 23 ----------- neomodel/sync_/core.py | 23 ++++++++++- neomodel/sync_/match.py | 23 ----------- pyproject.toml | 1 + requirements-dev.txt | 1 + test/async_/test_match_api.py | 76 ++++++++++++++++++++++++----------- test/sync_/test_match_api.py | 76 ++++++++++++++++++++++++----------- 9 files changed, 169 insertions(+), 101 deletions(-) diff --git a/doc/source/transactions.rst b/doc/source/transactions.rst index dfa97ee6..92f5b37e 100644 --- a/doc/source/transactions.rst +++ b/doc/source/transactions.rst @@ -51,7 +51,7 @@ Explicit Transactions Neomodel also supports `explicit transactions `_ that are pre-designated as either *read* or *write*. -This is vital when using neomodel over a `Neo4J causal cluster `_ because internally, queries will be rerouted to different servers depending on their designation. @@ -168,7 +168,7 @@ Impersonation *Neo4j Enterprise feature* -Impersonation (`see Neo4j driver documentation ``) +Impersonation (`see Neo4j driver documentation `_) can be enabled via a context manager:: from neomodel import db @@ -197,4 +197,22 @@ This can be mixed with other context manager like transactions:: @db.transaction() def func2(): - ... \ No newline at end of file + ... + + +Parallel runtime +---------------- + +As of version 5.13, Neo4j *Enterprise Edition* supports parallel runtime for read transactions. + +To use it, you can simply use the `parallel_read_transaction` context manager:: + + from neomodel import db + + with db.parallel_read_transaction: + # It works for both neomodel-generated and custom Cypher queries + parallel_count_1 = len(Coffee.nodes) + parallel_count_2 = db.cypher_query("MATCH (n:Coffee) RETURN count(n)") + +It is worth noting that the parallel runtime is only available for read transactions and that it is not enabled by default, because it is not always the fastest option. It is recommended to test it in your specific use case to see if it improves performance, and read the general considerations in the `Neo4j official documentation `_. + diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index c569eadf..e28895ba 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -104,6 +104,7 @@ def __init__(self): self._database_version = None self._database_edition = None self.impersonated_user = None + self._parallel_runtime = False async def set_connection(self, url: str = None, driver: AsyncDriver = None): """ @@ -239,6 +240,10 @@ def write_transaction(self): def read_transaction(self): return AsyncTransactionProxy(self, access_mode="READ") + @property + def parallel_read_transaction(self): + return AsyncTransactionProxy(self, access_mode="READ", parallel_runtime=True) + async def impersonate(self, user: str) -> "ImpersonationHandler": """All queries executed within this context manager will be executed as impersonated user @@ -454,7 +459,6 @@ async def cypher_query( :return: A tuple containing a list of results and a tuple of headers. """ - if self._active_transaction: # Use current session is a transaction is currently active results, meta = await self._run_cypher_query( @@ -493,6 +497,8 @@ async def _run_cypher_query( try: # Retrieve the data start = time.time() + if self._parallel_runtime: + query = "CYPHER runtime=parallel " + query response: AsyncResult = await session.run(query, params) results, meta = [list(r.values()) async for r in response], response.keys() end = time.time() @@ -1180,17 +1186,30 @@ async def install_all_labels(stdout=None): class AsyncTransactionProxy: bookmarks: Optional[Bookmarks] = None - def __init__(self, db: AsyncDatabase, access_mode=None): + def __init__( + self, db: AsyncDatabase, access_mode: str = None, parallel_runtime: bool = False + ): self.db = db self.access_mode = access_mode + self.parallel_runtime = parallel_runtime @ensure_connection async def __aenter__(self): + if self.parallel_runtime: + if not await self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False + self.db._parallel_runtime = self.parallel_runtime await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None return self async def __aexit__(self, exc_type, exc_value, traceback): + self.db._parallel_runtime = False if exc_value: await self.db.rollback() diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 82eaf67b..99d08a16 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -397,7 +397,6 @@ def __init__( lookup: TOptional[str] = None, additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, - use_parallel_runtime: TOptional[bool] = False, ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] @@ -411,7 +410,6 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count - self.use_parallel_runtime = use_parallel_runtime self.subgraph: Dict = {} @@ -435,19 +433,6 @@ async def build_ast(self) -> "AsyncQueryBuilder": self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): self._ast.limit = self.node_set.limit - if hasattr(self.node_set, "use_parallel_runtime"): - if ( - self.node_set.use_parallel_runtime - and not await adb.parallel_runtime_available() - ): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " - "Reverting to default runtime.", - UserWarning, - ) - self.node_set.use_parallel_runtime = False - else: - self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime return self @@ -830,8 +815,6 @@ def lookup_query_variable( def build_query(self) -> str: query: str = "" - if self._ast.use_parallel_runtime: - query += "CYPHER runtime=parallel " if self._ast.lookup: query += self._ast.lookup @@ -1258,8 +1241,6 @@ def __init__(self, source) -> None: self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] - self.use_parallel_runtime = False - def __await__(self): return self.all().__await__() @@ -1588,10 +1569,6 @@ def intermediate_transform( ) return self - def parallel_runtime(self) -> "AsyncNodeSet": - self.use_parallel_runtime = True - return self - class AsyncTraversal(AsyncBaseSet): """ diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 2175fa02..2b693908 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -104,6 +104,7 @@ def __init__(self): self._database_version = None self._database_edition = None self.impersonated_user = None + self._parallel_runtime = False def set_connection(self, url: str = None, driver: Driver = None): """ @@ -239,6 +240,10 @@ def write_transaction(self): def read_transaction(self): return TransactionProxy(self, access_mode="READ") + @property + def parallel_read_transaction(self): + return TransactionProxy(self, access_mode="READ", parallel_runtime=True) + def impersonate(self, user: str) -> "ImpersonationHandler": """All queries executed within this context manager will be executed as impersonated user @@ -452,7 +457,6 @@ def cypher_query( :return: A tuple containing a list of results and a tuple of headers. """ - if self._active_transaction: # Use current session is a transaction is currently active results, meta = self._run_cypher_query( @@ -491,6 +495,8 @@ def _run_cypher_query( try: # Retrieve the data start = time.time() + if self._parallel_runtime: + query = "CYPHER runtime=parallel " + query response: Result = session.run(query, params) results, meta = [list(r.values()) for r in response], response.keys() end = time.time() @@ -1171,17 +1177,30 @@ def install_all_labels(stdout=None): class TransactionProxy: bookmarks: Optional[Bookmarks] = None - def __init__(self, db: Database, access_mode=None): + def __init__( + self, db: Database, access_mode: str = None, parallel_runtime: bool = False + ): self.db = db self.access_mode = access_mode + self.parallel_runtime = parallel_runtime @ensure_connection def __enter__(self): + if self.parallel_runtime: + if not self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False + self.db._parallel_runtime = self.parallel_runtime self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None return self def __exit__(self, exc_type, exc_value, traceback): + self.db._parallel_runtime = False if exc_value: self.db.rollback() diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index c2c9539d..15a49cfb 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -397,7 +397,6 @@ def __init__( lookup: TOptional[str] = None, additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, - use_parallel_runtime: TOptional[bool] = False, ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] @@ -411,7 +410,6 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count - self.use_parallel_runtime = use_parallel_runtime self.subgraph: Dict = {} @@ -435,19 +433,6 @@ def build_ast(self) -> "QueryBuilder": self._ast.skip = self.node_set.skip if hasattr(self.node_set, "limit"): self._ast.limit = self.node_set.limit - if hasattr(self.node_set, "use_parallel_runtime"): - if ( - self.node_set.use_parallel_runtime - and not db.parallel_runtime_available() - ): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " - "Reverting to default runtime.", - UserWarning, - ) - self.node_set.use_parallel_runtime = False - else: - self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime return self @@ -830,8 +815,6 @@ def lookup_query_variable( def build_query(self) -> str: query: str = "" - if self._ast.use_parallel_runtime: - query += "CYPHER runtime=parallel " if self._ast.lookup: query += self._ast.lookup @@ -1256,8 +1239,6 @@ def __init__(self, source) -> None: self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] - self.use_parallel_runtime = False - def __await__(self): return self.all().__await__() @@ -1584,10 +1565,6 @@ def intermediate_transform( ) return self - def parallel_runtime(self) -> "NodeSet": - self.use_parallel_runtime = True - return self - class Traversal(BaseSet): """ diff --git a/pyproject.toml b/pyproject.toml index d72c546b..99335e40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dev = [ "pytest>=7.1", "pytest-asyncio", "pytest-cov>=4.0", + "pytest-mock", "pre-commit", "black", "isort", diff --git a/requirements-dev.txt b/requirements-dev.txt index 446dd8c1..ad82ba50 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,6 +5,7 @@ unasync>=0.5.0 pytest>=7.1 pytest-asyncio>=0.19.0 pytest-cov>=4.0 +pytest-mock pre-commit black isort diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 7df6f7d7..77b4b2ab 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1,9 +1,8 @@ import re -import warnings from datetime import datetime from test._async_compat import mark_async_test -from pytest import raises, warns +from pytest import raises, skip, warns from neomodel import ( INCOMING, @@ -32,7 +31,11 @@ RawCypher, RelationNameResolver, ) -from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined +from neomodel.exceptions import ( + FeatureNotSupported, + MultipleNodesReturned, + RelationshipClassNotDefined, +) class SupplierRel(AsyncStructuredRel): @@ -1116,31 +1119,56 @@ async def test_async_iterator(): assert counter == n -@mark_async_test -async def test_parallel_runtime(): - await Coffee(name="Java", price=99).save() - - node_set = AsyncNodeSet(Coffee).parallel_runtime() +def assert_last_query_startswith(mock_func, query) -> bool: + return mock_func.call_args_list[-1].args[0].startswith(query) - assert node_set.use_parallel_runtime +@mark_async_test +async def test_parallel_runtime(mocker): if ( not await adb.version_is_higher_than("5.13") or not await adb.edition_is_enterprise() ): - assert not await adb.parallel_runtime_available() - with warns( - UserWarning, - match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", - ): - qb = await AsyncQueryBuilder(node_set).build_ast() - assert not qb._ast.use_parallel_runtime - assert not qb.build_query().startswith("CYPHER runtime=parallel") - else: - assert await adb.parallel_runtime_available() - qb = await AsyncQueryBuilder(node_set).build_ast() - assert qb._ast.use_parallel_runtime - assert qb.build_query().startswith("CYPHER runtime=parallel") + skip("Only supported for Enterprise 5.13 and above.") + + assert await adb.parallel_runtime_available() + mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") + + # Parallel should be applied to custom Cypher query + async with adb.parallel_read_transaction: + # Mock transaction.run to access executed query + # Assert query starts with CYPHER runtime=parallel + assert adb._parallel_runtime == True + await adb.cypher_query("MATCH (n:Coffee) RETURN n") + assert assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) + # Test exiting the context sets the parallel_runtime to False + assert adb._parallel_runtime == False + + # Parallel should be applied to neomodel queries + async with adb.parallel_read_transaction: + await Coffee.nodes + assert len(mock_transaction_run.call_args_list) > 1 + assert assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) - results = [node async for node in qb._execute()] - assert len(results) == 1 + +@mark_async_test +async def test_parallel_runtime_conflict(mocker): + if await adb.version_is_higher_than("5.13") and await adb.edition_is_enterprise(): + skip("Test for unavailable parallel runtime.") + + assert not await adb.parallel_runtime_available() + mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") + with warns( + UserWarning, + match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", + ): + async with adb.parallel_read_transaction: + await Coffee.nodes + assert not adb._parallel_runtime + assert not assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 2b148601..16ffb532 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1,9 +1,8 @@ import re -import warnings from datetime import datetime from test._async_compat import mark_sync_test -from pytest import raises, warns +from pytest import raises, skip, warns from neomodel import ( INCOMING, @@ -21,7 +20,11 @@ db, ) from neomodel._async_compat.util import Util -from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined +from neomodel.exceptions import ( + FeatureNotSupported, + MultipleNodesReturned, + RelationshipClassNotDefined, +) from neomodel.sync_.match import ( Collect, Last, @@ -1100,28 +1103,53 @@ def test_async_iterator(): assert counter == n -@mark_sync_test -def test_parallel_runtime(): - Coffee(name="Java", price=99).save() - - node_set = NodeSet(Coffee).parallel_runtime() +def assert_last_query_startswith(mock_func, query) -> bool: + return mock_func.call_args_list[-1].args[0].startswith(query) - assert node_set.use_parallel_runtime +@mark_sync_test +def test_parallel_runtime(mocker): if not db.version_is_higher_than("5.13") or not db.edition_is_enterprise(): - assert not db.parallel_runtime_available() - with warns( - UserWarning, - match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", - ): - qb = QueryBuilder(node_set).build_ast() - assert not qb._ast.use_parallel_runtime - assert not qb.build_query().startswith("CYPHER runtime=parallel") - else: - assert db.parallel_runtime_available() - qb = QueryBuilder(node_set).build_ast() - assert qb._ast.use_parallel_runtime - assert qb.build_query().startswith("CYPHER runtime=parallel") + skip("Only supported for Enterprise 5.13 and above.") + + assert db.parallel_runtime_available() + mock_transaction_run = mocker.patch("neo4j.Transaction.run") + + # Parallel should be applied to custom Cypher query + with db.parallel_read_transaction: + # Mock transaction.run to access executed query + # Assert query starts with CYPHER runtime=parallel + assert db._parallel_runtime == True + db.cypher_query("MATCH (n:Coffee) RETURN n") + assert assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) + # Test exiting the context sets the parallel_runtime to False + assert db._parallel_runtime == False + + # Parallel should be applied to neomodel queries + with db.parallel_read_transaction: + Coffee.nodes + assert len(mock_transaction_run.call_args_list) > 1 + assert assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) - results = [node for node in qb._execute()] - assert len(results) == 1 + +@mark_sync_test +def test_parallel_runtime_conflict(mocker): + if db.version_is_higher_than("5.13") and db.edition_is_enterprise(): + skip("Test for unavailable parallel runtime.") + + assert not db.parallel_runtime_available() + mock_transaction_run = mocker.patch("neo4j.Transaction.run") + with warns( + UserWarning, + match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", + ): + with db.parallel_read_transaction: + Coffee.nodes + assert not db._parallel_runtime + assert not assert_last_query_startswith( + mock_transaction_run, "CYPHER runtime=parallel" + ) From 53291f4c86f07518cbffec9970d25a017fc851fd Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Nov 2024 16:05:41 +0100 Subject: [PATCH 3/4] Fix tests --- test/async_/test_match_api.py | 10 +++++----- test/sync_/test_match_api.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 77b4b2ab..2dff91c0 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1132,13 +1132,13 @@ async def test_parallel_runtime(mocker): skip("Only supported for Enterprise 5.13 and above.") assert await adb.parallel_runtime_available() - mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") # Parallel should be applied to custom Cypher query async with adb.parallel_read_transaction: # Mock transaction.run to access executed query # Assert query starts with CYPHER runtime=parallel assert adb._parallel_runtime == True + mock_transaction_run = mocker.patch("neo4j.AsyncTransaction.run") await adb.cypher_query("MATCH (n:Coffee) RETURN n") assert assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" @@ -1148,10 +1148,10 @@ async def test_parallel_runtime(mocker): # Parallel should be applied to neomodel queries async with adb.parallel_read_transaction: - await Coffee.nodes - assert len(mock_transaction_run.call_args_list) > 1 + mock_transaction_run_2 = mocker.patch("neo4j.AsyncTransaction.run") + await Coffee.nodes.all() assert assert_last_query_startswith( - mock_transaction_run, "CYPHER runtime=parallel" + mock_transaction_run_2, "CYPHER runtime=parallel" ) @@ -1167,7 +1167,7 @@ async def test_parallel_runtime_conflict(mocker): match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", ): async with adb.parallel_read_transaction: - await Coffee.nodes + await Coffee.nodes.all() assert not adb._parallel_runtime assert not assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 16ffb532..4df51866 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1113,13 +1113,13 @@ def test_parallel_runtime(mocker): skip("Only supported for Enterprise 5.13 and above.") assert db.parallel_runtime_available() - mock_transaction_run = mocker.patch("neo4j.Transaction.run") # Parallel should be applied to custom Cypher query with db.parallel_read_transaction: # Mock transaction.run to access executed query # Assert query starts with CYPHER runtime=parallel assert db._parallel_runtime == True + mock_transaction_run = mocker.patch("neo4j.Transaction.run") db.cypher_query("MATCH (n:Coffee) RETURN n") assert assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" @@ -1129,10 +1129,10 @@ def test_parallel_runtime(mocker): # Parallel should be applied to neomodel queries with db.parallel_read_transaction: - Coffee.nodes - assert len(mock_transaction_run.call_args_list) > 1 + mock_transaction_run_2 = mocker.patch("neo4j.Transaction.run") + Coffee.nodes.all() assert assert_last_query_startswith( - mock_transaction_run, "CYPHER runtime=parallel" + mock_transaction_run_2, "CYPHER runtime=parallel" ) @@ -1148,7 +1148,7 @@ def test_parallel_runtime_conflict(mocker): match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13", ): with db.parallel_read_transaction: - Coffee.nodes + Coffee.nodes.all() assert not db._parallel_runtime assert not assert_last_query_startswith( mock_transaction_run, "CYPHER runtime=parallel" From 01bfb6e405bea125def5b7575da8f826e057c191 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Nov 2024 16:10:24 +0100 Subject: [PATCH 4/4] Fixed leftover code smell --- neomodel/async_/core.py | 15 +++++++-------- neomodel/sync_/core.py | 15 +++++++-------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index e28895ba..bfa5b8b9 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -1195,14 +1195,13 @@ def __init__( @ensure_connection async def __aenter__(self): - if self.parallel_runtime: - if not await self.db.parallel_runtime_available(): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " - "Reverting to default runtime.", - UserWarning, - ) - self.parallel_runtime = False + if self.parallel_runtime and not await self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False self.db._parallel_runtime = self.parallel_runtime await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 2b693908..75b7a10e 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -1186,14 +1186,13 @@ def __init__( @ensure_connection def __enter__(self): - if self.parallel_runtime: - if not self.db.parallel_runtime_available(): - warnings.warn( - "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " - "Reverting to default runtime.", - UserWarning, - ) - self.parallel_runtime = False + if self.parallel_runtime and not self.db.parallel_runtime_available(): + warnings.warn( + "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. " + "Reverting to default runtime.", + UserWarning, + ) + self.parallel_runtime = False self.db._parallel_runtime = self.parallel_runtime self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None