From 5becc4977602f7c996c4d235bd7a0dce5f5ded1e Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 13 Mar 2023 12:16:27 +0100
Subject: [PATCH] Simplify `check_logp` and related testing helpers

---
 pymc/testing.py | 351 +++++++++++++++++++++++-------------------------
 1 file changed, 165 insertions(+), 186 deletions(-)

diff --git a/pymc/testing.py b/pymc/testing.py
index e56f3fe88d5..999fea1eeee 100644
--- a/pymc/testing.py
+++ b/pymc/testing.py
@@ -14,7 +14,7 @@
 import functools as ft
 import itertools as it

-from typing import Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import pytensor
@@ -26,13 +26,14 @@
 from pytensor.compile.mode import Mode
 from pytensor.graph.basic import ancestors
 from pytensor.graph.rewriting.basic import in2out
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
 from scipy import special as sp
 from scipy import stats as st

 import pymc as pm

-from pymc import logcdf, logp
+from pymc import Distribution, logcdf, logp
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.initial_point import make_initial_point_fn
 from pymc.logprob import joint_logp
@@ -40,6 +41,7 @@
 from pymc.pytensorf import (
     compile_pymc,
     floatX,
+    inputvars,
     intX,
     local_check_parameter_to_ninf_switch,
 )
@@ -246,17 +248,69 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None):
     return m, param_vars


+def create_dist_from_paramdomains(
+    pymc_dist: Distribution,
+    paramdomains: Dict[str, Domain],
+    extra_args: Optional[Dict[str, Any]] = None,
+) -> TensorVariable:
+    """Create a PyMC distribution from a dictionary of parameter domains.
+
+    Returns
+    -------
+    TensorVariable
+        PyMC distribution variable, whose parameter inputs are fresh tensor variables.
+    """
+    if extra_args is None:
+        extra_args = {}
+
+    param_vars = {}
+    for param, domain in paramdomains.items():
+        param_type = pt.constant(np.asarray(domain.vals[0])).type()
+        param_type.name = param
+        param_vars[param] = param_type
+
+    return pymc_dist.dist(**param_vars, **extra_args)
+
+
+def find_invalid_scalar_params(
+    paramdomains: Dict[str, Domain]
+) -> Dict[str, Tuple[Union[None, float], Union[None, float]]]:
+    """Find invalid parameter values from bounded scalar parameter domains.
+
+    For use in `check_logp`-like testing helpers.
+
+    Returns
+    -------
+    Invalid parameter values:
+        Dictionary mapping each parameter to a pair of lower and upper invalid values (out of domain).
+        If no lower or upper invalid value exists, None is returned for that entry.
+    """
+    invalid_params = {}
+    for param, paramdomain in paramdomains.items():
+        lower_edge, upper_edge = None, None
+
+        if np.ndim(paramdomain.lower) == 0:
+            if np.isfinite(paramdomain.lower):
+                lower_edge = paramdomain.lower - 1
+
+            if np.isfinite(paramdomain.upper):
+                upper_edge = paramdomain.upper + 1
+
+        invalid_params[param] = (lower_edge, upper_edge)
+    return invalid_params
+
+
 def check_logp(
-    pymc_dist,
-    domain,
-    paramdomains,
-    scipy_logp,
-    decimal=None,
-    n_samples=100,
-    extra_args=None,
-    scipy_args=None,
-    skip_paramdomain_outside_edge_test=False,
-):
+    pymc_dist: Distribution,
+    domain: Domain,
+    paramdomains: Dict[str, Domain],
+    scipy_logp: Callable,
+    decimal: Optional[int] = None,
+    n_samples: int = 100,
+    extra_args: Optional[Dict[str, Any]] = None,
+    scipy_args: Optional[Dict[str, Any]] = None,
+    skip_paramdomain_outside_edge_test: bool = False,
+) -> None:
     """
     Generic test for PyMC logp methods

@@ -291,122 +345,77 @@
     if decimal is None:
         decimal = select_by_precision(float64=6, float32=3)

-    if extra_args is None:
-        extra_args = {}
-
     if scipy_args is None:
         scipy_args = {}

-    def logp_reference(args):
+    def scipy_logp_with_scipy_args(**args):
         args.update(scipy_args)
         return scipy_logp(**args)

-    def _model_input_dict(model, param_vars, point):
-        """Create a dict with only the necessary, transformed logp inputs."""
-        pt_d = {}
-        for k, v in point.items():
-            rv_var = model.named_vars.get(k)
-            nv = param_vars.get(k, rv_var)
-            nv = model.rvs_to_values.get(nv, nv)
-
-            transform = model.rvs_to_transforms.get(rv_var, None)
-            if transform:
-                # todo: the compiled graph behind this should be cached and
-                # reused (if it isn't already).
-                v = transform.forward(rv_var, v).eval()
-
-            if nv.name in param_vars:
-                # update the shared parameter variables in `param_vars`
-                param_vars[nv.name].set_value(v)
-            else:
-                # create an argument entry for the (potentially
-                # transformed) "value" variable
-                pt_d[nv.name] = v
-
-        return pt_d
-
-    model, param_vars = build_model(pymc_dist, domain, paramdomains, extra_args)
-    logp_pymc = model.compile_logp(jacobian=False)
+    dist = create_dist_from_paramdomains(pymc_dist, paramdomains, extra_args)
+    value = dist.type()
+    value.name = "value"
+    pymc_dist_logp = logp(dist, value).sum()
+    pymc_logp = pytensor.function(list(inputvars(pymc_dist_logp)), pymc_dist_logp)

-    # Test supported value and parameters domain matches scipy
+    # Test supported value and parameters domain matches Scipy
     domains = paramdomains.copy()
     domains["value"] = domain
     for point in product(domains, n_samples=n_samples):
         point = dict(point)
-        pt_d = _model_input_dict(model, param_vars, point)
-        pt_logp = pm.Point(pt_d, model=model)
-        pt_ref = pm.Point(point, filter_model_vars=False, model=model)
         npt.assert_almost_equal(
-            logp_pymc(pt_logp),
-            logp_reference(pt_ref),
+            pymc_logp(**point),
+            scipy_logp_with_scipy_args(**point),
             decimal=decimal,
             err_msg=str(point),
         )

     valid_value = domain.vals[0]
     valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
-    valid_dist = pymc_dist.dist(**valid_params, **extra_args)
+    valid_params["value"] = valid_value

     # Test pymc distribution raises ParameterValueError for scalar parameters outside
     # the supported domain edges (excluding edges)
     if not skip_paramdomain_outside_edge_test:
-        # Step1: collect potential invalid parameters
-        invalid_params = {param: [None, None] for param in paramdomains}
-        for param, paramdomain in paramdomains.items():
-            if np.ndim(paramdomain.lower) != 0:
-                continue
-            if np.isfinite(paramdomain.lower):
-                invalid_params[param][0] = paramdomain.lower - 1
-            if np.isfinite(paramdomain.upper):
-                invalid_params[param][1] = paramdomain.upper + 1
+        invalid_params = find_invalid_scalar_params(paramdomains)

-        # Step2: test invalid parameters, one a time
         for invalid_param, invalid_edges in invalid_params.items():
             for invalid_edge in invalid_edges:
                 if invalid_edge is None:
                     continue
-                test_params = valid_params.copy()  # Shallow copy should be okay
-                test_params[invalid_param] = pt.as_tensor_variable(invalid_edge)
-                # We need to remove `Assert`s introduced by checks like
-                # `assert_negative_support` and disable test values;
-                # otherwise, we won't be able to create the `RandomVariable`
-                with pytensor.config.change_flags(compute_test_value="off"):
-                    invalid_dist = pymc_dist.dist(**test_params, **extra_args)
-                with pytensor.config.change_flags(mode=Mode("py")):
-                    with pytest.raises(ParameterValueError):
-                        logp(invalid_dist, valid_value).eval()
-                        pytest.fail(f"test_params={test_params}, valid_value={valid_value}")
+
+                point = valid_params.copy()  # Shallow copy should be okay
+                point[invalid_param] = invalid_edge
+                with pytest.raises(ParameterValueError):
+                    pymc_logp(**point)
+                    pytest.fail(f"test_params={point}")

     # Test that values outside of scalar domain support evaluate to -np.inf
-    if np.ndim(domain.lower) != 0:
-        return
-    invalid_values = [None, None]
-    if np.isfinite(domain.lower):
-        invalid_values[0] = domain.lower - 1
-    if np.isfinite(domain.upper):
-        invalid_values[1] = domain.upper + 1
+    invalid_values = find_invalid_scalar_params({"value": domain})["value"]

     for invalid_value in invalid_values:
         if invalid_value is None:
             continue
-        with pytensor.config.change_flags(mode=Mode("py")):
-            npt.assert_equal(
-                logp(valid_dist, invalid_value).eval(),
-                -np.inf,
-                err_msg=str(invalid_value),
-            )
+
+        point = valid_params.copy()
+        point["value"] = invalid_value
+        npt.assert_equal(
+            pymc_logp(**point),
+            -np.inf,
+            err_msg=str(point),
+        )


 def check_logcdf(
-    pymc_dist,
-    domain,
-    paramdomains,
-    scipy_logcdf,
-    decimal=None,
-    n_samples=100,
-    skip_paramdomain_inside_edge_test=False,
-    skip_paramdomain_outside_edge_test=False,
-):
+    pymc_dist: Distribution,
+    domain: Domain,
+    paramdomains: Dict[str, Domain],
+    scipy_logcdf: Callable,
+    decimal: Optional[int] = None,
+    n_samples: int = 100,
+    skip_paramdomain_inside_edge_test: bool = False,
+    skip_paramdomain_outside_edge_test: bool = False,
+) -> None:
     """
     Generic test for PyMC logcdf methods

@@ -448,133 +457,103 @@
         returns -inf for invalid parameter values outside the supported domain edge
     """
+    if decimal is None:
+        decimal = select_by_precision(float64=6, float32=3)
+
+    dist = create_dist_from_paramdomains(pymc_dist, paramdomains)
+    value = dist.type()
+    value.name = "value"
+    dist_logcdf = logcdf(dist, value)
+    pymc_logcdf = pytensor.function(list(inputvars(dist_logcdf)), dist_logcdf)
+
     # Test pymc and scipy distributions match for values and parameters
     # within the supported domain edges (excluding edges)
     if not skip_paramdomain_inside_edge_test:
         domains = paramdomains.copy()
         domains["value"] = domain
-
-        model, param_vars = build_model(pymc_dist, domain, paramdomains)
-        rv = model["value"]
-        value = model.rvs_to_values[rv]
-        pymc_logcdf = model.compile_fn(logcdf(rv, value))
-
-        if decimal is None:
-            decimal = select_by_precision(float64=6, float32=3)
-
         for point in product(domains, n_samples=n_samples):
-            params = dict(point)
-            scipy_eval = scipy_logcdf(**params)
-
-            value = params.pop("value")
-            # Update shared parameter variables in pymc_logcdf function
-            for param_name, param_value in params.items():
-                param_vars[param_name].set_value(param_value)
-            pymc_eval = pymc_logcdf({"value": value})
-
-            params["value"] = value  # for displaying in err_msg
+            point = dict(point)
             npt.assert_almost_equal(
-                pymc_eval,
-                scipy_eval,
+                pymc_logcdf(**point),
+                scipy_logcdf(**point),
                 decimal=decimal,
-                err_msg=str(params),
+                err_msg=str(point),
             )

     valid_value = domain.vals[0]
     valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
-    valid_dist = pymc_dist.dist(**valid_params)
+    valid_params["value"] = valid_value

     # Test pymc distribution raises ParameterValueError for parameters outside the
     # supported domain edges (excluding edges)
     if not skip_paramdomain_outside_edge_test:
-        # Step1: collect potential invalid parameters
-        invalid_params = {param: [None, None] for param in paramdomains}
-        for param, paramdomain in paramdomains.items():
-            if np.isfinite(paramdomain.lower):
-                invalid_params[param][0] = paramdomain.lower - 1
-            if np.isfinite(paramdomain.upper):
-                invalid_params[param][1] = paramdomain.upper + 1
-        # Step2: test invalid parameters, one a time
+        invalid_params = find_invalid_scalar_params(paramdomains)
+
         for invalid_param, invalid_edges in invalid_params.items():
             for invalid_edge in invalid_edges:
-                if invalid_edge is not None:
-                    test_params = valid_params.copy()  # Shallow copy should be okay
-                    test_params[invalid_param] = pt.as_tensor_variable(invalid_edge)
-                    # We need to remove `Assert`s introduced by checks like
-                    # `assert_negative_support` and disable test values;
-                    # otherwise, we won't be able to create the
-                    # `RandomVariable`
-                    with pytensor.config.change_flags(compute_test_value="off"):
-                        invalid_dist = pymc_dist.dist(**test_params)
-                    with pytensor.config.change_flags(mode=Mode("py")):
-                        with pytest.raises(ParameterValueError):
-                            logcdf(invalid_dist, valid_value).eval()
-
-    # Test that values below domain edge evaluate to -np.inf
-    if np.isfinite(domain.lower):
-        below_domain = domain.lower - 1
-        with pytensor.config.change_flags(mode=Mode("py")):
-            npt.assert_equal(
-                logcdf(valid_dist, below_domain).eval(),
-                -np.inf,
-                err_msg=str(below_domain),
-            )
-
-    # Test that values above domain edge evaluate to 0
-    if np.isfinite(domain.upper):
-        above_domain = domain.upper + 1
-        with pytensor.config.change_flags(mode=Mode("py")):
-            npt.assert_equal(
-                logcdf(valid_dist, above_domain).eval(),
-                0,
-                err_msg=str(above_domain),
-            )
+                if invalid_edge is None:
+                    continue

-    # Test that method works with multiple values or raises informative TypeError
-    valid_dist = pymc_dist.dist(**valid_params, size=2)
-    with pytensor.config.change_flags(mode=Mode("py")):
-        try:
-            logcdf(valid_dist, np.array([valid_value, valid_value])).eval()
-        except TypeError as err:
-            assert str(err).endswith(
-                "logcdf expects a scalar value but received a 1-dimensional object."
-            )
+                point = valid_params.copy()
+                point[invalid_param] = invalid_edge
+                with pytest.raises(ParameterValueError):
+                    pymc_logcdf(**point)
+                    pytest.fail(f"test_params={point}")
+
+    # Test that values below domain edge evaluate to -np.inf, and above evaluates to 0
+    invalid_lower, invalid_upper = find_invalid_scalar_params({"value": domain})["value"]
+    if invalid_lower is not None:
+        point = valid_params.copy()
+        point["value"] = invalid_lower
+        npt.assert_equal(
+            pymc_logcdf(**point),
+            -np.inf,
+            err_msg=str(point),
+        )
+    if invalid_upper is not None:
+        point = valid_params.copy()
+        point["value"] = invalid_upper
+        npt.assert_equal(
+            pymc_logcdf(**point),
+            0,
+            err_msg=str(point),
+        )


 def check_selfconsistency_discrete_logcdf(
-    distribution,
-    domain,
-    paramdomains,
-    decimal=None,
-    n_samples=100,
-):
+    distribution: Distribution,
+    domain: Domain,
+    paramdomains: Dict[str, Domain],
+    decimal: Optional[int] = None,
+    n_samples: int = 100,
+) -> None:
     """
-    Check that logcdf of discrete distributions matches sum of logps up to value
+    Check that logcdf of discrete distributions matches sum of logps up to value.
     """
-    domains = paramdomains.copy()
-    domains["value"] = domain
     if decimal is None:
         decimal = select_by_precision(float64=6, float32=3)

-    model, param_vars = build_model(distribution, domain, paramdomains)
-    rv = model["value"]
-    value = model.rvs_to_values[rv]
-    dist_logcdf = model.compile_fn(logcdf(rv, value))
-    dist_logp = model.compile_fn(logp(rv, value))
+    dist = create_dist_from_paramdomains(distribution, paramdomains)
+    value = dist.type()
+    value.name = "value"
+    dist_logp = logp(dist, value)
+    dist_logp_fn = pytensor.function(list(inputvars(dist_logp)), dist_logp)
+
+    dist_logcdf = logcdf(dist, value)
+    dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)
+
+    domains = paramdomains.copy()
+    domains["value"] = domain

     for point in product(domains, n_samples=n_samples):
-        params = dict(point)
-        value = params.pop("value")
+        point = dict(point)
+        value = point.pop("value")
         values = np.arange(domain.lower, value + 1)
-        # Update shared parameter variables in logp/logcdf function
-        for param_name, param_value in params.items():
-            param_vars[param_name].set_value(param_value)
-
         with pytensor.config.change_flags(mode=Mode("py")):
             npt.assert_almost_equal(
-                dist_logcdf({"value": value}),
-                sp.logsumexp([dist_logp({"value": value}) for value in values]),
+                dist_logcdf_fn(**point, value=value),
+                sp.logsumexp([dist_logp_fn(value=value, **point) for value in values]),
                 decimal=decimal,
                 err_msg=str(point),
             )
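For reviewers, a minimal sketch of how the simplified helper is driven from a distribution test (assuming the `R` and `Rplus` domains exported by `pymc.testing`; the keyword names in the scipy lambda must match the PyMC parameter names, because each sampled point is now passed by keyword to both the compiled PyMC logp and the scipy reference):

    import scipy.stats as st

    import pymc as pm
    from pymc.testing import R, Rplus, check_logp

    # check_logp builds pm.Normal.dist(mu=..., sigma=...) from fresh parameter inputs,
    # compiles its logp once, and compares it against the scipy reference over points
    # drawn from the value and parameter domains.
    check_logp(
        pm.Normal,
        R,
        {"mu": R, "sigma": Rplus},
        lambda value, mu, sigma: st.norm.logpdf(value, mu, sigma),
    )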