Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow variable mcmc length #365

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions py/dynesty/dynamicsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,8 @@ def results(self):
d = {}
for k in [
'nc', 'v', 'id', 'batch', 'it', 'u', 'n', 'logwt', 'logl',
'logvol', 'logz', 'logzvar', 'h', 'batch_nlive', 'batch_bounds'
'logvol', 'logz', 'logzvar', 'h', 'batch_nlive', 'batch_bounds',
'scale', 'walks',
]:
d[k] = np.array(self.saved_run.D[k])

Expand All @@ -633,6 +634,7 @@ def results(self):
results.append(
('samples_bound', np.array(self.saved_run.D['boundidx'])))
results.append(('scale', np.array(self.saved_run.D['scale'])))
results.append(('walks', np.array(self.saved_run.D['walks'])))

return Results(results)

Expand Down Expand Up @@ -865,7 +867,9 @@ def sample_initial(self,
n=self.nlive_init,
boundidx=results.boundidx,
bounditer=results.bounditer,
scale=self.sampler.scale)
scale=self.sampler.scale,
walks=self.sampler.walks,
)

self.base_run.append(add_info)
self.saved_run.append(add_info)
Expand Down Expand Up @@ -907,7 +911,9 @@ def sample_initial(self,
n=self.nlive_init - it,
boundidx=results.boundidx,
bounditer=results.bounditer,
scale=self.sampler.scale)
scale=self.sampler.scale,
walks=self.sampler.walks,
)

self.base_run.append(add_info)
self.saved_run.append(add_info)
Expand Down Expand Up @@ -1048,6 +1054,7 @@ def sample_batch(self,
saved_logl = np.array(self.saved_run.D['logl'])
saved_logvol = np.array(self.saved_run.D['logvol'])
saved_scale = np.array(self.saved_run.D['scale'])
saved_walks = np.array(self.saved_run.D['walks'])
nblive = self.nlive_init

update_interval = self.__get_update_interval(update_interval,
Expand Down Expand Up @@ -1141,6 +1148,7 @@ def sample_batch(self,
self.new_logl_min = logl_min

live_scale = saved_scale[subset0[0]]
live_walks = saved_walks[subset0[0]]
# set the scale based on the lowest point

# we are weighting each point by X_i to ensure
Expand Down Expand Up @@ -1188,6 +1196,7 @@ def sample_batch(self,
batch_sampler.live_v = live_v
batch_sampler.live_logl = live_logl
batch_sampler.scale = live_scale
batch_sampler.walks = live_walks

# Trigger an update of the internal bounding distribution based
# on the "new" set of live points.
Expand Down Expand Up @@ -1292,7 +1301,9 @@ def sample_batch(self,
n=nlive_new,
boundidx=results.boundidx,
bounditer=results.bounditer,
scale=batch_sampler.scale)
scale=batch_sampler.scale,
walks=batch_sampler.walks,
)
self.new_run.append(D)

# Increment relevant counters.
Expand Down Expand Up @@ -1328,7 +1339,9 @@ def sample_batch(self,
n=nlive_new - it,
boundidx=results.boundidx,
bounditer=results.bounditer,
scale=batch_sampler.scale)
scale=batch_sampler.scale,
walks=batch_sampler.walks,
)
self.new_run.append(D)

# Increment relevant counters.
Expand Down Expand Up @@ -1358,7 +1371,7 @@ def combine_runs(self):

for k in [
'id', 'u', 'v', 'logl', 'nc', 'boundidx', 'it', 'bounditer',
'n', 'scale'
'n', 'scale', 'walks',
]:
saved_d[k] = np.array(self.saved_run.D[k])
new_d[k] = np.array(self.new_run.D[k])
Expand Down Expand Up @@ -1411,7 +1424,7 @@ def combine_runs(self):

for k in [
'id', 'u', 'v', 'logl', 'nc', 'boundidx', 'it',
'bounditer', 'scale'
'bounditer', 'scale', 'walks',
]:
add_info[k] = add_source[k][add_idx]
self.saved_run.append(add_info)
Expand Down
36 changes: 34 additions & 2 deletions py/dynesty/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,13 @@ def NestedSampler(loglikelihood,
update_func=None,
ncdim=None,
save_history=False,
history_filename=None):
history_filename=None,
adapt_scale=True,
adapt_walks=False,
adapt_time=None,
max_walks=1000,
target_accept=None,
):
"""
Initializes and returns a sampler object for Static Nested Sampling.

Expand Down Expand Up @@ -506,6 +512,10 @@ def prior_transform(u):
if update_func is not None and not callable(update_func):
raise ValueError("Unknown update function: '{0}'".format(update_func))
kwargs['update_func'] = update_func
kwargs['adapt_scale'] = adapt_scale
kwargs['adapt_walks'] = adapt_walks
kwargs['adapt_time'] = adapt_time
kwargs['max_walks'] = max_walks

# Citation generator.
kwargs['cite'] = _get_citations('static', bound, sample)
Expand Down Expand Up @@ -561,6 +571,14 @@ def prior_transform(u):
kwargs['fmove'] = fmove
if max_move is not None:
kwargs['max_move'] = max_move
if adapt_walks is not None:
kwargs['adapt_time'] = adapt_time
if max_walks is not None:
kwargs['max_walks'] = max_walks
if target_accept is not None:
kwargs['target_accept'] = target_accept
kwargs['adapt_scale'] = adapt_scale
kwargs['adapt_walks'] = adapt_walks

update_interval_ratio = _get_update_interval_ratio(update_interval, sample,
bound, ndim, nlive,
Expand Down Expand Up @@ -662,7 +680,13 @@ def DynamicNestedSampler(loglikelihood,
update_func=None,
ncdim=None,
save_history=False,
history_filename=None):
history_filename=None,
adapt_scale=True,
adapt_walks=False,
adapt_time=None,
max_walks=None,
target_accept=None,
):
"""
Initializes and returns a sampler object for Dynamic Nested Sampling.

Expand Down Expand Up @@ -959,6 +983,14 @@ def prior_transform(u):
kwargs['fmove'] = fmove
if max_move is not None:
kwargs['max_move'] = max_move
if adapt_walks is not None:
kwargs['adapt_time'] = adapt_time
if max_walks is not None:
kwargs['max_walks'] = max_walks
if target_accept is not None:
kwargs['target_accept'] = target_accept
kwargs['adapt_scale'] = adapt_scale
kwargs['adapt_walks'] = adapt_walks

# Set up parallel (or serial) evaluation.
queue_size = _parse_pool_queue(pool, queue_size)[1]
Expand Down
29 changes: 29 additions & 0 deletions py/dynesty/nestedsamplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,18 @@ def __init__(self,
self.compute_jac = self.kwargs.get('compute_jac', False)

# Initialize random walk parameters.
self.adapt_walks = self.kwargs.get("adapt_walks", True)
self.walks = max(2, self.kwargs.get('walks', 25))
self.max_walks = self.kwargs.get("max_walks", 1000)
self.adapt_scale = self.kwargs.get("adapt_scale", True)
self.facc = self.kwargs.get('facc', 0.5)
self.facc = min(1., max(1. / self.walks, self.facc))
self.adapt_time = self.kwargs.get("adapt_time", None)
if self.adapt_time is None:
self.adapt_time = self.nlive / 5
self.target_accept = self.kwargs.get("target_accept", None)
if self.target_accept is None:
self.target_accept = self.walks * self.facc

# Initialize slice parameters.
self.slices = self.kwargs.get('slices', 5)
Expand All @@ -148,6 +157,12 @@ def update_unif(self, blob):
pass

def update_rwalk(self, blob):
    """Run the enabled rwalk adaptations after a proposal cycle.

    Depending on the ``adapt_scale``/``adapt_walks`` switches this
    forwards *blob* to the proposal-scale updater and/or the
    chain-length updater.
    """
    adaptations = ((self.adapt_scale, self.update_rwalk_scale),
                   (self.adapt_walks, self.update_rwalk_walks))
    for enabled, updater in adaptations:
        if enabled:
            updater(blob)

def update_rwalk_scale(self, blob):
"""Update the random walk proposal scale based on the current
number of accepted/rejected steps.
For rwalk the scale is important because it
Expand All @@ -173,6 +188,20 @@ def update_rwalk(self, blob):
# here because our coefficients a_k do not obey \sum a_k^2 = \infty
self.scale *= math.exp((facc - self.facc) / self.ncdim / self.facc)

def update_rwalk_walks(self, blob):
    """Adapt the rwalk chain length based on the last acceptance count.

    Aims to keep the number of accepted steps per iteration roughly
    constant at ``self.target_accept`` by rescaling ``self.walks``
    geometrically with a relaxation time of ``self.adapt_time``.
    """
    n_accepted = blob["accept"]
    if n_accepted:
        growth = (self.target_accept / n_accepted) ** (1.0 / self.adapt_time)
    else:
        # Nothing was accepted: lengthen the chain by a fixed factor.
        growth = 1.25
    proposed = self.walks * growth
    # Clamp the (float) chain length to [target_accept, max_walks];
    # the float is kept for smooth adaptation across iterations.
    bounded = min(self.max_walks, proposed)
    self.walks = max(bounded, self.target_accept)
    self.kwargs["walks"] = int(self.walks)

def update_slice(self, blob):
"""Update the slice proposal scale based on the relative
size of the slices compared to our initial guess.
Expand Down
3 changes: 2 additions & 1 deletion py/dynesty/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def print_fn_fallback(results,
('batch_nlive', 'array[int]',
"The number of live points added in a given batch ???"
"How is it different from samples_n", 'nbatch???'),
('scale', 'array[float]', "Scalar scale applied for proposals", 'niter')
('scale', 'array[float]', "Scalar scale applied for proposals", 'niter'),
('walks', 'array[float]', "MCMC chain length for rwalk", 'niter'),
]


Expand Down
9 changes: 7 additions & 2 deletions py/dynesty/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(self, loglikelihood, prior_transform, npdim, live_points,

# set to none just for qa
self.scale = None
self.walks = None
self.method = None
self.kwargs = {}

Expand Down Expand Up @@ -466,7 +467,9 @@ def add_live_points(self):
boundidx=boundidx,
it=point_it,
bounditer=bounditer,
scale=self.scale))
scale=self.scale,
walks=self.walks,
))
self.eff = 100. * (self.it + i) / self.ncall # efficiency

# Return our new "dead" point and ancillary quantities.
Expand Down Expand Up @@ -768,7 +771,9 @@ def sample(self,
nc=nc,
it=worst_it,
bounditer=bounditer,
scale=self.scale))
scale=self.scale,
walks=self.walks,
))

# Update the live point (previously our "worst" point).
self.live_u[worst] = u
Expand Down
3 changes: 2 additions & 1 deletion py/dynesty/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def __init__(self, dynamic=False):
'it', # iteration the live (now dead) point was proposed
'n', # number of live points interior to dead point
'bounditer', # active bound at a specific iteration
'scale' # scale factor at each iteration
'scale', # scale factor at each iteration
'walks', # number of steps taken at each iteration
]
if dynamic:
keys.extend([
Expand Down
47 changes: 47 additions & 0 deletions tests/test_adapt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import dynesty
import pytest
import itertools
from utils import get_rstate, get_printing
"""
Run a series of basic tests of the 2d eggbox
"""

nlive = 1000
printing = get_printing()

# EGGBOX


# see 1306.2144
def loglike_egg(x):
    """Eggbox log-likelihood (see arXiv:1306.2144)."""
    base = 2 + np.cos(x[0] / 2) * np.cos(x[1] / 2)
    return base**5


def prior_transform_egg(x):
    """Stretch the unit cube onto the eggbox domain [0, 10*pi]."""
    scaled = x * 10
    return scaled * np.pi


@pytest.mark.parametrize("scale,walks",
                         itertools.product([True, False], [True, False]))
def test_adapt(scale, walks):
    """Run the 2-D eggbox with every combination of scale/walks
    adaptation enabled and check the recovered evidence agrees with
    the analytic value to within 5 sigma.
    """
    ndim = 2
    rstate = get_rstate()
    sampler = dynesty.NestedSampler(
        loglike_egg,
        prior_transform_egg,
        ndim,
        nlive=nlive,
        bound="single",
        sample="rwalk",
        rstate=rstate,
        adapt_scale=scale,
        adapt_walks=walks,
    )
    sampler.run_nested(dlogz=0.01, print_progress=printing)
    logz_truth = 235.856
    logz_error = abs(logz_truth - sampler.results.logz[-1])
    assert logz_error < 5. * sampler.results.logzerr[-1]