diff --git a/sacc/tracers.py b/sacc/tracers.py index 161c290..5f501a8 100644 --- a/sacc/tracers.py +++ b/sacc/tracers.py @@ -652,7 +652,7 @@ class QPNZTracer(BaseTracer, tracer_type='QPNZ'): """ A Tracer type for tomographic n(z) data represented as a `qp.Ensemble` - Takes a `qp.Ensemble` + Takes a `qp.Ensemble` and optionally a redshift array. Requires the `qp` and `tables_io` packages to be installed. @@ -666,7 +666,7 @@ class QPNZTracer(BaseTracer, tracer_type='QPNZ'): The qp.ensemble in questions """ - def __init__(self, name, ens, **kwargs): + def __init__(self, name, ens, z=None, **kwargs): """ Create a tracer corresponding to a distribution in redshift n(z), for example of galaxies. @@ -680,6 +680,11 @@ def __init__(self, name, ens, **kwargs): ensemble: qp.Ensemble The qp.ensemble in questions + z: array + Optional grid of redshift values at which to evaluate the ensemble. + If left as None then the ensemble metadata is checked for a grid. + If that is not present then no redshift grid is saved. + Returns ------- instance: NZTracer object @@ -687,7 +692,16 @@ def __init__(self, name, ens, **kwargs): """ super().__init__(name, **kwargs) self.ensemble = ens - + if z is None: + ens_meta = ens.metadata() + if 'bins' in list(ens_meta.keys()): + z = ens_meta['bins'][0] + self.z = z + if z is None: + self.nz = None + else: + self.nz = np.mean(ens.pdf(self.z),axis=0) + @classmethod def to_tables(cls, instance_list): """Convert a list of NZTracers to a list of astropy tables @@ -710,6 +724,16 @@ def to_tables(cls, instance_list): tables = [] for tracer in instance_list: + if tracer.z is not None: + names = ['z', 'nz'] + cols = [tracer.z, tracer.nz] + fid_table = Table(data=cols, names=names) + fid_table.meta['SACCTYPE'] = 'tracer' + fid_table.meta['SACCCLSS'] = cls.tracer_type + fid_table.meta['SACCNAME'] = tracer.name + fid_table.meta['SACCQTTY'] = tracer.quantity + fid_table.meta['EXTNAME'] = f'tracer:{cls.tracer_type}:{tracer.name}:fid' + table_dict = tracer.ensemble.build_tables() ap_tables = convertToApTables(table_dict) data_table = ap_tables['data'] @@ -731,6 +755,8 @@ def to_tables(cls, instance_list): meta_table.meta['META_'+kk] = vv tables.append(data_table) tables.append(meta_table) + if tracer.z is not None: + tables.append(fid_table) if ancil_table: ancil_table.meta['SACCTYPE'] = 'tracer' ancil_table.meta['SACCCLSS'] = cls.tracer_type @@ -775,6 +801,10 @@ def from_tables(cls, table_list): for val in sorted_dict.values(): meta_table = val['meta'] + if 'fid' in val: + z = val['fid']['z'] + else: + z = None ensemble = qp.from_tables(val) name = meta_table.meta['SACCNAME'] quantity = meta_table.meta.get('SACCQTTY', 'generic') @@ -783,7 +813,7 @@ def from_tables(cls, table_list): for key, value in meta_table.meta.items(): if key.startswith("META_"): metadata[key[5:]] = value - tracers[name] = cls(name, ensemble, + tracers[name] = cls(name, ensemble, z=z, quantity=quantity, metadata=metadata) return tracers diff --git a/test/test_sacc2.py b/test/test_sacc2.py index 54fb09f..a50379a 100644 --- a/test/test_sacc2.py +++ b/test/test_sacc2.py @@ -833,10 +833,10 @@ def test_qpnz_tracer(): nz_qp_interp = qp.Ensemble(qp.interp, data=dict(xvals=z, yvals=np.ones(shape=(1, 101)))) nz_qp_hist = qp.Ensemble(qp.hist, data=dict(bins=z, pdfs=np.ones(shape=(1, 100)))) - T1 = sacc.BaseTracer.make('QPNZ', 'tracer1', nz_qp_interp, + T1 = sacc.BaseTracer.make('QPNZ', 'tracer1', nz_qp_interp, z, quantity='galaxy_density', metadata=md1) - T2 = sacc.BaseTracer.make('QPNZ', 'tracer2', nz_qp_hist, + T2 = sacc.BaseTracer.make('QPNZ', 'tracer2', nz_qp_hist, z, quantity='galaxy_shear', metadata=md2) assert T1.metadata == md1 @@ -850,6 +850,15 @@ def test_qpnz_tracer(): assert T1a.metadata == md1 assert T2a.metadata == md2 + # test version without saved z + T3 = sacc.BaseTracer.make('QPNZ', 'tracer3', nz_qp_interp, + quantity='galaxy_density', + metadata=md1) + tables = sacc.BaseTracer.to_tables([T3]) + D = sacc.BaseTracer.from_tables(tables) + assert D['tracer3'].z is None + + def test_io_qp(): s = sacc.Sacc() @@ -859,7 +868,7 @@ def test_io_qp(): nz = np.expand_dims((z-0.5)**2/0.1**2, 0) ens = qp.Ensemble(qp.interp, data=dict(xvals=z, yvals=nz)) ens.set_ancil(dict(modes = ens.mode(z))) - s.add_tracer('QpnZ', 'source_0', ens) + s.add_tracer('QpnZ', 'source_0', ens, z) for i in range(20): ee = 0.1 * i