Merge pull request #189 from danielward27/reduce_coupling
Reduce coupling
danielward27 authored Oct 17, 2024
2 parents c006585 + 0b870b6 commit d711b18
Showing 4 changed files with 24 additions and 17 deletions.
7 changes: 6 additions & 1 deletion flowjax/bijections/coupling.py

@@ -10,6 +10,7 @@
 import jax.numpy as jnp
 from jaxtyping import PRNGKeyArray
 
+from flowjax import wrappers
 from flowjax.bijections.bijection import AbstractBijection
 from flowjax.bijections.jax_transforms import Vmap
 from flowjax.utils import Array, get_ravelled_pytree_constructor
@@ -55,7 +56,11 @@ def __init__(
                 "Only unconditional transformers with shape () are supported.",
             )
 
-        constructor, num_params = get_ravelled_pytree_constructor(transformer)
+        constructor, num_params = get_ravelled_pytree_constructor(
+            transformer,
+            filter_spec=eqx.is_inexact_array,
+            is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
+        )
 
         self.transformer_constructor = constructor
         self.untransformed_dim = untransformed_dim
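For `Coupling`, the effect of the new arguments is that any leaf wrapped in `wrappers.NonTrainable` fails the filter and is partitioned out as static, so the conditioner network no longer predicts parameters for it. A minimal sketch under that assumption, using flowjax's `Affine` as the transformer (freezing `loc` here is purely illustrative):

```python
import equinox as eqx
from flowjax import wrappers
from flowjax.bijections import Affine
from flowjax.utils import get_ravelled_pytree_constructor


def is_leaf(leaf):
    return isinstance(leaf, wrappers.NonTrainable)


# By default, both the loc and scale parameters are trainable.
affine = Affine()
_, n_all = get_ravelled_pytree_constructor(
    affine, filter_spec=eqx.is_inexact_array, is_leaf=is_leaf
)

# Wrapping loc in NonTrainable turns it into a leaf that fails the filter,
# excluding it from the ravelled parameter vector.
frozen = eqx.tree_at(lambda a: a.loc, affine, wrappers.NonTrainable(affine.loc))
_, n_frozen = get_ravelled_pytree_constructor(
    frozen, filter_spec=eqx.is_inexact_array, is_leaf=is_leaf
)
assert n_frozen < n_all
```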
10 changes: 7 additions & 3 deletions flowjax/bijections/masked_autoregressive.py

@@ -9,11 +9,11 @@
 import jax.numpy as jnp
 from jaxtyping import Array, Int, PRNGKeyArray
 
+from flowjax import wrappers
 from flowjax.bijections.bijection import AbstractBijection
 from flowjax.bijections.jax_transforms import Vmap
 from flowjax.masks import rank_based_mask
 from flowjax.utils import get_ravelled_pytree_constructor
-from flowjax.wrappers import Parameterize
 
 
 class MaskedAutoregressive(AbstractBijection):
@@ -58,7 +58,11 @@ def __init__(
                 "Only unconditional transformers with shape () are supported.",
             )
 
-        constructor, num_params = get_ravelled_pytree_constructor(transformer)
+        constructor, num_params = get_ravelled_pytree_constructor(
+            transformer,
+            filter_spec=eqx.is_inexact_array,
+            is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
+        )
 
         if cond_dim is None:
             self.cond_shape = None
@@ -162,7 +166,7 @@ def masked_autoregressive_mlp(
         masked_linear = eqx.tree_at(
             lambda linear: linear.weight,
             linear,
-            Parameterize(jnp.where, mask, linear.weight, 0),
+            wrappers.Parameterize(jnp.where, mask, linear.weight, 0),
         )
         masked_layers.append(masked_linear)
 
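The last hunk only switches to the qualified name, but for context: `Parameterize` stores a callable and its arguments, and unwrapping re-runs `jnp.where`, so the autoregressive mask is re-applied to the weights whenever the model is unwrapped. A minimal sketch assuming flowjax's `wrappers` module (the 3x3 lower-triangular mask is illustrative):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from flowjax import wrappers

linear = eqx.nn.Linear(3, 3, key=jr.PRNGKey(0))
mask = jnp.tril(jnp.ones((3, 3), dtype=bool))

# Replace the weight leaf with a lazy jnp.where(mask, weight, 0).
masked_linear = eqx.tree_at(
    lambda lin: lin.weight,
    linear,
    wrappers.Parameterize(jnp.where, mask, linear.weight, 0),
)

# Unwrapping applies the stored function, zeroing the masked-out weights.
unwrapped = wrappers.unwrap(masked_linear)
assert bool(jnp.all(unwrapped.weight[~mask] == 0))
```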
22 changes: 10 additions & 12 deletions flowjax/utils.py

@@ -7,8 +7,6 @@
 from jax.flatten_util import ravel_pytree
 from jaxtyping import Array, ArrayLike
 
-import flowjax
-
 
 def inv_softplus(x: ArrayLike) -> Array:
     """The inverse of the softplus function, checking for positive inputs."""
@@ -70,28 +68,28 @@ def _shapes_to_str(shapes):
     return f"{in_shapes_str}->{out_shapes_str}"
 
 
-def get_ravelled_pytree_constructor(tree, filter_spec=eqx.is_inexact_array) -> tuple:
+def get_ravelled_pytree_constructor(
+    tree,
+    *args,
+    **kwargs,
+) -> tuple:
     """Get a pytree constructor taking ravelled parameters, and the number of params.
 
     The constructor takes a single argument as input, which is all the bijection
     parameters flattened into a single contiguous vector. This is useful when we wish to
     parameterize a pytree with a neural network. Calling the constructor
-    at the zero vector will return the initial pytree. Parameters wrapped in
-    ``NonTrainable`` are treated as leaves during partitioning.
+    at the zero vector will return the initial pytree. When using, you may wish to
+    specify ``NonTrainable`` nodes as leaves, using the ``is_leaf`` argument.
 
     Args:
         tree: Pytree to form constructor for.
-        filter_spec: Filter function to specify parameters. Defaults to
-            eqx.is_inexact_array.
+        *args: Arguments passed to ``eqx.partition``.
+        **kwargs: Keyword arguments passed to ``eqx.partition``.
 
     Returns:
         tuple: Tuple containing the constructor, and the number of parameters.
     """
-    params, static = eqx.partition(
-        tree,
-        filter_spec,
-        is_leaf=lambda leaf: isinstance(leaf, flowjax.wrappers.NonTrainable),
-    )
+    params, static = eqx.partition(tree, *args, **kwargs)
     init, unravel = ravel_pytree(params)
 
     def constructor(ravelled_params: Array):
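A minimal round-trip sketch of the decoupled signature, assuming an equinox MLP stands in for the bijection pytree (the shapes are illustrative):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from flowjax import wrappers
from flowjax.utils import get_ravelled_pytree_constructor

mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=1, key=jr.PRNGKey(0))

# The caller now chooses the partitioning, rather than it being hard-coded.
constructor, num_params = get_ravelled_pytree_constructor(
    mlp,
    filter_spec=eqx.is_inexact_array,
    is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
)

# Per the docstring, the zero vector reconstructs the initial pytree.
rebuilt = constructor(jnp.zeros(num_params))
```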
2 changes: 1 addition & 1 deletion pyproject.toml

@@ -16,7 +16,7 @@ license = { file = "LICENSE" }
 name = "flowjax"
 readme = "README.md"
 requires-python = ">=3.10"
-version = "15.1.0"
+version = "16.0.0"
 
 [project.urls]
 repository = "https://github.com/danielward27/flowjax"
