Skip to content

Commit

Permalink
Merge pull request #61 from danielward27/multiple_ndim
Browse files Browse the repository at this point in the history
Multiple ndim
  • Loading branch information
danielward27 authored Jan 12, 2023
2 parents 8d7d023 + 808d272 commit 91fede8
Show file tree
Hide file tree
Showing 39 changed files with 1,312 additions and 1,052 deletions.
11 changes: 4 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ flow.log_prob(x)

The package currently supports the following:

- `CouplingFlow` ([Dinh et al., 2017](https://arxiv.org/abs/1605.08803)) and `MaskedAutoregressiveFlow` ([Papamakarios et al., 2017](https://arxiv.org/abs/1705.07057v4)) conditioner architectures
- Common "transformers", such as `AffineTransformer` and `RationalQuadraticSplineTransformer` (the latter used in neural spline flows; [Durkan et al., 2019](https://arxiv.org/abs/1906.04032))
- `CouplingFlow` ([Dinh et al., 2017](https://arxiv.org/abs/1605.08803))
- `MaskedAutoregressiveFlow` ([Papamakarios et al., 2017](https://arxiv.org/abs/1705.07057v4)) conditioner architectures.
- Common "transformers", such as `Affine` and `RationalQuadraticSpline` (the latter used in neural spline flows; [Durkan et al., 2019](https://arxiv.org/abs/1906.04032))
- `BlockNeuralAutoregressiveFlow`, as introduced by [De Cao et al., 2019](https://arxiv.org/abs/1904.04676)
- `TriangularSplineFlow`, introduced here.

Expand All @@ -47,12 +48,8 @@ This package is new and may have substantial breaking changes between major rele
## TODO
A few limitations / things that could be worth including in the future:
- Add documentation
- Support varied "event" dimensions:
- i.e. allow `x` and `condition` instances to have `ndim==0` (scalar), or `ndim > 1`.
- Chaining of bijections with varied event `ndim` could follow numpy-like broadcasting rules.
- Allow vmap-like transform to define bijections with expanded event dimensions.
- Add ability to "reshape" bijections.
- Training script for variational inference
- Define transformers by wrapping a bijection?

## Related
We make use of the [Equinox](https://arxiv.org/abs/2111.00254) package, which facilitates object-oriented programming with Jax.
Expand Down
8 changes: 0 additions & 8 deletions docs/api/transformers.rst

This file was deleted.

36 changes: 19 additions & 17 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,49 @@
# Configuration file for the Sphinx documentation builder.
import sys
import os
sys.path.insert(0, os.path.abspath('..'))

sys.path.insert(0, os.path.abspath(".."))

# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = 'FlowJax'
copyright = '2022, Daniel Ward'
author = 'Daniel Ward'
release = 'v6.1.0'
project = "FlowJax"
copyright = "2022, Daniel Ward"
author = "Daniel Ward"
release = "v7.0.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
"sphinx.ext.todo", "sphinx.ext.viewcode", "sphinx.ext.autodoc",
"sphinx.ext.napoleon", "sphinx.ext.doctest", "nbsphinx", "sphinx_copybutton"
"sphinx.ext.viewcode",
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx.ext.doctest",
"nbsphinx",
"sphinx_copybutton",
]

templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

add_module_names = False
napoleon_include_init_with_doc = True

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]

html_css_files = [
'style.css',
"style.css",
]

napoleon_include_init_with_doc = True

html_theme_options = {
'navigation_depth': 2,
#'logo_only': True,
"navigation_depth": 2,
}

# html_logo = "../images/flowjax_logo.png"
18 changes: 9 additions & 9 deletions docs/examples/conditional.ipynb

Large diffs are not rendered by default.

42 changes: 25 additions & 17 deletions docs/examples/unconditional.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ Installation
.. toctree::
:caption: Examples
:maxdepth: 1
:glob:

examples/unconditional
examples/conditional


.. toctree::
:caption: API
:maxdepth: 1
Expand All @@ -29,4 +29,3 @@ Installation
:caption: Miscellaneous

faq

2 changes: 1 addition & 1 deletion flowjax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""flowjax - Basic flowjax implementation in jax."""

__version__ = "6.1.0"
__version__ = "7.0.0"
__author__ = "Daniel Ward <danielward27@outlook.com>"
__all__ = []
15 changes: 8 additions & 7 deletions flowjax/bijections/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""Bijections from ``flowjax.bijections``"""

from .abc import Bijection, Transformer
from .bijection import Bijection
from .affine import AdditiveLinearCondition, Affine, TriangularAffine
from .block_autoregressive_network import BlockAutoregressiveNetwork
from .chain import Chain, ScannableChain
from .chain import Chain
from .coupling import Coupling
from .masked_autoregressive import MaskedAutoregressive
from .tanh import Tanh, TanhLinearTails
from .utils import (EmbedCondition, Flip, Invert, Partial, Permute,
TransformerToBijection)
from .utils import EmbedCondition, Flip, Invert, Partial, Permute
from .rational_quadratic_spline import RationalQuadraticSpline
from .jax_transforms import Scan, Vmap

__all__ = [
"Bijection",
"Transformer",
"Affine",
"TriangularAffine",
"BlockAutoregressiveNetwork",
Expand All @@ -21,12 +21,13 @@
"Tanh",
"TanhLinearTails",
"Chain",
"ScannableChain",
"Scan",
"Vmap",
"Invert",
"Flip",
"Permute",
"TransformerToBijection",
"AdditiveLinearCondition",
"Partial",
"EmbedCondition",
"RationalQuadraticSpline",
]
54 changes: 0 additions & 54 deletions flowjax/bijections/abc.py

This file was deleted.

Loading

0 comments on commit 91fede8

Please sign in to comment.