diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 3f017c5a8f..07028f520a 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -21,6 +21,7 @@
 import pytensor.tensor as at
 
 from arviz.data.base import make_attrs
+from jax.experimental.maps import SerialLoop, xmap
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -143,6 +144,27 @@ def _sample_stats_to_xarray(posterior):
     return data
 
 
+def _postprocess_samples(
+    jax_fn: List[TensorVariable],
+    raw_mcmc_samples: List[TensorVariable],
+    postprocessing_backend: str,
+    num_chunks: Optional[int] = None,
+) -> List[TensorVariable]:
+    if num_chunks is not None:
+        loop = xmap(
+            jax_fn,
+            in_axes=["chain", "samples", ...],
+            out_axes=["chain", "samples", ...],
+            axis_resources={"samples": SerialLoop(num_chunks)},
+        )
+        f = xmap(loop, in_axes=[...], out_axes=[...])
+        return f(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))
+    else:
+        return jax.vmap(jax.vmap(jax_fn))(
+            *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+        )
+
+
 def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
     """Extract compatible stats from blackjax NUTS sampler
     with PyMC/Arviz naming conventions.
@@ -177,11 +199,13 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
     return converted_stats
 
 
-def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
+def _get_log_likelihood(
+    model: Model, samples, backend=None, num_chunks: Optional[int] = None
+) -> Dict:
     """Compute log-likelihood for all observations"""
     elemwise_logp = model.logp(model.observed_RVs, sum=False)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
-    result = jax.vmap(jax.vmap(jax_fn))(*jax.device_put(samples, jax.devices(backend)[0]))
+    result = _postprocess_samples(jax_fn, samples, backend, num_chunks=num_chunks)
     return {v.name: r for v, r in zip(model.observed_RVs, result)}
 
 
@@ -275,6 +299,7 @@ def sample_blackjax_nuts(
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Optional[str] = None,
+    postprocessing_chunks: Optional[int] = None,
     idata_kwargs: Optional[Dict[str, Any]] = None,
 ) -> az.InferenceData:
     """
@@ -314,6 +339,10 @@ def sample_blackjax_nuts(
         "vectorized".
     postprocessing_backend : str, optional
         Specify how postprocessing should be computed. gpu or cpu
+    postprocessing_chunks : Optional[int], default None
+        Specify the number of chunks the postprocessing should be computed in. More
+        chunks reduce memory usage at the cost of losing some vectorization; None
+        uses jax.vmap.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
         value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -400,8 +429,8 @@ def sample_blackjax_nuts(
 
     print("Transforming variables...", file=sys.stdout)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = jax.vmap(jax.vmap(jax_fn))(
-        *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+    result = _postprocess_samples(
+        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
     )
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
     mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
@@ -417,7 +446,10 @@ def sample_blackjax_nuts(
         tic5 = datetime.now()
         print("Computing Log Likelihood...", file=sys.stdout)
         log_likelihood = _get_log_likelihood(
-            model, raw_mcmc_samples, backend=postprocessing_backend
+            model,
+            raw_mcmc_samples,
+            backend=postprocessing_backend,
+            num_chunks=postprocessing_chunks,
         )
         tic6 = datetime.now()
         print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
@@ -478,6 +510,7 @@ def sample_numpyro_nuts(
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Optional[str] = None,
+    postprocessing_chunks: Optional[int] = None,
     idata_kwargs: Optional[Dict] = None,
     nuts_kwargs: Optional[Dict] = None,
 ) -> az.InferenceData:
@@ -522,6 +555,10 @@ def sample_numpyro_nuts(
         "parallel", and "vectorized".
     postprocessing_backend : Optional[str]
         Specify how postprocessing should be computed. gpu or cpu
+    postprocessing_chunks : Optional[int], default None
+        Specify the number of chunks the postprocessing should be computed in. More
+        chunks reduce memory usage at the cost of losing some vectorization; None
+        uses jax.vmap.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
         value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -622,8 +659,8 @@ def sample_numpyro_nuts(
 
     print("Transforming variables...", file=sys.stdout)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = jax.vmap(jax.vmap(jax_fn))(
-        *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+    result = _postprocess_samples(
+        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
     )
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
 
@@ -639,7 +676,10 @@ def sample_numpyro_nuts(
         tic5 = datetime.now()
         print("Computing Log Likelihood...", file=sys.stdout)
         log_likelihood = _get_log_likelihood(
-            model, raw_mcmc_samples, backend=postprocessing_backend
+            model,
+            raw_mcmc_samples,
+            backend=postprocessing_backend,
+            num_chunks=postprocessing_chunks,
         )
         tic6 = datetime.now()
         print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
diff --git a/pymc/tests/sampling/test_jax.py b/pymc/tests/sampling/test_jax.py
index ab20aca667..cbae94c021 100644
--- a/pymc/tests/sampling/test_jax.py
+++ b/pymc/tests/sampling/test_jax.py
@@ -55,7 +55,8 @@ def test_old_import_route():
         ),
     ],
 )
-def test_transform_samples(sampler, postprocessing_backend, chains):
+@pytest.mark.parametrize("postprocessing_chunks", [None, 10])
+def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_chunks):
     pytensor.config.on_opt_error = "raise"
     np.random.seed(13244)
 
@@ -71,6 +72,7 @@ def test_transform_samples(sampler, postprocessing_backend, chains):
         random_seed=1322,
         keep_untransformed=True,
         postprocessing_backend=postprocessing_backend,
+        postprocessing_chunks=postprocessing_chunks,
     )
 
     log_vals = trace.posterior["sigma_log__"].values
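Note for reviewers: the core of the change is `_postprocess_samples`, which maps the `samples` axis to a `SerialLoop(num_chunks)` resource inside `xmap`, so the postprocessing graph runs over `num_chunks` sequential slices instead of one fully vectorized `jax.vmap`. Below is a stand-alone sketch of the same memory/vectorization trade-off written with plain `jax.vmap` over manual chunks; it is an illustration only, not the `xmap`/`SerialLoop` mechanism the patch actually uses, and `chunked_vmap` plus the toy shapes are hypothetical:

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fn, x, num_chunks):
    # Vectorize fn over axis 0, but in num_chunks serial pieces, so peak
    # memory scales with the chunk size rather than the full axis length.
    chunks = jnp.split(x, num_chunks, axis=0)  # axis 0 must divide evenly
    return jnp.concatenate([jax.vmap(fn)(chunk) for chunk in chunks], axis=0)


# Fully vectorized vs. 10 serial chunks of 100 draws: identical results.
draws = jnp.arange(4000.0).reshape(1000, 4)
per_draw = lambda v: jnp.sin(v).sum()
assert jnp.allclose(jax.vmap(per_draw)(draws), chunked_vmap(per_draw, draws, 10))
```

Each chunk is still vectorized, but chunks run one after another, which is exactly the trade-off the new docstring describes: lower peak memory in exchange for somewhat less vectorization.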
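From the user side, the new keyword threads through both samplers identically. A minimal usage sketch, assuming this patch is applied and `numpyro` is installed (the toy model, sizes, and chunk count are illustrative only):

```python
import numpy as np
import pymc as pm
from pymc.sampling.jax import sample_numpyro_nuts

with pm.Model():
    sigma = pm.HalfNormal("sigma")
    pm.Normal("obs", 0.0, sigma, observed=np.random.standard_normal(100))
    idata = sample_numpyro_nuts(
        draws=500,
        tune=500,
        chains=2,
        postprocessing_backend="cpu",
        # None (the default) keeps the old jax.vmap fast path;
        # an int trades postprocessing speed for lower peak memory.
        postprocessing_chunks=10,
    )
```

One point that may be worth a docstring note: with the `SerialLoop` path the chunked `samples` axis presumably has to divide evenly by `postprocessing_chunks` (500 draws / 10 chunks here), which users reaching for this knob under memory pressure will want to know.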