Skip to content

Commit

Permalink
Added missing db property resolution call
Browse files Browse the repository at this point in the history
  • Loading branch information
tonioo committed Oct 8, 2024
1 parent 217d0c5 commit 272393f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 26 deletions.
34 changes: 21 additions & 13 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ def process_filter_args(cls, kwargs) -> Dict:
deflated_value, operator, prop = _deflate_value(
cls, property_obj, key, value, operator, prop
)

# map property to correct property name in the database
db_property = prop

Expand Down Expand Up @@ -357,7 +356,7 @@ def process_has_args(cls, kwargs):
class QueryAST:
match: List[str]
optional_match: List[str]
where: TOptional[list]
where: List[str]
with_clause: TOptional[str]
return_clause: TOptional[str]
order_by: TOptional[List[str]]
Expand All @@ -372,7 +371,7 @@ def __init__(
self,
match: TOptional[List[str]] = None,
optional_match: TOptional[List[str]] = None,
where: TOptional[list] = None,
where: TOptional[List[str]] = None,
with_clause: TOptional[str] = None,
return_clause: TOptional[str] = None,
order_by: TOptional[List[str]] = None,
Expand Down Expand Up @@ -403,7 +402,7 @@ def __init__(self, node_set, subquery_context: bool = False) -> None:
self.node_set = node_set
self._ast = QueryAST()
self._query_params: Dict = {}
self._place_holder_registry = {}
self._place_holder_registry: Dict = {}
self._ident_count: int = 0
self._node_counters = defaultdict(int)
self._subquery_context: bool = subquery_context
Expand Down Expand Up @@ -477,7 +476,7 @@ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None:
order_by.append(f"{order_by_clause}.{prop}")
self._ast.order_by = order_by

async def build_traversal(self, traversal):
async def build_traversal(self, traversal) -> str:
"""
traverse a relationship from a node to a set of nodes
"""
Expand Down Expand Up @@ -630,7 +629,7 @@ def build_additional_match(self, ident, node_set):
else:
raise ValueError("Expecting dict got: " + repr(val))

def _register_place_holder(self, key):
def _register_place_holder(self, key: str) -> str:
if key in self._place_holder_registry:
self._place_holder_registry[key] += 1
else:
Expand All @@ -645,7 +644,9 @@ def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]:
)
return ident, path, prop

def _finalize_filter_statement(self, operator, ident, prop, val) -> str:
def _finalize_filter_statement(
self, operator: str, ident: str, prop: str, val: Any
) -> str:
if operator in _UNARY_OPERATORS:
# unary operators do not have a parameter
statement = f"{ident}.{prop} {operator}"
Expand All @@ -663,16 +664,21 @@ def _finalize_filter_statement(self, operator, ident, prop, val) -> str:

return statement

def _build_filter_statements(self, ident, filters, target, source_class):
def _build_filter_statements(
self, ident: str, filters, target: List[str], source_class
) -> None:
for prop, op_and_val in filters.items():
path = None
if "__" in prop:
ident, path, prop = self._parse_path(source_class, prop)
operator, val = op_and_val
prop = source_class.defined_properties(rels=False)[
prop
].get_db_property_name(prop)
statement = self._finalize_filter_statement(operator, ident, prop, val)
target.append(statement)

def _parse_q_filters(self, ident, q, source_class):
def _parse_q_filters(self, ident, q, source_class) -> str:
target = []
for child in q.children:
if isinstance(child, QBase):
Expand All @@ -689,14 +695,16 @@ def _parse_q_filters(self, ident, q, source_class):
ret = f"NOT ({ret})"
return ret

def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
def build_where_stmt(
self, ident: str, filters, q_filters=None, source_class=None
) -> None:
"""
construct a where statement from some filters
"""
if q_filters is not None:
stmts = self._parse_q_filters(ident, q_filters, source_class)
if stmts:
self._ast.where.append(stmts)
stmt = self._parse_q_filters(ident, q_filters, source_class)
if stmt:
self._ast.where.append(stmt)
else:
stmts = []
for row in filters:
Expand Down
34 changes: 21 additions & 13 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ def process_filter_args(cls, kwargs) -> Dict:
deflated_value, operator, prop = _deflate_value(
cls, property_obj, key, value, operator, prop
)

# map property to correct property name in the database
db_property = prop

Expand Down Expand Up @@ -357,7 +356,7 @@ def process_has_args(cls, kwargs):
class QueryAST:
match: List[str]
optional_match: List[str]
where: TOptional[list]
where: List[str]
with_clause: TOptional[str]
return_clause: TOptional[str]
order_by: TOptional[List[str]]
Expand All @@ -372,7 +371,7 @@ def __init__(
self,
match: TOptional[List[str]] = None,
optional_match: TOptional[List[str]] = None,
where: TOptional[list] = None,
where: TOptional[List[str]] = None,
with_clause: TOptional[str] = None,
return_clause: TOptional[str] = None,
order_by: TOptional[List[str]] = None,
Expand Down Expand Up @@ -403,7 +402,7 @@ def __init__(self, node_set, subquery_context: bool = False) -> None:
self.node_set = node_set
self._ast = QueryAST()
self._query_params: Dict = {}
self._place_holder_registry = {}
self._place_holder_registry: Dict = {}
self._ident_count: int = 0
self._node_counters = defaultdict(int)
self._subquery_context: bool = subquery_context
Expand Down Expand Up @@ -477,7 +476,7 @@ def build_order_by(self, ident: str, source: "NodeSet") -> None:
order_by.append(f"{order_by_clause}.{prop}")
self._ast.order_by = order_by

def build_traversal(self, traversal):
def build_traversal(self, traversal) -> str:
"""
traverse a relationship from a node to a set of nodes
"""
Expand Down Expand Up @@ -630,7 +629,7 @@ def build_additional_match(self, ident, node_set):
else:
raise ValueError("Expecting dict got: " + repr(val))

def _register_place_holder(self, key):
def _register_place_holder(self, key: str) -> str:
if key in self._place_holder_registry:
self._place_holder_registry[key] += 1
else:
Expand All @@ -645,7 +644,9 @@ def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]:
)
return ident, path, prop

def _finalize_filter_statement(self, operator, ident, prop, val) -> str:
def _finalize_filter_statement(
self, operator: str, ident: str, prop: str, val: Any
) -> str:
if operator in _UNARY_OPERATORS:
# unary operators do not have a parameter
statement = f"{ident}.{prop} {operator}"
Expand All @@ -663,16 +664,21 @@ def _finalize_filter_statement(self, operator, ident, prop, val) -> str:

return statement

def _build_filter_statements(self, ident, filters, target, source_class):
def _build_filter_statements(
self, ident: str, filters, target: List[str], source_class
) -> None:
for prop, op_and_val in filters.items():
path = None
if "__" in prop:
ident, path, prop = self._parse_path(source_class, prop)
operator, val = op_and_val
prop = source_class.defined_properties(rels=False)[
prop
].get_db_property_name(prop)
statement = self._finalize_filter_statement(operator, ident, prop, val)
target.append(statement)

def _parse_q_filters(self, ident, q, source_class):
def _parse_q_filters(self, ident, q, source_class) -> str:
target = []
for child in q.children:
if isinstance(child, QBase):
Expand All @@ -689,14 +695,16 @@ def _parse_q_filters(self, ident, q, source_class):
ret = f"NOT ({ret})"
return ret

def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
def build_where_stmt(
self, ident: str, filters, q_filters=None, source_class=None
) -> None:
"""
construct a where statement from some filters
"""
if q_filters is not None:
stmts = self._parse_q_filters(ident, q_filters, source_class)
if stmts:
self._ast.where.append(stmts)
stmt = self._parse_q_filters(ident, q_filters, source_class)
if stmt:
self._ast.where.append(stmt)
else:
stmts = []
for row in filters:
Expand Down

0 comments on commit 272393f

Please sign in to comment.