Skip to content

Commit

Permalink
Implement check_icdf helper to test icdf implementations
Browse files Browse the repository at this point in the history
Note that adding a nan switch to the icdf expression of discrete variables, prevents the returned dtype to be the same as the original distribution. There is no integer nan!
  • Loading branch information
ricardoV94 committed Mar 13, 2023
1 parent f043ad9 commit 4da5edf
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 47 deletions.
10 changes: 9 additions & 1 deletion pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def polyagamma_cdf(*args, **kwargs):
from pymc.distributions import transforms
from pymc.distributions.dist_math import (
SplineWrapper,
check_icdf_parameters,
check_icdf_value,
check_parameters,
clipped_beta_rvs,
i0e,
Expand Down Expand Up @@ -532,7 +534,13 @@ def logcdf(value, mu, sigma):
)

def icdf(value, mu, sigma):
return mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value)
res = mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value)
res = check_icdf_value(res, value)
return check_icdf_parameters(
res,
sigma > 0,
msg="sigma > 0",
)


class TruncatedNormalRV(RandomVariable):
Expand Down
11 changes: 10 additions & 1 deletion pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from pymc.distributions.dist_math import (
betaln,
binomln,
check_icdf_parameters,
check_icdf_value,
check_parameters,
factln,
log_diff_normal_cdf,
Expand Down Expand Up @@ -820,7 +822,14 @@ def logcdf(value, p):
)

def icdf(value, p):
return at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64")
res = at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64")
res = check_icdf_value(res, value)
return check_icdf_parameters(
res,
0 <= p,
p <= 1,
msg="0 <= p <= 1",
)


class HyperGeometric(Discrete):
Expand Down
16 changes: 16 additions & 0 deletions pymc/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
import warnings

from functools import partial
from typing import Iterable

import numpy as np
Expand Down Expand Up @@ -77,6 +78,21 @@ def check_parameters(
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)


check_icdf_parameters = partial(check_parameters, can_be_replaced_by_ninf=False)


def check_icdf_value(expr: Variable, value: Variable) -> Variable:
"""Wrap icdf expression in nan switch for value."""
value = at.as_tensor_variable(value)
expr = at.switch(
at.and_(value >= 0, value <= 1),
expr,
np.nan,
)
expr.name = "0 <= value <= 1"
return expr


def logpow(x, m):
"""
Calculates log(x**m) since m*log(x) will fail when m, x = 0.
Expand Down
92 changes: 92 additions & 0 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from pymc.distributions.shape_utils import change_dist_size
from pymc.initial_point import make_initial_point_fn
from pymc.logprob import joint_logp
from pymc.logprob.abstract import icdf
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
compile_pymc,
Expand Down Expand Up @@ -520,6 +521,97 @@ def check_logcdf(
)


def check_icdf(
pymc_dist: Distribution,
paramdomains: Dict[str, Domain],
scipy_icdf: Callable,
decimal: Optional[int] = None,
n_samples: int = 100,
) -> None:
"""
Generic test for PyMC icdf methods
The following tests are performed by default:
1. Test PyMC icdf and equivalent scipy icdf (ppf) methods give similar
results for parameters inside the supported edges.
Edges are excluded by default, but can be artificially included by
creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`)
2. Test PyMC icdf method raises for invalid parameter values
outside the supported edges.
3. Test PyMC icdf method returns np.nan for values below 0 or above 1,
when using valid parameters.
Parameters
----------
pymc_dist: PyMC distribution
paramdomains : Dictionary of Parameter : Domain pairs
Supported domains of distribution parameters
scipy_icdf : Scipy icdf method
Scipy icdf (ppp) method of equivalent pymc_dist distribution
decimal : int, optional
Level of precision with which pymc_dist and scipy_icdf are compared.
Defaults to 6 for float64 and 3 for float32
n_samples : int
Upper limit on the number of valid domain and value combinations that
are compared between pymc and scipy methods. If n_samples is below the
total number of combinations, a random subset is evaluated. Setting
n_samples = -1, will return all possible combinations. Defaults to 100
"""
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)

dist = create_dist_from_paramdomains(pymc_dist, paramdomains)
q = pt.scalar(dtype="float64", name="q")
dist_icdf = icdf(dist, q)
pymc_icdf = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)

# Test pymc and scipy distributions match for values and parameters
# within the supported domain edges (excluding edges)
domains = paramdomains.copy()
domain = Domain([0, 0.1, 0.5, 0.75, 0.95, 0.99, 1]) # Values we test the icdf at
domains["q"] = domain

for point in product(domains, n_samples=n_samples):
point = dict(point)
npt.assert_almost_equal(
pymc_icdf(**point),
scipy_icdf(**point),
decimal=decimal,
err_msg=str(point),
)

valid_value = domain.vals[0]
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
valid_params["q"] = valid_value

# Test pymc distribution raises ParameterValueError for parameters outside the
# supported domain edges (excluding edges)
invalid_params = find_invalid_scalar_params(paramdomains)
for invalid_param, invalid_edges in invalid_params.items():
for invalid_edge in invalid_edges:
if invalid_edge is None:
continue

point = valid_params.copy()
point[invalid_param] = invalid_edge
with pytest.raises(ParameterValueError):
pymc_icdf(**point)
pytest.fail(f"test_params={point}")

# Test that values below 0 or above 1 evaluate to nan
invalid_values = find_invalid_scalar_params({"q": domain})["q"]
for invalid_value in invalid_values:
if invalid_value is not None:
point = valid_params.copy()
point["q"] = invalid_value
npt.assert_equal(
pymc_icdf(**point),
np.nan,
err_msg=str(point),
)


def check_selfconsistency_discrete_logcdf(
distribution: Distribution,
domain: Domain,
Expand Down
24 changes: 6 additions & 18 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Runif,
Unit,
assert_moment_is_expected,
check_icdf,
check_logcdf,
check_logp,
continuous_random_tester,
Expand Down Expand Up @@ -270,6 +271,11 @@ def test_normal(self):
lambda value, mu, sigma: st.norm.logcdf(value, mu, sigma),
decimal=select_by_precision(float64=6, float32=1),
)
check_icdf(
pm.Normal,
{"mu": R, "sigma": Rplus},
lambda q, mu, sigma: st.norm.ppf(q, mu, sigma),
)

def test_half_normal(self):
check_logp(
Expand Down Expand Up @@ -2269,21 +2275,3 @@ def dist(cls, **kwargs):
extra_args={"rng": pytensor.shared(rng)},
ref_rand=ref_rand,
)


class TestICDF:
@pytest.mark.parametrize(
"dist_params, obs, size",
[
((0, 1), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()),
((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()),
((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), (2, 3)),
],
)
def test_normal_icdf(self, dist_params, obs, size):
dist_params_at, obs_at, size_at = create_pytensor_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = Normal.dist(*dist_params_at, size=size_at)

scipy_logprob_tester(x, obs, dist_params, test_fn=st.norm.ppf, test="icdf")
32 changes: 6 additions & 26 deletions tests/distributions/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
UnitSortedVector,
Vector,
assert_moment_is_expected,
check_icdf,
check_logcdf,
check_logp,
check_selfconsistency_discrete_logcdf,
Expand Down Expand Up @@ -143,6 +144,11 @@ def test_geometric(self):
Nat,
{"p": Unit},
)
check_icdf(
pm.Geometric,
{"p": Unit},
st.geom.ppf,
)

def test_hypergeometric(self):
def modified_scipy_hypergeom_logcdf(value, N, k, n):
Expand Down Expand Up @@ -1148,29 +1154,3 @@ def test_shape_inputs(self, eta, cutpoints, sigma, expected):
)
p = categorical.owner.inputs[3].eval()
assert p.shape == expected


class TestICDF:
@pytest.mark.parametrize(
"dist_params, obs, size",
[
((0.1,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), ()),
((0.5,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), (3, 2)),
(
(np.array([0.0, 0.2, 0.5, 1.0]),),
np.array([0.7, 0.7, 0.7, 0.7], dtype=np.int64),
(),
),
],
)
def test_geometric_icdf(self, dist_params, obs, size):
dist_params_at, obs_at, size_at = create_pytensor_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = Geometric.dist(*dist_params_at, size=size_at)

def scipy_geom_icdf(value, p):
# Scipy ppf returns floats
return st.geom.ppf(value, p).astype(value.dtype)

scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_geom_icdf, test="icdf")
1 change: 0 additions & 1 deletion tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def test_truncation_discrete_random(op_type, lower, upper):
x = geometric_op(p, name="x", size=500)
xt = Truncated.dist(x, lower=lower, upper=upper)
assert isinstance(xt.owner.op, TruncatedRV)
assert xt.type.dtype == x.type.dtype

xt_draws = draw(xt)
assert np.all(xt_draws >= lower)
Expand Down

0 comments on commit 4da5edf

Please sign in to comment.