diff --git a/prospect/fitting/fitting.py b/prospect/fitting/fitting.py index 70e8cf34..aa4c0a22 100755 --- a/prospect/fitting/fitting.py +++ b/prospect/fitting/fitting.py @@ -430,7 +430,7 @@ def run_nested(observations, model, sps, lnprobfn=lnprobfn, nested_sampler="dynesty", nested_nlive=1000, - nested_neff=1000, + nested_target_n_effective=1000, verbose=False, **kwargs): """Thin wrapper on :py:class:`prospect.fitting.nested.run_nested_sampler` @@ -454,6 +454,13 @@ def run_nested(observations, model, sps, ``model``, and ``sps`` as keywords. By default use the :py:func:`lnprobfn` defined above. + nested_target_n_effective : int + Target number of effective samples + + nested_nlive : int + Number of live points for the nested sampler. Meaning somewhat + dependent on the chosen sampler + Returns -------- result: Dictionary @@ -476,7 +483,7 @@ def run_nested(observations, model, sps, nested_sampler=nested_sampler, verbose=verbose, nested_nlive=nested_nlive, - nested_neff=nested_neff, + nested_neff=nested_target_n_effective, nested_sampler_kwargs=ns_kwargs, nested_run_kwargs=nr_kwargs) info, result_obj = output diff --git a/prospect/fitting/nested.py b/prospect/fitting/nested.py index 3b467bf8..f0f25185 100644 --- a/prospect/fitting/nested.py +++ b/prospect/fitting/nested.py @@ -53,6 +53,8 @@ def run_nested_sampler(model, obj : Object The sampling object. This will depend on the nested sampler being used. """ + if verbose: + print(f"running {nested_sampler} for {nested_neff} effective samples") go = time.time() diff --git a/tests/tests_samplers.py b/tests/tests_samplers.py index 6bf59582..8c47d85f 100644 --- a/tests/tests_samplers.py +++ b/tests/tests_samplers.py @@ -13,15 +13,15 @@ from prospect.fitting import fit_model from prospect.fitting.nested import parse_nested_kwargs from prospect.io.write_results import write_hdf5 - +from prospect.io.read_results import results_from #@pytest.fixture -def get_sps(): +def get_sps(**kwargs): sps = CSPSpecBasis(zcontinuous=1) return sps -def build_model(add_neb=False, add_outlier=False): +def build_model(add_neb=False, add_outlier=False, **kwargs): model_params = templates.TemplateLibrary["parametric_sfh"] model_params["logzsol"]["isfree"] = False # built for speed if add_neb: # skip for speed @@ -63,10 +63,11 @@ def build_obs(**kwargs): parser = get_parser() parser.set_defaults(nested_target_n_effective=256, + nested_nlive=512, verbose=0) args = parser.parse_args() run_params = vars(args) - run_params["parameter_file"] = __file__ + run_params["param_file"] = __file__ # test the parsing run_params["nested_sampler"] = "dynesty" @@ -99,11 +100,14 @@ def build_obs(**kwargs): results[sampler], None, sps=sps) + ires, iobs, im = results_from(hfile) + assert (im is not None) + # compare runtime for sampler in samplers: print(sampler, results[sampler]["duration"]) - + # compare posteriors colors = ["royalblue", "darkorange", "firebrick"] import matplotlib.pyplot as pl from prospect.plotting import corner