Skip to content

Commit

Permalink
Update ReactionRole use in get_reaction_smiles (#722)
Browse files Browse the repository at this point in the history
* Substructure operator example

* function

* smarts

* reaction smarts

* query

* query

* carbon

* func

* cleanup

* rename methods

* rename methods

* reaction smarts

* lint

* lint

* BYPRODUCT

* silence rdkit

* hash
  • Loading branch information
skearnes authored Jun 16, 2024
1 parent bbe1ff5 commit 1a7bc51
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 68 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
run: |
# NOTE(skearnes): conda is only used for postgres (not python).
# NOTE(skearnes): rdkit-postgresql may not be available for ARM.
conda install -c rdkit rdkit-postgresql==2020.03.3.0 || conda install -c conda-forge postgresql
conda install -c rdkit rdkit-postgresql || conda install -c conda-forge postgresql
initdb
- uses: actions/setup-python@v4
with:
Expand All @@ -48,7 +48,7 @@ jobs:
shell: bash -l {0}
run: |
coverage erase
pytest -n auto -vv --cov=ord_schema --durations=0 --durations-min=1
pytest -vv --cov=ord_schema --durations=0 --durations-min=1
coverage xml
- uses: codecov/codecov-action@v1

Expand Down
2 changes: 1 addition & 1 deletion ord_schema/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging


def get_logger(name: str, level: int = logging.INFO) -> logging.Logger:
def get_logger(name: str, level: int = logging.DEBUG) -> logging.Logger:
"""Creates a Logger."""
if not get_logger.initialized:
logging.basicConfig(format="%(levelname)s %(asctime)s %(filename)s:%(lineno)d: %(message)s")
Expand Down
22 changes: 11 additions & 11 deletions ord_schema/message_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def get_reaction_smiles(
message: reaction_pb2.Reaction,
generate_if_missing: bool = False,
allow_incomplete: bool = True,
allow_unspecified_roles: bool = True,
validate: bool = False,
canonical: bool = True,
) -> Optional[str]:
Expand All @@ -364,6 +365,8 @@ def get_reaction_smiles(
allow_incomplete: Boolean whether to allow "incomplete" reaction SMILES
that do not include all components (e.g. if a component does not
have a structural identifier).
allow_unspecified_roles: If True, reactants and products with the UNSPECIFIED reaction role will be included
when generating a reaction SMILES.
validate: Boolean whether to validate the reaction SMILES with rdkit.
Only used if allow_incomplete is False.
canonical: Boolean whether to return a canonicalized reaction SMILES.
Expand All @@ -386,6 +389,11 @@ def get_reaction_smiles(

reactants, agents, products = set(), set(), set()
roles = reaction_pb2.ReactionRole
reactant_roles = [roles.REACTANT]
product_roles = [roles.PRODUCT]
if allow_unspecified_roles:
reactant_roles.append(roles.UNSPECIFIED)
product_roles.append(roles.UNSPECIFIED)
for key in sorted(message.inputs):
for compound in message.inputs[key].components:
try:
Expand All @@ -396,10 +404,8 @@ def get_reaction_smiles(
raise error
if compound.reaction_role in [roles.REAGENT, roles.SOLVENT, roles.CATALYST]:
agents.add(smiles)
elif compound.reaction_role == roles.REACTANT:
elif compound.reaction_role in reactant_roles:
reactants.add(smiles)
else:
continue

for outcome in message.outcomes:
for product in outcome.products:
Expand All @@ -409,15 +415,9 @@ def get_reaction_smiles(
if allow_incomplete:
continue
raise error
if product.reaction_role == roles.PRODUCT:
if product.reaction_role in product_roles:
products.add(smiles)
elif product.reaction_role in [
roles.REAGENT,
roles.SOLVENT,
roles.CATALYST,
roles.INTERNAL_STANDARD,
]:
continue

if not allow_incomplete and (not reactants or not products):
raise ValueError("reaction must contain at least one reactant and one product")
if not reactants and not products:
Expand Down
7 changes: 6 additions & 1 deletion ord_schema/message_helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,13 @@ def test_get_reaction_smiles(self):
reactant2 = reaction.inputs["reactant2"]
reactant2.components.add(reaction_role="REACTANT").identifiers.add(value="Cc1ccccc1", type="SMILES")
reactant2.components.add(reaction_role="SOLVENT").identifiers.add(value="N", type="SMILES")
reaction.outcomes.add().products.add(reaction_role="PRODUCT").identifiers.add(value="O=C=O", type="SMILES")
reaction.outcomes.add().products.add().identifiers.add(value="O=C=O", type="SMILES")
assert message_helpers.get_reaction_smiles(reaction, generate_if_missing=True) == "Cc1ccccc1.c1ccccc1>N>O=C=O"
reaction.outcomes.add().products.add(reaction_role="PRODUCT").identifiers.add(value="O=CC=O", type="SMILES")
assert (
message_helpers.get_reaction_smiles(reaction, generate_if_missing=True, allow_unspecified_roles=False)
== "Cc1ccccc1.c1ccccc1>N>O=CC=O"
)

def test_get_reaction_smiles_failure(self):
reaction = reaction_pb2.Reaction()
Expand Down
2 changes: 2 additions & 0 deletions ord_schema/orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

"""Base ORM objects."""
from rdkit import RDLogger
from sqlalchemy.orm import declarative_base

RDLogger.DisableLog("rdApp.*")
Base = declarative_base()
2 changes: 1 addition & 1 deletion ord_schema/orm/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ def test_delete_dataset(test_session):


def test_get_dataset_md5(test_session):
assert get_dataset_md5("test_dataset", test_session) == "42c687cafd247fd72d2a78c550b0b054"
assert get_dataset_md5("test_dataset", test_session) == "0343d39a98d38eb39abd69d899af2bdf"
assert get_dataset_md5("other_dataset", test_session) is None
4 changes: 3 additions & 1 deletion ord_schema/orm/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ def from_proto( # pylint: disable=too-many-branches
reaction_smiles = message_helpers.get_reaction_smiles(
message, generate_if_missing=True, allow_incomplete=False, validate=True
)
except ValueError:
except ValueError as error:
assert hasattr(message, "reaction_id") # Type hint.
logger.debug(f"Error generating reaction SMILES for {message.reaction_id}: {error}")
reaction_smiles = None
if reaction_smiles is not None:
kwargs["reaction_smiles"] = reaction_smiles.split()[0] # Handle CXSMILES.
Expand Down
44 changes: 20 additions & 24 deletions ord_schema/orm/rdkit_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from distutils.util import strtobool # pylint: disable=deprecated-module
from enum import Enum

from sqlalchemy import Column, Index, Integer, Text, func
from sqlalchemy import Column, Index, Integer, Text, cast, func
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.types import UserDefinedType

from ord_schema.orm import Base
Expand Down Expand Up @@ -116,33 +117,13 @@ def get_col_spec(self, **kwargs):
class FingerprintType(Enum):
"""RDKit PostgreSQL fingerprint types."""

# NOTE(skearnes): Add Column and Index entries for each member to RDKitMol below.
MORGAN_BFP = func.morganbv_fp
MORGAN_SFP = func.morgan_fp

def __call__(self, *args, **kwargs):
return self.value(*args, **kwargs)

@classmethod
def get_table_args(cls) -> list:
"""Returns a list of __table_args__ for _Structure.
Each fingerprint type is given a column (name.lower()) and a corresponding index.
Returns:
List of Column and Index objects.
"""
table_args = []
for fp_type in cls:
name = fp_type.name.lower()
if name.endswith("_bfp"):
dtype = _RDKitBfp
elif name.endswith("_sfp"):
dtype = _RDKitSfp
else:
raise ValueError(f"unable to determine dtype for {name}")
table_args.extend([Column(name, dtype), Index(f"{name}_index", name, postgresql_using="gist")])
return table_args


class RDKitMol(Base):
"""Table for storing compound structures and associated RDKit cartridge data."""
Expand All @@ -151,17 +132,28 @@ class RDKitMol(Base):
id = Column(Integer, primary_key=True)
smiles = Column(Text, index=True, unique=True)
mol = Column(_RDKitMol)
morgan_bfp = Column(_RDKitBfp)
morgan_sfp = Column(_RDKitSfp)

__table_args__ = (
Index("mol_index", "mol", postgresql_using="gist"),
*FingerprintType.get_table_args(),
Index("morgan_bfp_index", "morgan_bfp", postgresql_using="gist"),
Index("morgan_sfp_index", "morgan_sfp", postgresql_using="gist"),
{"schema": "rdkit"},
)

@classmethod
def tanimoto(cls, other: str, fp_type: FingerprintType = FingerprintType.MORGAN_BFP):
def tanimoto(cls, other: str, fp_type: FingerprintType = FingerprintType.MORGAN_BFP) -> ColumnElement[float]:
return func.tanimoto_sml(getattr(cls, fp_type.name.lower()), fp_type(other))

@classmethod
def contains_substructure(cls, pattern: str) -> ColumnElement[bool]:
return func.substruct(cls.mol, pattern)

@classmethod
def matches_smarts(cls, pattern: str) -> ColumnElement[bool]:
return func.substruct(cls.mol, func.qmol_from_smarts(cast(pattern, CString)))


class RDKitReaction(Base):
"""Table for storing reaction objects and associated RDKit cartridge data."""
Expand All @@ -175,3 +167,7 @@ class RDKitReaction(Base):
Index("reaction_index", "reaction", postgresql_using="gist"),
{"schema": "rdkit"},
)

@classmethod
def matches_smarts(cls, pattern: str) -> ColumnElement[bool]:
return func.substruct(cls.reaction, func.reaction_from_smarts(cast(pattern, CString)))
117 changes: 90 additions & 27 deletions ord_schema/orm/rdkit_mappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,103 @@
# limitations under the License.

"""Tests for ord_schema.orm.rdkit_mappers."""
import platform

import pytest
from sqlalchemy import select
from sqlalchemy.exc import ProgrammingError
from sqlalchemy import func, select

from ord_schema.orm.mappers import Mappers
from ord_schema.orm.rdkit_mappers import FingerprintType, RDKitMol
from ord_schema.orm.rdkit_mappers import FingerprintType, RDKitMol, RDKitReaction

pytestmark = pytest.mark.skipif(platform.machine() != "x86_64", reason="RDKit cartridge is required")


def test_tanimoto_operator(test_session):
try:
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.morgan_bfp % FingerprintType.MORGAN_BFP("c1ccccc1CCC(O)C"))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20
except ProgrammingError as error:
pytest.skip(f"RDKit cartridge is required: {error}")
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.morgan_bfp % FingerprintType.MORGAN_BFP("c1ccccc1CCC(O)C"))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20


@pytest.mark.parametrize("fp_type", list(FingerprintType))
def test_tanimoto(test_session, fp_type):
try:
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.tanimoto("c1ccccc1CCC(O)C", fp_type=fp_type) > 0.5)
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20
except ProgrammingError as error:
pytest.skip(f"RDKit cartridge is required: {error}")
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.tanimoto("c1ccccc1CCC(O)C", fp_type=fp_type) > 0.5)
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20


def test_substructure_operator(test_session):
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.mol.op("@>")("c1ccccc1CCC(O)C"))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20


def test_contains_substructure(test_session):
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.contains_substructure("c1ccccc1CCC(O)C"))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20


def test_smarts_operator(test_session):
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.mol.op("@>")(func.qmol_from_smarts("c1ccccc1CCC(O)[#6]")))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20


def test_matches_smarts(test_session):
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.matches_smarts("c1ccccc1CCC(O)[#6]"))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20


def test_reaction_smarts_operator(test_session):
query = (
select(Mappers.Reaction)
.join(RDKitReaction)
.where(RDKitReaction.reaction.op("@>")(func.reaction_from_smarts("[#6:1].[#9:2]>>[#6:1][#9:2]")))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 79 # One reaction has a BYPRODUCT outcome, so no reaction SMILES.


def test_reaction_matches_smarts(test_session):
query = (
select(Mappers.Reaction).join(RDKitReaction).where(RDKitReaction.matches_smarts("[#6:1].[#9:2]>>[#6:1][#9:2]"))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 79 # One reaction has a BYPRODUCT outcome, so no reaction SMILES.
1 change: 1 addition & 0 deletions ord_schema/orm/testdata/ord-nielsen-example.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ reactions {
analysis_key: "19f nmr of crude"
type: IDENTITY
}
reaction_role: BYPRODUCT
}
analyses {
key: "19f nmr of crude"
Expand Down

0 comments on commit 1a7bc51

Please sign in to comment.