Commit

Merge branch 'main' into seqprobit

gianlucadetommaso committed Jun 21, 2023
2 parents 1c23a9e + 4dea50f commit 915a1ea
Showing 18 changed files with 313 additions and 162 deletions.
37 changes: 25 additions & 12 deletions fortuna/calib_model/base.py
@@ -5,10 +5,8 @@
Optional,
)

from flax.core.frozen_dict import freeze
from flax.traverse_util import path_aware_map
from flax.core import FrozenDict
import jax.numpy as jnp
import optax

from fortuna.calib_model.calib_mixin import WithCalibCheckpointingMixin
from fortuna.calib_model.calib_model_calibrator import (
@@ -36,6 +34,11 @@
get_inputs_from_shape,
)
from fortuna.utils.device import select_trainer_given_devices
from fortuna.utils.freeze import get_trainable_paths
from fortuna.utils.nested_dicts import (
nested_get,
nested_set,
)
from fortuna.utils.random import RandomNumberGenerator


@@ -94,22 +97,32 @@ def _calibrate(
early_stopping_monitor=config.monitor.early_stopping_monitor,
early_stopping_min_delta=config.monitor.early_stopping_min_delta,
early_stopping_patience=config.monitor.early_stopping_patience,
freeze_fun=config.optimizer.freeze_fun,
)

state = self._init_state(calib_data_loader, config)

if config.optimizer.freeze_fun is not None:
partition_optimizers = {
"trainable": config.optimizer.method,
"frozen": optax.set_to_zero(),
}
partition_params = freeze(
path_aware_map(config.optimizer.freeze_fun, state.params)
trainable_paths = get_trainable_paths(
state.params, config.optimizer.freeze_fun
)
config.optimizer.method = optax.multi_transform(
partition_optimizers, partition_params
state = state.replace(
opt_state=config.optimizer.method.init(
FrozenDict(
nested_set(
d={},
key_paths=trainable_paths,
objs=tuple(
[
nested_get(state.params.unfreeze(), path)
for path in trainable_paths
]
),
allow_nonexistent=True,
)
)
)
)
state = self._init_state(calib_data_loader, config)

loss = Loss(self.likelihood, loss_fn=loss_fn)
loss.rng = self.rng
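This hunk drops the previous freezing strategy, which wrapped the optimizer in optax.multi_transform with optax.set_to_zero() on the frozen partition, in favour of initializing the optimizer state over only the trainable sub-tree selected by get_trainable_paths. A minimal standalone sketch of the two patterns, using a toy parameter dict and a made-up freeze rule (the freeze_fun below is illustrative, not taken from the repository):

# Sketch only: two ways to freeze part of a parameter tree with optax.
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict, freeze
from flax.traverse_util import path_aware_map

params = FrozenDict({"backbone": {"w": jnp.ones(3)}, "head": {"w": jnp.ones(2)}})

# Illustrative rule: train the head, freeze the backbone.
freeze_fun = lambda path, value: "trainable" if "head" in path else "frozen"

# Old pattern (removed above): keep every parameter in the optimizer,
# but zero out the updates of the frozen partition.
tx_old = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()},
    freeze(path_aware_map(freeze_fun, params)),
)
opt_state_old = tx_old.init(params)

# New pattern (added above): initialize the optimizer state only on the
# trainable sub-tree, selected here by hand instead of get_trainable_paths.
trainable = FrozenDict({"head": params["head"]})
opt_state_new = optax.adam(1e-3).init(trainable)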
2 changes: 1 addition & 1 deletion fortuna/calib_model/calib_model_calibrator.py
@@ -78,7 +78,7 @@ def validation_step(
kwargs: FrozenDict[str, Any] = FrozenDict(),
) -> Dict[str, jnp.ndarray]:
loss, aux = loss_fun(
state.params,
self._get_all_params(state),
batch,
n_data=n_data,
mutable=state.mutable,
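The validation step now evaluates the loss on all parameters rather than on state.params alone, which matters once frozen leaves are no longer part of the optimizer-tracked parameters. Purely as a hypothetical sketch of what a helper like _get_all_params could do when a frozen sub-tree is stored separately (the names and the stored attribute are invented, not Fortuna's actual implementation):

# Hypothetical sketch: merge trainable params with a separately kept frozen
# sub-tree so the loss sees every leaf. Names here are invented.
from flax.core.frozen_dict import FrozenDict, unfreeze


def merge(a: dict, b: dict) -> dict:
    # Recursively merge b into a copy of a.
    out = dict(a)
    for k, v in b.items():
        out[k] = merge(out[k], v) if isinstance(out.get(k), dict) and isinstance(v, dict) else v
    return out


def get_all_params(state, frozen_params=None):
    if frozen_params is None:
        return state.params
    return FrozenDict(merge(unfreeze(state.params), unfreeze(frozen_params)))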
29 changes: 23 additions & 6 deletions fortuna/prob_model/posterior/base.py
@@ -8,6 +8,7 @@
Type,
)

from flax.core import FrozenDict
from jax._src.prng import PRNGKeyArray

from fortuna.data.loader import DataLoader
@@ -23,7 +24,11 @@
Path,
Status,
)
from fortuna.utils.freeze import freeze_optimizer
from fortuna.utils.freeze import get_trainable_paths
from fortuna.utils.nested_dicts import (
nested_get,
nested_set,
)
from fortuna.utils.random import WithRNG


@@ -90,13 +95,25 @@ def _freeze_optimizer_in_state(
state: PosteriorState, fit_config: FitConfig
) -> PosteriorState:
if fit_config.optimizer.freeze_fun is not None:
frozen_optimizer = freeze_optimizer(
params=state.params,
optimizer=fit_config.optimizer.method,
freeze_fun=fit_config.optimizer.freeze_fun,
trainable_paths = get_trainable_paths(
state.params, fit_config.optimizer.freeze_fun
)
state = state.replace(
tx=frozen_optimizer, opt_state=frozen_optimizer.init(state.params)
opt_state=fit_config.optimizer.method.init(
FrozenDict(
nested_set(
d={},
key_paths=trainable_paths,
objs=tuple(
[
nested_get(state.params.unfreeze(), path)
for path in trainable_paths
]
),
allow_nonexistent=True,
)
)
)
)
return state

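_freeze_optimizer_in_state now mirrors the calibration-model change: instead of swapping the optimizer for a frozen/trainable multi_transform, it re-initializes opt_state on a dict that contains only the trainable key paths. A small standalone illustration of that nested-path pattern, with hand-rolled stand-ins for fortuna.utils.nested_dicts.nested_get/nested_set whose signatures are inferred from the call sites above rather than from the library:

# Toy stand-ins for nested_get/nested_set, inferred from their usage above.
def nested_get(d, path):
    for key in path:
        d = d[key]
    return d


def nested_set(d, key_paths, objs, allow_nonexistent=False):
    for path, obj in zip(key_paths, objs):
        node = d
        for key in path[:-1]:
            if key not in node:
                if not allow_nonexistent:
                    raise KeyError(key)
                node[key] = {}
            node = node[key]
        node[path[-1]] = obj
    return d


params = {"model": {"head": {"w": [1.0, 2.0]}, "backbone": {"w": [3.0]}}}
trainable_paths = [("model", "head", "w")]

# Build the sub-dict the optimizer state is initialized on.
sub = nested_set(
    d={},
    key_paths=trainable_paths,
    objs=tuple(nested_get(params, p) for p in trainable_paths),
    allow_nonexistent=True,
)
# sub == {"model": {"head": {"w": [1.0, 2.0]}}}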
@@ -142,6 +142,7 @@ def _fit(i):
early_stopping_monitor=fit_config.monitor.early_stopping_monitor,
early_stopping_min_delta=fit_config.monitor.early_stopping_min_delta,
early_stopping_patience=fit_config.monitor.early_stopping_patience,
freeze_fun=fit_config.optimizer.freeze_fun,
)

return trainer.train(
1 change: 1 addition & 0 deletions fortuna/prob_model/posterior/map/map_posterior.py
@@ -75,6 +75,7 @@ def fit(
early_stopping_monitor=fit_config.monitor.early_stopping_monitor,
early_stopping_min_delta=fit_config.monitor.early_stopping_min_delta,
early_stopping_patience=fit_config.monitor.early_stopping_patience,
freeze_fun=fit_config.optimizer.freeze_fun,
)

if super()._is_state_available_somewhere(fit_config):
2 changes: 1 addition & 1 deletion fortuna/prob_model/posterior/map/map_trainer.py
@@ -81,7 +81,7 @@ def validation_step(
kwargs: FrozenDict[str, Any] = FrozenDict(),
) -> Dict[str, jnp.ndarray]:
loss, aux = loss_fun(
state.params,
self._get_all_params(state),
batch,
n_data=n_data,
mutable=state.mutable,
@@ -126,7 +126,7 @@ def fit(
else:
which_params = None

rav, self._unravel, self._indices, rav_log_stds = self._get_unravel(
rav, self._unravel, sub_unravel, rav_log_stds = self._get_unravel(
params=state.params, log_stds=log_stds, which_params=which_params
)

@@ -151,12 +151,12 @@ def fit(
early_stopping_monitor=fit_config.monitor.early_stopping_monitor,
early_stopping_min_delta=fit_config.monitor.early_stopping_min_delta,
early_stopping_patience=fit_config.monitor.early_stopping_patience,
freeze_fun=fit_config.optimizer.freeze_fun,
base=self._base,
architecture=self._architecture,
which_params=which_params,
all_params=state.params if which_params else None,
indices=self._indices,
unravel=self._unravel,
sub_unravel=sub_unravel,
)

state = self._init_advi_from_map_state(
@@ -187,7 +187,6 @@ def fit(
max_grad_norm=fit_config.hyperparameters.max_grad_norm,
gradient_accumulation_steps=fit_config.hyperparameters.gradient_accumulation_steps,
)
trainer._all_params = None

self.state = PosteriorStateRepository(
fit_config.checkpointer.save_checkpoint_dir
@@ -239,7 +238,7 @@ def sample(
rng = self.rng.get()
state = self.state.get()

if self._base is None or not self._unravel is None:
if self._base is None or self._unravel is None:
if state._encoded_which_params is None:
n_params = len(ravel_pytree(state.params)[0]) // 2
which_params = None
@@ -257,28 +256,29 @@
)[0]
)
_base, _architecture = self._get_base_and_architecture(n_params)
_unravel, _indices = self._get_unravel(
params=nested_unpair(
d=state.params.unfreeze(),
key_paths=which_params,
labels=("mean", "log_std"),
)[0]
if which_params
else {
k: dict(params=v["params"]["mean"]) for k, v in state.params.items()
},
_unravel = self._get_unravel(
FrozenDict(
nested_unpair(
d=state.params.unfreeze(),
key_paths=which_params,
labels=("mean", "log_std"),
)[0]
if which_params
else {
k: dict(params=v["params"]["mean"])
for k, v in state.params.items()
}
),
which_params=which_params,
)[1:3]
)[1]

self._base = _base
self._architecture = _architecture
self._unravel = _unravel
self._indices = _indices
else:
_base = self._base
_architecture = self._architecture
_unravel = self._unravel
_indices = self._indices

if state._encoded_which_params is None:
means = _unravel(
@@ -311,18 +311,7 @@ def sample(
0
][0]

means = FrozenDict(
nested_set(
d=means,
key_paths=which_params,
objs=tuple(
[
_unravel(rav_params[_indices[i] : _indices[i + 1]])
for i, _unravel in enumerate(_unravel)
]
),
)
)
means = _unravel(rav_params)

return JointState(
params=FrozenDict(means),
@@ -413,32 +402,33 @@ def _get_unravel(
if which_params is None:
rav, unravel = ravel_pytree(params)
rav_log_stds = ravel_pytree(log_stds)[0] if log_stds is not None else None
indices = None
sub_unravel = None
else:
rav, sub_unravel = ravel_pytree(
tuple([nested_get(params, path) for path in which_params])
)

def unravel_fn(_params, _path):
return ravel_pytree(nested_get(_params, _path))

rav, unravel, indices, rav_log_stds = [], [], [], []

for path in which_params:
_rav, _unravel = unravel_fn(params, path)
unravel.append(_unravel)
rav.append(_rav)

if log_stds is not None:
rav_log_stds.append(unravel_fn(log_stds, path)[0])

indices.append(len(_rav))

rav = jnp.concatenate(rav)
def unravel(_rav):
return FrozenDict(
nested_set(
d=params.unfreeze(),
key_paths=which_params,
objs=sub_unravel(_rav),
)
)

if log_stds is not None:
rav_log_stds = jnp.concatenate(rav_log_stds)
rav_log_stds = ravel_pytree(
nested_set(
d={},
key_paths=which_params,
objs=tuple(
[nested_get(log_stds, path) for path in which_params]
),
allow_nonexistent=True,
)
)[0]
else:
rav_log_stds = None

unravel = tuple(unravel)
indices = np.concatenate((np.array([0]), np.cumsum(indices)))

return rav, unravel, indices, rav_log_stds
return rav, unravel, sub_unravel, rav_log_stds
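The refactored _get_unravel trades the old index bookkeeping for two closures: sub_unravel rebuilds the tuple of selected sub-trees from the flat vector, and unravel writes those sub-trees back into the full parameter dict. A self-contained sketch of the same idea using jax.flatten_util.ravel_pytree and plain dicts (no Fortuna helpers; the toy parameters and paths are illustrative):

# Sketch: flatten a chosen subset of a parameter tree and map it back.
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {
    "backbone": {"w": jnp.arange(3.0)},
    "head": {"w": jnp.arange(2.0), "b": jnp.zeros(1)},
}
which_params = (("head", "w"), ("head", "b"))

# Flatten only the selected sub-trees; sub_unravel restores them as a tuple.
selected = tuple(params[p[0]][p[1]] for p in which_params)
rav, sub_unravel = ravel_pytree(selected)


def unravel(flat):
    # Write the recovered sub-trees back into a copy of the full params dict.
    out = {k: dict(v) for k, v in params.items()}
    for path, leaf in zip(which_params, sub_unravel(flat)):
        out[path[0]][path[1]] = leaf
    return out


restored = unravel(rav + 1.0)  # shifts only the selected leaves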
@@ -8,12 +8,14 @@

import jax.numpy as jnp

from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.prob_model.posterior.normalizing_flow.normalizing_flow_state import (
NormalizingFlowState,
)
from fortuna.typing import Array
from fortuna.utils.strings import convert_string_to_jnp_array


class ADVIState(PosteriorState):
class ADVIState(NormalizingFlowState):
"""
Attributes
----------