diff --git a/CombineHarvester-flow/__init__.py b/CombineHarvester-flow/__init__.py
deleted file mode 100644
index d08574c..0000000
--- a/CombineHarvester-flow/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-__author__ = """Peter Taylor"""
-__email__ = 'taylor.4264@osu.edu'
-__version__ = '1.0.0'
-
-from CombineHarvesterFlow.harvest import Harvest
-from CombineHarvesterFlow.combine import Combine
diff --git a/CombineHarvester-flow/combine.py b/CombineHarvester-flow/combine.py
deleted file mode 100644
index a453992..0000000
--- a/CombineHarvester-flow/combine.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import numpy as np
-
-
-class Combine():
-    """Combine two Harvest objects.
-    """
-    def __init__(self, harvest_1, harvest_2):
-        """Initializes the Combine object.
-
-        Parameters
-        ----------
-        harvest_1 : Harvest
-            First Harvest object.
-        harvest_2 : Harvest
-            Second Harvest object.
-        """
-        self.harvest_1 = harvest_1
-        self.harvest_2 = harvest_2
-
-    def __getattr__(self, name):
-        if hasattr(self.harvest_1, name) and callable(getattr(self.harvest_1, name)):
-            return getattr(self.harvest_1, name)
-        elif hasattr(self.harvest_2, name) and callable(getattr(self.harvest_2, name)):
-            return getattr(self.harvest_2, name)
-        else:
-            raise AttributeError(f"'Combine' object has no attribute '{name}'")
-
-    def combine(self):
-        """Combine the two chains.
-
-        Returns
-        -------
-        (array, array)
-            The combined weights for the two chains.
-        """
-        # normalize each chain with the other chain's stats, since each is scored by the other's flows
-        norm_chain_1 = (self.harvest_1.chain - self.harvest_2.mean) / self.harvest_2.std
-        norm_chain_2 = (self.harvest_2.chain - self.harvest_1.mean) / self.harvest_1.std
-
-        # average log-density of each chain's samples under the other chain's flows
-        flow_weight_list_2, flow_weight_list_1 = [], []
-        for i in range(self.harvest_1.n_flows):
-            flow_weight_list_2 += [np.asarray(self.harvest_1.flow_list[i].log_prob(norm_chain_2))]
-        for i in range(self.harvest_2.n_flows):
-            flow_weight_list_1 += [np.asarray(self.harvest_2.flow_list[i].log_prob(norm_chain_1))]
-
-        ln_weights_2 = np.sum(np.vstack(flow_weight_list_2), axis=0) / self.harvest_1.n_flows
-        ln_weights_1 = np.sum(np.vstack(flow_weight_list_1), axis=0) / self.harvest_2.n_flows
-
-        # convert log-weights to weights; shift max(ln_weights) to 0 first so np.exp cannot overflow
-        ln_weights_1 -= np.max(ln_weights_1)
-        ln_weights_2 -= np.max(ln_weights_2)
-
-        chain_1_weights = self.harvest_1.weights * np.exp(ln_weights_1)
-        chain_2_weights = self.harvest_2.weights * np.exp(ln_weights_2)
-
-        return chain_1_weights, chain_2_weights
-
-    def combine_subset(self, n_flows_1, n_flows_2):
-        """Combine the two chains with a subset of the flows.
-
-        Parameters
-        ----------
-        n_flows_1 : int
-            Number of flows to use for the first chain.
-        n_flows_2 : int
-            Number of flows to use for the second chain.
-
-        Returns
-        -------
-        (array, array)
-            The combined weights for the two chains.
-        """
-        old_n_flows_1, old_n_flows_2 = self.harvest_1.n_flows, self.harvest_2.n_flows
-        self.harvest_1.n_flows = n_flows_1
-        self.harvest_2.n_flows = n_flows_2
-        chain_1_weights, chain_2_weights = self.combine()
-
-        # reset
-        self.harvest_1.n_flows = old_n_flows_1
-        self.harvest_2.n_flows = old_n_flows_2
-        return chain_1_weights, chain_2_weights
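For intuition, `combine` turns averaged flow log-densities into importance weights, and the `max` shift is what keeps `np.exp` from overflowing. A minimal standalone sketch of that step, with made-up log-weights standing in for the flow outputs:

```python
import numpy as np

# Made-up log-weights; values this large would overflow np.exp directly.
rng = np.random.default_rng(0)
ln_w = rng.normal(loc=700.0, scale=5.0, size=1000)

ln_w -= np.max(ln_w)   # largest log-weight becomes 0
w = np.exp(ln_w)       # now safe: all values lie in (0, 1]
w /= np.sum(w)         # normalized importance weights

print(w.sum())         # ~1.0
```

Since the weights are only ever used relatively (e.g. in `np.average`), the constant shift changes nothing downstream.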
- """ - old_n_flows_1, old_n_flows_2 = self.harvest_1.n_flows, self.harvest_2.n_flows - self.harvest_1.n_flows = n_flows_1 - self.harvest_2.n_flows = n_flows_2 - chain_1_weights, chain_2_weights = self.combine() - - # reset - self.harvest_1.n_flows = old_n_flows_1 - self.harvest_2.n_flows = old_n_flows_2 - return chain_1_weights, chain_2_weights diff --git a/CombineHarvester-flow/harvest.py b/CombineHarvester-flow/harvest.py deleted file mode 100644 index 9d81424..0000000 --- a/CombineHarvester-flow/harvest.py +++ /dev/null @@ -1,98 +0,0 @@ -import equinox as eqx -import jax -import jax.numpy as jnp -import jax.random as jr -import numpy as np -from flowjax.bijections import RationalQuadraticSpline -from flowjax.distributions import Normal -from flowjax.flows import masked_autoregressive_flow - -from CombineHarvesterFlow.utils import (WeightedMaximumLikelihoodLoss, - fit_to_data_weight) - - -class Harvest(): - """Class to harvest the chains and train the flows.""" - def __init__(self, harvest_path, chain, n_flows, weights=None, random_seed=42): - """Initializes the Harvest object. - - Parameters - ---------- - harvest_path : string - Path to save the models. - chain : array - The chain to train the flows on. - n_flows : int - Number of flows to train. - weights : array, optional - Weights for the chain, by default None - random_seed : int, optional - Random seed for the training of the flows, by default 42 - """ - self.harvest_path = harvest_path - self.chain = chain - self.n_flows = n_flows - self.random_seed = random_seed - - self.weights = weights - if self.weights is None: - self.weights = np.ones_like(self.chain[:, 0]) - - def _normalize_data(self): - """Normalize the chain""" - self.mean = np.average(self.chain, weights=self.weights, axis=0) - self.std = (np.average((self.chain - self.mean)**2, weights=self.weights, axis=0)) ** 0.5 - self.norm_chain = (self.chain - self.mean) / self.std - - def _train_models(self): - """Train the flows""" - self.flow_list = [] - x = self.norm_chain - for i in range(self.n_flows): - key = jax.random.PRNGKey(self.random_seed + i) - key, subkey = jax.random.split(key) - flow = masked_autoregressive_flow( - subkey, - base_dist=Normal(jnp.zeros(x.shape[1])), - transformer=RationalQuadraticSpline(knots=8, interval=4), - ) - - key, subkey = jax.random.split(key) - flow, losses = fit_to_data_weight( - weights=self.weights, key=subkey, dist=flow, x=x, - learning_rate=1e-3, loss_fn=WeightedMaximumLikelihoodLoss() - ) - self.flow_list += [flow] - - def harvest(self): - """Harvest the chains and train the flows.""" - self._normalize_data() - print('Training the flows') - self._train_models() - - def save_models(self): - """Save the models""" - np.save(self.harvest_path + '_mean.npy', self.mean) - np.save(self.harvest_path + '_std.npy', self.std) - np.save(self.harvest_path + '_weights.npy', self.weights) - np.save(self.harvest_path + '_norm_chain.npy', self.norm_chain) - np.save(self.harvest_path + '_chain.npy', self.chain) - for _ in range(len(self.flow_list)): - eqx.tree_serialise_leaves(self.harvest_path + f'_flow_{_}.eqx', self.flow_list[_]) - - def load_models(self): - """Load the models""" - self.mean = np.load(self.harvest_path + '_mean.npy') - self.std = np.load(self.harvest_path + '_std.npy') - self.weights = np.load(self.harvest_path + '_weights.npy') - self.norm_chain = np.load(self.harvest_path + '_norm_chain.npy') - self.chain = np.load(self.harvest_path + '_chain.npy') - self.flow_list = [] - for i in range(self.n_flows): - key, subkey = 
diff --git a/CombineHarvester-flow/utils.py b/CombineHarvester-flow/utils.py
deleted file mode 100644
index fd79565..0000000
--- a/CombineHarvester-flow/utils.py
+++ /dev/null
@@ -1,141 +0,0 @@
-from collections.abc import Callable
-
-import equinox as eqx
-import jax.numpy as jnp
-import jax.random as jr
-import numpy as np
-import optax
-from flowjax import wrappers
-from flowjax.distributions import AbstractDistribution
-from flowjax.train.train_utils import (count_fruitless, get_batches, step,
-                                       train_val_split)
-from flowjax.wrappers import unwrap
-from jaxtyping import Array, ArrayLike, Float, PRNGKeyArray, PyTree
-from tqdm import tqdm
-
-
-def fit_to_data_weight(
-    key: PRNGKeyArray,
-    dist: PyTree,  # Custom losses may support broader types than AbstractDistribution
-    x: ArrayLike,
-    *,
-    condition: ArrayLike | None = None,
-    loss_fn: Callable | None = None,
-    max_epochs: int = 100,
-    max_patience: int = 5,
-    batch_size: int = 100,
-    val_prop: float = 0.1,
-    learning_rate: float = 5e-4,
-    optimizer: optax.GradientTransformation | None = None,
-    return_best: bool = True,
-    show_progress: bool = True,
-    weights: ArrayLike | None = None
-):
-    r"""Train a distribution (e.g. a flow) to samples from the target distribution.
-
-    The distribution can be unconditional :math:`p(x)` or conditional
-    :math:`p(x|\text{condition})`. Note that the last batch in each epoch is dropped
-    if truncated (to avoid recompilation). This function can also be used to fit
-    non-distribution pytrees as long as a compatible loss function is provided.
-
-    Args:
-        key: Jax random key.
-        dist: The distribution to train.
-        x: Samples from target distribution.
-        condition: Conditioning variables. Defaults to None.
-        loss_fn: Loss function. Defaults to WeightedMaximumLikelihoodLoss.
-        max_epochs: Maximum number of epochs. Defaults to 100.
-        max_patience: Number of consecutive epochs with no validation loss improvement
-            after which training is terminated. Defaults to 5.
-        batch_size: Batch size. Defaults to 100.
-        val_prop: Proportion of data to use in validation set. Defaults to 0.1.
-        learning_rate: Adam learning rate. Defaults to 5e-4.
-        optimizer: Optax optimizer. If provided, this overrides the default Adam
-            optimizer, and the learning_rate is ignored. Defaults to None.
-        return_best: Whether the result should use the parameters where the minimum loss
-            was reached (when True), or the parameters after the last update (when
-            False). Defaults to True.
-        show_progress: Whether to show progress bar. Defaults to True.
-        weights: Per-sample weights for x. Defaults to None.
-
-    Returns:
-        A tuple containing the trained distribution and the losses.
-    """
-    # pack the weights as a trailing column of x so they are shuffled and batched with it
-    data = tuple(jnp.asarray(a) for a in (np.c_[x, weights],))
-
-    if optimizer is None:
-        optimizer = optax.adam(learning_rate)
-
-    if loss_fn is None:
-        loss_fn = WeightedMaximumLikelihoodLoss()
-
-    params, static = eqx.partition(
-        dist,
-        eqx.is_inexact_array,
-        is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
-    )
-    best_params = params
-    opt_state = optimizer.init(params)
-
-    # train val split
-    key, subkey = jr.split(key)
-    train_data, val_data = train_val_split(subkey, data, val_prop=val_prop)
-    losses = {"train": [], "val": []}
-
-    loop = tqdm(range(max_epochs), disable=not show_progress)
-
-    for _ in loop:
-        # Shuffle data
-        key, *subkeys = jr.split(key, 3)
-        train_data = [jr.permutation(subkeys[0], a) for a in train_data]
-        val_data = [jr.permutation(subkeys[1], a) for a in val_data]
-
-        # Train epoch
-        batch_losses = []
-        for batch in zip(*get_batches(train_data, batch_size), strict=True):
-            params, opt_state, loss_i = step(
-                params,
-                static,
-                batch[0][:, :-1], batch[0][:, -1],  # split samples from the weight column
-                optimizer=optimizer,
-                opt_state=opt_state,
-                loss_fn=loss_fn,
-            )
-            batch_losses.append(loss_i)
-        losses["train"].append(sum(batch_losses) / len(batch_losses))
-
-        # Val epoch
-        batch_losses = []
-        for batch in zip(*get_batches(val_data, batch_size), strict=True):
-            loss_i = loss_fn(params, static, batch[0][:, :-1], batch[0][:, -1])
-            batch_losses.append(loss_i)
-        losses["val"].append(sum(batch_losses) / len(batch_losses))
-
-        loop.set_postfix({k: v[-1] for k, v in losses.items()})
-        if losses["val"][-1] == min(losses["val"]):
-            best_params = params
-
-        elif count_fruitless(losses["val"]) > max_patience:
-            loop.set_postfix_str(f"{loop.postfix} (Max patience reached)")
-            break
-
-    params = best_params if return_best else params
-    dist = eqx.combine(params, static)
-    return dist, losses
-
-
-class WeightedMaximumLikelihoodLoss:
-    @eqx.filter_jit
-    def __call__(
-        self,
-        params: AbstractDistribution,
-        static: AbstractDistribution,
-        x: Array, weights: Array,
-        condition: Array | None = None,
-    ) -> Float[Array, ""]:
-        """Compute the negative log-likelihood, weighted per sample and averaged over the batch."""
-
-        dist = unwrap(eqx.combine(params, static))
-        evl = -dist.log_prob(x, condition)
-        return (evl * weights).sum() / len(evl)  # mean of weighted NLLs; weights are not renormalized