added postprocessing_chunks option to sample_blackjax_nuts and sample_numpyro_nuts (#6388)

* added postprocessing_chunks option to sample_blackjax_nuts and sample_numpyro_nuts

* make chunking optional, add chunking argument to _get_loglikelihood

* update tests for jax postprocessing chunking

* update docs

* Run pre-commit

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
wnorcbrown and github-actions[bot] authored Dec 19, 2022
1 parent 98ccc68 commit f231d13
Showing 2 changed files with 51 additions and 9 deletions.
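In practice, the new option is just an extra keyword argument on both samplers. A minimal usage sketch, assuming the argument names added in this commit; the toy model, draw counts, and backend choice below are illustrative, not from the diff:

```python
import pymc as pm
from pymc.sampling.jax import sample_numpyro_nuts

# Hypothetical toy model to show the new keyword argument.
with pm.Model():
    sigma = pm.HalfNormal("sigma")
    pm.Normal("obs", mu=0.0, sigma=sigma, observed=[0.1, -0.3, 0.2])
    idata = sample_numpyro_nuts(
        draws=500,
        tune=500,
        chains=2,
        postprocessing_backend="cpu",
        # Split the transform/log-likelihood step into 10 serial chunks,
        # trading some vectorization for lower peak memory. The chunk
        # count should evenly divide the number of draws.
        postprocessing_chunks=10,
    )
```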
56 changes: 48 additions & 8 deletions pymc/sampling/jax.py
@@ -21,6 +21,7 @@
 import pytensor.tensor as at
 
 from arviz.data.base import make_attrs
+from jax.experimental.maps import SerialLoop, xmap
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -143,6 +144,27 @@ def _sample_stats_to_xarray(posterior):
     return data
 
 
+def _postprocess_samples(
+    jax_fn: List[TensorVariable],
+    raw_mcmc_samples: List[TensorVariable],
+    postprocessing_backend: str,
+    num_chunks: Optional[int] = None,
+) -> List[TensorVariable]:
+    if num_chunks is not None:
+        loop = xmap(
+            jax_fn,
+            in_axes=["chain", "samples", ...],
+            out_axes=["chain", "samples", ...],
+            axis_resources={"samples": SerialLoop(num_chunks)},
+        )
+        f = xmap(loop, in_axes=[...], out_axes=[...])
+        return f(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))
+    else:
+        return jax.vmap(jax.vmap(jax_fn))(
+            *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+        )
+
+
 def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
     """Extract compatible stats from blackjax NUTS sampler
     with PyMC/Arviz naming conventions.
@@ -177,11 +199,13 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
     return converted_stats
 
 
-def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
+def _get_log_likelihood(
+    model: Model, samples, backend=None, num_chunks: Optional[int] = None
+) -> Dict:
     """Compute log-likelihood for all observations"""
     elemwise_logp = model.logp(model.observed_RVs, sum=False)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
-    result = jax.vmap(jax.vmap(jax_fn))(*jax.device_put(samples, jax.devices(backend)[0]))
+    result = _postprocess_samples(jax_fn, samples, backend, num_chunks=num_chunks)
     return {v.name: r for v, r in zip(model.observed_RVs, result)}
 
 
@@ -275,6 +299,7 @@ def sample_blackjax_nuts(
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Optional[str] = None,
+    postprocessing_chunks: Optional[int] = None,
     idata_kwargs: Optional[Dict[str, Any]] = None,
 ) -> az.InferenceData:
     """
@@ -314,6 +339,10 @@ def sample_blackjax_nuts(
         "vectorized".
     postprocessing_backend : str, optional
         Specify how postprocessing should be computed. gpu or cpu
+    postprocessing_chunks : Optional[int], default None
+        Specify the number of chunks the postprocessing should be computed in. More
+        chunks reduce memory usage at the cost of losing some vectorization; None
+        uses jax.vmap.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
         value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -400,8 +429,8 @@ def sample_blackjax_nuts(
 
     print("Transforming variables...", file=sys.stdout)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = jax.vmap(jax.vmap(jax_fn))(
-        *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+    result = _postprocess_samples(
+        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
     )
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
     mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
@@ -417,7 +446,10 @@
         tic5 = datetime.now()
         print("Computing Log Likelihood...", file=sys.stdout)
         log_likelihood = _get_log_likelihood(
-            model, raw_mcmc_samples, backend=postprocessing_backend
+            model,
+            raw_mcmc_samples,
+            backend=postprocessing_backend,
+            num_chunks=postprocessing_chunks,
         )
         tic6 = datetime.now()
         print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
@@ -478,6 +510,7 @@ def sample_numpyro_nuts(
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Optional[str] = None,
+    postprocessing_chunks: Optional[int] = None,
     idata_kwargs: Optional[Dict] = None,
     nuts_kwargs: Optional[Dict] = None,
 ) -> az.InferenceData:
@@ -522,6 +555,10 @@ def sample_numpyro_nuts(
         "parallel", and "vectorized".
     postprocessing_backend : Optional[str]
         Specify how postprocessing should be computed. gpu or cpu
+    postprocessing_chunks : Optional[int], default None
+        Specify the number of chunks the postprocessing should be computed in. More
+        chunks reduce memory usage at the cost of losing some vectorization; None
+        uses jax.vmap.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
         value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -622,8 +659,8 @@ def sample_numpyro_nuts(
 
     print("Transforming variables...", file=sys.stdout)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = jax.vmap(jax.vmap(jax_fn))(
-        *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+    result = _postprocess_samples(
+        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
     )
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
 
@@ -639,7 +676,10 @@
         tic5 = datetime.now()
         print("Computing Log Likelihood...", file=sys.stdout)
         log_likelihood = _get_log_likelihood(
-            model, raw_mcmc_samples, backend=postprocessing_backend
+            model,
+            raw_mcmc_samples,
+            backend=postprocessing_backend,
+            num_chunks=postprocessing_chunks,
        )
         tic6 = datetime.now()
         print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
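To make the chunking mechanics concrete, here is a self-contained sketch of the same xmap/SerialLoop pattern that `_postprocess_samples` uses above. It assumes a JAX version contemporary with this commit, where the experimental `jax.experimental.maps` module still exists; `f` stands in for the jaxified graph:

```python
import jax
import jax.numpy as jnp
from jax.experimental.maps import SerialLoop, xmap


def f(x):
    # Stand-in for the jaxified transform applied to the draws.
    return jnp.exp(x)


# Fake raw MCMC samples shaped (chain, samples).
samples = jnp.linspace(-1.0, 1.0, 4 * 1000).reshape(4, 1000)

# The inner xmap names both axes and assigns the "samples" axis to a
# SerialLoop resource, so draws are processed in 10 sequential chunks
# rather than all at once.
loop = xmap(
    f,
    in_axes=["chain", "samples", ...],
    out_axes=["chain", "samples", ...],
    axis_resources={"samples": SerialLoop(10)},
)
# The outer trivial xmap supplies the resource environment for SerialLoop.
chunked = xmap(loop, in_axes=[...], out_axes=[...])

out = chunked(samples)
# The num_chunks=None path is a doubly nested vmap and should agree:
assert jnp.allclose(out, jax.vmap(jax.vmap(f))(samples))
```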
4 changes: 3 additions & 1 deletion pymc/tests/sampling/test_jax.py
@@ -55,7 +55,8 @@ def test_old_import_route():
         ),
     ],
 )
-def test_transform_samples(sampler, postprocessing_backend, chains):
+@pytest.mark.parametrize("postprocessing_chunks", [None, 10])
+def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_chunks):
     pytensor.config.on_opt_error = "raise"
     np.random.seed(13244)
 
@@ -71,6 +72,7 @@ def test_transform_samples(sampler, postprocessing_backend, chains):
         random_seed=1322,
         keep_untransformed=True,
         postprocessing_backend=postprocessing_backend,
+        postprocessing_chunks=postprocessing_chunks,
     )
 
     log_vals = trace.posterior["sigma_log__"].values
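One note on the test change: the added `@pytest.mark.parametrize` stacks with the existing decorators, so every (sampler, backend, chains) combination now runs once without chunking and once with 10 chunks. A minimal, self-contained sketch of that stacking behavior (names here are illustrative, not from the test file):

```python
import pytest


@pytest.mark.parametrize("postprocessing_backend", [None, "cpu"])
@pytest.mark.parametrize("postprocessing_chunks", [None, 10])
def test_parametrize_stacking(postprocessing_backend, postprocessing_chunks):
    # Stacked parametrize decorators take the cartesian product,
    # so this test body runs 2 * 2 = 4 times.
    assert postprocessing_chunks in (None, 10)
```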
