Skip to content

Commit

Permalink
Supports WHERE clause for relationships in multigraphs
Browse files Browse the repository at this point in the history
  • Loading branch information
jackboyla committed May 9, 2024
1 parent 3595706 commit da81cfd
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _aggregate_edge_labels(edges: Dict) -> Dict:
aggregated[edge_id] = attrs
return aggregated

def _get_entity_from_host(host: nx.DiGraph, entity_name, entity_attribute=None):
def _get_entity_from_host(host: Union[nx.DiGraph, nx.MultiDiGraph], entity_name, entity_attribute=None):
if entity_name in host.nodes():
# We are looking for a node mapping in the target graph:
if entity_attribute:
Expand All @@ -276,7 +276,10 @@ def _get_entity_from_host(host: nx.DiGraph, entity_name, entity_attribute=None):
return None # print(f"Nothing found for {entity_name} {entity_attribute}")
if entity_attribute:
# looking for edge attribute:
return edge_data.get(entity_attribute, None)
if isinstance(host, nx.MultiDiGraph):
return [r.get(entity_attribute, None) for r in edge_data.values()]
else:
return edge_data.get(entity_attribute, None)
else:
return host.get_edge_data(*entity_name)

Expand Down Expand Up @@ -307,7 +310,7 @@ def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:


def cond_(should_be, entity_id, operator, value) -> CONDITION:
def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:
def inner(match: dict, host: Union[nx.DiGraph, nx.MultiDiGraph], return_endges: list) -> bool:
host_entity_id = entity_id.split(".")
if host_entity_id[0] in match:
host_entity_id[0] = match[host_entity_id[0]]
Expand All @@ -318,7 +321,13 @@ def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:
else:
raise IndexError(f"Entity {host_entity_id} not in graph.")
try:
val = operator(_get_entity_from_host(host, *host_entity_id), value)
if isinstance(host, nx.MultiDiGraph):
# if any of the relations between nodes satisfies condition, return True
r_vals = _get_entity_from_host(host, *host_entity_id)
val = any(operator(r_val, value) for r_val in r_vals)
else:
val = operator(_get_entity_from_host(host, *host_entity_id), value)

except:
val = False
if val != should_be:
Expand Down

0 comments on commit da81cfd

Please sign in to comment.