From 537dbacdc5071a9c3c41c881058e540e5aabc662 Mon Sep 17 00:00:00 2001 From: Steven Kearnes Date: Wed, 19 Jun 2024 21:30:20 -0400 Subject: [PATCH] Simplify ORM dataset ingestion (#725) --- ord_schema/logging.py | 2 +- ord_schema/orm/README.md | 4 +--- ord_schema/orm/conftest.py | 9 ++------- ord_schema/orm/database.py | 7 ++++++- ord_schema/orm/mappers.py | 5 +++-- ord_schema/orm/scripts/add_datasets.py | 13 +------------ setup.py | 2 +- 7 files changed, 15 insertions(+), 27 deletions(-) diff --git a/ord_schema/logging.py b/ord_schema/logging.py index ac531d647..fa19e09d7 100644 --- a/ord_schema/logging.py +++ b/ord_schema/logging.py @@ -15,7 +15,7 @@ import logging -def get_logger(name: str, level: int = logging.DEBUG) -> logging.Logger: +def get_logger(name: str, level: int = logging.INFO) -> logging.Logger: """Creates a Logger.""" if not get_logger.initialized: logging.basicConfig(format="%(levelname)s %(asctime)s %(filename)s:%(lineno)d: %(message)s") diff --git a/ord_schema/orm/README.md b/ord_schema/orm/README.md index 1371ecb39..94b4dcc22 100644 --- a/ord_schema/orm/README.md +++ b/ord_schema/orm/README.md @@ -94,7 +94,7 @@ methods (this list is not an endorsement of any particular provider): # Create a new conda environment. conda create -n ord python=3.10 conda activate ord -conda install -c rdkit rdkit-postgresql==2020.03.3.0 +conda install -c rdkit rdkit-postgresql # Install ord-schema in this environment. cd ord-schema pip install . @@ -143,8 +143,6 @@ connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database engine = create_engine(connection_string, future=True) with Session(engine) as session: add_dataset(dataset, session) - session.flush() - add_rdkit(session) session.commit() ``` diff --git a/ord_schema/orm/conftest.py b/ord_schema/orm/conftest.py index 24416210b..7486e7ec9 100644 --- a/ord_schema/orm/conftest.py +++ b/ord_schema/orm/conftest.py @@ -22,7 +22,7 @@ from testing.postgresql import Postgresql from ord_schema.message_helpers import load_message -from ord_schema.orm.database import add_dataset, prepare_database, update_rdkit_ids, update_rdkit_tables +from ord_schema.orm.database import add_dataset, prepare_database from ord_schema.proto import dataset_pb2 @@ -38,12 +38,7 @@ def test_session() -> Iterator[Session]: rdkit_cartridge = prepare_database(engine) with Session(engine) as session: for dataset in datasets: - add_dataset(dataset, session) - if rdkit_cartridge: - session.flush() - update_rdkit_tables(dataset.dataset_id, session) - session.flush() - update_rdkit_ids(dataset.dataset_id, session) + add_dataset(dataset, session, rdkit_cartridge=rdkit_cartridge) session.commit() with Session(engine) as session: yield session diff --git a/ord_schema/orm/database.py b/ord_schema/orm/database.py index 78f799b6e..9b9f424c7 100644 --- a/ord_schema/orm/database.py +++ b/ord_schema/orm/database.py @@ -70,7 +70,7 @@ def prepare_database(engine: Engine) -> bool: return rdkit_cartridge -def add_dataset(dataset: dataset_pb2.Dataset, session: Session) -> None: +def add_dataset(dataset: dataset_pb2.Dataset, session: Session, rdkit_cartridge: bool = True) -> None: """Adds a dataset to the database.""" logger.info(f"Adding dataset {dataset.dataset_id}") start = time.time() @@ -79,6 +79,11 @@ def add_dataset(dataset: dataset_pb2.Dataset, session: Session) -> None: start = time.time() session.add(mapped_dataset) logger.info(f"session.add() took {time.time() - start:g}s") + if rdkit_cartridge: + session.flush() + update_rdkit_tables(dataset.dataset_id, session) + session.flush() + update_rdkit_ids(dataset.dataset_id, session) def get_dataset_md5(dataset_id: str, session: Session) -> str | None: diff --git a/ord_schema/orm/mappers.py b/ord_schema/orm/mappers.py index 7233890f5..af11160c7 100644 --- a/ord_schema/orm/mappers.py +++ b/ord_schema/orm/mappers.py @@ -80,7 +80,7 @@ def _get_message_contexts( if field.type == FieldDescriptor.TYPE_MESSAGE: if set(field.message_type.fields_by_name.keys()) == {"key", "value"}: # Check for maps. - logger.info(f"Found map: ({descriptor.full_name}, {field.name})") + logger.debug(f"Found map: ({descriptor.full_name}, {field.name})") field_message_type = field.message_type.fields_by_name["value"].message_type else: field_message_type = field.message_type @@ -99,10 +99,11 @@ def build_mappers() -> dict[Type[Message], Type]: Returns: Dict mapping protocol buffer message types to mapper classes. """ + logger.info("Building ORM mappers") mappers = {} parents = get_parents(dataset_pb2.Dataset) for message_type in sorted(parents, key=lambda x: x.DESCRIPTOR.name): - logger.info(f"Building mapper for {message_type}") + logger.debug(f"Building mapper for {message_type}") mappers[message_type] = build_mapper(message_type, parents=parents) return mappers diff --git a/ord_schema/orm/scripts/add_datasets.py b/ord_schema/orm/scripts/add_datasets.py index 04b8161ed..bd41b5260 100644 --- a/ord_schema/orm/scripts/add_datasets.py +++ b/ord_schema/orm/scripts/add_datasets.py @@ -43,14 +43,7 @@ from ord_schema.logging import get_logger from ord_schema.message_helpers import load_message -from ord_schema.orm.database import ( - add_dataset, - delete_dataset, - get_connection_string, - get_dataset_md5, - update_rdkit_ids, - update_rdkit_tables, -) +from ord_schema.orm.database import add_dataset, delete_dataset, get_connection_string, get_dataset_md5 from ord_schema.proto import dataset_pb2 logger = get_logger(__name__) @@ -85,10 +78,6 @@ def _add_dataset(filename: str, url: str, overwrite: bool) -> None: logger.info(f"existing dataset {dataset.dataset_id} unchanged; skipping") return add_dataset(dataset, session) - session.flush() - update_rdkit_tables(dataset.dataset_id, session=session) - session.flush() - update_rdkit_ids(dataset.dataset_id, session=session) start = time.time() session.commit() logger.info(f"session.commit() took {time.time() - start:g}s") diff --git a/setup.py b/setup.py index 48a2f5158..a6bdd5001 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ "inflection>=0.5.1", "jinja2>=2.0.0", "joblib>=1.0.0", - "numpy>=1.18.1", + "numpy<2", "openpyxl>=3.0.5", "pandas>=1.0.4", "protobuf==4.22.3",