Skip to content

Commit

Permalink
Implement power transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 18, 2022
1 parent 90b6bec commit 98ccc68
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 20 deletions.
149 changes: 130 additions & 19 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from copy import copy
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pytensor.tensor as at

from pytensor.gradient import DisconnectedType, jacobian
Expand All @@ -48,10 +49,22 @@
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal
from pytensor.scalar import Add, Exp, Log, Mul, Pow, Sqr, Sqrt
from pytensor.scan.op import Scan
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, exp, log, mul, neg, reciprocal, sub, true_div
from pytensor.tensor.math import (
add,
exp,
log,
mul,
neg,
pow,
reciprocal,
sqr,
sqrt,
sub,
true_div,
)
from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
Expand Down Expand Up @@ -110,8 +123,11 @@ def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
"""Apply the transformation."""

@abc.abstractmethod
def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
"""Invert the transformation."""
def backward(
self, value: TensorVariable, *inputs: Variable
) -> Union[TensorVariable, Tuple[TensorVariable, ...]]:
"""Invert the transformation. Multiple values may be returned when the
transformation is not 1-to-1"""

def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
"""Construct the log of the absolute value of the Jacobian determinant."""
Expand Down Expand Up @@ -320,7 +336,7 @@ def apply(self, fgraph: FunctionGraph):
class MeasurableTransform(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""

valid_scalar_types = (Exp, Log, Add, Mul, Reciprocal)
valid_scalar_types = (Exp, Log, Add, Mul, Pow)

# Cannot use `transform` as name because it would clash with the property added by
# the `TransformValuesRewrite`
Expand Down Expand Up @@ -349,16 +365,64 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
# The value variable must still be back-transformed to be on the natural support of
# the respective measurable input.
backward_value = op.transform_elemwise.backward(value, *other_inputs)
input_logprob = logprob(measurable_input, backward_value, **kwargs)

# Some transformations, like squaring may produce multiple backward values
if isinstance(backward_value, tuple):
input_logprob = at.logaddexp(
*(logprob(measurable_input, backward_val, **kwargs) for backward_val in backward_value)
)
else:
input_logprob = logprob(measurable_input, backward_value)

jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)

return input_logprob + jacobian


@node_rewriter([reciprocal])
def measurable_reciprocal_to_power(fgraph, node):
"""Convert reciprocal of `MeasurableVariable`s to power."""
inp = node.inputs[0]
if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
return None

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

# Only apply this rewrite if the variable is unvalued
if inp in rv_map_feature.rv_values:
return None # pragma: no cover

return [at.pow(inp, -1.0)]


@node_rewriter([sqr, sqrt])
def measurable_sqrt_sqr_to_power(fgraph, node):
"""Convert square root or square of `MeasurableVariable`s to power form."""

inp = node.inputs[0]
if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
return None

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

# Only apply this rewrite if the variable is unvalued
if inp in rv_map_feature.rv_values:
return None # pragma: no cover

if isinstance(node.op.scalar_op, Sqr):
return [at.pow(inp, 2)]

if isinstance(node.op.scalar_op, Sqrt):
return [at.pow(inp, 1 / 2)]


@node_rewriter([true_div])
def measurable_div_to_reciprocal_product(fgraph, node):
"""Convert divisions involving `MeasurableVariable`s to product with reciprocal."""
def measurable_div_to_product(fgraph, node):
"""Convert divisions involving `MeasurableVariable`s to products."""

measurable_vars = [
var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable))
Expand All @@ -379,9 +443,13 @@ def measurable_div_to_reciprocal_product(fgraph, node):
# Check if numerator is 1
try:
if at.get_scalar_constant_value(numerator) == 1:
return [at.reciprocal(denominator)]
# We convert the denominator directly to a power transform as this
# must be the measurable input
return [at.pow(denominator, -1)]
except NotScalarConstantError:
pass
# We don't convert the denominator directly to a power transform as
# it might not be measurable (and therefore not needed)
return [at.mul(numerator, at.reciprocal(denominator))]


Expand Down Expand Up @@ -425,7 +493,7 @@ def measurable_sub_to_neg(fgraph, node):
return [at.add(minuend, at.neg(subtrahend))]


@node_rewriter([exp, log, add, mul, reciprocal])
@node_rewriter([exp, log, add, mul, pow])
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
"""Find measurable transformations from Elemwise operators."""

Expand Down Expand Up @@ -485,8 +553,18 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
transform = ExpTransform()
elif isinstance(scalar_op, Log):
transform = LogTransform()
elif isinstance(scalar_op, Reciprocal):
transform = ReciprocalTransform()
elif isinstance(scalar_op, Pow):
# We only allow for the base to be measurable
if measurable_input_idx != 0:
return None
try:
(power,) = other_inputs
power = at.get_scalar_constant_value(power).item()
# Power needs to be a constant
except NotScalarConstantError:
return None
transform_inputs = (measurable_input, power)
transform = PowerTransform(power=power)
elif isinstance(scalar_op, Add):
transform_inputs = (measurable_input, at.add(*other_inputs))
transform = LocTransform(
Expand All @@ -510,12 +588,29 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li


measurable_ir_rewrites_db.register(
"measurable_div_to_reciprocal_product",
measurable_div_to_reciprocal_product,
"measurable_reciprocal_to_power",
measurable_reciprocal_to_power,
"basic",
"transform",
)


measurable_ir_rewrites_db.register(
"measurable_sqrt_sqr_to_power",
measurable_sqrt_sqr_to_power,
"basic",
"transform",
)


measurable_ir_rewrites_db.register(
"measurable_div_to_product",
measurable_div_to_product,
"basic",
"transform",
)


measurable_ir_rewrites_db.register(
"measurable_neg_to_product",
measurable_neg_to_product,
Expand Down Expand Up @@ -601,17 +696,33 @@ def log_jac_det(self, value, *inputs):
return -at.log(value)


class ReciprocalTransform(RVTransform):
name = "reciprocal"
class PowerTransform(RVTransform):
name = "power"

def __init__(self, power=None):
if not isinstance(power, (int, float)):
raise TypeError(f"Power must be integer or float, got {type(power)}")
if power == 0:
raise ValueError("Power cannot be 0")
self.power = power
super().__init__()

def forward(self, value, *inputs):
return at.reciprocal(value)
at.power(value, self.power)

def backward(self, value, *inputs):
return at.reciprocal(value)
backward_value = at.power(value, (1 / self.power))

# In this case the transform is not 1-to-1
if (self.power > 1) and (self.power % 2 == 0):
return -backward_value, backward_value
else:
return backward_value

def log_jac_det(self, value, *inputs):
return -2 * at.log(value)
inv_power = 1 / self.power
# Note: This fails for value==0
return np.log(np.abs(inv_power)) + (inv_power - 1) * at.log(value)


class IntervalTransform(RVTransform):
Expand Down
32 changes: 31 additions & 1 deletion pymc/tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):

a = at_dist(*dist_params, size=size)
a.name = "a"
a_value_var = at.tensor(a.dtype, shape=(None,) * a.ndim)
a_value_var = at.tensor(dtype=a.dtype, shape=(None,) * a.ndim)
a_value_var.name = "a_value"

b = at.random.normal(a, 1.0)
Expand Down Expand Up @@ -807,6 +807,36 @@ def test_reciprocal_rv_transform(numerator):
)


def test_sqr_transform():
# The square of a unit normal is a chi-square with 1 df
x_rv = at.random.normal(0, 1, size=(3,)) ** 2
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

x_test_val = np.r_[0.5, 1, 2.5]
assert np.allclose(
x_logp_fn(x_test_val),
sp.stats.chi2(df=1).logpdf(x_test_val),
)


def test_sqrt_transform():
# The sqrt of a chisquare with n df is a chi distribution with n df
x_rv = at.sqrt(at.random.chisquare(df=3, size=(3,)))
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

x_test_val = np.r_[0.5, 1, 2.5]
assert np.allclose(
x_logp_fn(x_test_val),
sp.stats.chi(df=3).logpdf(x_test_val),
)


def test_negated_rv_transform():
x_rv = -at.random.halfnormal()
x_rv.name = "x"
Expand Down

0 comments on commit 98ccc68

Please sign in to comment.