Skip to content

Commit

Permalink
Bump PyTensor dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 11, 2024
1 parent 0c6d0df commit 330bbcc
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 34 deletions.
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- rich>=13.7.1
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies:
- numpyro>=0.8.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.23,<2.24
- pytensor>=2.25.1,<2.26
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,11 +836,11 @@ def create_partial_observed_rv(
if can_rewrite:
masked_rv = rv[mask]
fgraph = FunctionGraph(outputs=[masked_rv], clone=False, features=[ShapeFeature()])
[unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
unobserved_rv = local_subtensor_rv_lift.transform(fgraph, masked_rv.owner)[masked_rv]

antimasked_rv = rv[antimask]
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False, features=[ShapeFeature()])
[observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
observed_rv = local_subtensor_rv_lift.transform(fgraph, antimasked_rv.owner)[antimasked_rv]

# Make a clone of the observedRV, with a distinct rng so that observed and
# unobserved are never treated as equivalent (and mergeable) nodes by pytensor.
Expand Down
32 changes: 18 additions & 14 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,20 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
if not all(params.type.broadcastable):
return None

# Check whether axis covers all dimensions
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None
if node.op.axis is None:
axis = tuple(range(base_var.ndim))
else:
# Check whether axis covers all dimensions
axis = tuple(sorted(node.op.axis))
if axis != tuple(range(base_var.ndim)):
return None

# distinguish measurable discrete and continuous (because logprob is different)
measurable_max: Max
if base_var.type.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(list(axis))
measurable_max = MeasurableMaxDiscrete(axis)
else:
measurable_max = MeasurableMax(list(axis))
measurable_max = MeasurableMax(axis)

max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs
Expand Down Expand Up @@ -206,21 +208,23 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVa
if not all(params.type.broadcastable):
return None

# Check whether axis is supported or not
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None
if node.op.axis is None:
axis = tuple(range(base_var.ndim))
else:
# Check whether axis is supported or not
axis = tuple(sorted(node.op.axis))
if axis != tuple(range(base_var.ndim)):
return None

if not rv_map_feature.request_measurable([base_rv]):
return None

# distinguish measurable discrete and continuous (because logprob is different)
measurable_min: Max
if base_rv.type.dtype.startswith("int"):
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
measurable_min = MeasurableDiscreteMaxNeg(axis)
else:
measurable_min = MeasurableMaxNeg(list(axis))
measurable_min = MeasurableMaxNeg(axis)

return measurable_min.make_node(base_rv).outputs

Expand Down
7 changes: 0 additions & 7 deletions pymc/logprob/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.rewriting.math import local_exp_over_1_plus_exp
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -374,12 +373,6 @@ def incsubtensor_rv_replace(fgraph, node):

logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic")

# Split max_and_argmax
# We only register this in the measurable IR db because max does not have a grad implemented
# And running this on any MaxAndArgmax would lead to issues: https://github.com/pymc-devs/pymc/issues/7251
# This special registering can be removed after https://github.com/pymc-devs/pytensor/issues/334 is fixed
measurable_ir_rewrites_db.register("local_max_and_argmax", local_max_and_argmax, "basic")

# These rewrites push random/measurable variables "down", making them closer to
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
# "up" through the random/measurable variables and into their inputs.
Expand Down
4 changes: 3 additions & 1 deletion pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
# SOFTWARE.


from pathlib import Path

import pytensor

from pytensor import tensor as pt
Expand Down Expand Up @@ -237,7 +239,7 @@ class MeasurableDimShuffle(DimShuffle):

# Need to get the absolute path of `c_func_file`, otherwise it tries to
# find it locally and fails when a new `Op` is initialized
c_func_file = DimShuffle.get_path(DimShuffle.c_func_file)
c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file)))


MeasurableVariable.register(MeasurableDimShuffle)
Expand Down
4 changes: 2 additions & 2 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
graph_inputs,
walk,
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.scalar.basic import Cast
from pytensor.scan.op import Scan
Expand Down Expand Up @@ -897,7 +897,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
[client, _] = rng_clients[0]

# RNG is an output of the function, this is not a problem
if client == "output":
if isinstance(client.op, Output):
return rng

# RNG is used by another operator, which should output an update for the RNG
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ numpydoc
pandas>=0.24.0
polyagamma
pre-commit>=2.8.0
pytensor>=2.23,<2.24
pytensor>=2.25.1,<2.26
pytest-cov>=2.5
pytest>=3.0
rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ cachetools>=4.2.1
cloudpickle
numpy>=1.15.0
pandas>=0.24.0
pytensor>=2.23,<2.24
pytensor>=2.25.1,<2.26
rich>=13.7.1
scipy>=1.4.1
threadpoolctl>=3.1.0,<4.0.0
Expand Down

0 comments on commit 330bbcc

Please sign in to comment.