diff --git a/ord_schema/orm/conftest.py b/ord_schema/orm/conftest.py index cd40786f..24416210 100644 --- a/ord_schema/orm/conftest.py +++ b/ord_schema/orm/conftest.py @@ -22,23 +22,28 @@ from testing.postgresql import Postgresql from ord_schema.message_helpers import load_message -from ord_schema.orm.database import add_dataset, prepare_database, update_rdkit +from ord_schema.orm.database import add_dataset, prepare_database, update_rdkit_ids, update_rdkit_tables from ord_schema.proto import dataset_pb2 @pytest.fixture def test_session() -> Iterator[Session]: - dataset = load_message( - os.path.join(os.path.dirname(__file__), "testdata", "ord-nielsen-example.pbtxt"), dataset_pb2.Dataset - ) + datasets = [ + load_message( + os.path.join(os.path.dirname(__file__), "testdata", "ord-nielsen-example.pbtxt"), dataset_pb2.Dataset + ) + ] with Postgresql() as postgres: engine = create_engine(postgres.url(), future=True) rdkit_cartridge = prepare_database(engine) with Session(engine) as session: - add_dataset(dataset, session) - session.flush() - if rdkit_cartridge: - update_rdkit(dataset.dataset_id, session) - session.commit() + for dataset in datasets: + add_dataset(dataset, session) + if rdkit_cartridge: + session.flush() + update_rdkit_tables(dataset.dataset_id, session) + session.flush() + update_rdkit_ids(dataset.dataset_id, session) + session.commit() with Session(engine) as session: yield session diff --git a/ord_schema/orm/database.py b/ord_schema/orm/database.py index c35ee575..dec27259 100644 --- a/ord_schema/orm/database.py +++ b/ord_schema/orm/database.py @@ -94,9 +94,14 @@ def delete_dataset(dataset_id: str, session: Session) -> None: logger.info(f"delete took {time.time() - start}s") -def update_rdkit(dataset_id: str, session: Session) -> None: +def update_rdkit_tables(dataset_id: str, session: Session) -> None: """Updates RDKit PostgreSQL cartridge data.""" - # select distinct smiles, count(*) from rdkit.mols where smiles similar to '\[[A-Z][a-z]*[+-]*[0-9]*\]' + _update_rdkit_reactions(dataset_id, session) + _update_rdkit_mols(dataset_id, session) + + +def _update_rdkit_reactions(dataset_id: str, session: Session) -> None: + """Updates the RDKit reactions table.""" logger.info("Updating RDKit reactions") assert hasattr(RDKitReaction, "__table__") # Type hint. table = RDKitReaction.__table__ @@ -118,10 +123,15 @@ def update_rdkit(dataset_id: str, session: Session) -> None: .values(reaction=func.reaction_from_smiles(cast(table.c.reaction_smiles, CString))) ) logger.info(f"Updating reactions took {time.time() - start:g}s") + + +def _update_rdkit_mols(dataset_id: str, session: Session) -> None: + """Updates the RDKit mols table.""" logger.info("Updating RDKit mols") assert hasattr(RDKitMol, "__table__") # Type hint. table = RDKitMol.__table__ start = time.time() + # NOTE(skearnes): This join path will not include non-input compounds like workups, internal standards, etc. session.execute( insert(table) .from_select( @@ -167,3 +177,47 @@ def update_rdkit(dataset_id: str, session: Session) -> None: .values(**{column: fp_type(table.c.mol)}) ) logger.info(f"Updating {fp_type} took {time.time() - start:g}s") + + +def update_rdkit_ids(dataset_id: str, session: Session) -> None: + """Updates RDKit reaction and mol ID associations in the ORD tables.""" + logger.info("Updating RDKit ID associations") + start = time.time() + # Update Reaction. + query = session.execute( + select(Mappers.Reaction.id, RDKitReaction.id) + .join(RDKitReaction, Mappers.Reaction.reaction_smiles == RDKitReaction.reaction_smiles) + .join(Mappers.Dataset) + .where(Mappers.Dataset.dataset_id == dataset_id) + ) + updates = [] + for ord_id, rdkit_id in query.fetchall(): + updates.append({"id": ord_id, "rdkit_reaction_id": rdkit_id}) + session.execute(update(Mappers.Reaction), updates) + # Update Compound. + query = session.execute( + select(Mappers.Compound.id, RDKitMol.id) + .join(RDKitMol, Mappers.Compound.smiles == RDKitMol.smiles) + .join(Mappers.ReactionInput) + .join(Mappers.Reaction) + .join(Mappers.Dataset) + .where(Mappers.Dataset.dataset_id == dataset_id) + ) + updates = [] + for ord_id, rdkit_id in query.fetchall(): + updates.append({"id": ord_id, "rdkit_mol_id": rdkit_id}) + session.execute(update(Mappers.Compound), updates) + # Update ProductCompound. + query = session.execute( + select(Mappers.ProductCompound.id, RDKitMol.id) + .join(RDKitMol, Mappers.ProductCompound.smiles == RDKitMol.smiles) + .join(Mappers.ReactionOutcome) + .join(Mappers.Reaction) + .join(Mappers.Dataset) + .where(Mappers.Dataset.dataset_id == dataset_id) + ) + updates = [] + for ord_id, rdkit_id in query.fetchall(): + updates.append({"id": ord_id, "rdkit_mol_id": rdkit_id}) + session.execute(update(Mappers.ProductCompound), updates) + logger.info(f"Updating RDKit IDs took {time.time() - start:g}s") diff --git a/ord_schema/orm/mappers.py b/ord_schema/orm/mappers.py index 319b38f0..b386fa75 100644 --- a/ord_schema/orm/mappers.py +++ b/ord_schema/orm/mappers.py @@ -176,8 +176,12 @@ def build_mapper( # pylint: disable=too-many-branches # Serialize and store the entire Reaction proto. attrs["proto"] = Column(LargeBinary, nullable=False) attrs["reaction_smiles"] = Column(Text, index=True) + attrs["rdkit_reaction_id"] = Column(Integer, ForeignKey("rdkit.reactions.id")) + attrs["rdkit_reaction"] = relationship("RDKitReaction") elif message_type in {reaction_pb2.Compound, reaction_pb2.ProductCompound}: attrs["smiles"] = Column(Text, index=True) + attrs["rdkit_mol_id"] = Column(Integer, ForeignKey("rdkit.mols.id")) + attrs["rdkit_mol"] = relationship("RDKitMol") elif message_type in {reaction_pb2.CompoundPreparation, reaction_pb2.CrudeComponent}: # Add foreign key to reaction.reaction_id. kwargs = {} diff --git a/ord_schema/orm/rdkit_mappers.py b/ord_schema/orm/rdkit_mappers.py index 5db7ede9..c14f709a 100644 --- a/ord_schema/orm/rdkit_mappers.py +++ b/ord_schema/orm/rdkit_mappers.py @@ -149,7 +149,7 @@ class RDKitMol(Base): __tablename__ = "mols" id = Column(Integer, primary_key=True) - smiles = Column(Text, unique=True) + smiles = Column(Text, index=True, unique=True) mol = Column(_RDKitMol) __table_args__ = ( @@ -168,7 +168,7 @@ class RDKitReaction(Base): __tablename__ = "reactions" id = Column(Integer, primary_key=True) - reaction_smiles = Column(Text, unique=True) + reaction_smiles = Column(Text, index=True, unique=True) reaction = Column(_RDKitReaction) __table_args__ = ( diff --git a/ord_schema/orm/rdkit_mappers_test.py b/ord_schema/orm/rdkit_mappers_test.py index cd4e7faa..6c00f255 100644 --- a/ord_schema/orm/rdkit_mappers_test.py +++ b/ord_schema/orm/rdkit_mappers_test.py @@ -24,11 +24,8 @@ def test_tanimoto_operator(test_session): select(Mappers.Reaction) .join(Mappers.ReactionInput) .join(Mappers.Compound) - .where( - Mappers.Compound.smiles.in_( - select(RDKitMol.smiles).where(RDKitMol.morgan_bfp % FingerprintType.MORGAN_BFP("c1ccccc1CCC(O)C")) - ) - ) + .join(RDKitMol) + .where(RDKitMol.morgan_bfp % FingerprintType.MORGAN_BFP("c1ccccc1CCC(O)C")) ) results = test_session.execute(query) assert len(results.fetchall()) == 20 @@ -40,11 +37,8 @@ def test_tanimoto(test_session, fp_type): select(Mappers.Reaction) .join(Mappers.ReactionInput) .join(Mappers.Compound) - .where( - Mappers.Compound.smiles.in_( - select(RDKitMol.smiles).where(RDKitMol.tanimoto("c1ccccc1CCC(O)C", fp_type=fp_type) > 0.5) - ) - ) + .join(RDKitMol) + .where(RDKitMol.tanimoto("c1ccccc1CCC(O)C", fp_type=fp_type) > 0.5) ) results = test_session.execute(query) assert len(results.fetchall()) == 20 diff --git a/ord_schema/orm/scripts/add_datasets.py b/ord_schema/orm/scripts/add_datasets.py index c9063003..e2832e37 100644 --- a/ord_schema/orm/scripts/add_datasets.py +++ b/ord_schema/orm/scripts/add_datasets.py @@ -43,7 +43,14 @@ from ord_schema.logging import get_logger from ord_schema.message_helpers import load_message -from ord_schema.orm.database import add_dataset, delete_dataset, get_connection_string, get_dataset_md5, update_rdkit +from ord_schema.orm.database import ( + add_dataset, + delete_dataset, + get_connection_string, + get_dataset_md5, + update_rdkit_ids, + update_rdkit_tables, +) from ord_schema.proto import dataset_pb2 logger = get_logger(__name__) @@ -79,7 +86,9 @@ def _add_dataset(filename: str, url: str, overwrite: bool) -> None: return add_dataset(dataset, session) session.flush() - update_rdkit(dataset.dataset_id, session=session) + update_rdkit_tables(dataset.dataset_id, session=session) + session.flush() + update_rdkit_ids(dataset.dataset_id, session=session) start = time.time() session.commit() logger.info(f"session.commit() took {time.time() - start:g}s")