From 98ccc6844d60abd7e8bbc9445a5111d658aaa48f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 16 Dec 2022 17:04:20 +0100
Subject: [PATCH] Implement power transforms

---
 pymc/logprob/transforms.py            | 149 ++++++++++++++++++++++----
 pymc/tests/logprob/test_transforms.py |  32 +++++-
 2 files changed, 161 insertions(+), 20 deletions(-)

diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index fa5bcd5cef..cd8b041bcb 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -39,6 +39,7 @@
 from copy import copy
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
+import numpy as np
 import pytensor.tensor as at
 
 from pytensor.gradient import DisconnectedType, jacobian
@@ -48,10 +49,22 @@
 from pytensor.graph.op import Op
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
-from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal
+from pytensor.scalar import Add, Exp, Log, Mul, Pow, Sqr, Sqrt
 from pytensor.scan.op import Scan
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import add, exp, log, mul, neg, reciprocal, sub, true_div
+from pytensor.tensor.math import (
+    add,
+    exp,
+    log,
+    mul,
+    neg,
+    pow,
+    reciprocal,
+    sqr,
+    sqrt,
+    sub,
+    true_div,
+)
 from pytensor.tensor.rewriting.basic import (
     register_specialize,
     register_stabilize,
@@ -110,8 +123,11 @@ def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
         """Apply the transformation."""
 
     @abc.abstractmethod
-    def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
-        """Invert the transformation."""
+    def backward(
+        self, value: TensorVariable, *inputs: Variable
+    ) -> Union[TensorVariable, Tuple[TensorVariable, ...]]:
+        """Invert the transformation. Multiple values may be returned when the
+        transformation is not 1-to-1"""
 
     def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
         """Construct the log of the absolute value of the Jacobian determinant."""
@@ -320,7 +336,7 @@ def apply(self, fgraph: FunctionGraph):
 class MeasurableTransform(MeasurableElemwise):
     """A placeholder used to specify a log-likelihood for a transformed measurable variable"""
 
-    valid_scalar_types = (Exp, Log, Add, Mul, Reciprocal)
+    valid_scalar_types = (Exp, Log, Add, Mul, Pow)
 
     # Cannot use `transform` as name because it would clash with the property added by
     # the `TransformValuesRewrite`
@@ -349,16 +365,64 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
     # The value variable must still be back-transformed to be on the natural support of
     # the respective measurable input.
     backward_value = op.transform_elemwise.backward(value, *other_inputs)
-    input_logprob = logprob(measurable_input, backward_value, **kwargs)
+
+    # Some transformations, like squaring, may produce multiple backward values
+    if isinstance(backward_value, tuple):
+        input_logprob = at.logaddexp(
+            *(logprob(measurable_input, backward_val, **kwargs) for backward_val in backward_value)
+        )
+    else:
+        input_logprob = logprob(measurable_input, backward_value)
 
     jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
 
     return input_logprob + jacobian
 
 
+@node_rewriter([reciprocal])
+def measurable_reciprocal_to_power(fgraph, node):
+    """Convert reciprocal of `MeasurableVariable`s to power."""
+    inp = node.inputs[0]
+    if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
+        return None
+
+    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    # Only apply this rewrite if the variable is unvalued
+    if inp in rv_map_feature.rv_values:
+        return None  # pragma: no cover
+
+    return [at.pow(inp, -1.0)]
+
+
+@node_rewriter([sqr, sqrt])
+def measurable_sqrt_sqr_to_power(fgraph, node):
+    """Convert square root or square of `MeasurableVariable`s to power form."""
+
+    inp = node.inputs[0]
+    if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
+        return None
+
+    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    # Only apply this rewrite if the variable is unvalued
+    if inp in rv_map_feature.rv_values:
+        return None  # pragma: no cover
+
+    if isinstance(node.op.scalar_op, Sqr):
+        return [at.pow(inp, 2)]
+
+    if isinstance(node.op.scalar_op, Sqrt):
+        return [at.pow(inp, 1 / 2)]
+
+
 @node_rewriter([true_div])
-def measurable_div_to_reciprocal_product(fgraph, node):
-    """Convert divisions involving `MeasurableVariable`s to product with reciprocal."""
+def measurable_div_to_product(fgraph, node):
+    """Convert divisions involving `MeasurableVariable`s to products."""
     measurable_vars = [
         var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable))
     ]
@@ -379,9 +443,13 @@ def measurable_div_to_reciprocal_product(fgraph, node):
 
     # Check if numerator is 1
     try:
         if at.get_scalar_constant_value(numerator) == 1:
-            return [at.reciprocal(denominator)]
+            # We convert the denominator directly to a power transform as this
+            # must be the measurable input
+            return [at.pow(denominator, -1)]
     except NotScalarConstantError:
         pass
 
+    # We don't convert the denominator directly to a power transform as
+    # it might not be measurable (and therefore not needed)
     return [at.mul(numerator, at.reciprocal(denominator))]
@@ -425,7 +493,7 @@ def measurable_sub_to_neg(fgraph, node):
     return [at.add(minuend, at.neg(subtrahend))]
 
 
-@node_rewriter([exp, log, add, mul, reciprocal])
+@node_rewriter([exp, log, add, mul, pow])
 def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
     """Find measurable transformations from Elemwise operators."""
@@ -485,8 +553,18 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
         transform = ExpTransform()
     elif isinstance(scalar_op, Log):
         transform = LogTransform()
-    elif isinstance(scalar_op, Reciprocal):
-        transform = ReciprocalTransform()
+    elif isinstance(scalar_op, Pow):
+        # We only allow for the base to be measurable
+        if measurable_input_idx != 0:
+            return None
+        try:
+            (power,) = other_inputs
+            power = at.get_scalar_constant_value(power).item()
+        # Power needs to be a constant
+        except NotScalarConstantError:
+            return None
+        transform_inputs = (measurable_input, power)
+        transform = PowerTransform(power=power)
     elif isinstance(scalar_op, Add):
         transform_inputs = (measurable_input, at.add(*other_inputs))
         transform = LocTransform(
@@ -510,12 +588,29 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
 measurable_ir_rewrites_db.register(
-    "measurable_div_to_reciprocal_product",
-    measurable_div_to_reciprocal_product,
+    "measurable_reciprocal_to_power",
+    measurable_reciprocal_to_power,
     "basic",
     "transform",
 )
+
+measurable_ir_rewrites_db.register(
+    "measurable_sqrt_sqr_to_power",
+    measurable_sqrt_sqr_to_power,
+    "basic",
+    "transform",
+)
+
+
+measurable_ir_rewrites_db.register(
+    "measurable_div_to_product",
+    measurable_div_to_product,
+    "basic",
+    "transform",
+)
+
+
 measurable_ir_rewrites_db.register(
     "measurable_neg_to_product",
     measurable_neg_to_product,
     "basic",
     "transform",
 )
@@ -601,17 +696,33 @@ def log_jac_det(self, value, *inputs):
         return -at.log(value)
 
 
-class ReciprocalTransform(RVTransform):
-    name = "reciprocal"
+class PowerTransform(RVTransform):
+    name = "power"
+
+    def __init__(self, power=None):
+        if not isinstance(power, (int, float)):
+            raise TypeError(f"Power must be integer or float, got {type(power)}")
+        if power == 0:
+            raise ValueError("Power cannot be 0")
+        self.power = power
+        super().__init__()
 
     def forward(self, value, *inputs):
-        return at.reciprocal(value)
+        return at.power(value, self.power)
 
     def backward(self, value, *inputs):
-        return at.reciprocal(value)
+        backward_value = at.power(value, (1 / self.power))
+
+        # In this case the transform is not 1-to-1
+        if (self.power > 1) and (self.power % 2 == 0):
+            return -backward_value, backward_value
+        else:
+            return backward_value
 
     def log_jac_det(self, value, *inputs):
-        return -2 * at.log(value)
+        inv_power = 1 / self.power
+        # Note: This fails for value==0
+        return np.log(np.abs(inv_power)) + (inv_power - 1) * at.log(value)
 
 
 class IntervalTransform(RVTransform):
diff --git a/pymc/tests/logprob/test_transforms.py b/pymc/tests/logprob/test_transforms.py
index 9d41929624..bf35c7f103 100644
--- a/pymc/tests/logprob/test_transforms.py
+++ b/pymc/tests/logprob/test_transforms.py
@@ -224,7 +224,7 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
     a = at_dist(*dist_params, size=size)
     a.name = "a"
-    a_value_var = at.tensor(a.dtype, shape=(None,) * a.ndim)
+    a_value_var = at.tensor(dtype=a.dtype, shape=(None,) * a.ndim)
     a_value_var.name = "a_value"
 
     b = at.random.normal(a, 1.0)
@@ -807,6 +807,36 @@ def test_reciprocal_rv_transform(numerator):
     )
 
 
+def test_sqr_transform():
+    # The square of a unit normal is a chi-square with 1 df
+    x_rv = at.random.normal(0, 1, size=(3,)) ** 2
+    x_rv.name = "x"
+
+    x_vv = x_rv.clone()
+    x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))
+
+    x_test_val = np.r_[0.5, 1, 2.5]
+    assert np.allclose(
+        x_logp_fn(x_test_val),
+        sp.stats.chi2(df=1).logpdf(x_test_val),
+    )
+
+
+def test_sqrt_transform():
+    # The sqrt of a chisquare with n df is a chi distribution with n df
+    x_rv = at.sqrt(at.random.chisquare(df=3, size=(3,)))
+    x_rv.name = "x"
+
+    x_vv = x_rv.clone()
+    x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))
+
+    x_test_val = np.r_[0.5, 1, 2.5]
+    assert np.allclose(
+        x_logp_fn(x_test_val),
+        sp.stats.chi(df=3).logpdf(x_test_val),
+    )
+
+
 def test_negated_rv_transform():
     x_rv = -at.random.halfnormal()
     x_rv.name = "x"
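
Usage sketch (not part of the patch): the snippet below mirrors the new test_sqr_transform above and shows what the Pow transform enables, namely deriving the logprob of a squared standard normal and checking it against SciPy's chi-square with 1 df. The import path of the joint_logprob helper is an assumption here; in the patch it is simply the helper already used by pymc/tests/logprob/test_transforms.py.

    import numpy as np
    import pytensor
    import pytensor.tensor as at

    from scipy import stats

    # Assumed location of the test helper used throughout these tests
    from pymc.tests.logprob.utils import joint_logprob

    # x = z**2 with z ~ Normal(0, 1); the new rewrites make this square
    # measurable via the power transform
    x_rv = at.random.normal(0, 1, size=(3,)) ** 2
    x_vv = x_rv.clone()

    # sum=False keeps the elementwise logprob, as in the new tests
    x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

    x_test_val = np.r_[0.5, 1.0, 2.5]
    assert np.allclose(x_logp_fn(x_test_val), stats.chi2(df=1).logpdf(x_test_val))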