Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add foreign keys to rdkit tables #702

Merged
merged 3 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions ord_schema/orm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,28 @@
from testing.postgresql import Postgresql

from ord_schema.message_helpers import load_message
from ord_schema.orm.database import add_dataset, prepare_database, update_rdkit
from ord_schema.orm.database import add_dataset, prepare_database, update_rdkit_ids, update_rdkit_tables
from ord_schema.proto import dataset_pb2


@pytest.fixture
def test_session() -> Iterator[Session]:
dataset = load_message(
os.path.join(os.path.dirname(__file__), "testdata", "ord-nielsen-example.pbtxt"), dataset_pb2.Dataset
)
datasets = [
load_message(
os.path.join(os.path.dirname(__file__), "testdata", "ord-nielsen-example.pbtxt"), dataset_pb2.Dataset
)
]
with Postgresql() as postgres:
engine = create_engine(postgres.url(), future=True)
rdkit_cartridge = prepare_database(engine)
with Session(engine) as session:
add_dataset(dataset, session)
session.flush()
if rdkit_cartridge:
update_rdkit(dataset.dataset_id, session)
session.commit()
for dataset in datasets:
add_dataset(dataset, session)
if rdkit_cartridge:
session.flush()
update_rdkit_tables(dataset.dataset_id, session)
session.flush()
update_rdkit_ids(dataset.dataset_id, session)
session.commit()
with Session(engine) as session:
yield session
58 changes: 56 additions & 2 deletions ord_schema/orm/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,14 @@ def delete_dataset(dataset_id: str, session: Session) -> None:
logger.info(f"delete took {time.time() - start}s")


def update_rdkit(dataset_id: str, session: Session) -> None:
def update_rdkit_tables(dataset_id: str, session: Session) -> None:
    """Updates RDKit PostgreSQL cartridge data for one dataset.

    Populates the rdkit.reactions and rdkit.mols tables (reaction/mol objects
    and fingerprints) for every reaction and compound in the dataset.

    Args:
        dataset_id: Dataset ID whose reactions and mols should be updated.
        session: Active SQLAlchemy session; rows are staged but not committed here.
    """
    _update_rdkit_reactions(dataset_id, session)
    _update_rdkit_mols(dataset_id, session)


def _update_rdkit_reactions(dataset_id: str, session: Session) -> None:
"""Updates the RDKit reactions table."""
logger.info("Updating RDKit reactions")
assert hasattr(RDKitReaction, "__table__") # Type hint.
table = RDKitReaction.__table__
Expand All @@ -118,10 +123,15 @@ def update_rdkit(dataset_id: str, session: Session) -> None:
.values(reaction=func.reaction_from_smiles(cast(table.c.reaction_smiles, CString)))
)
logger.info(f"Updating reactions took {time.time() - start:g}s")


def _update_rdkit_mols(dataset_id: str, session: Session) -> None:
"""Updates the RDKit mols table."""
logger.info("Updating RDKit mols")
assert hasattr(RDKitMol, "__table__") # Type hint.
table = RDKitMol.__table__
start = time.time()
# NOTE(skearnes): This join path will not include non-input compounds like workups, internal standards, etc.
session.execute(
insert(table)
.from_select(
Expand Down Expand Up @@ -167,3 +177,47 @@ def update_rdkit(dataset_id: str, session: Session) -> None:
.values(**{column: fp_type(table.c.mol)})
)
logger.info(f"Updating {fp_type} took {time.time() - start:g}s")


def update_rdkit_ids(dataset_id: str, session: Session) -> None:
    """Updates RDKit reaction and mol ID associations in the ORD tables.

    Joins ORD rows to the rdkit.reactions/rdkit.mols tables on (reaction)
    SMILES and writes the matching foreign keys (rdkit_reaction_id /
    rdkit_mol_id) back onto Reaction, Compound, and ProductCompound rows.

    Args:
        dataset_id: Dataset ID whose rows should be associated.
        session: Active SQLAlchemy session; rows are staged but not committed here.
    """
    logger.info("Updating RDKit ID associations")
    start = time.time()

    def _apply(query, mapper, fk_column: str) -> None:
        # One (ord_id, rdkit_id) pair per row; bulk-update by primary key.
        values = [{"id": ord_id, fk_column: rdkit_id} for ord_id, rdkit_id in session.execute(query).fetchall()]
        # NOTE(review): SQLAlchemy raises on an empty parameter list, so skip
        # when the dataset has no matching rows.
        if values:
            session.execute(update(mapper), values)

    # Update Reaction -> RDKitReaction.
    _apply(
        select(Mappers.Reaction.id, RDKitReaction.id)
        .join(RDKitReaction, Mappers.Reaction.reaction_smiles == RDKitReaction.reaction_smiles)
        .join(Mappers.Dataset)
        .where(Mappers.Dataset.dataset_id == dataset_id),
        Mappers.Reaction,
        "rdkit_reaction_id",
    )
    # Update Compound -> RDKitMol (input compounds).
    _apply(
        select(Mappers.Compound.id, RDKitMol.id)
        .join(RDKitMol, Mappers.Compound.smiles == RDKitMol.smiles)
        .join(Mappers.ReactionInput)
        .join(Mappers.Reaction)
        .join(Mappers.Dataset)
        .where(Mappers.Dataset.dataset_id == dataset_id),
        Mappers.Compound,
        "rdkit_mol_id",
    )
    # Update ProductCompound -> RDKitMol (outcome products).
    _apply(
        select(Mappers.ProductCompound.id, RDKitMol.id)
        .join(RDKitMol, Mappers.ProductCompound.smiles == RDKitMol.smiles)
        .join(Mappers.ReactionOutcome)
        .join(Mappers.Reaction)
        .join(Mappers.Dataset)
        .where(Mappers.Dataset.dataset_id == dataset_id),
        Mappers.ProductCompound,
        "rdkit_mol_id",
    )
    logger.info(f"Updating RDKit IDs took {time.time() - start:g}s")
4 changes: 4 additions & 0 deletions ord_schema/orm/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,12 @@ def build_mapper( # pylint: disable=too-many-branches
# Serialize and store the entire Reaction proto.
attrs["proto"] = Column(LargeBinary, nullable=False)
attrs["reaction_smiles"] = Column(Text, index=True)
attrs["rdkit_reaction_id"] = Column(Integer, ForeignKey("rdkit.reactions.id"))
attrs["rdkit_reaction"] = relationship("RDKitReaction")
elif message_type in {reaction_pb2.Compound, reaction_pb2.ProductCompound}:
attrs["smiles"] = Column(Text, index=True)
attrs["rdkit_mol_id"] = Column(Integer, ForeignKey("rdkit.mols.id"))
attrs["rdkit_mol"] = relationship("RDKitMol")
elif message_type in {reaction_pb2.CompoundPreparation, reaction_pb2.CrudeComponent}:
# Add foreign key to reaction.reaction_id.
kwargs = {}
Expand Down
4 changes: 2 additions & 2 deletions ord_schema/orm/rdkit_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class RDKitMol(Base):

__tablename__ = "mols"
id = Column(Integer, primary_key=True)
smiles = Column(Text, unique=True)
smiles = Column(Text, index=True, unique=True)
mol = Column(_RDKitMol)

__table_args__ = (
Expand All @@ -168,7 +168,7 @@ class RDKitReaction(Base):

__tablename__ = "reactions"
id = Column(Integer, primary_key=True)
reaction_smiles = Column(Text, unique=True)
reaction_smiles = Column(Text, index=True, unique=True)
reaction = Column(_RDKitReaction)

__table_args__ = (
Expand Down
14 changes: 4 additions & 10 deletions ord_schema/orm/rdkit_mappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,8 @@ def test_tanimoto_operator(test_session):
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.where(
Mappers.Compound.smiles.in_(
select(RDKitMol.smiles).where(RDKitMol.morgan_bfp % FingerprintType.MORGAN_BFP("c1ccccc1CCC(O)C"))
)
)
.join(RDKitMol)
.where(RDKitMol.morgan_bfp % FingerprintType.MORGAN_BFP("c1ccccc1CCC(O)C"))
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20
Expand All @@ -40,11 +37,8 @@ def test_tanimoto(test_session, fp_type):
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.where(
Mappers.Compound.smiles.in_(
select(RDKitMol.smiles).where(RDKitMol.tanimoto("c1ccccc1CCC(O)C", fp_type=fp_type) > 0.5)
)
)
.join(RDKitMol)
.where(RDKitMol.tanimoto("c1ccccc1CCC(O)C", fp_type=fp_type) > 0.5)
)
results = test_session.execute(query)
assert len(results.fetchall()) == 20
13 changes: 11 additions & 2 deletions ord_schema/orm/scripts/add_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@

from ord_schema.logging import get_logger
from ord_schema.message_helpers import load_message
from ord_schema.orm.database import add_dataset, delete_dataset, get_connection_string, get_dataset_md5, update_rdkit
from ord_schema.orm.database import (
add_dataset,
delete_dataset,
get_connection_string,
get_dataset_md5,
update_rdkit_ids,
update_rdkit_tables,
)
from ord_schema.proto import dataset_pb2

logger = get_logger(__name__)
Expand Down Expand Up @@ -79,7 +86,9 @@ def _add_dataset(filename: str, url: str, overwrite: bool) -> None:
return
add_dataset(dataset, session)
session.flush()
update_rdkit(dataset.dataset_id, session=session)
update_rdkit_tables(dataset.dataset_id, session=session)
session.flush()
update_rdkit_ids(dataset.dataset_id, session=session)
start = time.time()
session.commit()
logger.info(f"session.commit() took {time.time() - start:g}s")
Expand Down
Loading