diff --git a/appletree/context.py b/appletree/context.py
index c3602fc5..4c021746 100644
--- a/appletree/context.py
+++ b/appletree/context.py
@@ -13,7 +13,7 @@
 import appletree as apt
 from appletree import randgen
 from appletree import Parameter
-from appletree.utils import load_json
+from appletree.utils import load_json, get_file_path
 from appletree.share import _cached_configs, set_global_config

 os.environ["OMP_NUM_THREADS"] = "1"
@@ -55,6 +55,17 @@ def __init__(self, instruct, par_config=None):

         self.register_all_likelihood(instruct)

+    @classmethod
+    def from_backend(cls, backend_h5_file_name):
+        """Initialize context from a backend_h5 file."""
+        with h5py.File(get_file_path(backend_h5_file_name)) as file:
+            instruct = eval(file["mcmc"].attrs["instruct"])
+            nwalkers = file["mcmc"].attrs["nwalkers"]
+            batch_size = file["mcmc"].attrs["batch_size"]
+        tree = cls(instruct)
+        tree.pre_fitting(nwalkers, batch_size=batch_size)
+        return tree
+
     def __getitem__(self, keys):
         """Get likelihood in context."""
         return self.likelihoods[keys]
@@ -174,19 +185,19 @@ def log_posterior(self, parameters, batch_size=1_000_000):
     def _ndim(self):
         return len(self.par_manager.parameter_fit_array)

-    def _set_backend(self, nwalkers=100, read_only=True):
+    def _set_backend(self, nwalkers=100, read_only=True, reset=False):
         if self.backend_h5 is None:
             self._backend = None
             print("With no backend")
         else:
             self._backend = emcee.backends.HDFBackend(self.backend_h5, read_only=read_only)
-            if not read_only:
+            if reset:
                 self._backend.reset(nwalkers, self._ndim)
             print(f"With h5 backend {self.backend_h5}")

-    def pre_fitting(self, nwalkers=100, read_only=True, batch_size=1_000_000):
+    def pre_fitting(self, nwalkers=100, read_only=True, reset=False, batch_size=1_000_000):
         """Prepare for fitting, initialize backend and sampler."""
-        self._set_backend(nwalkers, read_only=read_only)
+        self._set_backend(nwalkers, read_only=read_only, reset=reset)
         self.sampler = emcee.EnsembleSampler(
             nwalkers,
             self._ndim,
@@ -212,7 +223,7 @@ def fitting(self, nwalkers=200, iteration=500, batch_size=1_000_000):
             self.par_manager.sample_init()
             p0.append(self.par_manager.parameter_fit_array)

-        self.pre_fitting(nwalkers=nwalkers, read_only=False, batch_size=batch_size)
+        self.pre_fitting(nwalkers=nwalkers, read_only=False, reset=True, batch_size=batch_size)

         result = self.sampler.run_mcmc(
             p0,
@@ -221,10 +232,10 @@
             progress=True,
         )

-        self._dump_meta()
+        self._dump_meta(batch_size=batch_size)
         return result

-    def continue_fitting(self, context, iteration=500, batch_size=1_000_000):
+    def continue_fitting(self, context=None, iteration=500, batch_size=1_000_000):
         """Continue a fitting of another context.

         Args:
@@ -232,14 +243,19 @@
             iteration: int, number of steps to generate.

         """
-        # Final iteration
-        final_iteration = context.sampler.get_chain()[-1, :, :]
-        p0 = final_iteration.tolist()
+        # If context is None, use self, i.e. continue the fitting defined in self
+        if context is None:
+            context = self
+            p0 = None
+        else:
+            # Final iteration
+            final_iteration = context.sampler.get_chain()[-1, :, :]
+            p0 = final_iteration.tolist()

-        nwalkers = len(p0)
+        nwalkers = context.sampler.get_chain().shape[1]

         # Init sampler for current context
-        self.pre_fitting(nwalkers=nwalkers, read_only=False, batch_size=batch_size)
+        self.pre_fitting(nwalkers=nwalkers, read_only=False, reset=False, batch_size=batch_size)

         result = self.sampler.run_mcmc(
             p0,
@@ -249,7 +265,7 @@
             skip_initial_state_check=True,
         )

-        self._dump_meta()
+        self._dump_meta(batch_size=batch_size)
         return result

     def get_post_parameters(self):
@@ -276,7 +292,7 @@ def dump_post_parameters(self, file_name):
         with open(file_name, "w") as fp:
             json.dump(parameters, fp)

-    def _dump_meta(self, metadata=None):
+    def _dump_meta(self, batch_size, metadata=None):
         """Save parameters name as attributes."""
         if metadata is None:
             metadata = {
@@ -299,6 +315,8 @@
             opt[name].attrs["config"] = json.dumps(self.config)
             # configurations, maybe users will manually add some maps
             opt[name].attrs["_cached_configs"] = json.dumps(_cached_configs)
+            # batch size
+            opt[name].attrs["batch_size"] = batch_size

     def get_template(
         self,
diff --git a/appletree/utils.py b/appletree/utils.py
index c5340ef6..9240864e 100644
--- a/appletree/utils.py
+++ b/appletree/utils.py
@@ -149,9 +149,9 @@ def get_file_path(fname):
     * can be downloaded from MongoDB, download and return cached path

     """
-    # 1. From absolute path
+    # 1. From absolute path if file exists
     # Usually Config.default is a absolute path
-    if fname.startswith("/"):
+    if os.path.isfile(fname):
         return fname

     # 2. From local folder
diff --git a/tests/test_context.py b/tests/test_context.py
index 6e462433..63ffdad4 100644
--- a/tests/test_context.py
+++ b/tests/test_context.py
@@ -84,3 +84,19 @@ def test_literature_context():
     parameters = context.get_post_parameters()
     context.get_num_events_accepted(parameters, batch_size=batch_size)
     check_unused_configs()
+
+
+def test_backend():
+    """Test backend, initialize from backend and continue fitting."""
+    _cached_functions.clear()
+    _cached_configs.clear()
+    instruct = apt.utils.load_json("rn220.json")
+    instruct["backend_h5"] = "test_backend.h5"
+    context = apt.Context(instruct)
+    context.fitting(nwalkers=100, iteration=2, batch_size=int(1e4))
+
+    _cached_functions.clear()
+    _cached_configs.clear()
+    context = apt.Context.from_backend("test_backend.h5")
+    context.continue_fitting(iteration=2, batch_size=int(1e4))
+    assert context.sampler.get_chain().shape[0] == 4
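
For reference, a minimal sketch of the workflow this change enables, mirroring test_backend above. The file name fit_backend.h5 and the iteration counts are illustrative; the API calls are the ones added in this diff.

import appletree as apt

# First run: fit with an HDF5 backend so the chain and the sampling
# metadata (instruct, nwalkers, batch_size) are persisted to disk.
instruct = apt.utils.load_json("rn220.json")
instruct["backend_h5"] = "fit_backend.h5"  # illustrative output path
context = apt.Context(instruct)
context.fitting(nwalkers=100, iteration=500, batch_size=int(1e6))

# Later, possibly in a fresh process: rebuild the context purely from the
# backend file. from_backend reads instruct, nwalkers and batch_size back
# from the "mcmc" group's attributes, and continue_fitting(context=None)
# passes p0=None so emcee resumes from the last sample stored in the backend.
context = apt.Context.from_backend("fit_backend.h5")
context.continue_fitting(iteration=500, batch_size=int(1e6))

Note the design choice here: reset is decoupled from read_only, so continue_fitting opens the backend writable with reset=False and appends to the stored chain, whereas the old "if not read_only" logic would have wiped it on every writable open.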