Enable OSX tests for ORM (#744)
* Enable OSX tests for ORM

* fix for mac

* require cartridge

* readme
skearnes authored Jul 28, 2024
1 parent a3822d8 commit 34eadce
Showing 7 changed files with 29 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cleanup.yml
@@ -42,7 +42,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install wheel
- python -m pip install .[tests]
+ python -m pip install .[examples,tests]
- name: Run black
run: |
black --check .
3 changes: 1 addition & 2 deletions .github/workflows/run_tests.yml
@@ -34,8 +34,7 @@ jobs:
shell: bash -l {0}
run: |
# NOTE(skearnes): conda is only used for postgres (not python).
- # NOTE(skearnes): rdkit-postgresql may not be available for ARM.
- conda install -c rdkit rdkit-postgresql || conda install -c conda-forge postgresql
+ conda install -c conda-forge rdkit-postgresql
initdb
- uses: actions/setup-python@v4
with:
3 changes: 1 addition & 2 deletions ord_schema/orm/README.md
@@ -92,9 +92,8 @@ methods (this list is not an endorsement of any particular provider):

```shell
# Create a new conda environment.
- conda create -n ord python=3.10
+ conda create -n ord -c conda-forge python=3.10 rdkit rdkit-postgresql
conda activate ord
- conda install -c rdkit rdkit-postgresql
# Install ord-schema in this environment.
cd ord-schema
pip install .
4 changes: 2 additions & 2 deletions ord_schema/orm/rdkit_mappers.py
@@ -24,9 +24,9 @@
from __future__ import annotations

import os
- from distutils.util import strtobool # pylint: disable=deprecated-module
from enum import Enum

+ from setuptools import distutils # pytype: disable=import-error
from sqlalchemy import Column, Index, Integer, Text, cast, func
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.types import UserDefinedType
@@ -36,7 +36,7 @@

def rdkit_cartridge() -> bool:
"""Returns whether to use RDKit PostgreSQL cartridge functionality."""
- return bool(strtobool(os.environ.get("ORD_POSTGRES_RDKIT", "1")))
+ return bool(distutils.util.strtobool(os.environ.get("ORD_POSTGRES_RDKIT", "1")))


class RDKitMol(UserDefinedType):
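For reference, `rdkit_cartridge()` simply parses the `ORD_POSTGRES_RDKIT` environment variable as a boolean at call time, defaulting to enabled. A minimal usage sketch, assuming the function behaves as shown in the diff above; the `"0"` override is illustrative only and not part of this commit:

```python
import os

from ord_schema.orm.rdkit_mappers import rdkit_cartridge

# With no override, the default value "1" enables RDKit cartridge functionality.
os.environ.pop("ORD_POSTGRES_RDKIT", None)
assert rdkit_cartridge() is True

# Any falsy string accepted by strtobool ("0", "false", "no", ...) disables it.
os.environ["ORD_POSTGRES_RDKIT"] = "0"
assert rdkit_cartridge() is False
```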
2 changes: 0 additions & 2 deletions ord_schema/orm/rdkit_mappers_test.py
@@ -21,8 +21,6 @@
from ord_schema.orm.mappers import Mappers
from ord_schema.orm.rdkit_mappers import CString, FingerprintType, RDKitMol, RDKitMols, RDKitReactions

- pytestmark = pytest.mark.skipif(platform.machine() != "x86_64", reason="RDKit cartridge is required")
-

def test_tanimoto_operator(test_session):
query = (
48 changes: 22 additions & 26 deletions ord_schema/orm/scripts/add_datasets.py
@@ -21,7 +21,7 @@
Options:
--pattern=<str> Pattern for dataset filenames
--overwrite Update changed datasets
- --url=<str> Postgres connection string
+ --dsn=<str> Postgres connection string
--database=<str> Database [default: orm]
--username=<str> Database username [default: postgres]
--password=<str> Database password
@@ -48,20 +48,14 @@
from ord_schema.orm import database
from ord_schema.proto import dataset_pb2

- engine: Engine = None
logger = get_logger(__name__)


- def initializer():
- """Initializer for child processes."""
- # See https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork.
- engine.dispose(close=False)
-
-
- def add_dataset(filename: str, overwrite: bool) -> str:
+ def add_dataset(dsn: str, filename: str, overwrite: bool) -> str:
"""Adds a single dataset to the database.
Args:
+ dsn: Database connection string.
filename: Dataset filename.
overwrite: If True, update the dataset if the MD5 hash has changed.
@@ -73,6 +67,9 @@ def add_dataset(filename: str, overwrite: bool) -> str:
"""
logger.debug(f"Loading {filename}")
dataset = load_message(filename, dataset_pb2.Dataset)
+ # NOTE(skearnes): Multiprocessing is hard to get right for shared connection pools, so we don't even try; see
+ # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork.
+ engine = create_engine(dsn)
with Session(engine) as session:
with session.begin():
dataset_md5 = database.get_dataset_md5(dataset.dataset_id, session)
@@ -92,7 +89,7 @@ def add_dataset(filename: str, overwrite: bool) -> str:
return dataset.dataset_id


- def add_rdkit(dataset_id: str) -> None:
+ def add_rdkit(engine: Engine, dataset_id: str) -> None:
"""Updates RDKit tables."""
with Session(engine) as session:
with session.begin():
@@ -105,24 +102,22 @@ def main(**kwargs):
RDLogger.DisableLog("rdApp.*")
if kwargs["--debug"]:
get_logger(database.__name__, level=logging.DEBUG)
if kwargs["--url"]:
url = kwargs["--url"]
if kwargs["--dsn"]:
dsn = kwargs["--dsn"]
else:
- url = database.get_connection_string(
+ dsn = database.get_connection_string(
database=kwargs["--database"],
username=kwargs["--username"],
password=kwargs["--password"] or os.environ["PGPASSWORD"],
host=kwargs["--host"],
port=int(kwargs["--port"]),
)
- global engine # pylint: disable=global-statement
- engine = create_engine(url)
filenames = sorted(glob(kwargs["--pattern"]))
- with ProcessPoolExecutor(initializer=initializer, max_workers=int(kwargs["--n_jobs"])) as executor:
+ with ProcessPoolExecutor(max_workers=int(kwargs["--n_jobs"])) as executor:
logger.info("Adding datasets")
futures = {}
for filename in filenames:
- future = executor.submit(add_dataset, filename=filename, overwrite=kwargs["--overwrite"])
+ future = executor.submit(add_dataset, dsn=dsn, filename=filename, overwrite=kwargs["--overwrite"])
futures[future] = filename
dataset_ids = []
failures = []
@@ -133,15 +128,16 @@ def main(**kwargs):
filename = futures[future]
failures.append(filename)
logger.error(f"Adding dataset {filename} failed: {error}")
logger.info("Adding RDKit functionality")
for dataset_id in tqdm(dataset_ids):
try:
add_rdkit(dataset_id) # NOTE(skearnes): Do this serially to avoid deadlocks.
except Exception as error: # pylint: disable=broad-exception-caught
failures.append(dataset_id)
logger.error(f"Adding RDKit functionality for {dataset_id} failed: {error}")
if failures:
raise RuntimeError(failures)
logger.info("Adding RDKit functionality")
engine = create_engine(dsn)
for dataset_id in tqdm(dataset_ids):
try:
add_rdkit(engine, dataset_id) # NOTE(skearnes): Do this serially to avoid deadlocks.
except Exception as error: # pylint: disable=broad-exception-caught
failures.append(dataset_id)
logger.error(f"Adding RDKit functionality for {dataset_id} failed: {error}")
if failures:
raise RuntimeError(failures)


if __name__ == "__main__":
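The add_datasets.py changes replace the shared module-level engine (and its pool-disposing initializer) with an engine created from the DSN inside each worker, following the SQLAlchemy pooling guidance linked in the NOTE above. A standalone sketch of that pattern, with hypothetical `DSN` and `fetch_value` names that are not part of ord-schema:

```python
from concurrent.futures import ProcessPoolExecutor

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

DSN = "postgresql://postgres:postgres@localhost:5432/orm"  # hypothetical connection string


def fetch_value(dsn: str, value: int) -> int:
    """Runs one unit of work with an engine created inside the worker process."""
    # Building the engine here avoids sharing pooled connections across processes,
    # which is the failure mode the SQLAlchemy docs warn about.
    engine = create_engine(dsn)
    with Session(engine) as session:
        return session.execute(text("SELECT :value"), {"value": value}).scalar_one()


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=4) as executor:
        results = list(executor.map(fetch_value, [DSN] * 3, [1, 2, 3]))
    print(results)  # [1, 2, 3]
```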
6 changes: 2 additions & 4 deletions ord_schema/orm/scripts/add_datasets_test.py
@@ -16,17 +16,15 @@
import os

import docopt
- import pytest

from ord_schema.orm.database import prepare_database
from ord_schema.orm.scripts import add_datasets


def test_main(test_engine):
- if not prepare_database(test_engine):
- pytest.skip("RDKit cartridge is required")
+ assert prepare_database(test_engine)
argv = [
"--url",
"--dsn",
test_engine.url,
"--pattern",
os.path.join(os.path.dirname(__file__), "..", "testdata", "ord-nielsen-example.pbtxt"),
