Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize context from backend #156

Merged
merged 5 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions appletree/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import appletree as apt
from appletree import randgen
from appletree import Parameter
from appletree.utils import load_json
from appletree.utils import load_json, get_file_path
from appletree.share import _cached_configs, set_global_config

os.environ["OMP_NUM_THREADS"] = "1"
Expand Down Expand Up @@ -55,6 +55,17 @@ def __init__(self, instruct, par_config=None):

self.register_all_likelihood(instruct)

@classmethod
def from_backend(cls, backend_h5_file_name):
"""Initialize context from a backend_h5 file."""
with h5py.File(get_file_path(backend_h5_file_name)) as file:
instruct = eval(file["mcmc"].attrs["instruct"])
nwalkers = file["mcmc"].attrs["nwalkers"]
batch_size = file["mcmc"].attrs["batch_size"]
tree = cls(instruct)
tree.pre_fitting(nwalkers, batch_size=batch_size)
return tree

def __getitem__(self, keys):
"""Get likelihood in context."""
return self.likelihoods[keys]
Expand Down Expand Up @@ -174,19 +185,19 @@ def log_posterior(self, parameters, batch_size=1_000_000):
def _ndim(self):
return len(self.par_manager.parameter_fit_array)

def _set_backend(self, nwalkers=100, read_only=True):
def _set_backend(self, nwalkers=100, read_only=True, reset=False):
if self.backend_h5 is None:
self._backend = None
print("With no backend")
else:
self._backend = emcee.backends.HDFBackend(self.backend_h5, read_only=read_only)
if not read_only:
if reset:
self._backend.reset(nwalkers, self._ndim)
print(f"With h5 backend {self.backend_h5}")

def pre_fitting(self, nwalkers=100, read_only=True, batch_size=1_000_000):
def pre_fitting(self, nwalkers=100, read_only=True, reset=False, batch_size=1_000_000):
"""Prepare for fitting, initialize backend and sampler."""
self._set_backend(nwalkers, read_only=read_only)
self._set_backend(nwalkers, read_only=read_only, reset=reset)
self.sampler = emcee.EnsembleSampler(
nwalkers,
self._ndim,
Expand All @@ -212,7 +223,7 @@ def fitting(self, nwalkers=200, iteration=500, batch_size=1_000_000):
self.par_manager.sample_init()
p0.append(self.par_manager.parameter_fit_array)

self.pre_fitting(nwalkers=nwalkers, read_only=False, batch_size=batch_size)
self.pre_fitting(nwalkers=nwalkers, read_only=False, reset=True, batch_size=batch_size)

result = self.sampler.run_mcmc(
p0,
Expand All @@ -221,25 +232,30 @@ def fitting(self, nwalkers=200, iteration=500, batch_size=1_000_000):
progress=True,
)

self._dump_meta()
self._dump_meta(batch_size=batch_size)
return result

def continue_fitting(self, context, iteration=500, batch_size=1_000_000):
def continue_fitting(self, context=None, iteration=500, batch_size=1_000_000):
"""Continue a fitting of another context.

Args:
context: appletree context.
iteration: int, number of steps to generate.

"""
# Final iteration
final_iteration = context.sampler.get_chain()[-1, :, :]
p0 = final_iteration.tolist()
# If context is None, use self, i.e. continue the fitting defined in self
if context is None:
context = self
p0 = None
else:
# Final iteration
final_iteration = context.sampler.get_chain()[-1, :, :]
p0 = final_iteration.tolist()

nwalkers = len(p0)
nwalkers = context.sampler.get_chain().shape[1]

# Init sampler for current context
self.pre_fitting(nwalkers=nwalkers, read_only=False, batch_size=batch_size)
self.pre_fitting(nwalkers=nwalkers, read_only=False, reset=False, batch_size=batch_size)

result = self.sampler.run_mcmc(
p0,
Expand All @@ -249,7 +265,7 @@ def continue_fitting(self, context, iteration=500, batch_size=1_000_000):
skip_initial_state_check=True,
)

self._dump_meta()
self._dump_meta(batch_size=batch_size)
return result

def get_post_parameters(self):
Expand All @@ -276,7 +292,7 @@ def dump_post_parameters(self, file_name):
with open(file_name, "w") as fp:
json.dump(parameters, fp)

def _dump_meta(self, metadata=None):
def _dump_meta(self, batch_size, metadata=None):
"""Save parameters name as attributes."""
if metadata is None:
metadata = {
Expand All @@ -299,6 +315,8 @@ def _dump_meta(self, metadata=None):
opt[name].attrs["config"] = json.dumps(self.config)
# configurations, maybe users will manually add some maps
opt[name].attrs["_cached_configs"] = json.dumps(_cached_configs)
# batch size
opt[name].attrs["batch_size"] = batch_size

def get_template(
self,
Expand Down
4 changes: 2 additions & 2 deletions appletree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def get_file_path(fname):
* can be downloaded from MongoDB, download and return cached path

"""
# 1. From absolute path
# 1. From absolute path if file exists
# Usually Config.default is a absolute path
if fname.startswith("/"):
if os.path.isfile(fname):
return fname

# 2. From local folder
Expand Down
16 changes: 16 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,19 @@ def test_literature_context():
parameters = context.get_post_parameters()
context.get_num_events_accepted(parameters, batch_size=batch_size)
check_unused_configs()


def test_backend():
"""Test backend, initialize from backend and continue fitting."""
_cached_functions.clear()
_cached_configs.clear()
instruct = apt.utils.load_json("rn220.json")
instruct["backend_h5"] = "test_backend.h5"
context = apt.Context(instruct)
context.fitting(nwalkers=100, iteration=2, batch_size=int(1e4))

_cached_functions.clear()
_cached_configs.clear()
context = apt.Context.from_backend("test_backend.h5")
context.continue_fitting(iteration=2, batch_size=int(1e4))
assert context.sampler.get_chain().shape[0] == 4
Loading