diff --git a/CHANGELOG.md b/CHANGELOG.md index 4aeae0b11..94950a101 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,13 +15,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `hypothesis` strategies and roundtrip test for kernels, constraints and objectives ### Changed -- `torch` numeric types are now loaded lazily - Reorganized acquisition.py into `acquisition` subpackage -- `torch` is imported lazily in `surrogates` - Acquisition functions are now their own objects - `acquisition_function_cls` constructor parameter renamed to `acquisition_function` - User guide now explains the new objective classes - Telemetry deactivation warning is only shown to developers +- `torch`, `gpytorch` and `botorch` are lazy-loaded for improved startup time ### Removed - `model_params` attribute from `Surrogate` base class, `GaussianProcessSurrogate` and diff --git a/baybe/acquisition/__init__.py b/baybe/acquisition/__init__.py index 3eedf15ad..42acd9533 100644 --- a/baybe/acquisition/__init__.py +++ b/baybe/acquisition/__init__.py @@ -8,8 +8,6 @@ qProbabilityOfImprovement, qUpperConfidenceBound, ) -from baybe.acquisition.adapter import AdapterModel, debotorchize -from baybe.acquisition.partial import PartialAcquisitionFunction EI = ExpectedImprovement PI = ProbabilityOfImprovement @@ -35,9 +33,4 @@ "qEI", "qPI", "qUCB", - # --------------------------- - # Helpers - "debotorchize", - "AdapterModel", - "PartialAcquisitionFunction", ] diff --git a/baybe/acquisition/base.py b/baybe/acquisition/base.py index cca166b80..bfaa35192 100644 --- a/baybe/acquisition/base.py +++ b/baybe/acquisition/base.py @@ -8,7 +8,6 @@ from attrs import define -from baybe.acquisition.adapter import debotorchize from baybe.serialization.core import ( converter, get_base_structure_hook, @@ -29,6 +28,8 @@ def to_botorch(self, surrogate: Surrogate, best_f: float): """Create the botorch-ready representation of the function.""" import botorch.acquisition as 
botorch_acquisition + from baybe.acquisition.adapter import debotorchize + acqf_cls = getattr(botorch_acquisition, self.__class__.__name__) return debotorchize(acqf_cls)(surrogate, best_f) diff --git a/baybe/recommenders/naive.py b/baybe/recommenders/naive.py index 1f79bc35c..b09b7b7c0 100644 --- a/baybe/recommenders/naive.py +++ b/baybe/recommenders/naive.py @@ -1,13 +1,11 @@ """Naive recommender for hybrid spaces.""" import warnings -from typing import ClassVar, Optional, cast +from typing import ClassVar, Optional import pandas as pd from attrs import define, evolve, field, fields -from torch import Tensor -from baybe.acquisition import PartialAcquisitionFunction from baybe.recommenders.pure.base import PureRecommender from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.recommenders.pure.bayesian.sequential_greedy import ( @@ -86,6 +84,8 @@ def recommend( # noqa: D102 ) -> pd.DataFrame: # See base class. + from baybe.acquisition.partial import PartialAcquisitionFunction + if (not isinstance(self.disc_recommender, BayesianRecommender)) and ( not isinstance(self.disc_recommender, NonPredictiveRecommender) ): @@ -116,7 +116,7 @@ def recommend( # noqa: D102 # will then be attached to every discrete point when the acquisition function # is evaluated. cont_part = searchspace.continuous.samples_random(1) - cont_part_tensor = cast(Tensor, to_tensor(cont_part)).unsqueeze(-2) + cont_part_tensor = to_tensor(cont_part).unsqueeze(-2) # Get discrete candidates. The metadata flags are ignored since the search space # is hybrid @@ -151,7 +151,7 @@ def recommend( # noqa: D102 # Get one random discrete point that will be attached when evaluating the # acquisition function in the discrete space. 
disc_part = searchspace.discrete.comp_rep.loc[disc_rec_idx].sample(1) - disc_part_tensor = cast(Tensor, to_tensor(disc_part)).unsqueeze(-2) + disc_part_tensor = to_tensor(disc_part).unsqueeze(-2) # Setup a fresh acquisition function for the continuous recommender self.cont_recommender._setup_botorch_acqf(searchspace, train_x, train_y) diff --git a/baybe/recommenders/pure/bayesian/sequential_greedy.py b/baybe/recommenders/pure/bayesian/sequential_greedy.py index 8c5df51e1..486806531 100644 --- a/baybe/recommenders/pure/bayesian/sequential_greedy.py +++ b/baybe/recommenders/pure/bayesian/sequential_greedy.py @@ -4,7 +4,6 @@ import pandas as pd from attrs import define, field, validators -from botorch.optim import optimize_acqf, optimize_acqf_discrete, optimize_acqf_mixed from baybe.exceptions import NoMCAcquisitionFunctionError from baybe.recommenders.pure.bayesian.base import BayesianRecommender @@ -69,6 +68,8 @@ def _recommend_discrete( ) -> pd.Index: # See base class. + from botorch.optim import optimize_acqf_discrete + # determine the next set of points to be tested candidates_tensor = to_tensor(candidates_comp) try: @@ -102,7 +103,9 @@ def _recommend_continuous( batch_size: int, ) -> pd.DataFrame: # See base class. + import torch + from botorch.optim import optimize_acqf try: points, _ = optimize_acqf( @@ -161,6 +164,7 @@ def _recommend_hybrid( is chosen. 
""" import torch + from botorch.optim import optimize_acqf_mixed if len(candidates_comp) > 0: # Calculate the number of samples from the given percentage diff --git a/baybe/surrogates/custom.py b/baybe/surrogates/custom.py index 554d4a40a..03a4f4194 100644 --- a/baybe/surrogates/custom.py +++ b/baybe/surrogates/custom.py @@ -27,7 +27,6 @@ from baybe.surrogates.utils import batchify, catch_constant_targets from baybe.surrogates.validation import validate_custom_architecture_cls from baybe.utils.numerical import DTypeFloatONNX -from baybe.utils.torch import DTypeFloatTorch try: import onnxruntime as ort @@ -156,6 +155,8 @@ def default_model(self) -> ort.InferenceSession: def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: import torch + from baybe.utils.torch import DTypeFloatTorch + model_inputs = { self.onnx_input_name: candidates.numpy().astype(DTypeFloatONNX) } diff --git a/baybe/surrogates/gaussian_process.py b/baybe/surrogates/gaussian_process.py index db486cf55..c154aa8d7 100644 --- a/baybe/surrogates/gaussian_process.py +++ b/baybe/surrogates/gaussian_process.py @@ -2,17 +2,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar from attr import define, field -from botorch import fit_gpytorch_mll -from botorch.models import SingleTaskGP -from botorch.models.transforms import Normalize, Standardize -from gpytorch import ExactMarginalLogLikelihood -from gpytorch.kernels import IndexKernel, ScaleKernel -from gpytorch.likelihoods import GaussianLikelihood -from gpytorch.means import ConstantMean -from gpytorch.priors import GammaPrior from baybe.kernels import MaternKernel from baybe.kernels.base import Kernel @@ -38,7 +30,9 @@ class GaussianProcessSurrogate(Surrogate): kernel: Kernel = field(factory=MaternKernel) """The kernel used by the Gaussian Process.""" - _model: Optional[SingleTaskGP] = field(init=False, default=None) + # TODO: type should be 
Optional[botorch.models.SingleTaskGP] but is currently + # omitted due to: https://github.com/python-attrs/cattrs/issues/531 + _model = field(init=False, default=None) """The actual model.""" def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: @@ -49,7 +43,10 @@ def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None: # See base class. + import botorch + import gpytorch import torch + from gpytorch.priors import GammaPrior # identify the indexes of the task and numeric dimensions # TODO: generalize to multiple task parameters @@ -63,10 +60,10 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No # define the input and outcome transforms # TODO [Scaling]: scaling should be handled by search space object - input_transform = Normalize( + input_transform = botorch.models.transforms.Normalize( train_x.shape[1], bounds=bounds, indices=numeric_idxs ) - outcome_transform = Standardize(train_y.shape[1]) + outcome_transform = botorch.models.transforms.Standardize(train_y.shape[1]) # ---------- GP prior selection ---------- # # TODO: temporary prior choices adapted from edbo, replace later on @@ -105,7 +102,7 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No batch_shape = train_x.shape[:-2] # create GP mean - mean_module = ConstantMean(batch_shape=batch_shape) + mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape) # define the covariance module for the numeric dimensions gpytorch_kernel = self.kernel.to_gpytorch( @@ -114,7 +111,7 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No batch_shape=batch_shape, lengthscale_prior=lengthscale_prior[0], ) - base_covar_module = ScaleKernel( + base_covar_module = gpytorch.kernels.ScaleKernel( gpytorch_kernel, batch_shape=batch_shape, outputscale_prior=outputscale_prior[0], @@ -130,7 +127,7 @@ def _fit(self, searchspace: 
SearchSpace, train_x: Tensor, train_y: Tensor) -> No if task_idx is None: covar_module = base_covar_module else: - task_covar_module = IndexKernel( + task_covar_module = gpytorch.kernels.IndexKernel( num_tasks=searchspace.n_tasks, active_dims=task_idx, rank=searchspace.n_tasks, # TODO: make controllable @@ -138,14 +135,14 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No covar_module = base_covar_module * task_covar_module # create GP likelihood - likelihood = GaussianLikelihood( + likelihood = gpytorch.likelihoods.GaussianLikelihood( noise_prior=noise_prior[0], batch_shape=batch_shape ) if noise_prior[1] is not None: likelihood.noise = torch.tensor([noise_prior[1]]) # construct and fit the Gaussian process - self._model = SingleTaskGP( + self._model = botorch.models.SingleTaskGP( train_x, train_y, input_transform=input_transform, @@ -154,5 +151,5 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No covar_module=covar_module, likelihood=likelihood, ) - mll = ExactMarginalLogLikelihood(self._model.likelihood, self._model) - fit_gpytorch_mll(mll) + mll = gpytorch.ExactMarginalLogLikelihood(self._model.likelihood, self._model) + botorch.fit_gpytorch_mll(mll) diff --git a/baybe/surrogates/utils.py b/baybe/surrogates/utils.py index 15cef0bbc..d5fb41765 100644 --- a/baybe/surrogates/utils.py +++ b/baybe/surrogates/utils.py @@ -5,13 +5,12 @@ from functools import wraps from typing import TYPE_CHECKING, Callable, ClassVar -import torch -from torch import Tensor - from baybe.scaler import DefaultScaler from baybe.searchspace import SearchSpace if TYPE_CHECKING: + from torch import Tensor + from baybe.surrogates.base import Surrogate _MIN_TARGET_STD = 1e-6 @@ -90,6 +89,8 @@ def __init__(self, *args, **kwargs): def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: """Call the posterior function of the internal model instance.""" + import torch + mean, var = self.model._posterior(candidates) # If a 
joint posterior is expected but the model has been overridden by one @@ -105,6 +106,8 @@ def _fit( self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor ) -> None: """Select a model based on the variance of the targets and fits it.""" + import torch + from baybe.surrogates.naive import MeanPredictionSurrogate # https://github.com/pytorch/pytorch/issues/29372 @@ -232,6 +235,8 @@ def sequential_posterior(model: Surrogate, candidates: Tensor) -> [Tensor, Tenso Returns: The mean and the covariance. """ + import torch + # If no batch dimensions are given, call the model directly if candidates.ndim == 2: return posterior(model, candidates) diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 749ab3ae2..b5d32060e 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -3,12 +3,13 @@ from __future__ import annotations import logging -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Iterator, Sequence from typing import ( TYPE_CHECKING, Literal, Optional, Union, + overload, ) import numpy as np @@ -28,7 +29,17 @@ _logger = logging.getLogger(__name__) -def to_tensor(*dfs: pd.DataFrame) -> Union[Tensor, Iterable[Tensor]]: +@overload +def to_tensor(df: pd.DataFrame) -> Tensor: + ... + + +@overload +def to_tensor(*dfs: pd.DataFrame) -> Iterator[Tensor]: + ... + + +def to_tensor(*dfs: pd.DataFrame) -> Union[Tensor, Iterator[Tensor]]: """Convert a given set of dataframes into tensors (dropping all indices). 
Args: diff --git a/streamlit/surrogate_models.py b/streamlit/surrogate_models.py index 767c4937a..ca17e0d3a 100644 --- a/streamlit/surrogate_models.py +++ b/streamlit/surrogate_models.py @@ -15,7 +15,7 @@ from funcy import rpartial import streamlit as st -from baybe.acquisition import debotorchize +from baybe.acquisition.adapter import debotorchize from baybe.parameters import NumericalDiscreteParameter from baybe.searchspace import SearchSpace from baybe.surrogates import get_available_surrogates diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 000000000..286b2a715 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,93 @@ +"""Tests for module imports.""" + +import importlib +import os +import pkgutil +import subprocess +import sys +from collections.abc import Sequence + +import pytest +from pytest import param + +pytestmark = pytest.mark.skipif( + os.environ.get("BAYBE_TEST_ENV") != "FULLTEST", + reason="Only possible in FULLTEST environment.", +) + +_EAGER_LOADING_EXIT_CODE = 42 + + +def find_modules() -> list[str]: + """Return all BayBE module names.""" + package = importlib.import_module("baybe") + return [ + name + for _, name, _ in pkgutil.walk_packages( + path=package.__path__, prefix=package.__name__ + "." + ) + ] + + +def make_import_check(modules: Sequence[str], target: str) -> str: + """Create code that tests if importing the given modules also imports the target. + + Args: + modules: The modules to be imported by the created code. + target: The target module whose presence is to be checked after the import. + + Returns: + Code that signals the presence of the target via a non-zero exit code. 
+    """
+    imports = "\n".join([f"import {module}" for module in modules])
+    return "\n".join(
+        [
+            "import sys",
+            imports,
+            f"hit = '{target}' in sys.modules",
+            f"sys.exit({_EAGER_LOADING_EXIT_CODE} if hit else 0)",
+        ]
+    )
+
+
+@pytest.mark.parametrize("module", find_modules())
+def test_imports(module: str):
+    """All modules can be imported without throwing errors."""
+    importlib.import_module(module)
+
+
+WHITELISTS = {
+    "torch": [
+        "baybe.acquisition.adapter",
+        "baybe.acquisition.partial",
+        "baybe.utils.botorch_wrapper",
+        "baybe.utils.torch",
+    ],
+}
+"""Modules (dict values) for which certain imports (dict keys) are permitted."""
+
+
+@pytest.mark.parametrize(
+    ("target", "whitelist"), [param(k, v, id=k) for k, v in WHITELISTS.items()]
+)
+def test_lazy_loading(target: str, whitelist: Sequence[str]):
+    """The target does not appear in the module list after loading BayBE modules."""
+    all_modules = find_modules()
+    # `all` is required: a bare generator expression would always be truthy.
+    assert all(w in all_modules for w in whitelist)
+    modules = [i for i in all_modules if i not in whitelist]
+    code = make_import_check(modules, target)
+    result = subprocess.call([sys.executable, "-c", code])
+    assert result == 0
+
+
+@pytest.mark.parametrize(
+    ("target", "module"),
+    [param(k, m, id=f"{k}-{m}") for k, v in WHITELISTS.items() for m in v],
+)
+def test_whitelist_modules_are_true_positives(target, module):
+    """The whitelisted modules actually import the target."""
+    code = make_import_check([module], target)
+    python_interpreter = sys.executable
+    result = subprocess.call([python_interpreter, "-c", code])
+    assert result == _EAGER_LOADING_EXIT_CODE
diff --git a/tox.ini b/tox.ini
index 7dd318bf3..c2487a299 100644
--- a/tox.ini
+++ b/tox.ini
@@ -9,6 +9,7 @@ extras = test,chem,examples,simulation,onnx
 passenv = CI
 setenv =
     SMOKE_TEST = true
+    BAYBE_TEST_ENV = FULLTEST
 commands =
     python --version
     pytest -p no:warnings --cov=baybe --durations=5 {posargs}
@@ -19,6 +20,7 @@ extras = test
passenv = CI setenv = SMOKE_TEST = true + BAYBE_TEST_ENV = CORETEST commands = python --version pytest -p no:warnings --cov=baybe --durations=5 {posargs}