diff --git a/alchemiscale/models.py b/alchemiscale/models.py index 69fc1288..ed7a6cfb 100644 --- a/alchemiscale/models.py +++ b/alchemiscale/models.py @@ -8,6 +8,8 @@ from pydantic import BaseModel, Field, validator, root_validator from gufe.tokenization import GufeKey from re import fullmatch +import unicodedata +import string class Scope(BaseModel): @@ -114,6 +116,9 @@ def specific(self) -> bool: return all(self.to_tuple()) +class InvalidGufeKeyError(ValueError): ... + + class ScopedKey(BaseModel): """Unique identifier for GufeTokenizables in state store. @@ -131,8 +136,26 @@ class Config: frozen = True @validator("gufe_key") - def cast_gufe_key(cls, v): - return GufeKey(v) + def gufe_key_validator(cls, v): + v = str(v) + + # GufeKey is of form - + try: + _prefix, _token = v.split("-") + except ValueError: + raise InvalidGufeKeyError("gufe_key must be of the form '-'") + + # Normalize the input to NFC form + v_normalized = unicodedata.normalize("NFC", v) + + # Allowed characters: letters, numbers, underscores, hyphens + allowed_chars = set(string.ascii_letters + string.digits + "_-") + + if not set(v_normalized).issubset(allowed_chars): + raise InvalidGufeKeyError("gufe_key contains invalid characters") + + # Cast to GufeKey + return GufeKey(v_normalized) def __repr__(self): # pragma: no cover return f"" diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 4e6b9c3c..801b6600 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -9,6 +9,7 @@ from contextlib import contextmanager import json from functools import lru_cache +from operator import ne from typing import Dict, List, Optional, Union, Tuple import weakref import numpy as np @@ -99,43 +100,22 @@ def _select_tasks_from_taskpool(taskpool: List[Tuple[str, float]], count) -> Lis return list(np.random.choice(tasks, count, replace=False, p=prob)) -def _generate_claim_query( - task_sks: List[ScopedKey], compute_service_id: ComputeServiceID -) -> str: - """Generate a query to claim a list of Tasks. - - Parameters - ---------- - task_sks - A list of ScopedKeys of Tasks to claim. - compute_service_id - ComputeServiceID of the claiming service. - - Returns - ------- - query: str - The Cypher query to claim the Task. - """ - - task_data = cypher_list_from_scoped_keys(task_sks) - - query = f""" +CLAIM_QUERY = f""" // only match the task if it doesn't have an existing CLAIMS relationship - UNWIND {task_data} AS task_sk + UNWIND $tasks_list AS task_sk MATCH (t:Task {{_scoped_key: task_sk}}) WHERE NOT (t)<-[:CLAIMS]-(:ComputeServiceRegistration) WITH t // create CLAIMS relationship with given compute service - MATCH (csreg:ComputeServiceRegistration {{identifier: '{compute_service_id}'}}) - CREATE (t)<-[cl:CLAIMS {{claimed: localdatetime('{datetime.utcnow().isoformat()}')}}]-(csreg) + MATCH (csreg:ComputeServiceRegistration {{identifier: $compute_service_id}}) + CREATE (t)<-[cl:CLAIMS {{claimed: localdatetime($datetimestr)}}]-(csreg) SET t.status = '{TaskStatusEnum.running.value}' RETURN t - """ - return query +""" class Neo4jStore(AlchemiscaleStateStore): @@ -468,26 +448,21 @@ def _get_node( ) -> Union[Node, Tuple[Node, Subgraph]]: """ If `return_subgraph = True`, also return subgraph for gufe object. - """ - qualname = scoped_key.qualname - - properties = {"_scoped_key": str(scoped_key)} - prop_string = ", ".join( - "{}: '{}'".format(key, value) for key, value in properties.items() - ) - prop_string = f" {{{prop_string}}}" + # Safety: qualname comes from GufeKey which is validated + qualname = scoped_key.qualname + parameters = {"scoped_key": str(scoped_key)} q = f""" - MATCH (n:{qualname}{prop_string}) + MATCH (n:{qualname} {{ _scoped_key: $scoped_key }}) """ if return_subgraph: q += """ OPTIONAL MATCH p = (n)-[r:DEPENDS_ON*]->(m) WHERE NOT (m)-[:DEPENDS_ON]->() - RETURN n,p + RETURN n, p """ else: q += """ @@ -497,10 +472,12 @@ def _get_node( nodes = set() subgraph = Subgraph() - for record in self.execute_query(q).records: + result = self.execute_query(q, parameters_=parameters) + + for record in result.records: node = record_data_to_node(record["n"]) nodes.add(node) - if return_subgraph and record["p"] is not None: + if return_subgraph and record.get("p") is not None: subgraph = subgraph | subgraph_from_path_record(record["p"]) else: subgraph = node @@ -521,8 +498,8 @@ def _query( self, *, qualname: str, - additional: Dict = None, - key: GufeKey = None, + additional: Optional[Dict] = None, + key: Optional[GufeKey] = None, scope: Scope = Scope(), return_gufe=False, ): @@ -532,9 +509,8 @@ def _query( "_project": scope.project, } - for k, v in list(properties.items()): - if v is None: - properties.pop(k) + # Remove None values from properties + properties = {k: v for k, v in properties.items() if v is not None} if key is not None: properties["_gufe_key"] = str(key) @@ -547,7 +523,7 @@ def _query( prop_string = "" else: prop_string = ", ".join( - "{}: '{}'".format(key, value) for key, value in properties.items() + "{}: ${}".format(key, key) for key in properties.keys() ) prop_string = f" {{{prop_string}}}" @@ -568,7 +544,7 @@ def _query( """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run(q, **properties).to_eager_result() nodes = list() subgraph = Subgraph() @@ -707,8 +683,8 @@ def delete_network( self.delete_taskhub(network) # then delete the network - q = f""" - MATCH (an:AlchemicalNetwork {{_scoped_key: "{network}"}}) + q = """ + MATCH (an:AlchemicalNetwork {_scoped_key: $network}) DETACH DELETE an """ raise NotImplementedError @@ -848,11 +824,14 @@ def query_networks( *, name=None, key=None, - scope: Optional[Scope] = Scope(), + scope: Optional[Scope] = None, state: Optional[str] = None, ) -> List[ScopedKey]: """Query for `AlchemicalNetwork`\s matching given attributes.""" + if scope is None: + scope = Scope() + query_params = dict( name_pattern=name, org_pattern=scope.org, @@ -916,14 +895,14 @@ def query_chemicalsystems(self, *, name=None, key=None, scope: Scope = Scope()): def get_network_transformations(self, network: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for Transformations associated with the given AlchemicalNetwork.""" - q = f""" - MATCH (:AlchemicalNetwork {{_scoped_key: '{network}'}})-[:DEPENDS_ON]->(t:Transformation|NonTransformation) + q = """ + MATCH (:AlchemicalNetwork {_scoped_key: $network})-[:DEPENDS_ON]->(t:Transformation|NonTransformation) WITH t._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, network=str(network)) for rec in res: sks.append(rec["sk"]) @@ -931,14 +910,14 @@ def get_network_transformations(self, network: ScopedKey) -> List[ScopedKey]: def get_transformation_networks(self, transformation: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for AlchemicalNetworks associated with the given Transformation.""" - q = f""" - MATCH (:Transformation|NonTransformation {{_scoped_key: '{transformation}'}})<-[:DEPENDS_ON]-(an:AlchemicalNetwork) + q = """ + MATCH (:Transformation|NonTransformation {_scoped_key: $transformation})<-[:DEPENDS_ON]-(an:AlchemicalNetwork) WITH an._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, transformation=str(transformation)) for rec in res: sks.append(rec["sk"]) @@ -946,14 +925,14 @@ def get_transformation_networks(self, transformation: ScopedKey) -> List[ScopedK def get_network_chemicalsystems(self, network: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for ChemicalSystems associated with the given AlchemicalNetwork.""" - q = f""" - MATCH (:AlchemicalNetwork {{_scoped_key: '{network}'}})-[:DEPENDS_ON]->(cs:ChemicalSystem) + q = """ + MATCH (:AlchemicalNetwork {_scoped_key: $network})-[:DEPENDS_ON]->(cs:ChemicalSystem) WITH cs._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, network=str(network)) for rec in res: sks.append(rec["sk"]) @@ -961,14 +940,14 @@ def get_network_chemicalsystems(self, network: ScopedKey) -> List[ScopedKey]: def get_chemicalsystem_networks(self, chemicalsystem: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for AlchemicalNetworks associated with the given ChemicalSystem.""" - q = f""" - MATCH (:ChemicalSystem {{_scoped_key: '{chemicalsystem}'}})<-[:DEPENDS_ON]-(an:AlchemicalNetwork) + q = """ + MATCH (:ChemicalSystem {_scoped_key: $chemicalsystem})<-[:DEPENDS_ON]-(an:AlchemicalNetwork) WITH an._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, chemicalsystem=str(chemicalsystem)) for rec in res: sks.append(rec["sk"]) @@ -978,14 +957,14 @@ def get_transformation_chemicalsystems( self, transformation: ScopedKey ) -> List[ScopedKey]: """List ScopedKeys for the ChemicalSystems associated with the given Transformation.""" - q = f""" - MATCH (:Transformation|NonTransformation {{_scoped_key: '{transformation}'}})-[:DEPENDS_ON]->(cs:ChemicalSystem) + q = """ + MATCH (:Transformation|NonTransformation {_scoped_key: $transformation})-[:DEPENDS_ON]->(cs:ChemicalSystem) WITH cs._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, transformation=str(transformation)) for rec in res: sks.append(rec["sk"]) @@ -995,14 +974,14 @@ def get_chemicalsystem_transformations( self, chemicalsystem: ScopedKey ) -> List[ScopedKey]: """List ScopedKeys for the Transformations associated with the given ChemicalSystem.""" - q = f""" - MATCH (:ChemicalSystem {{_scoped_key: '{chemicalsystem}'}})<-[:DEPENDS_ON]-(t:Transformation|NonTransformation) + q = """ + MATCH (:ChemicalSystem {_scoped_key: $chemicalsystem})<-[:DEPENDS_ON]-(t:Transformation|NonTransformation) WITH t._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, chemicalsystem=str(chemicalsystem)) for rec in res: sks.append(rec["sk"]) @@ -1093,10 +1072,10 @@ def deregister_computeservice(self, compute_service_id: ComputeServiceID): """ q = f""" - MATCH (n:ComputeServiceRegistration {{identifier: '{compute_service_id}'}}) + MATCH (n:ComputeServiceRegistration {{identifier: $compute_service_id}}) - OPTIONAL MATCH (n)-[cl:CLAIMS]->(t:Task {{status: 'running'}}) - SET t.status = 'waiting' + OPTIONAL MATCH (n)-[cl:CLAIMS]->(t:Task {{status: '{TaskStatusEnum.running.value}'}}) + SET t.status = '{TaskStatusEnum.waiting.value}' WITH n, n.identifier as identifier @@ -1106,7 +1085,7 @@ def deregister_computeservice(self, compute_service_id: ComputeServiceID): """ with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, compute_service_id=str(compute_service_id)) identifier = next(res)["identifier"] return ComputeServiceID(identifier) @@ -1117,12 +1096,12 @@ def heartbeat_computeservice( """Update the heartbeat for the given ComputeServiceID.""" q = f""" - MATCH (n:ComputeServiceRegistration {{identifier: '{compute_service_id}'}}) + MATCH (n:ComputeServiceRegistration {{identifier: $compute_service_id}}) SET n.heartbeat = localdatetime('{heartbeat.isoformat()}') """ with self.transaction() as tx: - tx.run(q) + tx.run(q, compute_service_id=str(compute_service_id)) return compute_service_id @@ -1134,8 +1113,8 @@ def expire_registrations(self, expire_time: datetime): WITH n - OPTIONAL MATCH (n)-[cl:CLAIMS]->(t:Task {{status: 'running'}}) - SET t.status = 'waiting' + OPTIONAL MATCH (n)-[cl:CLAIMS]->(t:Task {{status: '{TaskStatusEnum.running.value}'}}) + SET t.status = '{TaskStatusEnum.waiting.value}' WITH n, n.identifier as ident @@ -1221,13 +1200,15 @@ def get_taskhub( "`network` ScopedKey does not correspond to an `AlchemicalNetwork`" ) - q = f""" - match (th:TaskHub {{network: "{network}"}})-[:PERFORMS]->(an:AlchemicalNetwork) - return th - """ + q = """ + MATCH (th:TaskHub {network: $network})-[:PERFORMS]->(an:AlchemicalNetwork) + RETURN th + """ try: - node = record_data_to_node(self.execute_query(q).records[0]["th"]) + node = record_data_to_node( + self.execute_query(q, network=str(network)).records[0]["th"] + ) except IndexError: raise KeyError("No such object in database") @@ -1249,11 +1230,11 @@ def delete_taskhub( taskhub = self.get_taskhub(network) - q = f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}}), + q = """ + MATCH (th:TaskHub {_scoped_key: $taskhub}) DETACH DELETE th """ - self.execute_query(q) + self.execute_query(q, taskhub=str(taskhub)) return taskhub @@ -1314,14 +1295,14 @@ def get_taskhub_actioned_tasks( A list of dicts, one per TaskHub, which contains the Task ScopedKeys that are actioned on the given TaskHub as keys, with their weights as values. """ - - q = f""" - UNWIND {cypher_list_from_scoped_keys(taskhubs)} as th_sk - MATCH (th: TaskHub {{_scoped_key: th_sk}})-[a:ACTIONS]->(t:Task) + th_scoped_keys = [str(taskhub) for taskhub in taskhubs if taskhub is not None] + q = """ + UNWIND $taskhubs as th_sk + MATCH (th: TaskHub {_scoped_key: th_sk})-[a:ACTIONS]->(t:Task) RETURN t._scoped_key, a.weight, th._scoped_key """ - results = self.execute_query(q) + results = self.execute_query(q, taskhubs=th_scoped_keys) data = {taskhub: {} for taskhub in taskhubs} for record in results.records: @@ -1374,13 +1355,17 @@ def get_taskhub_weight(self, networks: List[ScopedKey]) -> List[float]: "`network` ScopedKey does not correspond to an `AlchemicalNetwork`" ) - q = f""" - UNWIND {cypher_list_from_scoped_keys(networks)} as network - MATCH (th:TaskHub {{network: network}}) + networks_scoped_keys = [ + str(network) for network in networks if network is not None + ] + + q = """ + UNWIND $networks as network + MATCH (th:TaskHub {network: network}) RETURN network, th.weight """ - results = self.execute_query(q) + results = self.execute_query(q, networks=networks_scoped_keys) network_weights = {str(network): None for network in networks} for record in results.records: @@ -1411,10 +1396,12 @@ def action_tasks( # so we can properly return `None` if needed task_map = {str(task): None for task in tasks} + tasks_scoped_keys = [str(task) for task in tasks if task is not None] + q = f""" // get our TaskHub - UNWIND {cypher_list_from_scoped_keys(tasks)} AS task_sk - MATCH (th:TaskHub {{_scoped_key: "{taskhub}"}})-[:PERFORMS]->(an:AlchemicalNetwork) + UNWIND $tasks as task_sk + MATCH (th:TaskHub {{_scoped_key: $taskhub}})-[:PERFORMS]->(an:AlchemicalNetwork) // get the task we want to add to the hub; check that it connects to same network MATCH (task:Task {{_scoped_key: task_sk}})-[:PERFORMS]->(tf:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an) @@ -1423,7 +1410,7 @@ def action_tasks( // and where the task is either in 'waiting', 'running', or 'error' status WITH th, an, task WHERE NOT (th)-[:ACTIONS]->(task) - AND task.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}'] + AND task.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}'] // create the connection CREATE (th)-[ar:ACTIONS {{weight: 0.5}}]->(task) @@ -1434,7 +1421,7 @@ def action_tasks( RETURN task """ - results = self.execute_query(q) + results = self.execute_query(q, tasks=tasks_scoped_keys, taskhub=str(taskhub)) # update our map with the results, leaving None for tasks that aren't found for task_record in results.records: @@ -1496,13 +1483,19 @@ def set_task_weights( if not all([0 <= weight <= 1 for weight in tasks.values()]): raise ValueError("weights must be between 0 and 1 (inclusive)") - for t, w in tasks.items(): - q = f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}}) - SET ar.weight = {w} - RETURN task, ar - """ - results.append(tx.run(q).to_eager_result()) + tasks_list = [{"task": str(t), "weight": w} for t, w in tasks.items()] + + q = """ + UNWIND $tasks_list AS item + MATCH (th:TaskHub {_scoped_key: $taskhub})-[ar:ACTIONS]->(task:Task {_scoped_key: item.task}) + SET ar.weight = item.weight + RETURN task, ar + """ + results.append( + tx.run( + q, taskhub=str(taskhub), tasks_list=tasks_list + ).to_eager_result() + ) elif isinstance(tasks, list): if weight is None: @@ -1513,14 +1506,19 @@ def set_task_weights( if not 0 <= weight <= 1: raise ValueError("weight must be between 0 and 1 (inclusive)") - # TODO: remove for loop with an unwind clause - for t in tasks: - q = f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}}) - SET ar.weight = {weight} - RETURN task, ar - """ - results.append(tx.run(q).to_eager_result()) + tasks_list = [str(t) for t in tasks] + + q = """ + UNWIND $tasks_list AS task_sk + MATCH (th:TaskHub {_scoped_key: $taskhub})-[ar:ACTIONS]->(task:Task {_scoped_key: task_sk}) + SET ar.weight = $weight + RETURN task, ar + """ + results.append( + tx.run( + q, taskhub=str(taskhub), tasks_list=tasks_list, weight=weight + ).to_eager_result() + ) # return ScopedKeys for Tasks we changed; `None` for tasks we didn't for res in results: @@ -1553,22 +1551,18 @@ def get_task_weights( weights Weights for the list of Tasks, in the same order. """ - weights = [] + with self.transaction() as tx: - for t in tasks: - q = f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}}) - RETURN ar.weight - """ - result = tx.run(q) + q = """ + UNWIND $tasks_list AS task_scoped_key + OPTIONAL MATCH (th:TaskHub {_scoped_key: $taskhub})-[ar:ACTIONS]->(task:Task {_scoped_key: task_scoped_key}) + RETURN task_scoped_key, ar.weight AS weight + """ - weight = [record.get("ar.weight") for record in result] + result = tx.run(q, taskhub=str(taskhub), tasks_list=list(map(str, tasks))) + results = result.data() - # if no match for the given Task, we put a `None` as result - if len(weight) == 0: - weights.append(None) - else: - weights.extend(weight) + weights = [record["weight"] for record in results] return weights @@ -1609,13 +1603,13 @@ def get_taskhub_tasks( ) -> Union[List[ScopedKey], Dict[ScopedKey, Task]]: """Get a list of Tasks on the TaskHub.""" - q = f""" + q = """ // get list of all tasks associated with the taskhub - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[:ACTIONS]->(task:Task) + MATCH (th:TaskHub {_scoped_key: $taskhub})-[:ACTIONS]->(task:Task) RETURN task """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run(q, taskhub=str(taskhub)).to_eager_result() tasks = [] subgraph = Subgraph() @@ -1636,14 +1630,14 @@ def get_taskhub_unclaimed_tasks( ) -> Union[List[ScopedKey], Dict[ScopedKey, Task]]: """Get a list of unclaimed Tasks in the TaskHub.""" - q = f""" + q = """ // get list of all unclaimed tasks in the hub - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[:ACTIONS]->(task:Task) + MATCH (th:TaskHub {_scoped_key: $taskhub})-[:ACTIONS]->(task:Task) WHERE NOT (task)<-[:CLAIMS]-(:ComputeServiceRegistration) RETURN task """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run(q, taskhub=str(taskhub)).to_eager_result() tasks = [] subgraph = Subgraph() @@ -1695,7 +1689,7 @@ def claim_taskhub_tasks( raise ValueError("`protocols` must be either `None` or not empty") q = f""" - MATCH (th:TaskHub {{`_scoped_key`: '{taskhub}'}})-[actions:ACTIONS]-(task:Task) + MATCH (th:TaskHub {{_scoped_key: $taskhub}})-[actions:ACTIONS]-(task:Task) WHERE task.status = '{TaskStatusEnum.waiting.value}' AND actions.weight > 0 OPTIONAL MATCH (task)-[:EXTENDS]->(other_task:Task) @@ -1725,14 +1719,15 @@ def claim_taskhub_tasks( _tasks = {} with self.transaction() as tx: tx.run( - f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}}) + """ + MATCH (th:TaskHub {_scoped_key: $taskhub}) - // lock the TaskHub to avoid other queries from changing its state while we claim - SET th._lock = True - """ + // lock the TaskHub to avoid other queries from changing its state while we claim + SET th._lock = True + """, + taskhub=str(taskhub), ) - _taskpool = tx.run(q) + _taskpool = tx.run(q, taskhub=str(taskhub)) def task_count(task_dict: dict): return sum(map(len, task_dict.values())) @@ -1797,16 +1792,21 @@ def task_count(task_dict: dict): # if tasks is not empty, proceed with claiming if tasks: - q = _generate_claim_query(tasks, compute_service_id) - tx.run(q) + tx.run( + CLAIM_QUERY, + tasks_list=[str(task) for task in tasks if task is not None], + datetimestr=str(datetime.utcnow().isoformat()), + compute_service_id=str(compute_service_id), + ) tx.run( - f""" - MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}}) + """ + MATCH (th:TaskHub {_scoped_key: $taskhub}) - // remove lock on the TaskHub now that we're done with it - SET th._lock = null - """ + // remove lock on the TaskHub now that we're done with it + SET th._lock = null + """, + taskhub=str(taskhub), ) return tasks + [None] * (count - len(tasks)) @@ -1818,13 +1818,13 @@ def _validate_extends_tasks(self, task_list) -> Dict[str, Tuple[Node, str]]: if not task_list: return {} - q = f""" - UNWIND {cypher_list_from_scoped_keys(task_list)} as task - MATCH (t:Task {{`_scoped_key`: task}})-[PERFORMS]->(tf:Transformation|NonTransformation) + q = """ + UNWIND $task_list AS task + MATCH (t:Task {_scoped_key: task})-[PERFORMS]->(tf:Transformation|NonTransformation) return t, tf._scoped_key as tf_sk """ - results = self.execute_query(q) + results = self.execute_query(q, task_list=list(map(str, task_list))) nodes = {} @@ -1919,12 +1919,14 @@ def create_tasks( continue q = f""" - UNWIND {cypher_list_from_scoped_keys(transformation_subset)} as sk + UNWIND $transformation_subset AS sk MATCH (n:{node_type} {{`_scoped_key`: sk}}) RETURN n """ - results = self.execute_query(q) + results = self.execute_query( + q, transformation_subset=list(map(str, transformation_subset)) + ) transformation_nodes = {} for record in results.records: @@ -2007,14 +2009,14 @@ def get_network_tasks( self, network: ScopedKey, status: Optional[TaskStatusEnum] = None ) -> List[ScopedKey]: """List ScopedKeys for all Tasks associated with the given AlchemicalNetwork.""" - q = f""" - MATCH (an:AlchemicalNetwork {{_scoped_key: "{network}"}})-[:DEPENDS_ON]->(tf:Transformation|NonTransformation), + q = """ + MATCH (an:AlchemicalNetwork {_scoped_key: $network})-[:DEPENDS_ON]->(tf:Transformation|NonTransformation), (tf)<-[:PERFORMS]-(t:Task) """ if status is not None: - q += f""" - WHERE t.status = '{status.value}' + q += """ + WHERE t.status = $status """ q += """ @@ -2023,7 +2025,9 @@ def get_network_tasks( """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run( + q, network=str(network), status=status.value if status else None + ) for rec in res: sks.append(rec["sk"]) @@ -2031,15 +2035,15 @@ def get_network_tasks( def get_task_networks(self, task: ScopedKey) -> List[ScopedKey]: """List ScopedKeys for AlchemicalNetworks associated with the given Task.""" - q = f""" - MATCH (t:Task {{_scoped_key: '{task}'}})-[:PERFORMS]->(tf:Transformation|NonTransformation), + q = """ + MATCH (t:Task {_scoped_key: $task})-[:PERFORMS]->(tf:Transformation|NonTransformation), (tf)<-[:DEPENDS_ON]-(an:AlchemicalNetwork) WITH an._scoped_key as sk RETURN sk """ sks = [] with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, task=str(task)) for rec in res: sks.append(rec["sk"]) @@ -2068,18 +2072,18 @@ def get_transformation_tasks( extends """ - q = f""" - MATCH (trans:Transformation|NonTransformation {{_scoped_key: '{transformation}'}})<-[:PERFORMS]-(task:Task) + q = """ + MATCH (trans:Transformation|NonTransformation {_scoped_key: $transformation})<-[:PERFORMS]-(task:Task) """ if status is not None: - q += f""" - WHERE task.status = '{status.value}' + q += """ + WHERE task.status = $status """ if extends: - q += f""" - MATCH (trans)<-[:PERFORMS]-(extends:Task {{_scoped_key: '{extends}'}}) + q += """ + MATCH (trans)<-[:PERFORMS]-(extends:Task {_scoped_key: $extends}) WHERE (task)-[:EXTENDS*]->(extends) RETURN task """ @@ -2089,7 +2093,12 @@ def get_transformation_tasks( """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run( + q, + transformation=str(transformation), + status=status.value if status else None, + extends=str(extends) if extends else None, + ).to_eager_result() tasks = [] for record in res.records: @@ -2123,14 +2132,14 @@ def get_task_transformation( `ScopedKey`\s for these instead. """ - q = f""" - MATCH (task:Task {{_scoped_key: "{task}"}})-[:PERFORMS]->(trans:Transformation|NonTransformation) + q = """ + MATCH (task:Task {_scoped_key: $task})-[:PERFORMS]->(trans:Transformation|NonTransformation) OPTIONAL MATCH (task)-[:EXTENDS]->(prev:Task)-[:RESULTS_IN]->(result:ProtocolDAGResultRef) RETURN trans, result """ with self.transaction() as tx: - res = tx.run(q).to_eager_result() + res = tx.run(q, task=str(task)).to_eager_result() transformations = [] results = [] @@ -2229,7 +2238,9 @@ def set_task_priority( RETURN scoped_key, t """ res = tx.run( - q, scoped_keys=[str(t) for t in tasks], priority=priority + q, + scoped_keys=list(map(str, tasks)), + priority=priority, ).to_eager_result() task_results = [] @@ -2266,7 +2277,7 @@ def get_task_priority(self, tasks: List[ScopedKey]) -> List[Optional[int]]: WHERE t._scoped_key = scoped_key RETURN t.priority as priority """ - res = tx.run(q, scoped_keys=[str(t) for t in tasks]) + res = tx.run(q, scoped_keys=list(map(str, tasks))) priorities = [rec["priority"] for rec in res] return priorities @@ -2308,7 +2319,7 @@ def get_scope_status( } prop_string = ", ".join( - "{}: '{}'".format(key, value) + "{}: ${}".format(key, key) for key, value in properties.items() if value is not None ) @@ -2325,22 +2336,22 @@ def get_scope_status( RETURN n.status AS status, count(DISTINCT n) as counts """ with self.transaction() as tx: - res = tx.run(q, state_pattern=network_state) + res = tx.run(q, state_pattern=network_state, **properties) counts = {rec["status"]: rec["counts"] for rec in res} return counts def get_network_status(self, networks: List[ScopedKey]) -> List[Dict[str, int]]: """Return status counts for all Tasks associated with the given AlchemicalNetworks.""" - q = f""" - UNWIND {cypher_list_from_scoped_keys(networks)} as network - MATCH (an:AlchemicalNetwork {{_scoped_key: network}})-[:DEPENDS_ON]->(tf:Transformation|NonTransformation), + q = """ + UNWIND $networks AS network + MATCH (an:AlchemicalNetwork {_scoped_key: network})-[:DEPENDS_ON]->(tf:Transformation|NonTransformation), (tf)<-[:PERFORMS]-(t:Task) RETURN an._scoped_key AS sk, t.status AS status, count(t) as counts """ network_data = {str(network_sk): {} for network_sk in networks} - for rec in self.execute_query(q).records: + for rec in self.execute_query(q, networks=list(map(str, networks))).records: sk = rec["sk"] status = rec["status"] counts = rec["counts"] @@ -2350,12 +2361,12 @@ def get_network_status(self, networks: List[ScopedKey]) -> List[Dict[str, int]]: def get_transformation_status(self, transformation: ScopedKey) -> Dict[str, int]: """Return status counts for all Tasks associated with the given Transformation.""" - q = f""" - MATCH (:Transformation|NonTransformation {{_scoped_key: "{transformation}"}})<-[:PERFORMS]-(t:Task) + q = """ + MATCH (:Transformation|NonTransformation {_scoped_key: $transformation})<-[:PERFORMS]-(t:Task) RETURN t.status AS status, count(t) as counts """ with self.transaction() as tx: - res = tx.run(q) + res = tx.run(q, transformation=str(transformation)) counts = {rec["status"]: rec["counts"] for rec in res} return counts @@ -2507,15 +2518,15 @@ def set_task_waiting( """ - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE t_.status IN ['waiting', 'running', 'error'] - SET t_.status = 'waiting' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE t_.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}'] + SET t_.status = '{TaskStatusEnum.waiting.value}' WITH scoped_key, t, t_ @@ -2541,15 +2552,15 @@ def set_task_running( """ - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE t_.status IN ['running', 'waiting'] - SET t_.status = 'running' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE t_.status IN ['{TaskStatusEnum.running.value}', '{TaskStatusEnum.waiting.value}'] + SET t_.status = '{TaskStatusEnum.running.value}' RETURN scoped_key, t, t_ """ @@ -2568,15 +2579,15 @@ def set_task_complete( """ - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE t_.status IN ['complete', 'running'] - SET t_.status = 'complete' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE t_.status IN ['{TaskStatusEnum.complete.value}', '{TaskStatusEnum.running.value}'] + SET t_.status = '{TaskStatusEnum.complete.value}' WITH scoped_key, t, t_ @@ -2609,15 +2620,15 @@ def set_task_error( """ - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE t_.status IN ['error', 'running'] - SET t_.status = 'error' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE t_.status IN ['{TaskStatusEnum.error.value}', '{TaskStatusEnum.running.value}'] + SET t_.status = '{TaskStatusEnum.error.value}' WITH scoped_key, t, t_ @@ -2647,20 +2658,20 @@ def set_task_invalid( # set the status and delete the ACTIONS relationship # make sure we follow the extends chain and set all tasks to invalid # and remove actions relationships - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE NOT t_.status IN ['deleted'] - SET t_.status = 'invalid' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE NOT t_.status IN ['{TaskStatusEnum.deleted.value}'] + SET t_.status = '{TaskStatusEnum.invalid.value}' WITH scoped_key, t, t_ OPTIONAL MATCH (t_)<-[er:EXTENDS*]-(extends_task:Task) - SET extends_task.status = 'invalid' + SET extends_task.status = '{TaskStatusEnum.invalid.value}' WITH scoped_key, t, t_, extends_task @@ -2697,20 +2708,20 @@ def set_task_deleted( # set the status and delete the ACTIONS relationship # make sure we follow the extends chain and set all tasks to deleted # and remove actions relationships - q = """ + q = f""" WITH $scoped_keys AS batch UNWIND batch AS scoped_key - OPTIONAL MATCH (t:Task {_scoped_key: scoped_key}) + OPTIONAL MATCH (t:Task {{_scoped_key: scoped_key}}) - OPTIONAL MATCH (t_:Task {_scoped_key: scoped_key}) - WHERE NOT t_.status IN ['invalid'] - SET t_.status = 'deleted' + OPTIONAL MATCH (t_:Task {{_scoped_key: scoped_key}}) + WHERE NOT t_.status IN ['{TaskStatusEnum.invalid.value}'] + SET t_.status = '{TaskStatusEnum.deleted.value}' WITH scoped_key, t, t_ OPTIONAL MATCH (t_)<-[er:EXTENDS*]-(extends_task:Task) - SET extends_task.status = 'deleted' + SET extends_task.status = '{TaskStatusEnum.deleted.value}' WITH scoped_key, t, t_, extends_task diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 3ec702a6..f2f25ef5 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -271,6 +271,59 @@ def test_query_transformations(self, n4js, network_tyk2, multiple_scopes): == 1 ) + def test_query_transformations_exploit(self, n4js, multiple_scopes, network_tyk2): + # This test is to show that common cypher exploits are mitigated by using parameters + + an = network_tyk2 + + n4js.assemble_network(an, multiple_scopes[0]) + n4js.assemble_network(an, multiple_scopes[1]) + + malicious_name = """'}) + WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n + RETURN n + UNION + MATCH (m) DETACH DELETE m + WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n + RETURN n + UNION + CREATE (mark:InjectionMark {_scoped_key: 'InjectionMark-12345-test-testcamp-testproj'}) + WITH {_org: '', _campaign: '', _project: '', _gufe_key: ''} AS n // """ + try: + n4js.query_transformations(name=malicious_name) + except AttributeError as e: + # With old _query, AttributeError would be thrown AFTER the transaction has finished, and the database is already corrupted + assert "'dict' object has no attribute 'labels'" in str(e) + assert len(n4js.query_transformations(scope=multiple_scopes[0])) == 0 + + mark_from__query = n4js._query(qualname="InjectionMark") + # Just to be double sure, check explicitly + q = """ + match (m:InjectionMark) + return m + """ + mark_explicit = n4js.execute_query(q).records + + assert len(mark_from__query) == len(mark_explicit) == 0 + + assert len(n4js.query_transformations()) == len(network_tyk2.edges) * 2 + assert len(n4js.query_transformations(scope=multiple_scopes[0])) == len( + network_tyk2.edges + ) + + assert ( + len(n4js.query_transformations(name="lig_ejm_31_to_lig_ejm_50_complex")) + == 2 + ) + assert ( + len( + n4js.query_transformations( + scope=multiple_scopes[0], name="lig_ejm_31_to_lig_ejm_50_complex" + ) + ) + == 1 + ) + def test_query_chemicalsystems(self, n4js, network_tyk2, multiple_scopes): an = network_tyk2 diff --git a/alchemiscale/tests/unit/test_models.py b/alchemiscale/tests/unit/test_models.py index c8285fbf..ba7fc389 100644 --- a/alchemiscale/tests/unit/test_models.py +++ b/alchemiscale/tests/unit/test_models.py @@ -2,7 +2,7 @@ from pydantic import ValidationError -from alchemiscale.models import Scope +from alchemiscale.models import Scope, ScopedKey @pytest.mark.parametrize( @@ -101,3 +101,35 @@ def test_scope_non_alphanumeric_invalid(scope_string): ) def test_underscore_scopes_valid(scope_string): scope = Scope.from_str(scope_string) + + +@pytest.mark.parametrize( + "gufe_key", + [ + "White Space-token", + "WhiteSpace-tok en", + "NoToken", + "Unicode-\u0027MATCH", + "CredentialedEntity) DETACH DELETE n //", + "BadPrefix-token`backtick", + ], +) +def test_gufe_key_invalid(gufe_key): + with pytest.raises(ValidationError): + ScopedKey( + gufe_key=gufe_key, org="org1", campaign="campaignA", project="projectI" + ) + + +@pytest.mark.parametrize( + "gufe_key", + [ + "ClassName-uuid4hex", + "DummyProtocol-1234567890abcdef", + "DummyProtocol-1234567890abcdef41234567890abcdef", + ], +) +def test_gufe_key_valid(gufe_key): + scoped_key = ScopedKey( + gufe_key=gufe_key, org="org1", campaign="campaignA", project="projectI" + )