Skip to content

Commit

Permalink
Simplify ORM dataset ingestion (#725)
Browse files Browse the repository at this point in the history
  • Loading branch information
skearnes authored Jun 20, 2024
1 parent 2a1a11c commit 537dbac
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 27 deletions.
2 changes: 1 addition & 1 deletion ord_schema/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging


def get_logger(name: str, level: int = logging.DEBUG) -> logging.Logger:
def get_logger(name: str, level: int = logging.INFO) -> logging.Logger:
"""Creates a Logger."""
if not get_logger.initialized:
logging.basicConfig(format="%(levelname)s %(asctime)s %(filename)s:%(lineno)d: %(message)s")
Expand Down
4 changes: 1 addition & 3 deletions ord_schema/orm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ methods (this list is not an endorsement of any particular provider):
# Create a new conda environment.
conda create -n ord python=3.10
conda activate ord
conda install -c rdkit rdkit-postgresql==2020.03.3.0
conda install -c rdkit rdkit-postgresql
# Install ord-schema in this environment.
cd ord-schema
pip install .
Expand Down Expand Up @@ -143,8 +143,6 @@ connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database
engine = create_engine(connection_string, future=True)
with Session(engine) as session:
add_dataset(dataset, session)
session.flush()
add_rdkit(session)
session.commit()
```

Expand Down
9 changes: 2 additions & 7 deletions ord_schema/orm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from testing.postgresql import Postgresql

from ord_schema.message_helpers import load_message
from ord_schema.orm.database import add_dataset, prepare_database, update_rdkit_ids, update_rdkit_tables
from ord_schema.orm.database import add_dataset, prepare_database
from ord_schema.proto import dataset_pb2


Expand All @@ -38,12 +38,7 @@ def test_session() -> Iterator[Session]:
rdkit_cartridge = prepare_database(engine)
with Session(engine) as session:
for dataset in datasets:
add_dataset(dataset, session)
if rdkit_cartridge:
session.flush()
update_rdkit_tables(dataset.dataset_id, session)
session.flush()
update_rdkit_ids(dataset.dataset_id, session)
add_dataset(dataset, session, rdkit_cartridge=rdkit_cartridge)
session.commit()
with Session(engine) as session:
yield session
7 changes: 6 additions & 1 deletion ord_schema/orm/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def prepare_database(engine: Engine) -> bool:
return rdkit_cartridge


def add_dataset(dataset: dataset_pb2.Dataset, session: Session) -> None:
def add_dataset(dataset: dataset_pb2.Dataset, session: Session, rdkit_cartridge: bool = True) -> None:
"""Adds a dataset to the database."""
logger.info(f"Adding dataset {dataset.dataset_id}")
start = time.time()
Expand All @@ -79,6 +79,11 @@ def add_dataset(dataset: dataset_pb2.Dataset, session: Session) -> None:
start = time.time()
session.add(mapped_dataset)
logger.info(f"session.add() took {time.time() - start:g}s")
if rdkit_cartridge:
session.flush()
update_rdkit_tables(dataset.dataset_id, session)
session.flush()
update_rdkit_ids(dataset.dataset_id, session)


def get_dataset_md5(dataset_id: str, session: Session) -> str | None:
Expand Down
5 changes: 3 additions & 2 deletions ord_schema/orm/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _get_message_contexts(
if field.type == FieldDescriptor.TYPE_MESSAGE:
if set(field.message_type.fields_by_name.keys()) == {"key", "value"}:
# Check for maps.
logger.info(f"Found map: ({descriptor.full_name}, {field.name})")
logger.debug(f"Found map: ({descriptor.full_name}, {field.name})")
field_message_type = field.message_type.fields_by_name["value"].message_type
else:
field_message_type = field.message_type
Expand All @@ -99,10 +99,11 @@ def build_mappers() -> dict[Type[Message], Type]:
Returns:
Dict mapping protocol buffer message types to mapper classes.
"""
logger.info("Building ORM mappers")
mappers = {}
parents = get_parents(dataset_pb2.Dataset)
for message_type in sorted(parents, key=lambda x: x.DESCRIPTOR.name):
logger.info(f"Building mapper for {message_type}")
logger.debug(f"Building mapper for {message_type}")
mappers[message_type] = build_mapper(message_type, parents=parents)
return mappers

Expand Down
13 changes: 1 addition & 12 deletions ord_schema/orm/scripts/add_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,7 @@

from ord_schema.logging import get_logger
from ord_schema.message_helpers import load_message
from ord_schema.orm.database import (
add_dataset,
delete_dataset,
get_connection_string,
get_dataset_md5,
update_rdkit_ids,
update_rdkit_tables,
)
from ord_schema.orm.database import add_dataset, delete_dataset, get_connection_string, get_dataset_md5
from ord_schema.proto import dataset_pb2

logger = get_logger(__name__)
Expand Down Expand Up @@ -85,10 +78,6 @@ def _add_dataset(filename: str, url: str, overwrite: bool) -> None:
logger.info(f"existing dataset {dataset.dataset_id} unchanged; skipping")
return
add_dataset(dataset, session)
session.flush()
update_rdkit_tables(dataset.dataset_id, session=session)
session.flush()
update_rdkit_ids(dataset.dataset_id, session=session)
start = time.time()
session.commit()
logger.info(f"session.commit() took {time.time() - start:g}s")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"inflection>=0.5.1",
"jinja2>=2.0.0",
"joblib>=1.0.0",
"numpy>=1.18.1",
"numpy<2",
"openpyxl>=3.0.5",
"pandas>=1.0.4",
"protobuf==4.22.3",
Expand Down

0 comments on commit 537dbac

Please sign in to comment.