Skip to content

Commit

Permalink
Fixes merging excess properties. (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 9, 2020
1 parent f063eaa commit c1c4098
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 3 deletions.
9 changes: 8 additions & 1 deletion docs/releasehistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@ Releases follow the ``major.minor.micro`` scheme recommended by
0.3.1
-----

This release ...
This release fixes a bug introduced in version 0.3.0 of this framework, whereby the default workflows for computing
excess properties could in rare cases be incorrectly merged leading to downstream protocols taking their inputs from
the wrong upstream protocol outputs.

While this bug should not affect most calculations, it is recommended that any production calculations performed
using version 0.3.0 of this framework be repeated using version 0.3.1.

Bugfixes
""""""""

* PR `#331 <https://github.com/openforcefield/openff-evaluator/pull/331>`_: Fixes merging excess properties.

0.3.0
-----

Expand Down
4 changes: 3 additions & 1 deletion openff/evaluator/properties/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def default_simulation_schema(
component_substance = ReplicatorValue(component_replicator.id)

component_protocols, _, component_stored_data = generate_simulation_protocols(
analysis.AverageObservable("extract_observable_component"),
analysis.AverageObservable(
f"extract_observable_component_{component_replicator.placeholder_id}"
),
use_target_uncertainty,
id_suffix=f"_component_{component_replicator.placeholder_id}",
n_molecules=n_molecules,
Expand Down
51 changes: 50 additions & 1 deletion openff/evaluator/tests/test_workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from openff.evaluator.protocols.groups import ConditionalGroup
from openff.evaluator.protocols.miscellaneous import DummyProtocol
from openff.evaluator.thermodynamics import ThermodynamicState
from openff.evaluator.workflow import Workflow, WorkflowResult, WorkflowSchema
from openff.evaluator.workflow import (
ProtocolGroup,
Workflow,
WorkflowResult,
WorkflowSchema,
)
from openff.evaluator.workflow.schemas import ProtocolReplicator
from openff.evaluator.workflow.utils import ProtocolPath, ReplicatorValue

Expand Down Expand Up @@ -210,3 +215,47 @@ def test_from_schema():
rebuilt_schema.outputs_to_store = UNDEFINED

assert rebuilt_schema.json(format=True) == schema.json(format=True)


def test_unique_ids():

protocol_a = DummyProtocol("protocol-a")
protocol_a.input_value = 1

group_a = ProtocolGroup("group-a")
group_a.add_protocols(protocol_a)

group_b = ProtocolGroup("group-b")
group_b.add_protocols(protocol_a)

schema = WorkflowSchema()
schema.protocol_schemas = [group_a.schema, group_b.schema]

with pytest.raises(ValueError) as error_info:
schema.validate()

assert "Several protocols in the schema have the same id" in str(error_info.value)
assert "protocol-a" in str(error_info.value)


def test_replicated_ids():

replicator = ProtocolReplicator("replicator-a")

protocol_a = DummyProtocol("protocol-a")
protocol_a.input_value = 1

group_a = ProtocolGroup(f"group-a-{replicator.placeholder_id}")
group_a.add_protocols(protocol_a)

schema = WorkflowSchema()
schema.protocol_schemas = [group_a.schema]
schema.protocol_replicators = [replicator]

with pytest.raises(ValueError) as error_info:
schema.validate()

assert (
f"The children of replicated protocol {group_a.id} must also contain the "
"replicators placeholder" in str(error_info.value)
)
93 changes: 93 additions & 0 deletions openff/evaluator/workflow/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
A collection of schemas which represent elements of a workflow.
"""
import re
from typing import Dict, Iterable, List

from openff.evaluator.attributes import UNDEFINED, Attribute, AttributeClass
from openff.evaluator.attributes.typing import is_type_subclass_of_type
Expand Down Expand Up @@ -886,6 +887,94 @@ def _validate_interfaces(self, schemas_by_id):
f"input type ({expected_input_type}) of {source_path}."
)

@classmethod
def _find_child_ids(cls, schemas_by_id: Dict[str, ProtocolSchema]) -> List[str]:
"""A function which will recursive find the ids of all protocols in a
workflow.
Parameters
----------
schemas_by_id
The protocols to find the child ids of.
"""

protocol_ids = []

for protocol_id, protocol_schema in schemas_by_id.items():

protocol_ids.append(protocol_id)

if not isinstance(protocol_schema, ProtocolGroupSchema):
continue

protocol_ids.extend(cls._find_child_ids(protocol_schema.protocol_schemas))

return protocol_ids

@classmethod
def _find_duplicates(cls, iterable: Iterable[str]) -> List[str]:
"""Returns the duplicate items in a list.
Notes
-----
* Based on the answer by moooeeeep (accessed 09/11/2020 14:56) on stack overflow
here: https://stackoverflow.com/a/9836685
"""

seen = set()
seen_add = seen.add

seen_twice = set(x for x in iterable if x in seen or seen_add(x))

return list(seen_twice)

def _validate_unique_children(self, schemas_by_id: Dict[str, ProtocolSchema]):
"""Validates that every protocol in a workflow has a unique id."""

all_protocol_ids = self._find_child_ids(schemas_by_id)
duplicate_ids = self._find_duplicates(all_protocol_ids)

if len(duplicate_ids) > 0:

raise ValueError(
f"Several protocols in the schema have the same id: "
f"{duplicate_ids}. This is currently unsupported due to issues "
f"with merging two graphs which contain duplicate ids."
)

def _validate_replicated_child_ids(self, schemas_by_id: Dict[str, ProtocolSchema]):
"""Validates that the children of replicated protocols also unique ids to
avoid issues when merging workflows."""

if self.protocol_replicators == UNDEFINED:
return

replicator_ids = [x.placeholder_id for x in self.protocol_replicators]

for protocol_id, protocol_schema in schemas_by_id.items():

if not isinstance(protocol_schema, ProtocolGroupSchema):
continue

for replicator_id in replicator_ids:

if replicator_id not in protocol_id:
continue

if any(
replicator_id not in child_id
for child_id in protocol_schema.protocol_schemas
):

raise ValueError(
f"The children of replicated protocol {protocol_id} must also "
f"contain the replicators placeholder id in their id to ensure "
f"all replicated protocols have a unique id. This is to avoid "
f"issues when mering multiple workflows."
)

self._validate_replicated_child_ids(protocol_schema.protocol_schemas)

def validate(self, attribute_type=None):

super(WorkflowSchema, self).validate(attribute_type)
Expand All @@ -896,6 +985,10 @@ def validate(self, attribute_type=None):

schemas_by_id = {x.id: x for x in self.protocol_schemas}

# Validate unique ids. This is critical to ensure correct merging.
self._validate_unique_children(schemas_by_id)
self._validate_replicated_child_ids(schemas_by_id)

# Validate the different pieces of data to populate / draw from.
self._validate_final_value(schemas_by_id)
self._validate_replicators(schemas_by_id)
Expand Down

0 comments on commit c1c4098

Please sign in to comment.