Releases: danielward27/flowjax
v16.0.0
What's Changed
- Return dict of list of floats or list of floats in fitting functions by @danielward27 in #185
- No need for static by @danielward27 in #186
- Support old style keys in numpyro wrappers by @danielward27 in #187
- key_based_loss update by @danielward27 in #188
- Reduce coupling by @danielward27 in #189
Breaking changes:
- Calls to `get_ravelled_pytree_constructor` will now need to explicitly pass the `*args` and `**kwargs` for partitioning parameters (usually setting `is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable)`).
- `fit_to_data` now returns a list of floats, rather than a list of scalar arrays.
Note: `fit_to_variational_target` will be deprecated in the next version, and this version adds its replacement, `fit_to_key_based_loss`. The replacement was introduced primarily because some of the old defaults were "bad", e.g. `steps: int = 100` and `return_best=True` (see #188 for details). The new name is also more general: the function can be used to fit any pytree and does not have to be used with a variational inference loss function.
Full Changelog: v15.1.0...v16.0.0
v15.1.0
What's Changed
- Add gamma and power transform by @danielward27 in #182
- Update docs by @danielward27 in #183
- Add Beta distribution and Sigmoid transform by @danielward27 in #184
Full Changelog: v15.0.0...v15.1.0
v15.0.0
Breaking changes:
- Users must switch from old-style PRNGKey arrays to new-style ones (replacing `jax.random.PRNGKey` with `jax.random.key`). The old keys will be deprecated in JAX.
- `LogNormal` is now an `Exp`-transformed `Normal`. This is an implementation detail, unless you previously relied on `log_normal.base_dist` or `log_normal.bijection`.
- `recursive_unwrap` was removed, as on its own it didn't provide additional functionality.
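Representing a log-normal as an exponentiated normal relies on the change-of-variables formula. If `Z ~ Normal(loc, scale)` and `X = exp(Z)`, then `log p_X(x) = log p_Z(log x) - log x`. The plain-Python check below is independent of flowjax, and its function names are illustrative only. It confirms that this formula matches the closed-form log-normal density.

```python
import math


def normal_logpdf(z, loc, scale):
    # Log density of Normal(loc, scale) evaluated at z.
    return -0.5 * ((z - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def exp_normal_logpdf(x, loc, scale):
    # X = exp(Z): invert with z = log(x); the inverse's log-Jacobian is -log(x).
    return normal_logpdf(math.log(x), loc, scale) - math.log(x)


def lognormal_logpdf(x, loc, scale):
    # Closed-form log-normal log density, for comparison.
    return (
        -math.log(x * scale * math.sqrt(2 * math.pi))
        - (math.log(x) - loc) ** 2 / (2 * scale**2)
    )
```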
What's Changed
- Avoid array in AutoregressiveBisectionInverter by @danielward27 in #174
- Simplify wrappers by @danielward27 in #175
- Docs and switch to jr.key by @danielward27 in #177
- Documentation example fixes by @danielward27 in #178
- Update README.md by @danielward27 in #179
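The switch to `jr.key` is usually a one-line change. The snippet below uses only public JAX functions (`jax.random.key`, `jax.random.PRNGKey`, `jax.random.key_data`, `jax.random.wrap_key_data`). It assumes a JAX version recent enough to provide typed keys.

```python
import jax
import jax.random as jr

old_key = jr.PRNGKey(0)  # legacy style: raw uint32 array of shape (2,)
new_key = jr.key(0)      # new style: typed key array of shape ()

# An existing legacy key can be converted by wrapping its raw data.
converted = jr.wrap_key_data(old_key)

# Typed keys are split and consumed just like legacy ones.
key, subkey = jr.split(new_key)
sample = jr.normal(subkey, (3,))
```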
Full Changelog: v14.0.0...v15.0.0
v14.0.0
What's Changed
I was not happy with the original implementation of some wrappers in `flowjax.wrappers`; e.g. `BijectionReparam` introduced some (avoidable, but untidy) circular dependency issues. We now primarily use `Parameterize`, a new name for what was `Lambda`. This introduces some breaking changes: primarily the aforementioned renaming, plus the removal of `BijectionReparam` and `Where`. See the pull request for more information.
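The toy class below gives a rough illustration of the idea behind `Parameterize`. It stores a function and its arguments, and only calls the function when the wrapper is unwrapped. This is a hypothetical re-implementation for explanation, not flowjax's code, and the `unwrap` helper here is also illustrative.

```python
import math


class Parameterize:
    # Toy stand-in: defer fn(*args, **kwargs) until the value is unwrapped.
    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def unwrap(self):
        return self.fn(*self.args, **self.kwargs)


def unwrap(tree):
    # Toy helper: replace every Parameterize node in a nested dict by its value.
    if isinstance(tree, Parameterize):
        return tree.unwrap()
    if isinstance(tree, dict):
        return {k: unwrap(v) for k, v in tree.items()}
    return tree


# Store an unconstrained log-scale; the positive scale is computed on unwrap.
params = {"loc": 0.0, "scale": Parameterize(math.exp, 0.5)}
unwrapped = unwrap(params)
```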
- Update wrappers by @danielward27 in #172
Full Changelog: v13.1.1...v14.0.0
v13.1.1
What's Changed
- Test for updated equinox errors by @danielward27 in #171
Full Changelog: v13.1.0...v13.1.1
v13.1.0
What's Changed
- Planar inverse with leaky relu by @danielward27 in #170.
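A leaky ReLU makes an analytic inverse possible. Take the usual planar form `y = x + u * h(w·x + b)`, where `h` is a leaky ReLU with negative slope `alpha`. Projecting onto `w` gives `w·y + b = a + (w·u) h(a)`, where `a = w·x + b`. Assuming `1 + w·u > 0` and `1 + alpha w·u > 0`, `a` has the same sign as `w·y + b`. So `a` can be recovered piecewise, and then `x = y - u h(a)`. The plain-Python sketch below works through this under those assumptions; it is not flowjax's `Planar` implementation.

```python
def leaky_relu(a, slope):
    return a if a >= 0 else slope * a


def dot(p, q):
    return sum(pi * qi for pi, qi in zip(p, q))


def planar_forward(x, u, w, b, slope):
    # y = x + u * h(w.x + b), with h a leaky ReLU.
    h = leaky_relu(dot(w, x) + b, slope)
    return [xi + ui * h for xi, ui in zip(x, u)]


def planar_inverse(y, u, w, b, slope):
    # w.y + b = a + (w.u) * h(a), where a = w.x + b is the unknown pre-activation.
    wu = dot(w, u)
    if not (1 + wu > 0 and 1 + slope * wu > 0):
        raise ValueError("invertibility assumptions violated")
    c = dot(w, y) + b
    # Under the assumptions, c and a share a sign, so pick the matching linear piece.
    a = c / (1 + wu) if c >= 0 else c / (1 + slope * wu)
    h = leaky_relu(a, slope)
    return [yi - ui * h for yi, ui in zip(y, u)]
```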
Full Changelog: v13.0.1...v13.1.0
Thanks for the help @weiyaw!
v13.0.1
v13.0.0
Some small breaking changes:
- The `interval` attribute of `RationalQuadraticSpline` is now a two-tuple with type `tuple[float | int, float | int]`, representing the lower and upper bounds of the spline. Previously, the interval was a `float | int`, forcing the interval to be symmetric about 0.
- Any custom loss functions used with `fit_to_data` should now accept a key (whether or not it is used), with signature `loss(params, static, x, condition, key)`, for consistency of API.
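The function below is a minimal sketch of a custom loss with the new signature. In practice, `params` and `static` would be the two halves of a partitioned model. Here, for illustration only, they are a hypothetical `(loc, log_scale)` tuple and `None`. The function computes a Gaussian negative log-likelihood, and it accepts `condition` and `key` even though it ignores them.

```python
import math


def gaussian_nll_loss(params, static, x, condition=None, key=None):
    # Hypothetical loss: params = (loc, log_scale); static, condition, key unused.
    loc, log_scale = params
    scale = math.exp(log_scale)
    per_point = [
        0.5 * ((xi - loc) / scale) ** 2 + log_scale + 0.5 * math.log(2 * math.pi)
        for xi in x
    ]
    return sum(per_point) / len(per_point)
```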
What's Changed
- Allow uneven interval in spline by @danielward27 in #166
- Sample contrastive by @danielward27 in #167
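The difference between the old symmetric interval and the new two-tuple can be sketched with a hypothetical normalizing helper. `as_interval` is not part of flowjax; it only shows how a bare number maps to the old symmetric bounds, while a tuple allows uneven bounds.

```python
from __future__ import annotations


def as_interval(
    interval: float | int | tuple[float | int, float | int],
) -> tuple[float | int, float | int]:
    # Hypothetical helper: a bare number gives the old symmetric interval.
    if isinstance(interval, tuple):
        lower, upper = interval
    else:
        lower, upper = -interval, interval
    if not lower < upper:
        raise ValueError("interval must satisfy lower < upper")
    return lower, upper
```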
Full Changelog: v12.4.0...v13.0.0
v12.4.0
v12.3.0
What's Changed
The main thing of note is a bug fix to `ContrastiveLoss` (see #164). Apologies for any inconvenience!
- Fix bug in `ContrastiveLoss` normalization and improve numpyro transformed wrapper by @danielward27 in #164
Full Changelog: v12.2.0...v12.3.0