Skip to content

Commit

Permalink
More options for get_post_parameters (#193)
Browse files Browse the repository at this point in the history
* add function to allow user sample from walker file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more doc

---------

Co-authored-by: Dacheng Xu <dx2227@columbia.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 26, 2024
1 parent 6c0cb68 commit 5ed0744
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions appletree/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,17 +272,36 @@ def continue_fitting(self, context=None, iteration=500, batch_size=1_000_000):
self._dump_meta(batch_size=batch_size)
return result

def get_post_parameters(self):
"""Get parameters correspondes to max posterior."""
logp = self.sampler.get_log_prob(flat=True)
chain = self.sampler.get_chain(flat=True)
mpe_parameters = chain[np.argmax(logp)]
mpe_parameters = emcee.ensemble.ndarray_to_list_of_dicts(
[mpe_parameters],
def get_post_parameters(self, which="mpe"):
"""Get parameters from the backend.
Args:
which: str, 'mpe', 'random' or 'median'. 'mpe' is the maximum posterior estimate,
i.e. the parameter set with the highest posterior value. 'random' returns a
random parameter set from the posterior distribution. 'median' is the marginal medians.
"""
# Assign attributes for the first time
# This speeds up if the user wanna call this function many times
if not hasattr(self, "_logp"):
self._logp = self.sampler.get_log_prob(flat=True)
if not hasattr(self, "_chain"):
self._chain = self.sampler.get_chain(flat=True)
if which == "mpe":
_parameters = self._chain[np.argmax(self._logp)]
elif which == "random":
_parameters = self._chain[np.random.randint(len(self._logp))]
elif which == "median":
_parameters = np.median(self._chain, axis=0)
else:
raise ValueError(f"which should be 'mpe', 'random' or 'median', got {which}!")

_parameters = emcee.ensemble.ndarray_to_list_of_dicts(
[_parameters],
self.sampler.parameter_names,
)[0]
parameters = copy.deepcopy(self.par_manager.get_all_parameter())
parameters.update(mpe_parameters)
parameters.update(_parameters)
return parameters

def get_all_post_parameters(self, **kwargs):
Expand Down

0 comments on commit 5ed0744

Please sign in to comment.