Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(arviz): add arviz import and export #32

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

davecwright3
Copy link

@davecwright3 davecwright3 commented Oct 3, 2024

This PR implements a very basic import and export scheme with ArviZ.

Checks for new file types

If it finds netcdf ".nc" or zarr ".zarr" files, it tries to read them with ArviZ.

la_forge/la_forge/core.py

Lines 118 to 126 in 6d23069

# Check if it's in a common arviz format
elif os.path.isfile(chaindir + '/chain.nc') or os.path.isfile(chaindir + '/chain.zarr'):
self.chainpath = chaindir + '/chain.nc' if os.path.isfile(chaindir + '/chain.nc') else chaindir + '/chain.zarr'
extension = self.chainpath.split(".")[-1]
try:
inf_data = az.from_netcdf(self.chainpath) if extension=="nc" else az.from_zarr(self.chainpath)
except:
msg = f"{self.chainpath} is not a valid ArviZ InferenceData object."
raise ValueError(msg)

Converts ArviZ data and metadata into a single "chain" for La Forge to consume

It filters out any parameters that do not correspond to samples in the chain. This means that the number of parameters always equals the number of columns in the chain. It therefore skips the check later on that would add extra PTMCMC parameters, such as lnpost.

In this implementation, you need to have already named your variables in the ArviZ InferenceData object to their desired final names. I would recommend naming them to their usual PTMCMC values for backwards compatibility.

la_forge/la_forge/core.py

Lines 127 to 129 in 6d23069

stacked = az.extract(inf_data) # combines chains
self.chain = stacked.to_array().to_numpy().T # ArviZ uses dimension 1 for samples, we want it to be 0
self.params = [param for param in stacked.variables if param not in ['sample', 'chain', 'draw']]

Adds an arviz cached_property to the Core class

Assuming you have a Core named my_core, calling my_core.arviz will return an ArviZ InferenceData object populated with the Core's data.

la_forge/la_forge/core.py

Lines 687 to 713 in 6d23069

@cached_property
def arviz(self) -> az.InferenceData:
"""Create an arviz.InferenceData object from a Core."""
# Easiest to make a dataframe first
df = pd.DataFrame(data=self.chain, columns=self.params)
# ArviZ wants to see `chain` and `draw` dimensions
df["chain"] = 0
df["draw"] = np.arange(len(df), dtype=int)
df = df.set_index(["chain", "draw"])
# Make an xarray `Dataset` to give ArviZ
xdata = xr.Dataset.from_dataframe(df)
# Store some metadata
xdata.attrs.update(
source="la_forge_core",
created_at=datetime.datetime.now(datetime.timezone.utc)
.replace(microsecond=0)
.isoformat(),
)
# Make the ArviZ object
dataset = az.InferenceData(posterior=xdata)
return dataset

Self contained example

import arviz as az
from la_forge.core import Core
from pathlib import Path

chain_dir = Path("test_chain/")
chain_dir.mkdir(parents=True, exist_ok=True)

inf_data = az.load_arviz_data("regression1d")
inf_data.to_netcdf(chain_dir/"chain.nc")

az_core = Core(chain_dir.as_posix())
>>> print(az_core.chain)
[[ 1.5665029  -1.33202579  1.8831507  -1.34775906  1.21947951]
 [ 1.80178445 -1.18480163  1.8831507  -1.34775906  1.10626457]
 [ 1.84332941 -1.22758633  1.8831507  -1.34775906  1.0078494 ]
 ...
 [ 1.68224315 -1.18911818  1.8831507  -1.34775906  0.91507343]
 [ 2.01824417 -1.34356813  1.8831507  -1.34775906  1.14519821]
 [ 1.94768056 -1.37682251  1.8831507  -1.34775906  0.99698405]]

>>> print(az_core.params)
['slope', 'intercept', 'true_slope', 'true_intercept', 'eps']
>>> print(az_core.arviz)
Inference data with groups:
	> posterior

>>> print(az_core.arviz.posterior)
<xarray.Dataset> Size: 96kB
Dimensions:         (chain: 1, draw: 2000)
Coordinates:
  * chain           (chain) int64 8B 0
  * draw            (draw) int64 16kB 0 1 2 3 4 5 ... 1995 1996 1997 1998 1999
Data variables:
    slope           (chain, draw) float64 16kB 1.567 1.802 1.843 ... 2.018 1.948
    intercept       (chain, draw) float64 16kB -1.332 -1.185 ... -1.344 -1.377
    true_slope      (chain, draw) float64 16kB 1.883 1.883 1.883 ... 1.883 1.883
    true_intercept  (chain, draw) float64 16kB -1.348 -1.348 ... -1.348 -1.348
    eps             (chain, draw) float64 16kB 1.219 1.106 1.008 ... 1.145 0.997
Attributes:
    source:      la_forge_core
    created_at:  2024-10-04T00:00:57+00:00

@davecwright3
Copy link
Author

Tests are failing because the minimum arviz version I specified is too high for these python versions. I'll lower the arviz version bound.

@davecwright3
Copy link
Author

Also I should raise a warning if an inference data object is imported that doesn't have all of the usual PTMCMC fields defined. This is just so users are aware some methods may fail that depend on those fields existing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant