add acceleration option to JointPrimaryMarginalizedModel likelihood #4688

Merged on Sep 9, 2024 (56 commits)
Changes from 41 commits

Commits (56)
55e5541  Update hierarchical.py (WuShichao, Apr 7, 2024)
a6ac76d  Update hierarchical.py (WuShichao, Apr 8, 2024)
f4ed98d  Update hierarchical.py (WuShichao, Apr 8, 2024)
41224cc  Update hierarchical.py (WuShichao, Apr 10, 2024)
7109db9  Update hierarchical.py (WuShichao, Apr 10, 2024)
0bbe7a4  fix cc issues (WuShichao, Apr 10, 2024)
6920a8c  Update hierarchical.py (WuShichao, Apr 15, 2024)
55fba36  Update relbin.py (WuShichao, Apr 15, 2024)
b993833  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, May 16, 2024)
7099515  add complex phase correction for sh_others (WuShichao, May 16, 2024)
94ba798  Update hierarchical.py (WuShichao, May 16, 2024)
6afbc4e  Update relbin.py (WuShichao, May 16, 2024)
e20b6f5  fix cc issues (WuShichao, May 16, 2024)
25a4562  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, Jun 4, 2024)
8fb82a1  make code more general (WuShichao, Jun 7, 2024)
757a30b  update (WuShichao, Jun 13, 2024)
a2d64d1  fix (WuShichao, Jun 13, 2024)
8a9287e  rename (WuShichao, Jun 13, 2024)
5423d8c  update (WuShichao, Jun 15, 2024)
0b9d44e  WIP (WuShichao, Jun 17, 2024)
eeb8890  fix a bug in frame transform (WuShichao, Jun 18, 2024)
50e3599  fix overwritten issues (WuShichao, Jun 19, 2024)
dba5292  update (WuShichao, Jun 19, 2024)
537256e  update (WuShichao, Jun 19, 2024)
0af3fed  fix reconstruct (WuShichao, Jun 19, 2024)
0a09b12  make this PR general (WuShichao, Jun 28, 2024)
eb57268  update (WuShichao, Jun 28, 2024)
6d856b3  update (WuShichao, Jun 28, 2024)
075c39a  fix cc issues (WuShichao, Jun 28, 2024)
9ffbb70  rename (WuShichao, Jun 28, 2024)
e0f1ec4  rename (WuShichao, Jun 28, 2024)
6ce67c3  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, Jul 1, 2024)
bf105a4  add multiband description (WuShichao, Jul 1, 2024)
273264f  fix (WuShichao, Jul 4, 2024)
5226b7d  add comments (WuShichao, Jul 4, 2024)
21fd035  fix hdf's config (WuShichao, Jul 5, 2024)
ba3816d  fix (WuShichao, Jul 5, 2024)
28fc1b2  fix (WuShichao, Jul 6, 2024)
b06d32e  fix (WuShichao, Jul 28, 2024)
b4a47af  fix (WuShichao, Jul 29, 2024)
a5b6d8c  remove print (WuShichao, Jul 29, 2024)
ca096ec  update for Alex's comments (WuShichao, Jul 29, 2024)
e8825be  wip (WuShichao, Jul 30, 2024)
be2b066  update (WuShichao, Jul 30, 2024)
c03652e  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, Jul 31, 2024)
02b6937  fix (WuShichao, Jul 31, 2024)
87f10cb  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, Jul 31, 2024)
cbcd5a2  update (WuShichao, Aug 1, 2024)
36af111  seems work (WuShichao, Aug 12, 2024)
c084f4f  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, Aug 12, 2024)
709b524  fix CC issue (WuShichao, Aug 12, 2024)
3a17bf5  fix (WuShichao, Aug 12, 2024)
3865d5e  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, Aug 12, 2024)
0df23f2  fix demargin (WuShichao, Aug 25, 2024)
da34461  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, Aug 25, 2024)
f2b0798  Merge branch 'gwastro:master' into accelerate_multiband (WuShichao, Sep 5, 2024)
bin/inference/pycbc_inference (7 changes: 6 additions & 1 deletion)

@@ -141,7 +141,12 @@ with ctx:
    if pool.is_main_process():
        for fn in [sampler.checkpoint_file, sampler.backup_file]:
            with loadfile(fn, 'a') as fp:
                fp.write_config_file(cp)
                # some models will internally modify the original cp for sampling,
ahnitz (Member):
This shouldn't be an option, as one should always save the original configuration. Why isn't your internal sampler modifying a copy? The version saved here then doesn't have to know (and really shouldn't) that you may have modified it internally.

WuShichao (Contributor, Author):
@ahnitz My understanding is that cp is saved when it runs `logging.info("Loading joint_primary_marginalized model"); return super(HierarchicalModel, cls).from_config(cp, submodels=submodels, **kwargs)` (https://github.com/WuShichao/pycbc/blob/accelerate_multiband/pycbc/inference/models/hierarchical.py#L1002). The cp here is the modified config for sampling and can't be used as the initial config. So how can I let PyCBC save the original one here?

ahnitz (Member):
@WuShichao I don't see anything related to saving the config file where you've pointed. The place is in pycbc_inference, where I'm commenting. Take a look at my first comment; I say how to do it there. Don't modify the configparser you are passed in-place. Make a copy so you aren't editing the original. Then you don't need to save a separate copy, and this particular line will just work to begin with.
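A minimal sketch of the copy-then-modify approach suggested above, assuming a plain configparser object (PyCBC's WorkflowConfigParser also exposes a __deepcopy__, which this PR uses in from_config further down):

```python
import io
import configparser

def copy_config(cp):
    """Return an independent copy of a ConfigParser by round-tripping
    it through an in-memory buffer."""
    buf = io.StringIO()
    cp.write(buf)
    buf.seek(0)
    cp_copy = configparser.ConfigParser()
    cp_copy.read_file(buf)
    return cp_copy

cp = configparser.ConfigParser()
cp.read_string("[model]\nname = joint_primary_marginalized\n")

# modify only the copy; the cp the caller passed stays pristine,
# so pycbc_inference can save it unchanged
internal_cp = copy_config(cp)
internal_cp.set('model', 'name', 'relative')  # hypothetical internal edit
assert cp.get('model', 'name') == 'joint_primary_marginalized'
```

With that pattern inside the model, the original `fp.write_config_file(cp)` line would work as-is and the `original_config` attribute below would be unnecessary.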

                # such as joint_primary_marginalized, we need to save the original
                if hasattr(model, 'original_config'):
                    fp.write_config_file(model.original_config)
                else:
                    fp.write_config_file(cp)

    # Run the sampler
    sampler.run()
bin/inference/pycbc_inference_model_stats (5 changes: 4 additions & 1 deletion)

@@ -96,7 +96,10 @@ model.sampling_transforms = None
def callmodel(arg):
    iteration, paramvals = arg
    # calculate the logposterior to get all stats populated
    model.update(**{p: paramvals[p] for p in model.variable_params})
    try:
ahnitz (Member):
Again, this is not correct. Your top-level model should just have an update method. That method should do whatever is needed to prepare for the log* methods to actually work. You shouldn't require anyone to know about a new method 'update_all_models'. It's fine if you want to have a new method for internal use, but not for the top-level API.

WuShichao (Contributor, Author):
So do I need to rename my update_all_models to update (overwriting the base one), or just use the original update in callmodel?

ahnitz (Member):
@WuShichao A sampler just uses 'update', so why are you doing it differently here? Think through how to do this best, but it should be clear that making this change in this program seems very inconsistent. If it is required, that indicates something is wrong with your model, and if so, fix that. Otherwise maybe you made this change in error.
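A hypothetical sketch of the shape the reviewer is describing: the joint model keeps the standard public API by making update() itself do the submodel bookkeeping, with any extra method kept internal:

```python
from pycbc.inference.models.hierarchical import HierarchicalModel

class JointPrimaryMarginalizedModel(HierarchicalModel):
    def update(self, **params):
        super().update(**params)           # usual bookkeeping
        self._update_all_models(params)    # internal helper, not public API

    def _update_all_models(self, params):
        # hand each submodel only the parameters it knows about
        for submodel in self.submodels.values():
            common = {p: v for p, v in params.items()
                      if p in submodel.variable_params}
            submodel.update(**common)
```

With update() shaped like this, callmodel can keep its original single model.update(...) call and the try/except below disappears.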

        model.update_all_models(**{p: paramvals[p] for p in model.variable_params})
    except:
        model.update(**{p: paramvals[p] for p in model.variable_params})
    _ = model.logposterior
    stats = model.get_current_stats()

pycbc/inference/models/hierarchical.py (158 changes: 107 additions & 51 deletions)

@@ -607,39 +607,33 @@ def _loglikelihood(self):


class JointPrimaryMarginalizedModel(HierarchicalModel):
    """ Hierarchical heterodyne likelihood for coherent multiband
    parameter estimation which combines data from space-borne and
    ground-based GW detectors coherently. Currently, this only
    supports LISA as the space-borne GW detector.

    Sub models are treated as if the same GW source (such as a GW
    from stellar-mass BBH) is observed in different frequency bands by
    space-borne and ground-based GW detectors, then transform all
    the parameters into the same frame in the sub model level, use
    `HierarchicalModel` to get the joint likelihood, and marginalize
    over all the extrinsic parameters supported by `RelativeTimeDom`
    or its variants. Note that LISA submodel only supports the `Relative`
    for now, for ground-based detectors, please use `RelativeTimeDom`
    or its variants.

    Although this likelihood model is used for multiband parameter
    estimation, users can still use it for other purposes, such as
    GW + EM parameter estimation, in this case, please use `RelativeTimeDom`
    or its variants for the GW data, for the likelihood of EM data,
    there is no restrictions.
    """
    """This likelihood model can be used for cases when one of the submodels
    can be marginalized to accelerate the total likelihood. This likelihood
    model also allows for further acceleration of other models during
    marginalization, if some extrinsic parameters can be tightly constrained
    by the primary model. More specifically, in cases such as EM + GW
    parameter estimation, the sky localization can be well measured. For
    LISA + 3G multiband observation, SOBHB signals' (tc, ra, dec) can be
    tightly constrained by the 3G network, so this model is also useful
    for that case.
    """
    name = 'joint_primary_marginalized'

    def __init__(self, variable_params, submodels, **kwargs):
        super().__init__(variable_params, submodels, **kwargs)

        # store the original config to self
        self.original_config = kwargs['original_config']
        # assume the ground-based submodel as the primary model
        self.primary_model = self.submodels[kwargs['primary_lbl'][0]]
        self.primary_lbl = kwargs['primary_lbl'][0]
        self.other_models = self.submodels.copy()
        self.other_models.pop(kwargs['primary_lbl'][0])
        self.other_models = list(self.other_models.values())

        # determine whether to accelerate total_loglr
        self.static_margin_params_in_other_models = \
ahnitz (Member):
This still has the old name in the config file. Also, why not just do

self.static_margin_params = 'static_margin_params' in kwargs

WuShichao (Contributor, Author):
@ahnitz OK, I have updated.

            'static_margin_params_in_other_models' in kwargs
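As an aside, a sketch of reading the switch as a genuinely optional config flag (hypothetical; from_config below reads it with cp.get, which requires the option to be present in the config):

```python
import configparser

cp = configparser.ConfigParser()
cp.read_string("[model]\nname = joint_primary_marginalized\n")

# an absent option would simply disable the acceleration
# instead of raising NoOptionError
use_static_margin = cp.has_option('model',
                                  'static_margin_params_in_other_models')
print(use_static_margin)  # False when the option is omitted
```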

    def write_metadata(self, fp, group=None):
        """Adds metadata to the output files

@@ -652,6 +646,9 @@ def write_metadata(self, fp, group=None):
            by group, i.e., to ``fp[group].attrs``. Otherwise, metadata is
            written to the top-level attrs (``fp.attrs``).
        """
        # replace the internal config for the top-level model with
        # the original config
        fp.write_config_file(self.original_config)
        super().write_metadata(fp, group=group)
        sampattrs = fp.getattrs(group=fp.samples_group)
        # if a group is specified, prepend the lognl names with it
@@ -686,25 +683,48 @@ def total_loglr(self):
"""
# calculate <d-h|d-h> = <h|h> - 2<h|d> + <d|d> up to a constant

# note that for SOBHB signals, ground-based detectors dominant SNR
# and accuracy of (tc, ra, dec)
self.primary_model.return_sh_hh = True
sh_primary, hh_primary = self.primary_model.loglr
self.primary_model.return_sh_hh = False
# set logr, otherwise it will store (sh, hh)
setattr(self.primary_model._current_stats, 'loglr',
self.primary_model.marginalize_loglr(sh_primary, hh_primary))

margin_names_vector = list(
self.primary_model.marginalize_vector_params.keys())
if 'logw_partial' in margin_names_vector:
margin_names_vector.remove('logw_partial')

margin_params = {}
nums = 1
for key, value in self.primary_model.current_params.items():
# add marginalize_vector_params
if key in margin_names_vector:
margin_params[key] = value
if isinstance(value, numpy.ndarray):
nums = len(value)

if self.static_margin_params_in_other_models:
# Due to the high precision of extrinsic parameters constrined
# by the primary model, the mismatch of wavefroms in others by
# varing those parameters is pretty small, so we can keep them
# static to accelerate total_loglr. Here, we use matched-filering
# SNR instead of lilkelihood, because luminosity distance and
# inclination has a very strong degeneracy, change of inclination
# will change best match distance, so change the amplitude of
# waveform. Using SNR will cancel out the effect of amplitude.err
i_max_extrinsic = numpy.argmax(
numpy.abs(sh_primary) / hh_primary**0.5)
for p in margin_names_vector:
if isinstance(self.primary_model.current_params[p],
numpy.ndarray):
margin_params[p] = \
self.primary_model.current_params[p][i_max_extrinsic]
nums = len(self.primary_model.current_params[p])
else:
ahnitz (Member):
@WuShichao This logic should already take care of the distance case, except that later on you assume that if any parameter is a scalar they all are. That's the part you should stop assuming. Don't assume they are any particular mix of scalar or vector.
                    margin_params[p] = self.primary_model.current_params[p]
                    nums = 1
        else:
            for key, value in self.primary_model.current_params.items():
                # add marginalize_vector_params
                if key in margin_names_vector:
                    margin_params[key] = value
                    if isinstance(value, numpy.ndarray):
                        nums = len(value)
                    else:
                        nums = 1
        # add distance if it has been marginalized,
ahnitz (Member):
Again, avoid needing to know explicitly about distance here. Think instead about the format that you require, and, if the format differs, how to generically convert.
        # use a numpy array just so it has the same
        # shape as marginalize_vector_params; here we assume
@@ -713,7 +733,7 @@ def total_loglr(self):
            margin_params['distance'] = numpy.full(
                nums, self.primary_model.current_params['distance'])
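Condensed, the acceleration step above amounts to the following hypothetical helper, which treats scalar and vector parameters uniformly (addressing the review comment about not assuming a particular mix):

```python
import numpy

def pick_static_margin_point(sh, hh, current_params, names):
    """Pin each marginalized parameter to the extrinsic sample with the
    largest matched-filter SNR |sh| / sqrt(hh) (a sketch, not the PR's code)."""
    i_max = int(numpy.argmax(numpy.abs(sh) / numpy.sqrt(hh)))
    return {p: (current_params[p][i_max]
                if isinstance(current_params[p], numpy.ndarray)
                else current_params[p])
            for p in names}
```

Something like pick_static_margin_point(sh_primary, hh_primary, self.primary_model.current_params, margin_names_vector) would then replace the explicit loop over margin_names_vector.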

        # add likelihood contribution from space-borne detectors, we
        # add likelihood contribution from other_models, we
        # calculate sh/hh for each marginalized parameter point
        sh_others = numpy.full(nums, 0 + 0.0j)
        hh_others = numpy.zeros(nums)
@@ -723,24 +743,47 @@
            # not using self.primary_model.current_params, because other
            # models may have their own static parameters
            current_params_other = other_model.current_params.copy()
            for i in range(nums):
            if not self.static_margin_params_in_other_models:
                for i in range(nums):
                    current_params_other.update(
                        {key: value[i] if isinstance(value, numpy.ndarray)
                         else value for key, value in margin_params.items()})
                    other_model.update(**current_params_other)
                    other_model.return_sh_hh = True
                    sh_other, hh_other = other_model.loglr
                    sh_others[i] += sh_other
                    hh_others[i] += hh_other
                    other_model.return_sh_hh = False
                    # set loglr, otherwise it will store (sh, hh)
                    setattr(other_model._current_stats, 'loglr',
ahnitz (Member):
Do you need to store this? It's not necessarily a problem, but it will slow down the code slightly (maybe not important at the moment). Why not think about why it was being set to a vector (and from where)? Do you even want this stored in the case of a submodel? Maybe the solution is simply not to store this when it's not actually a marginalized loglr anyway, no?

WuShichao (Contributor, Author):
My understanding is that when pycbc_inference_model_stats runs `for pi, p in enumerate(model.default_stats):` it will try to access the submodel's loglr, no? If so, I need to store it. When I don't reset it, I found other_model._current_stats also contains (sh, hh) for each point.

ahnitz (Member):
@WuShichao OK, good. So now the question is what is the right thing to do in this case?
                        other_model.marginalize_loglr(sh_other, hh_other))
            else:
                # use one margin point set to approximate all the others
                current_params_other.update(
                    {key: value[i] if isinstance(value, numpy.ndarray) else
                     value for key, value in margin_params.items()})
                    {key: value[0] if isinstance(value, numpy.ndarray)
                     else value for key, value in margin_params.items()})
                other_model.update(**current_params_other)
                other_model.return_sh_hh = True
                sh_others[i], hh_others[i] = other_model.loglr
                sh_other, hh_other = other_model.loglr
                other_model.return_sh_hh = False
                # set loglr, otherwise it will store (sh, hh)
                setattr(other_model._current_stats, 'loglr',
ahnitz (Member):
Same as above. It's not necessarily a problem, as it might be useful to have the separate loglrs, but it's not clear that it will always make sense.
                    other_model.marginalize_loglr(sh_other, hh_other))
                sh_others += sh_other
                hh_others += hh_other

        if nums == 1:
            # the original sh/hh_others are numpy arrays, which might not
            # be the same type as sh/hh_primary during reconstruction;
            # when reconstructing distance, sh/hh_others need to be scalars
            sh_others = sh_others[0]
            hh_others = hh_others[0]
        sh_total = sh_primary + sh_others
        hh_total = hh_primary + hh_others

        # calculate marginalize_vector_weights
        self.primary_model.marginalize_vector_weights = \
            - numpy.log(self.primary_model.vsamples)
        loglr = self.primary_model.marginalize_loglr(sh_total, hh_total)
        setattr(self._current_stats, 'total_loglr', loglr)

        return loglr
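For orientation, the relationship between the summed (sh, hh) and the final loglr is, schematically, as below. This is a sketch assuming phase marginalization over uniform extrinsic samples; marginalize_loglr in pycbc/inference/models/tools.py is the authoritative implementation:

```python
import numpy
from scipy.special import i0e, logsumexp

def sketch_marginalized_loglr(sh_total, hh_total, log_weights):
    """Schematic phase-marginalized loglr from summed inner products."""
    x = numpy.abs(sh_total)
    # log I0(x) evaluated stably via the exponentially scaled Bessel function
    loglr_samples = numpy.log(i0e(x)) + x - 0.5 * hh_total
    # weighted logsumexp over the extrinsic samples,
    # with log_weights = -log(N) for N uniform samples
    return logsumexp(loglr_samples + log_weights)
```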

    def others_lognl(self):
@@ -795,6 +838,9 @@ def from_config(cls, cp, **kwargs):
        # we need the read from config function from the init; to prevent
        # circular imports, we import it here
        from pycbc.inference.models import read_from_config
        # store the original config file; here use deepcopy to avoid later
        # changes of cp affecting it
        kwargs['original_config'] = cp.__deepcopy__(cp)
        # get the submodels
        kwargs['primary_lbl'] = shlex.split(cp.get('model', 'primary_model'))
        kwargs['others_lbls'] = shlex.split(cp.get('model', 'other_models'))
@@ -805,6 +851,10 @@ def from_config(cls, cp, **kwargs):
        sparam_map = map_params(hpiter(cp.options('static_params'),
                                       submodel_lbls))

        # get the acceleration label
        kwargs['static_margin_params_in_other_models'] = shlex.split(
            cp.get('model', 'static_margin_params_in_other_models'))
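Put together, the cp.get calls above imply a [model] section along these lines. The submodel labels 'gb' and 'lisa' are purely illustrative; note that shlex.split of any non-empty value yields a truthy list, which is what __init__ tests:

```python
import configparser
import shlex

cp = configparser.ConfigParser()
cp.read_string("""
[model]
name = joint_primary_marginalized
primary_model = gb
other_models = lisa
static_margin_params_in_other_models = True
""")
print(shlex.split(cp.get('model', 'primary_model')))  # ['gb']
print(shlex.split(cp.get('model', 'other_models')))   # ['lisa']
```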

        # we'll need any waveform transforms for the initializing sub-models,
        # as the underlying models will receive the output of those transforms

@@ -856,18 +906,21 @@ def from_config(cls, cp, **kwargs):
                      cp.get('static_params', param.fullname))

        # set the variable params: different from the standard
        # hierarchical model, in this multiband model, all sub-models
        # has the same variable parameters, so we don't need to worry
        # about the unique variable issue. Besides, the primary model
        # needs to do marginalization, so we must set variable_params
        # and prior section before initializing it.
        # hierarchical model, in this JointPrimaryMarginalizedModel model,
        # all sub-models have the same variable parameters, so we don't
        # need to worry about the unique variable issue. Besides,
        # the primary model needs to do marginalization, so we must set
        # variable_params and prior section before initializing it.

        subcp.add_section('variable_params')
        for param in vparam_map[lbl]:
            if lbl in kwargs['primary_lbl']:
                # set variable_params for the primary model
                subcp.set('variable_params', param.subname,
                          cp.get('variable_params', param.fullname))
            else:
                # all variable_params in other models will come
                # from the primary model during sampling
                subcp.set('static_params', param.subname, 'REPLACE')

        for section in cp.sections():
@@ -919,14 +972,13 @@ def from_config(cls, cp, **kwargs):
        # it will not be listed in `variable_params` and `prior` sections
        primary_model = submodels[kwargs['primary_lbl'][0]]
        marginalized_params = primary_model.marginalize_vector_params.copy()
        if 'logw_partial' in marginalized_params:
            marginalized_params.pop('logw_partial')
            marginalized_params = list(marginalized_params.keys())
        else:
            marginalized_params = []
        # this may also include 'f_ref', 'f_lower', 'approximant',
        # but doesn't matter
        marginalized_params += list(primary_model.static_params.keys())
        marginalized_params = list(marginalized_params.keys())
        # add distance or phase if they are marginalized
        if primary_model.distance_marginalization:
ahnitz (Member):
Why do you need to hardcode this? This will be brittle and break easily with changes to marginalization, for example, which we don't want. If you need a list of marginalized parameters, why not add it to the class in inference/tools so that it is kept up to date?

WuShichao (Contributor, Author):
OK, I will save it in the tools module and just use it here.

            marginalized_params.append('distance')
        if primary_model.marginalize_phase:
            marginalized_params.append('coa_phase')
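Following the exchange above, the list could live with the marginalization support code itself; a sketch of such a property (hypothetical name and placement in pycbc/inference/models/tools.py, shown here on a minimal stand-in class):

```python
class MarginalizedExample:
    """Minimal stand-in for the marginalization mixin (sketch only)."""
    def __init__(self):
        self.marginalize_vector_params = {'tc': None, 'ra': None,
                                          'dec': None, 'logw_partial': None}
        self.distance_marginalization = True
        self.marginalize_phase = False

    @property
    def full_marginalized_params(self):
        # every parameter marginalized over, kept in one place so models
        # like JointPrimaryMarginalizedModel need not hardcode the names
        params = [p for p in self.marginalize_vector_params
                  if p != 'logw_partial']
        if self.distance_marginalization:
            params.append('distance')
        if self.marginalize_phase:
            params.append('coa_phase')
        return params

print(MarginalizedExample().full_marginalized_params)
# ['tc', 'ra', 'dec', 'distance']
```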

        for p in primary_model.static_params.keys():
            p_full = '%s__%s' % (kwargs['primary_lbl'][0], p)
            if p_full not in cp['static_params']:
@@ -940,6 +992,10 @@ def from_config(cls, cp, **kwargs):
                cp['variable_params'].pop(p)
            cp.pop(section)

        # save the virtual config file to disk for later checks
        with open('internal_top.ini', 'w', encoding='utf-8') as file:
            cp.write(file)

        # now load the model
        logging.info("Loading joint_primary_marginalized model")
        return super(HierarchicalModel, cls).from_config(
pycbc/inference/models/relbin.py (3 changes: 2 additions & 1 deletion)

@@ -596,10 +596,11 @@ def _loglr(self):
            filt += filter_i
            norm += norm_i

        loglr = self.marginalize_loglr(filt, norm)
        if self.return_sh_hh:
ahnitz (Member):
Why do we need this if statement? Shouldn't the existing flag already used for demarginalization take care of this? E.g., why not use the reconstruct_phase flag?

https://github.com/gwastro/pycbc/blob/master/pycbc/inference/models/tools.py#L241
            results = (filt, norm)
        else:
            results = self.marginalize_loglr(filt, norm)
            results = loglr
        return results
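For reference, the toggle is consumed by the joint model in the pattern already shown in total_loglr above; condensed (a sketch, where model stands for any relbin-style submodel instance):

```python
# condensed usage of the return_sh_hh toggle (sketch of the pattern
# used by JointPrimaryMarginalizedModel.total_loglr)
model.return_sh_hh = True
sh, hh = model.loglr          # _loglr hands back the raw (filt, norm)
model.return_sh_hh = False
loglr = model.marginalize_loglr(sh, hh)
```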

    def write_metadata(self, fp, group=None):