Skip to content

Commit

Permalink
Feature/pass model specific arguments (#3)
Browse files Browse the repository at this point in the history
* Adds dependency and bumps

* Adds support for passing model specific parameters instead of having to serialize model etc

* Can now build model specific kernel

* Import

* Default to true

* Minor fix

* Docs and minor helper

* Minor fix, might be bad behaviour

* Not time independent

* Set name as well

* Fixes clonability

* Handles case with semi-deterministic elements

* Test fix

* Helper stuff

* Make hidden

* Fix?

* Missed

* Bug fix

* Bug fix

* Pass cutoff

* Moves method

* WIP but simplified API for PPC

* Ruff

* RuffRename

* Test fix

* Seems to be working

* Adds xarray

* Now uses xarray

* Uses xarray

* Revert version and set to tag

* Sets commit as well

* bump version 0.0.1 -> 0.1.0

* Sets dynamic args

* Docs

* Arviz

* Test 3.12 as well

* Removes 3.12 as not supported

* Fix version

* Skip python version since dependencies cause it

* Use samples instead

* Skip import

* Set default tags specific for numpyro only

* Set default tags specific for numpyro only

* Fix
  • Loading branch information
tingiskhan authored Nov 16, 2024
1 parent 80b3d0d commit 2d60bd2
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 133 deletions.
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"sktime",
"skbase",
"skpro",
"xarray",
]

[project.optional-dependencies]
Expand All @@ -35,6 +36,7 @@ dev = [
"pre-commit",
"pytest",
"coverage",
"bumpver",
]

viz = [
Expand All @@ -58,11 +60,11 @@ exclude = '''
line-length = 120

[tool.bumpver]
current_version = "0.0.1"
current_version = "0.1.0"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "bump version {old_version} -> {new_version}"
commit = false
tag = false
commit = true
tag = true
push = false

[tool.bumpver.file_patterns]
Expand Down
2 changes: 1 addition & 1 deletion skyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .sklearn import BaseNumpyroEstimator
from .sktime import BaseNumpyroForecaster

__version__ = "0.0.1"
__version__ = "0.1.0"


__all__ = [
Expand Down
55 changes: 27 additions & 28 deletions skyro/_mixin.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import sys
from contextlib import contextmanager
from functools import cached_property
from operator import attrgetter
from random import randint
from typing import Any, Dict

import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
from numpyro.diagnostics import summary
from numpyro.infer import MCMC, NUTS

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from numpyro.infer.mcmc import MCMCKernel

from ._result import NumpyroResultSet
from .exc import ConvergenceError
Expand All @@ -41,8 +34,9 @@ def __init__(
num_chains: int = 1,
chain_method: str = "parallel",
seed: int = None,
progress_bar: bool = False,
progress_bar: bool = True,
kernel_kwargs: Dict[str, Any] = None,
model_kwargs: Dict[str, Any] = None,
):
self.num_samples = num_samples
self.num_warmup = num_warmup
Expand All @@ -52,10 +46,11 @@ def __init__(
self.seed = seed
self.progress_bar = progress_bar

self.model_kwargs = model_kwargs

self.result_set_: NumpyroResultSet = None

self._is_vectorized = False
self._prior_predictive = False

def reduce(self, posterior: np.ndarray) -> np.ndarray:
"""
Expand Down Expand Up @@ -83,9 +78,22 @@ def build_model(self, *args, **kwargs):
def _get_key(self) -> PRNGKey:
return PRNGKey(self.seed or randint(0, 1_000))

def build_kernel(self, **kwargs) -> MCMCKernel:
    """
    Constructs the MCMC kernel for this model. Derived classes may override this to supply a model specific
    kernel; the base implementation wraps ``self.build_model`` in NUTS.

    Args:
        **kwargs: Keyword arguments forwarded verbatim to the kernel constructor.

    Returns:
        Returns a :class:`MCMCKernel`.
    """

    kernel = NUTS(self.build_model, **kwargs)

    return kernel

@cached_property
def mcmc(self) -> MCMC:
kernel = NUTS(self.build_model, **(self.kernel_kwargs or {}))
kernel = self.build_kernel(**(self.kernel_kwargs or {}))

mcmc = MCMC(
kernel,
Expand Down Expand Up @@ -142,23 +150,9 @@ def __setstate__(self, state):

return

@contextmanager
def prior_predictive(self, **kwargs) -> Self:
def sample_prior_predictive(self, **kwargs) -> Dict[str, np.ndarray]:
"""
Does posterior/prior predictive checking.
Returns:
Returns samples.
"""

raise NotImplementedError("abstract method")

def select_output(self, x: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""
Abstract method overridden by derived classes to format output given returned predictions.
Args:
x: Samples.
Samples from the prior predictive density.
Returns:
Returns samples.
Expand All @@ -175,8 +169,13 @@ def _process_results(self, mcmc: MCMC) -> NumpyroResultSet:
sub_samples = {k: v for k, v in samples.items() if k in sites}
s = summary(sub_samples, group_by_chain=self.group_by_chain)

# TODO: need to handle case when some of the dimensions of the variables are nan
for name, summary_ in s.items():
if (summary_["r_hat"] <= self.max_rhat).all():
# NB: some variables have deterministic elements (s.a. samples from LKJCov).
mask = np.isnan(summary_["n_eff"])
r_hat = summary_["r_hat"][~mask]

if (r_hat <= self.max_rhat).all():
continue

raise ConvergenceError(f"Parameter '{name}' did not converge!")
Expand Down
2 changes: 1 addition & 1 deletion skyro/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ def map_to_output(x: np.ndarray, y, fh: ForecastingHorizon = None, full_posterio
return pd.Series(x.reshape(-1), index=index, name=y.name)

if isinstance(y, pd.DataFrame):
return pd.DataFrame(x.reshape(-1, y.shape[-1]), columns=y.columns, index=y.index)
return pd.DataFrame(x.reshape(-1, x.shape[-1]), columns=y.columns, index=index)

return x
45 changes: 22 additions & 23 deletions skyro/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
import sys
from contextlib import contextmanager
from typing import Any, Dict

import numpy as np
from numpyro.infer import Predictive
from skbase.base import BaseEstimator

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from skyro._mixin import BaseNumpyroMixin
from ._mixin import BaseNumpyroMixin


class BaseNumpyroEstimator(BaseNumpyroMixin, BaseEstimator):
Expand All @@ -29,6 +22,7 @@ def __init__(
seed: int = None,
progress_bar: bool = False,
kernel_kwargs: Dict[str, Any] = None,
model_kwargs: Dict[str, Any] = None,
):
super().__init__(
num_samples=num_samples,
Expand All @@ -38,6 +32,7 @@ def __init__(
seed=seed,
progress_bar=progress_bar,
kernel_kwargs=kernel_kwargs,
model_kwargs=model_kwargs,
)
BaseEstimator.__init__(self)

Expand All @@ -47,23 +42,36 @@ def build_model(self, X, y=None, **kwargs):
def fit(self, X, y=None):
key = self._get_key()

self.mcmc.run(key, X=X, y=y)
self.mcmc.run(key, X=X, y=y, **(self.model_kwargs or {}))
self.result_set_ = self._process_results(self.mcmc)

self._is_fitted = True

return

def _do_sample(self, X, **kwargs) -> Dict[str, np.ndarray]:
samples = None if self._prior_predictive else self.result_set_.get_samples(group_by_chain=False)
def _do_sample(self, X, prior_predictive: bool = False, **kwargs) -> Dict[str, np.ndarray]:
samples = None if prior_predictive else self.result_set_.get_samples(group_by_chain=False)
predictive = Predictive(
self.build_model, posterior_samples=samples, num_samples=self.num_samples if samples is None else None
)

output = predictive(self._get_key(), X=X, **kwargs)
output = predictive(self._get_key(), X=X, **(self.model_kwargs or {}), **kwargs)

return {k: np.array(v) for k, v in output.items()}

def select_output(self, x: Dict[str, np.ndarray]) -> np.ndarray:
    """
    Formats the output given the returned predictions. This is an abstract hook and must be overridden by
    derived classes.

    Args:
        x: Samples, keyed by name.

    Returns:
        Returns samples.

    Raises:
        NotImplementedError: Always, as the base class provides no implementation.
    """

    message = "abstract method"
    raise NotImplementedError(message)

def predict(self, X, full_posterior: bool = False, **kwargs):
ppc = self._do_sample(X, **kwargs)

Expand All @@ -74,14 +82,5 @@ def predict(self, X, full_posterior: bool = False, **kwargs):

return self.reduce(output)

@contextmanager
def prior_predictive(self, **kwargs) -> Self:
try:
self._prior_predictive = True
yield self
except Exception:
raise
finally:
self._prior_predictive = False

return
def sample_prior_predictive(self, X, **kwargs):
    """
    Draws samples from the prior predictive density.

    Args:
        X: Input data passed to the model.
        **kwargs: Additional keyword arguments forwarded to ``_do_sample``.

    Returns:
        Returns whatever ``_do_sample`` returns for a prior predictive draw.
    """

    # NB: the flag must be set explicitly, since ``_do_sample`` defaults to drawing from the fitted posterior and
    # would otherwise require (and use) the fitted result set.
    return self._do_sample(X=X, prior_predictive=True, **kwargs)
Loading

0 comments on commit 2d60bd2

Please sign in to comment.