Skip to content

Commit

Permalink
refactor: made utility models and types private
Browse files Browse the repository at this point in the history
  • Loading branch information
MothNik committed May 20, 2024
1 parent 6ef864a commit 907baa4
Showing 16 changed files with 88 additions and 84 deletions.
2 changes: 1 addition & 1 deletion chemotools/smooth/__init__.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@

### Imports ###

from chemotools.utils.models import ( # noqa: F401
from chemotools.utils._models import ( # noqa: F401
WhittakerSmoothLambda,
WhittakerSmoothMethods,
)
2 changes: 1 addition & 1 deletion chemotools/smooth/_whittaker_smooth.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
from sklearn.utils.validation import check_is_fitted

from chemotools.utils.check_inputs import check_input, check_weights
from chemotools.utils.types import RealNumeric
from chemotools.utils._types import RealNumeric
from chemotools.utils._whittaker_base import (
WhittakerLikeSolver,
WhittakerSmoothLambda,
2 changes: 1 addition & 1 deletion chemotools/utils/_banded_linalg.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
from numpy.typing import ArrayLike
from scipy.linalg import lapack

from chemotools.utils.models import BandedLUFactorization
from chemotools.utils._models import BandedLUFactorization

### Type Aliases ###

File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion chemotools/utils/_whittaker_base/__init__.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@

### Imports ###

from chemotools.utils.models import ( # noqa: F401
from chemotools.utils._models import ( # noqa: F401
WhittakerSmoothLambda,
WhittakerSmoothMethods,
)
4 changes: 2 additions & 2 deletions chemotools/utils/_whittaker_base/auto_lambda/logml.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
import numpy as np

from chemotools.utils import _banded_linalg as bla
from chemotools.utils import models
from chemotools.utils import _models
from chemotools.utils._whittaker_base.auto_lambda.shared import get_smooth_wrss

### Constants ###
@@ -22,7 +22,7 @@
### Type Aliases ###

# TODO: add QR factorization
_FactorizationForLogMarginalLikelihood = models.BandedLUFactorization
_FactorizationForLogMarginalLikelihood = _models.BandedLUFactorization

### Functions ###

Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@

from scipy.optimize import minimize_scalar

from chemotools.utils.models import WhittakerSmoothLambda
from chemotools.utils._models import WhittakerSmoothLambda

### Constants ###

6 changes: 4 additions & 2 deletions chemotools/utils/_whittaker_base/auto_lambda/shared.py
Original file line number Diff line number Diff line change
@@ -11,11 +11,13 @@

import numpy as np

from chemotools.utils import models
from chemotools.utils import _models

### Type Aliases ###

_Factorization = Union[models.BandedLUFactorization, models.BandedPentapyFactorization]
_Factorization = Union[
_models.BandedLUFactorization, _models.BandedPentapyFactorization
]

### Functions ###

18 changes: 9 additions & 9 deletions chemotools/utils/_whittaker_base/initialisation.py
Original file line number Diff line number Diff line change
@@ -12,14 +12,14 @@

from chemotools.utils import _banded_linalg as bla
from chemotools.utils import finite_differences as fdiff
from chemotools.utils import models
from chemotools.utils.types import RealNumeric
from chemotools.utils import _models
from chemotools.utils._types import RealNumeric

### Type Aliases ###

_StrWhittakerSmoothMethods = Literal["fixed", "logml"]
_AllWhittakerSmoothMethods = Union[
models.WhittakerSmoothMethods, _StrWhittakerSmoothMethods
_models.WhittakerSmoothMethods, _StrWhittakerSmoothMethods
]
_WhittakerSmoothLambdaPlain = Tuple[
RealNumeric,
@@ -29,7 +29,7 @@
_LambdaSpecs = Union[
RealNumeric,
_WhittakerSmoothLambdaPlain,
models.WhittakerSmoothLambda,
_models.WhittakerSmoothLambda,
]

### Constants ###
@@ -39,7 +39,7 @@
### Functions ###


def get_checked_lambda(lam: Any) -> models.WhittakerSmoothLambda:
def get_checked_lambda(lam: Any) -> _models.WhittakerSmoothLambda:
"""
Checks the penalty weights lambda and casts it to the respective dataclass used
inside the ``WhittakerLikeSolver`` class.
@@ -48,14 +48,14 @@ def get_checked_lambda(lam: Any) -> models.WhittakerSmoothLambda:

# if lambda is already the correct dataclass, it can be returned directly since all
# the checks have already been performed
if isinstance(lam, models.WhittakerSmoothLambda):
if isinstance(lam, _models.WhittakerSmoothLambda):
return lam

# now, there are other cases to check
# Case 1: lambda is a single number
if isinstance(lam, RealNumericTypes):
return models.WhittakerSmoothLambda(
bounds=lam, method=models.WhittakerSmoothMethods.FIXED
return _models.WhittakerSmoothLambda(
bounds=lam, method=_models.WhittakerSmoothMethods.FIXED
)

# Case 2: lambda is a tuple
@@ -69,7 +69,7 @@ def get_checked_lambda(lam: Any) -> models.WhittakerSmoothLambda:
)

# otherwise, the tuple is unpacked and the dataclass is created
return models.WhittakerSmoothLambda(
return _models.WhittakerSmoothLambda(
bounds=(lam[0], lam[1]),
method=lam[2],
)
12 changes: 6 additions & 6 deletions chemotools/utils/_whittaker_base/main.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
import numpy as np

from chemotools._runtime import PENTAPY_AVAILABLE
from chemotools.utils import models
from chemotools.utils import _models
from chemotools.utils._banded_linalg import LAndUBandCounts
from chemotools.utils._whittaker_base import auto_lambda as auto
from chemotools.utils._whittaker_base import initialisation as init
@@ -113,7 +113,7 @@ def _setup_for_fit(
# the input arguments are stored and validated
self.n_data_: int = n_data
self.differences_: int = differences
self._lam_inter_: models.WhittakerSmoothLambda = init.get_checked_lambda(
self._lam_inter_: _models.WhittakerSmoothLambda = init.get_checked_lambda(
lam=lam
)
self.__child_class_name: str = child_class_name
@@ -147,7 +147,7 @@ def _setup_for_fit(
self._diff_kernel_flipped_: np.ndarray = np.ndarray([], dtype=self.__dtype)
self._penalty_mat_log_pseudo_det_: float = float("nan")
if self._lam_inter_.fit_auto and self._lam_inter_.method_used in {
models.WhittakerSmoothMethods.LOGML,
_models.WhittakerSmoothMethods.LOGML,
}:
# NOTE: the kernel is also returned with integer entries because integer
# computations can be carried out at maximum precision
@@ -178,7 +178,7 @@ def _solve(
lam: float,
b_weighted: np.ndarray,
w: Union[float, np.ndarray],
) -> tuple[np.ndarray, models.BandedSolvers, auto._Factorization]:
) -> tuple[np.ndarray, _models.BandedSolvers, auto._Factorization]:
"""
Internal wrapper for the solver methods to solve the linear system of equations
for the Whittaker-like smoother.
@@ -422,8 +422,8 @@ def _whittaker_solve(
# first, the smoothing method is specified depending on whether the penalty
# weight lambda is fitted automatically or not
smooth_method_assignment = {
models.WhittakerSmoothMethods.FIXED: self._solve_single_b_fixed_lam,
models.WhittakerSmoothMethods.LOGML: self._solve_single_b_auto_lam_logml,
_models.WhittakerSmoothMethods.FIXED: self._solve_single_b_fixed_lam,
_models.WhittakerSmoothMethods.LOGML: self._solve_single_b_auto_lam_logml,
}
smooth_method = smooth_method_assignment[self._lam_inter_.method_used]

20 changes: 11 additions & 9 deletions chemotools/utils/_whittaker_base/solvers.py
Original file line number Diff line number Diff line change
@@ -13,14 +13,16 @@

from chemotools._runtime import PENTAPY_AVAILABLE
from chemotools.utils import _banded_linalg as bla
from chemotools.utils import models
from chemotools.utils import _models

if PENTAPY_AVAILABLE:
import pentapy as pp

### Type Aliases ###

_Factorization = Union[models.BandedLUFactorization, models.BandedPentapyFactorization]
_Factorization = Union[
_models.BandedLUFactorization, _models.BandedPentapyFactorization
]

### Functions ###

@@ -71,7 +73,7 @@ def solve_ppivoted_lu(
l_and_u: bla.LAndUBandCounts,
a_banded: np.ndarray,
b_weighted: np.ndarray,
) -> tuple[np.ndarray, models.BandedLUFactorization]:
) -> tuple[np.ndarray, _models.BandedLUFactorization]:
"""
Solves the linear system of equations ``(W + lam * D.T @ D) @ x = W @ b`` with a
partially pivoted LU decomposition. This is the same as solving the linear system
@@ -107,7 +109,7 @@ def solve_normal_equations(
b_weighted: np.ndarray,
w: Union[float, np.ndarray],
pentapy_enabled: bool,
) -> tuple[np.ndarray, models.BandedSolvers, _Factorization]:
) -> tuple[np.ndarray, _models.BandedSolvers, _Factorization]:
"""
Solves the linear system of equations ``(W + lam * D.T @ D) @ x = W @ b`` where
``W`` is a diagonal matrix with the weights ``w`` on the main diagonal and ``D`` is
@@ -185,8 +187,8 @@ def solve_normal_equations(
if np.isfinite(x).all():
return (
x,
models.BandedSolvers.PENTAPY,
models.BandedPentapyFactorization(),
_models.BandedSolvers.PENTAPY,
_models.BandedPentapyFactorization(),
)

# Case 2: LU decomposition (final fallback for pentapy)
@@ -198,14 +200,14 @@ def solve_normal_equations(
)
return (
x,
models.BandedSolvers.PIVOTED_LU,
_models.BandedSolvers.PIVOTED_LU,
lub_factorization,
)

except np.linalg.LinAlgError:
available_solvers = f"{models.BandedSolvers.PIVOTED_LU}"
available_solvers = f"{_models.BandedSolvers.PIVOTED_LU}"
if pentapy_enabled:
available_solvers = f"{models.BandedSolvers.PENTAPY}, {available_solvers}"
available_solvers = f"{_models.BandedSolvers.PENTAPY}, {available_solvers}"

raise RuntimeError(
f"\nAll available solvers ({available_solvers}) failed to solve the "
Loading

0 comments on commit 907baa4

Please sign in to comment.