Skip to content

Commit

Permalink
Test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tingiskhan committed Nov 15, 2024
1 parent a0398eb commit 4669473
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/test_sktime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from numpyro.contrib.control_flow import scan
from numpyro.distributions import Normal, TransformedDistribution, HalfNormal
from numpyro.distributions.transforms import SigmoidTransform
from xarray import DataArray

from skyro import BaseNumpyroForecaster

Expand Down Expand Up @@ -43,18 +44,17 @@ def build_model(self, y, length: int, X=None, future=0, **kwargs):

return

def select_output(self, x):
return x["y"]
def format_output(self, x, index):
return DataArray(x["y"], dims=["draw", "time"], coords={"time": index})


def test_autoregressive():
model = AutoRegressive(num_warmup=1_000, num_samples=500, seed=123)

with model.prior_predictive(output="np.ndarray"):
to_predict = np.arange(0, 100)
train = model.predict(to_predict)
samples = model.sample_prior_predictive(100)
train = samples["y"][0]

assert train.shape == to_predict.shape
assert train.shape == (100,)

train = pd.Series(train, index=pd.date_range("2024-01-01", periods=train.shape[0], freq="W"))

Expand All @@ -73,7 +73,7 @@ def test_autoregressive():
new_predictions = new_model.predict(fh)
assert new_predictions.index.equals(predictions.index)

fh = np.arange(1, 12)
fh = np.arange(-5, 12)
proba = new_model.predict_proba(fh)

assert proba.shape == (fh.shape[0], 1)

0 comments on commit 4669473

Please sign in to comment.