From 217d0c5d7e4a25f14ce064ae134cde73637a328c Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Tue, 8 Oct 2024 11:44:54 +0200 Subject: [PATCH] Added traversal support to filtering and ordering features --- neomodel/async_/match.py | 358 ++++++++++++++++++---------------- neomodel/properties.py | 13 +- neomodel/sync_/match.py | 358 ++++++++++++++++++---------------- test/async_/test_match_api.py | 50 ++++- test/sync_/test_match_api.py | 48 ++++- 5 files changed, 495 insertions(+), 332 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index c0919aca..fa7e0bf6 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -6,11 +6,12 @@ from typing import Optional as TOptional from typing import Tuple, Union +from neomodel.async_ import relationship_manager from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.relationship import AsyncStructuredRel from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase -from neomodel.properties import AliasProperty, ArrayProperty +from neomodel.properties import AliasProperty, ArrayProperty, Property from neomodel.util import INCOMING, OUTGOING @@ -197,6 +198,8 @@ def _rel_merge_helper( # add all regex operators OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE) +path_split_regex = re.compile(r"__(?!_)") + def install_traversals(cls, node_set): """ @@ -216,136 +219,111 @@ def install_traversals(cls, node_set): setattr(node_set, key, traversal) -def process_filter_args(cls, kwargs) -> Dict: - """ - loop through properties in filter parameters check they match class definition - deflate them and convert into something easy to generate cypher from - """ - - output = {} - - for key, value in kwargs.items(): - if "__" in key: - prop, operator = key.rsplit("__") - operator = OPERATOR_TABLE[operator] - else: - prop = key - operator = "=" - - if prop not in cls.defined_properties(rels=False): +def _handle_special_operators( + property_obj: Property, key: str, value: str, operator: str, prop: str +) -> Tuple[str, str, str]: + if operator == _SPECIAL_OPERATOR_IN: + if not isinstance(value, (list, tuple)): raise ValueError( - f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + f"Value must be a tuple or list for IN operation {key}={value}" ) - - property_obj = getattr(cls, prop) - if isinstance(property_obj, AliasProperty): - prop = property_obj.aliased_to() - deflated_value = getattr(cls, prop).deflate(value) + if isinstance(property_obj, ArrayProperty): + deflated_value = property_obj.deflate(value) + operator = _SPECIAL_OPERATOR_ARRAY_IN else: - operator, deflated_value = transform_operator_to_filter( - operator=operator, - filter_key=key, - filter_value=value, - property_obj=property_obj, - ) + deflated_value = [property_obj.deflate(v) for v in value] + elif operator == _SPECIAL_OPERATOR_ISNULL: + if not isinstance(value, bool): + raise ValueError(f"Value must be a bool for isnull operation on {key}") + operator = "IS NULL" if value else "IS NOT NULL" + deflated_value = None + elif operator in _REGEX_OPERATOR_TABLE.values(): + deflated_value = property_obj.deflate(value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {key}") + if operator in _STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + else: + deflated_value = property_obj.deflate(value) - # map property to correct property name in the database - db_property = cls.defined_properties(rels=False)[prop].get_db_property_name( - prop + return deflated_value, operator, prop + + +def _deflate_value( + cls, property_obj: Property, key: str, value: str, operator: str, prop: str +) -> Tuple[str, str, str]: + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() + deflated_value = getattr(cls, prop).deflate(value) + else: + # handle special operators + deflated_value, operator, prop = _handle_special_operators( + property_obj, key, value, operator, prop ) - output[db_property] = (operator, deflated_value) + return deflated_value, operator, prop - return output +def _initialize_filter_args_variables(cls, key: str): + current_class = cls + leaf_prop = None + operator = "=" + prop = key -def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj): - """ - Transform in operator to a cypher filter - Args: - operator (str): operator to transform - filter_key (str): filter key - filter_value (str): filter value - property_obj (object): property object - Returns: - tuple: operator, deflated_value - """ - if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): - raise ValueError( - f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" - ) - if isinstance(property_obj, ArrayProperty): - deflated_value = property_obj.deflate(filter_value) - operator = _SPECIAL_OPERATOR_ARRAY_IN - else: - deflated_value = [property_obj.deflate(v) for v in filter_value] + return current_class, leaf_prop, operator, prop - return operator, deflated_value +def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: + ( + current_class, + leaf_prop, + operator, + prop, + ) = _initialize_filter_args_variables(cls, key) -def transform_null_operator_to_filter(filter_key, filter_value): - """ - Transform null operator to a cypher filter - Args: - filter_key (str): filter key - filter_value (str): filter value - Returns: - tuple: operator, deflated_value - """ - if not isinstance(filter_value, bool): - raise ValueError(f"Value must be a bool for isnull operation on {filter_key}") - operator = "IS NULL" if filter_value else "IS NOT NULL" - deflated_value = None - return operator, deflated_value + for part in re.split(path_split_regex, key): + defined_props = current_class.defined_properties(rels=True) + if part in defined_props: + if isinstance( + defined_props[part], relationship_manager.AsyncRelationshipDefinition + ): + defined_props[part].lookup_node_class() + current_class = defined_props[part].definition["node_class"] + elif part in OPERATOR_TABLE: + operator = OPERATOR_TABLE[part] + prop, _ = prop.rsplit("__", 1) + continue + else: + raise ValueError( + f"No such property {part} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + ) + leaf_prop = part + property_obj = getattr(current_class, leaf_prop) -def transform_regex_operator_to_filter( - operator, filter_key, filter_value, property_obj -): - """ - Transform regex operator to a cypher filter - Args: - operator (str): operator to transform - filter_key (str): filter key - filter_value (str): filter value - property_obj (object): property object - Returns: - tuple: operator, deflated_value - """ + return property_obj, operator, prop - deflated_value = property_obj.deflate(filter_value) - if not isinstance(deflated_value, str): - raise ValueError(f"Must be a string value for {filter_key}") - if operator in _STRING_REGEX_OPERATOR_TABLE.values(): - deflated_value = re.escape(deflated_value) - deflated_value = operator.format(deflated_value) - operator = _SPECIAL_OPERATOR_REGEX - return operator, deflated_value +def process_filter_args(cls, kwargs) -> Dict: + """ + loop through properties in filter parameters check they match class definition + deflate them and convert into something easy to generate cypher from + """ + output = {} -def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): - if operator == _SPECIAL_OPERATOR_IN: - operator, deflated_value = transform_in_operator_to_filter( - operator=operator, - filter_key=filter_key, - filter_value=filter_value, - property_obj=property_obj, - ) - elif operator == _SPECIAL_OPERATOR_ISNULL: - operator, deflated_value = transform_null_operator_to_filter( - filter_key=filter_key, filter_value=filter_value - ) - elif operator in _REGEX_OPERATOR_TABLE.values(): - operator, deflated_value = transform_regex_operator_to_filter( - operator=operator, - filter_key=filter_key, - filter_value=filter_value, - property_obj=property_obj, + for key, value in kwargs.items(): + property_obj, operator, prop = _process_filter_key(cls, key) + deflated_value, operator, prop = _deflate_value( + cls, property_obj, key, value, operator, prop ) - else: - deflated_value = property_obj.deflate(filter_value) - return operator, deflated_value + # map property to correct property name in the database + db_property = prop + + output[db_property] = (operator, deflated_value) + return output def process_has_args(cls, kwargs): @@ -382,7 +360,7 @@ class QueryAST: where: TOptional[list] with_clause: TOptional[str] return_clause: TOptional[str] - order_by: TOptional[str] + order_by: TOptional[List[str]] skip: TOptional[int] limit: TOptional[int] result_class: TOptional[type] @@ -397,7 +375,7 @@ def __init__( where: TOptional[list] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, - order_by: TOptional[str] = None, + order_by: TOptional[List[str]] = None, skip: TOptional[int] = None, limit: TOptional[int] = None, result_class: TOptional[type] = None, @@ -421,18 +399,14 @@ def __init__( class AsyncQueryBuilder: - def __init__( - self, node_set, with_subgraph: bool = False, subquery_context: bool = False - ) -> None: + def __init__(self, node_set, subquery_context: bool = False) -> None: self.node_set = node_set self._ast = QueryAST() - self._query_params = {} + self._query_params: Dict = {} self._place_holder_registry = {} - self._ident_count = 0 + self._ident_count: int = 0 self._node_counters = defaultdict(int) - self._with_subgraph: bool = with_subgraph self._subquery_context: bool = subquery_context - self._relation_identifiers: Dict[str, str] = {} async def build_ast(self) -> "AsyncQueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): @@ -477,19 +451,31 @@ async def build_source(self, source) -> str: return await self.build_node(source) raise ValueError("Unknown source type " + repr(source)) - def create_ident(self, relation_name: TOptional[str] = None) -> str: + def create_ident(self) -> str: self._ident_count += 1 - result = f"r{self._ident_count}" - if relation_name: - self._relation_identifiers[relation_name] = result - return result + return f"r{self._ident_count}" - def build_order_by(self, ident, source): + def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None: if "?" in source.order_by_elements: self._ast.with_clause = f"{ident}, rand() as r" - self._ast.order_by = "r" + self._ast.order_by = ["r"] else: - self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements] + order_by = [] + for elm in source.order_by_elements: + if "__" not in elm: + prop = elm.split(" ")[0] if " " in elm else elm + if prop not in source.source_class.defined_properties(rels=False): + raise ValueError( + f"No such property {prop} on {source.source_class.__name__}. " + f"Note that Neo4j internals like id or element_id are not allowed " + f"for use in this operation." + ) + order_by.append(f"{ident}.{elm}") + else: + path, prop = elm.rsplit("__", 1) + order_by_clause = self.lookup_query_variable(path) + order_by.append(f"{order_by_clause}.{prop}") + self._ast.order_by = order_by async def build_traversal(self, traversal): """ @@ -527,9 +513,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class - parts = path.split("__") - if self._with_subgraph: - subgraph = self._ast.subgraph + parts = re.split(path_split_regex, path) + subgraph = self._ast.subgraph rel_iterator: str = "" for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) @@ -569,8 +554,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: lhs_ident = stmt - rel_ident = self.create_ident(rel_iterator) - if self._with_subgraph and part not in self._ast.subgraph: + rel_ident = self.create_ident() + if part not in self._ast.subgraph: subgraph[part] = { "target": relationship.definition["node_class"], "children": {}, @@ -587,8 +572,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: relation_type=relationship.definition["relation_type"], ) source_class_iterator = relationship.definition["node_class"] - if self._with_subgraph: - subgraph = subgraph[part]["children"] + subgraph = subgraph[part]["children"] if relation.get("optional"): self._ast.optional_match.append(stmt) @@ -653,6 +637,41 @@ def _register_place_holder(self, key): self._place_holder_registry[key] = 1 return key + "_" + str(self._place_holder_registry[key]) + def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: + path, prop = prop.rsplit("__", 1) + ident = self.build_traversal_from_path( + {"path": path, "include_in_return": True}, + source_class, + ) + return ident, path, prop + + def _finalize_filter_statement(self, operator, ident, prop, val) -> str: + if operator in _UNARY_OPERATORS: + # unary operators do not have a parameter + statement = f"{ident}.{prop} {operator}" + else: + place_holder = self._register_place_holder(ident + "_" + prop) + if operator == _SPECIAL_OPERATOR_ARRAY_IN: + statement = operator.format( + ident=ident, + prop=prop, + val=f"${place_holder}", + ) + else: + statement = f"{ident}.{prop} {operator} ${place_holder}" + self._query_params[place_holder] = val + + return statement + + def _build_filter_statements(self, ident, filters, target, source_class): + for prop, op_and_val in filters.items(): + path = None + if "__" in prop: + ident, path, prop = self._parse_path(source_class, prop) + operator, val = op_and_val + statement = self._finalize_filter_statement(operator, ident, prop, val) + target.append(statement) + def _parse_q_filters(self, ident, q, source_class): target = [] for child in q.children: @@ -664,23 +683,7 @@ def _parse_q_filters(self, ident, q, source_class): else: kwargs = {child[0]: child[1]} filters = process_filter_args(source_class, kwargs) - for prop, op_and_val in filters.items(): - operator, val = op_and_val - if operator in _UNARY_OPERATORS: - # unary operators do not have a parameter - statement = f"{ident}.{prop} {operator}" - else: - place_holder = self._register_place_holder(ident + "_" + prop) - if operator == _SPECIAL_OPERATOR_ARRAY_IN: - statement = operator.format( - ident=ident, - prop=prop, - val=f"${place_holder}", - ) - else: - statement = f"{ident}.{prop} {operator} ${place_holder}" - self._query_params[place_holder] = val - target.append(statement) + self._build_filter_statements(ident, filters, target, source_class) ret = f" {q.connector} ".join(target) if q.negated: ret = f"NOT ({ret})" @@ -719,6 +722,35 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) + def lookup_query_variable( + self, path: str, return_relation: bool = False + ) -> TOptional[str]: + """Retrieve the variable name generated internally for the given traversal path.""" + subgraph = self._ast.subgraph + if not subgraph: + return None + traversals = re.split(path_split_regex, path) + if len(traversals) == 0: + raise ValueError("Can only lookup traversal variables") + if traversals[0] not in subgraph: + return None + subgraph = subgraph[traversals[0]] + variable_to_return = None + last_property = traversals[-1] + for part in traversals: + if part in subgraph["children"]: + subgraph = subgraph["children"][part] + elif part == last_property: + # if last part of prop is the last traversal + # we are safe to lookup the variable from the query + if return_relation: + variable_to_return = f"{subgraph['rel_variable_name']}" + else: + variable_to_return = f"{subgraph['variable_name']}" + else: + break + return variable_to_return + def build_query(self) -> str: query: str = "" @@ -754,7 +786,9 @@ def build_query(self) -> str: if type(source) is str: injected_vars.append(f"{source} AS {name}") elif isinstance(source, RelationNameResolver): - internal_name = self._relation_identifiers.get(source.relation) + internal_name = self.lookup_query_variable( + source.relation, return_relation=True + ) if not internal_name: raise ValueError( f"Unable to resolve variable name for relation {source.relation}." @@ -880,6 +914,7 @@ class AsyncBaseSet: """ query_cls = AsyncQueryBuilder + source_class: AsyncStructuredNode async def all(self, lazy=False): """ @@ -1029,6 +1064,7 @@ def __init__(self, source) -> None: self.filters: List = [] self.q_filters = Q() + self.order_by_elements: List = [] # used by has() self.must_match: Dict = {} @@ -1179,14 +1215,10 @@ def order_by(self, *props): else: desc = False - if prop not in self.source_class.defined_properties(rels=False): - raise ValueError( - f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." - ) - - property_obj = getattr(self.source_class, prop) - if isinstance(property_obj, AliasProperty): - prop = property_obj.aliased_to() + if prop in self.source_class.defined_properties(rels=False): + property_obj = getattr(self.source_class, prop) + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() self.order_by_elements.append(prop + (" DESC" if desc else "")) @@ -1300,7 +1332,7 @@ async def resolve_subgraph(self) -> list: "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." ) results: list = [] - qbuilder = self.query_cls(self, with_subgraph=True) + qbuilder = self.query_cls(self) await qbuilder.build_ast() all_nodes = qbuilder._execute(dict_output=True) other_nodes = {} diff --git a/neomodel/properties.py b/neomodel/properties.py index d4a91885..8c848ea8 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -1,9 +1,10 @@ import functools import json import re -import sys import uuid +from abc import ABCMeta, abstractmethod from datetime import date, datetime +from typing import Any import neo4j.time import pytz @@ -37,7 +38,7 @@ def _validator(self, value, obj=None, rethrow=True): return _validator -class FulltextIndex(object): +class FulltextIndex: """ Fulltext index definition """ @@ -57,7 +58,7 @@ def __init__( self.eventually_consistent = eventually_consistent -class VectorIndex(object): +class VectorIndex: """ Vector index definition """ @@ -73,7 +74,7 @@ def __init__(self, dimensions=1536, similarity_function="cosine"): self.similarity_function = similarity_function -class Property: +class Property(metaclass=ABCMeta): """ Base class for object properties. @@ -158,6 +159,10 @@ def get_db_property_name(self, attribute_name): def is_indexed(self): return self.unique_index or self.index + @abstractmethod + def deflate(self, value: Any) -> Any: + pass + class NormalizedProperty(Property): """ diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 16ad5350..a1d311fd 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -8,7 +8,8 @@ from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase -from neomodel.properties import AliasProperty, ArrayProperty +from neomodel.properties import AliasProperty, ArrayProperty, Property +from neomodel.sync_ import relationship_manager from neomodel.sync_.core import StructuredNode, db from neomodel.sync_.relationship import StructuredRel from neomodel.util import INCOMING, OUTGOING @@ -197,6 +198,8 @@ def _rel_merge_helper( # add all regex operators OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE) +path_split_regex = re.compile(r"__(?!_)") + def install_traversals(cls, node_set): """ @@ -216,136 +219,111 @@ def install_traversals(cls, node_set): setattr(node_set, key, traversal) -def process_filter_args(cls, kwargs) -> Dict: - """ - loop through properties in filter parameters check they match class definition - deflate them and convert into something easy to generate cypher from - """ - - output = {} - - for key, value in kwargs.items(): - if "__" in key: - prop, operator = key.rsplit("__") - operator = OPERATOR_TABLE[operator] - else: - prop = key - operator = "=" - - if prop not in cls.defined_properties(rels=False): +def _handle_special_operators( + property_obj: Property, key: str, value: str, operator: str, prop: str +) -> Tuple[str, str, str]: + if operator == _SPECIAL_OPERATOR_IN: + if not isinstance(value, (list, tuple)): raise ValueError( - f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + f"Value must be a tuple or list for IN operation {key}={value}" ) - - property_obj = getattr(cls, prop) - if isinstance(property_obj, AliasProperty): - prop = property_obj.aliased_to() - deflated_value = getattr(cls, prop).deflate(value) + if isinstance(property_obj, ArrayProperty): + deflated_value = property_obj.deflate(value) + operator = _SPECIAL_OPERATOR_ARRAY_IN else: - operator, deflated_value = transform_operator_to_filter( - operator=operator, - filter_key=key, - filter_value=value, - property_obj=property_obj, - ) + deflated_value = [property_obj.deflate(v) for v in value] + elif operator == _SPECIAL_OPERATOR_ISNULL: + if not isinstance(value, bool): + raise ValueError(f"Value must be a bool for isnull operation on {key}") + operator = "IS NULL" if value else "IS NOT NULL" + deflated_value = None + elif operator in _REGEX_OPERATOR_TABLE.values(): + deflated_value = property_obj.deflate(value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {key}") + if operator in _STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + else: + deflated_value = property_obj.deflate(value) - # map property to correct property name in the database - db_property = cls.defined_properties(rels=False)[prop].get_db_property_name( - prop + return deflated_value, operator, prop + + +def _deflate_value( + cls, property_obj: Property, key: str, value: str, operator: str, prop: str +) -> Tuple[str, str, str]: + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() + deflated_value = getattr(cls, prop).deflate(value) + else: + # handle special operators + deflated_value, operator, prop = _handle_special_operators( + property_obj, key, value, operator, prop ) - output[db_property] = (operator, deflated_value) + return deflated_value, operator, prop - return output +def _initialize_filter_args_variables(cls, key: str): + current_class = cls + leaf_prop = None + operator = "=" + prop = key -def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj): - """ - Transform in operator to a cypher filter - Args: - operator (str): operator to transform - filter_key (str): filter key - filter_value (str): filter value - property_obj (object): property object - Returns: - tuple: operator, deflated_value - """ - if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): - raise ValueError( - f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" - ) - if isinstance(property_obj, ArrayProperty): - deflated_value = property_obj.deflate(filter_value) - operator = _SPECIAL_OPERATOR_ARRAY_IN - else: - deflated_value = [property_obj.deflate(v) for v in filter_value] + return current_class, leaf_prop, operator, prop - return operator, deflated_value +def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: + ( + current_class, + leaf_prop, + operator, + prop, + ) = _initialize_filter_args_variables(cls, key) -def transform_null_operator_to_filter(filter_key, filter_value): - """ - Transform null operator to a cypher filter - Args: - filter_key (str): filter key - filter_value (str): filter value - Returns: - tuple: operator, deflated_value - """ - if not isinstance(filter_value, bool): - raise ValueError(f"Value must be a bool for isnull operation on {filter_key}") - operator = "IS NULL" if filter_value else "IS NOT NULL" - deflated_value = None - return operator, deflated_value + for part in re.split(path_split_regex, key): + defined_props = current_class.defined_properties(rels=True) + if part in defined_props: + if isinstance( + defined_props[part], relationship_manager.RelationshipDefinition + ): + defined_props[part].lookup_node_class() + current_class = defined_props[part].definition["node_class"] + elif part in OPERATOR_TABLE: + operator = OPERATOR_TABLE[part] + prop, _ = prop.rsplit("__", 1) + continue + else: + raise ValueError( + f"No such property {part} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + ) + leaf_prop = part + property_obj = getattr(current_class, leaf_prop) -def transform_regex_operator_to_filter( - operator, filter_key, filter_value, property_obj -): - """ - Transform regex operator to a cypher filter - Args: - operator (str): operator to transform - filter_key (str): filter key - filter_value (str): filter value - property_obj (object): property object - Returns: - tuple: operator, deflated_value - """ + return property_obj, operator, prop - deflated_value = property_obj.deflate(filter_value) - if not isinstance(deflated_value, str): - raise ValueError(f"Must be a string value for {filter_key}") - if operator in _STRING_REGEX_OPERATOR_TABLE.values(): - deflated_value = re.escape(deflated_value) - deflated_value = operator.format(deflated_value) - operator = _SPECIAL_OPERATOR_REGEX - return operator, deflated_value +def process_filter_args(cls, kwargs) -> Dict: + """ + loop through properties in filter parameters check they match class definition + deflate them and convert into something easy to generate cypher from + """ + output = {} -def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): - if operator == _SPECIAL_OPERATOR_IN: - operator, deflated_value = transform_in_operator_to_filter( - operator=operator, - filter_key=filter_key, - filter_value=filter_value, - property_obj=property_obj, - ) - elif operator == _SPECIAL_OPERATOR_ISNULL: - operator, deflated_value = transform_null_operator_to_filter( - filter_key=filter_key, filter_value=filter_value - ) - elif operator in _REGEX_OPERATOR_TABLE.values(): - operator, deflated_value = transform_regex_operator_to_filter( - operator=operator, - filter_key=filter_key, - filter_value=filter_value, - property_obj=property_obj, + for key, value in kwargs.items(): + property_obj, operator, prop = _process_filter_key(cls, key) + deflated_value, operator, prop = _deflate_value( + cls, property_obj, key, value, operator, prop ) - else: - deflated_value = property_obj.deflate(filter_value) - return operator, deflated_value + # map property to correct property name in the database + db_property = prop + + output[db_property] = (operator, deflated_value) + return output def process_has_args(cls, kwargs): @@ -382,7 +360,7 @@ class QueryAST: where: TOptional[list] with_clause: TOptional[str] return_clause: TOptional[str] - order_by: TOptional[str] + order_by: TOptional[List[str]] skip: TOptional[int] limit: TOptional[int] result_class: TOptional[type] @@ -397,7 +375,7 @@ def __init__( where: TOptional[list] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, - order_by: TOptional[str] = None, + order_by: TOptional[List[str]] = None, skip: TOptional[int] = None, limit: TOptional[int] = None, result_class: TOptional[type] = None, @@ -421,18 +399,14 @@ def __init__( class QueryBuilder: - def __init__( - self, node_set, with_subgraph: bool = False, subquery_context: bool = False - ) -> None: + def __init__(self, node_set, subquery_context: bool = False) -> None: self.node_set = node_set self._ast = QueryAST() - self._query_params = {} + self._query_params: Dict = {} self._place_holder_registry = {} - self._ident_count = 0 + self._ident_count: int = 0 self._node_counters = defaultdict(int) - self._with_subgraph: bool = with_subgraph self._subquery_context: bool = subquery_context - self._relation_identifiers: Dict[str, str] = {} def build_ast(self) -> "QueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): @@ -477,19 +451,31 @@ def build_source(self, source) -> str: return self.build_node(source) raise ValueError("Unknown source type " + repr(source)) - def create_ident(self, relation_name: TOptional[str] = None) -> str: + def create_ident(self) -> str: self._ident_count += 1 - result = f"r{self._ident_count}" - if relation_name: - self._relation_identifiers[relation_name] = result - return result + return f"r{self._ident_count}" - def build_order_by(self, ident, source): + def build_order_by(self, ident: str, source: "NodeSet") -> None: if "?" in source.order_by_elements: self._ast.with_clause = f"{ident}, rand() as r" - self._ast.order_by = "r" + self._ast.order_by = ["r"] else: - self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements] + order_by = [] + for elm in source.order_by_elements: + if "__" not in elm: + prop = elm.split(" ")[0] if " " in elm else elm + if prop not in source.source_class.defined_properties(rels=False): + raise ValueError( + f"No such property {prop} on {source.source_class.__name__}. " + f"Note that Neo4j internals like id or element_id are not allowed " + f"for use in this operation." + ) + order_by.append(f"{ident}.{elm}") + else: + path, prop = elm.rsplit("__", 1) + order_by_clause = self.lookup_query_variable(path) + order_by.append(f"{order_by_clause}.{prop}") + self._ast.order_by = order_by def build_traversal(self, traversal): """ @@ -527,9 +513,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class - parts = path.split("__") - if self._with_subgraph: - subgraph = self._ast.subgraph + parts = re.split(path_split_regex, path) + subgraph = self._ast.subgraph rel_iterator: str = "" for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) @@ -569,8 +554,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: lhs_ident = stmt - rel_ident = self.create_ident(rel_iterator) - if self._with_subgraph and part not in self._ast.subgraph: + rel_ident = self.create_ident() + if part not in self._ast.subgraph: subgraph[part] = { "target": relationship.definition["node_class"], "children": {}, @@ -587,8 +572,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: relation_type=relationship.definition["relation_type"], ) source_class_iterator = relationship.definition["node_class"] - if self._with_subgraph: - subgraph = subgraph[part]["children"] + subgraph = subgraph[part]["children"] if relation.get("optional"): self._ast.optional_match.append(stmt) @@ -653,6 +637,41 @@ def _register_place_holder(self, key): self._place_holder_registry[key] = 1 return key + "_" + str(self._place_holder_registry[key]) + def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: + path, prop = prop.rsplit("__", 1) + ident = self.build_traversal_from_path( + {"path": path, "include_in_return": True}, + source_class, + ) + return ident, path, prop + + def _finalize_filter_statement(self, operator, ident, prop, val) -> str: + if operator in _UNARY_OPERATORS: + # unary operators do not have a parameter + statement = f"{ident}.{prop} {operator}" + else: + place_holder = self._register_place_holder(ident + "_" + prop) + if operator == _SPECIAL_OPERATOR_ARRAY_IN: + statement = operator.format( + ident=ident, + prop=prop, + val=f"${place_holder}", + ) + else: + statement = f"{ident}.{prop} {operator} ${place_holder}" + self._query_params[place_holder] = val + + return statement + + def _build_filter_statements(self, ident, filters, target, source_class): + for prop, op_and_val in filters.items(): + path = None + if "__" in prop: + ident, path, prop = self._parse_path(source_class, prop) + operator, val = op_and_val + statement = self._finalize_filter_statement(operator, ident, prop, val) + target.append(statement) + def _parse_q_filters(self, ident, q, source_class): target = [] for child in q.children: @@ -664,23 +683,7 @@ def _parse_q_filters(self, ident, q, source_class): else: kwargs = {child[0]: child[1]} filters = process_filter_args(source_class, kwargs) - for prop, op_and_val in filters.items(): - operator, val = op_and_val - if operator in _UNARY_OPERATORS: - # unary operators do not have a parameter - statement = f"{ident}.{prop} {operator}" - else: - place_holder = self._register_place_holder(ident + "_" + prop) - if operator == _SPECIAL_OPERATOR_ARRAY_IN: - statement = operator.format( - ident=ident, - prop=prop, - val=f"${place_holder}", - ) - else: - statement = f"{ident}.{prop} {operator} ${place_holder}" - self._query_params[place_holder] = val - target.append(statement) + self._build_filter_statements(ident, filters, target, source_class) ret = f" {q.connector} ".join(target) if q.negated: ret = f"NOT ({ret})" @@ -719,6 +722,35 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) + def lookup_query_variable( + self, path: str, return_relation: bool = False + ) -> TOptional[str]: + """Retrieve the variable name generated internally for the given traversal path.""" + subgraph = self._ast.subgraph + if not subgraph: + return None + traversals = re.split(path_split_regex, path) + if len(traversals) == 0: + raise ValueError("Can only lookup traversal variables") + if traversals[0] not in subgraph: + return None + subgraph = subgraph[traversals[0]] + variable_to_return = None + last_property = traversals[-1] + for part in traversals: + if part in subgraph["children"]: + subgraph = subgraph["children"][part] + elif part == last_property: + # if last part of prop is the last traversal + # we are safe to lookup the variable from the query + if return_relation: + variable_to_return = f"{subgraph['rel_variable_name']}" + else: + variable_to_return = f"{subgraph['variable_name']}" + else: + break + return variable_to_return + def build_query(self) -> str: query: str = "" @@ -754,7 +786,9 @@ def build_query(self) -> str: if type(source) is str: injected_vars.append(f"{source} AS {name}") elif isinstance(source, RelationNameResolver): - internal_name = self._relation_identifiers.get(source.relation) + internal_name = self.lookup_query_variable( + source.relation, return_relation=True + ) if not internal_name: raise ValueError( f"Unable to resolve variable name for relation {source.relation}." @@ -878,6 +912,7 @@ class BaseSet: """ query_cls = QueryBuilder + source_class: StructuredNode def all(self, lazy=False): """ @@ -1027,6 +1062,7 @@ def __init__(self, source) -> None: self.filters: List = [] self.q_filters = Q() + self.order_by_elements: List = [] # used by has() self.must_match: Dict = {} @@ -1177,14 +1213,10 @@ def order_by(self, *props): else: desc = False - if prop not in self.source_class.defined_properties(rels=False): - raise ValueError( - f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." - ) - - property_obj = getattr(self.source_class, prop) - if isinstance(property_obj, AliasProperty): - prop = property_obj.aliased_to() + if prop in self.source_class.defined_properties(rels=False): + property_obj = getattr(self.source_class, prop) + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() self.order_by_elements.append(prop + (" DESC" if desc else "")) @@ -1298,7 +1330,7 @@ def resolve_subgraph(self) -> list: "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." ) results: list = [] - qbuilder = self.query_cls(self, with_subgraph=True) + qbuilder = self.query_cls(self) qbuilder.build_ast() all_nodes = qbuilder._execute(dict_output=True) other_nodes = {} diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 0659f468..1d90b625 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -310,7 +310,7 @@ async def test_order_by(): ns = ns.order_by("?") qb = await AsyncQueryBuilder(ns).build_ast() assert qb._ast.with_clause == "coffee, rand() as r" - assert qb._ast.order_by == "r" + assert qb._ast.order_by == ["r"] with raises( ValueError, @@ -544,8 +544,32 @@ async def test_traversal_filter_left_hand_statement(): assert lidl in lidl_supplier +@mark_async_test +async def test_filter_with_traversal(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe", price=11).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=99).save() + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + await nescafe_gold.species.connect(robusta) + + results = await Coffee.nodes.filter(species__name="Arabica").all() + assert len(results) == 1 + assert len(results[0]) == 3 + assert results[0][0] == nescafe + + @mark_async_test async def test_fetch_relations(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + arabica = await Species(name="Arabica").save() robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe 1000", price=99).save() @@ -601,6 +625,30 @@ async def test_fetch_relations(): ) +@mark_async_test +async def test_traverse_and_order_by(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=110).save() + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + await nescafe_gold.species.connect(robusta) + + results = ( + await Species.nodes.fetch_relations("coffees").order_by("-coffees__price").all() + ) + assert len(results) == 2 + assert len(results[0]) == 3 # 2 nodes and 1 relation + assert results[0][0] == robusta + assert results[1][0] == arabica + + @mark_async_test async def test_annotate_and_collect(): # Clean DB before we start anything... diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 3976c92a..f9beafa6 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -306,7 +306,7 @@ def test_order_by(): ns = ns.order_by("?") qb = QueryBuilder(ns).build_ast() assert qb._ast.with_clause == "coffee, rand() as r" - assert qb._ast.order_by == "r" + assert qb._ast.order_by == ["r"] with raises( ValueError, @@ -538,8 +538,32 @@ def test_traversal_filter_left_hand_statement(): assert lidl in lidl_supplier +@mark_sync_test +def test_filter_with_traversal(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe", price=11).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=99).save() + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + nescafe_gold.species.connect(robusta) + + results = Coffee.nodes.filter(species__name="Arabica").all() + assert len(results) == 1 + assert len(results[0]) == 3 + assert results[0][0] == nescafe + + @mark_sync_test def test_fetch_relations(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe 1000", price=99).save() @@ -595,6 +619,28 @@ def test_fetch_relations(): ) +@mark_sync_test +def test_traverse_and_order_by(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe", price=99).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=110).save() + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + nescafe_gold.species.connect(robusta) + + results = Species.nodes.fetch_relations("coffees").order_by("-coffees__price").all() + assert len(results) == 2 + assert len(results[0]) == 3 # 2 nodes and 1 relation + assert results[0][0] == robusta + assert results[1][0] == arabica + + @mark_sync_test def test_annotate_and_collect(): # Clean DB before we start anything...