diff --git a/la_forge/core.py b/la_forge/core.py index d6bef5a..c6981e8 100644 --- a/la_forge/core.py +++ b/la_forge/core.py @@ -5,8 +5,14 @@ import logging from typing import Type +import datetime +from functools import cached_property + +import arviz as az import h5py import numpy as np +import pandas as pd +import xarray as xr from astropy.io import fits from astropy.table import Table @@ -109,6 +115,18 @@ def __init__(self, chaindir=None, corepath=None, burn=0.25, label=None, self.params = table.colnames self.chain = np.array([table[p] for p in self.params]).T self.chainpath = chaindir + '/chain.fits' + # Check if it's in a common arviz format + elif os.path.isfile(chaindir + '/chain.nc') or os.path.isfile(chaindir + '/chain.zarr'): + self.chainpath = chaindir + '/chain.nc' if os.path.isfile(chaindir + '/chain.nc') else chaindir + '/chain.zarr' + extension = self.chainpath.split(".")[-1] + try: + inf_data = az.from_netcdf(self.chainpath) if extension=="nc" else az.from_zarr(self.chainpath) + except: + msg = f"{self.chainpath} is not a valid ArviZ InferenceData object." + raise ValueError(msg) + stacked = az.extract(inf_data) # combines chains + self.chain = stacked.to_array().to_numpy().T # ArviZ uses dimension 1 for samples, we want it to be 0 + self.params = [param for param in stacked.variables if param not in ['sample', 'chain', 'draw']] else: # Load chain if os.path.isfile(chaindir + '/chain_1.txt'): @@ -666,6 +684,35 @@ def map_params(self): """Return all Maximum a posteri parameters.""" return self.chain[self.burn + self.map_idx, :] + @cached_property + def arviz(self) -> az.InferenceData: + """Create an arviz.InferenceData object from a Core.""" + + # Easiest to make a dataframe first + df = pd.DataFrame(data=self.chain, columns=self.params) + + # ArviZ wants to see `chain` and `draw` dimensions + df["chain"] = 0 + df["draw"] = np.arange(len(df), dtype=int) + df = df.set_index(["chain", "draw"]) + + # Make an xarray `Dataset` to give ArviZ + xdata = xr.Dataset.from_dataframe(df) + + # Store some metadata + xdata.attrs.update( + source="la_forge_core", + created_at=datetime.datetime.now(datetime.timezone.utc) + .replace(microsecond=0) + .isoformat(), + ) + + # Make the ArviZ object + dataset = az.InferenceData(posterior=xdata) + + return dataset + + # --------------------------------------------# # ---------------HyperModel Core--------------# # --------------------------------------------# diff --git a/requirements.txt b/requirements.txt index ea5a734..82731b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ astropy>=3.0 corner six h5py>=3.4.0 +arviz>=0.19.0 diff --git a/setup.py b/setup.py index 0809594..dc6261e 100644 --- a/setup.py +++ b/setup.py @@ -11,15 +11,23 @@ with open('HISTORY.rst') as history_file: history = history_file.read() -requirements = ['numpy>=1.16', - 'scipy>=1.0.0', - 'matplotlib>=2.0.0', - 'corner', - 'h5py>=3.4.0', - 'astropy>=3.0', - 'six', - - ] +requirements = [ + "numpy>=1.16", + "scipy>=1.0.0", + "matplotlib>=2.0.0", + "corner", + "h5py>=3.4.0", + "astropy>=3.0", + "six", + "arviz>=0.19.0", + "zarr>=2.5.0,<3", + "netcdf4", + "xarray-datatree", + "dm-tree", + "contourpy", + "bokeh>=3", + +] test_requirements = ['pytest', ]