Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cover more priors with empirical distribution extension function #184

Merged
merged 9 commits into from
Dec 21, 2022
Merged
1 change: 1 addition & 0 deletions enterprise_extensions/empirical_distr.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def make_empirical_distributions(pta, paramlist, params, chain,

if len(pl) == 1:
idx = pta.param_names.index(pl[0])

prior_min = pta.params[idx].prior._defaults['pmin']
prior_max = pta.params[idx].prior._defaults['pmax']

Expand Down
69 changes: 51 additions & 18 deletions enterprise_extensions/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,35 @@
EmpiricalDistribution2DKDE)


def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outdir='chains'):
def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outdir='./chains'):
new_emp_dists = []
modified = False # check if anything was changed
for emp_dist in emp_dists:
if isinstance(emp_dist, EmpiricalDistribution2D) or isinstance(emp_dist, EmpiricalDistribution2DKDE):
# check if we need to extend the distribution
prior_ok=True
for ii, (param, nbins) in enumerate(zip(emp_dist.param_names, emp_dist._Nbins)):
if param not in pta.param_names: # skip if one of the parameters isn't in our PTA object
continue
param_names = [par.name for par in pta.params]
if param not in param_names: # skip if one of the parameters isn't in our PTA object
short_par = '_'.join(param.split('_')[:-1]) # make sure we aren't skipping priors with size!=None
if short_par in param_names:
param = short_par
else:
continue
# check 2 conditions on both params to make sure that they cover their priors
# skip if emp dist already covers the prior
param_idx = pta.param_names.index(param)
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
param_idx = param_names.index(param)
if pta.params[param_idx].type not in ['uniform', 'normal']:
msg = '{} cannot be covered automatically by the empirical distribution\n'.format(pta.params[param_idx].prior)
msg += 'Please check that your prior is covered by the empirical distribution.\n'
print(msg)
continue
elif pta.params[param_idx].type == 'uniform':
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
elif pta.params[param_idx].type == 'normal':
prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma']
prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma']

# no need to extend if histogram edges are already prior min/max
if isinstance(emp_dist, EmpiricalDistribution2D):
Expand All @@ -53,9 +67,13 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd
maxvals = []
idxs_to_remove = []
for ii, (param, nbins) in enumerate(zip(emp_dist.param_names, emp_dist._Nbins)):
param_idx = pta.param_names.index(param)
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
param_idx = param_names.index(param)
if pta.params[param_idx].type == 'uniform':
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
elif pta.params[param_idx].type == 'normal':
prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma']
prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma']
# drop samples that are outside the prior range (in case prior is smaller than samples)
if isinstance(emp_dist, EmpiricalDistribution2D):
samples[(samples[:, ii] < prior_min) | (samples[:, ii] > prior_max), ii] = -np.inf
Expand All @@ -74,11 +92,27 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd
new_emp_dists.append(new_emp)

elif isinstance(emp_dist, EmpiricalDistribution1D) or isinstance(emp_dist, EmpiricalDistribution1DKDE):
if emp_dist.param_name not in pta.param_names:
param_names = [par.name for par in pta.params]
if emp_dist.param_name not in param_names: # skip if one of the parameters isn't in our PTA object
short_par = '_'.join(emp_dist.param_name.split('_')[:-1]) # make sure we aren't skipping priors with size!=None
if short_par in param_names:
param = short_par
else:
continue
else:
param = emp_dist.param_name
param_idx = param_names.index(param)
if pta.params[param_idx].type not in ['uniform', 'normal']:
msg = 'This prior cannot be covered automatically by the empirical distribution\n'
msg += 'Please check that your prior is covered by the empirical distribution.\n'
print(msg)
continue
param_idx = pta.param_names.index(emp_dist.param_name)
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
if pta.params[param_idx].type == 'uniform':
prior_min = pta.params[param_idx].prior._defaults['pmin']
prior_max = pta.params[param_idx].prior._defaults['pmax']
elif pta.params[param_idx].type == 'normal':
prior_min = pta.params[param_idx].prior._defaults['mu'] - 10 * pta.params[param_idx].prior._defaults['sigma']
prior_max = pta.params[param_idx].prior._defaults['mu'] + 10 * pta.params[param_idx].prior._defaults['sigma']
# check 2 conditions on param to make sure that it covers the prior
# skip if emp dist already covers the prior
if isinstance(emp_dist, EmpiricalDistribution1D):
Expand All @@ -96,7 +130,6 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd
new_bins = []
idxs_to_remove = []
# drop samples that are outside the prior range (in case prior is smaller than samples)

if isinstance(emp_dist, EmpiricalDistribution1D):
samples[(samples < prior_min) | (samples > prior_max)] = -np.inf
elif isinstance(emp_dist, EmpiricalDistribution1DKDE):
Expand All @@ -111,20 +144,20 @@ def extend_emp_dists(pta, emp_dists, npoints=100_000, save_ext_dists=False, outd
minval=prior_min, maxval=prior_max,
bandwidth=emp_dist.bandwidth)
new_emp_dists.append(new_emp)

else:
print('Unable to extend class of unknown type to the edges of the priors.')
new_emp_dists.append(emp_dist)
continue

if save_ext_dists and modified: # if user wants to save them, and they have been modified...
pickle.dump(new_emp_dists, outdir + 'new_emp_dists.pkl')
if save_ext_dists and modified: # if user wants to save them, and they have been modified...
with open(outdir + '/new_emp_dists.pkl', 'wb') as f:
pickle.dump(new_emp_dists, f)
return new_emp_dists


class JumpProposal(object):

def __init__(self, pta, snames=None, empirical_distr=None, f_stat_file=None, save_ext_dists=False, outdir='chains'):
def __init__(self, pta, snames=None, empirical_distr=None, f_stat_file=None, save_ext_dists=False, outdir='./chains'):
"""Set up some custom jump proposals"""
self.params = pta.params
self.pnames = pta.param_names
Expand Down