Skip to content

Commit

Permalink
New exp plots 1 (#58)
Browse files Browse the repository at this point in the history
* n_outliers

* more comparisons

* add more plots to exp1

* new seed

* new plots on exp2

* forgot to add prior to notebook
  • Loading branch information
ismael-mendoza authored Dec 2, 2024
1 parent 45abd58 commit 2ada016
Show file tree
Hide file tree
Showing 17 changed files with 96 additions and 37 deletions.
Binary file modified experiments/exp1/figs/calibration.pdf
Binary file not shown.
Binary file modified experiments/exp1/figs/contours.pdf
Binary file not shown.
Binary file added experiments/exp1/figs/mean_std_hist.pdf
Binary file not shown.
Binary file modified experiments/exp1/figs/multiplicative_bias_hist.pdf
Binary file not shown.
Binary file modified experiments/exp1/figs/traces.pdf
Binary file not shown.
2 changes: 1 addition & 1 deletion experiments/exp1/get_figures.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#!/bin/bash
./make_figures.py
./make_figures.py 43
2 changes: 1 addition & 1 deletion experiments/exp1/get_posteriors.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#!/bin/bash
../../scripts/slurm_toy_shear_vectorized.py 42 toy_shear_42
../../scripts/slurm_toy_shear_vectorized.py 43 toy_shear_43
39 changes: 34 additions & 5 deletions experiments/exp1/make_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_ENABLE_X64"] = "True"

from math import sqrt

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import typer
from jax import Array
from matplotlib.backends.backend_pdf import PdfPages
from tqdm import tqdm
Expand All @@ -17,7 +20,7 @@
from bpd.diagnostics import get_contour_plot, get_gauss_pc_fig, get_pc_fig


def make_trace_plots(g_samples: Array, n_examples: int = 10) -> None:
def make_trace_plots(g_samples: Array, n_examples: int = 25) -> None:
"""Make example figure showing example trace plots of shear posteriors."""
# by default, we choose 10 random traces to plot in 1 PDF file.
fname = "figs/traces.pdf"
Expand All @@ -39,7 +42,7 @@ def make_trace_plots(g_samples: Array, n_examples: int = 10) -> None:
plt.close(fig)


def make_contour_plots(g_samples: Array, n_examples=10) -> None:
def make_contour_plots(g_samples: Array, n_examples: int = 25) -> None:
"""Make example figure showing example contour plots of shear posterios"""
# by default, we choose 10 random contours to plot in 1 PDF file.
fname = "figs/contours.pdf"
Expand Down Expand Up @@ -80,16 +83,41 @@ def make_histogram_mbias(g_samples: Array) -> None:
fname = "figs/multiplicative_bias_hist.pdf"
with PdfPages(fname) as pdf:
g1 = g_samples[:, :, 0]

mbias = (g1.mean(axis=1) - 0.02) / 0.02
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.hist(mbias, bins=31, histtype="step")
pdf.savefig(fig)
plt.close(fig)


def make_histogram_means_and_stds(g_samples: Array) -> None:
fname = "figs/mean_std_hist.pdf"
with PdfPages(fname) as pdf:
g1 = g_samples[:, :, 0]

means = g1.mean(axis=1)
stds = g1.std(axis=1)

fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.hist(means, bins=31, histtype="step")
ax.set_title(f"Std: {means.std():.5g}")
ax.axvline(means.mean(), linestyle="--", color="k", label="mean")
ax.legend()
pdf.savefig(fig)
plt.close(fig)

fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.hist(stds, bins=31, histtype="step")
ax.axvline(1e-3 / sqrt(1000), linestyle="--", color="k")
ax.set_title(f"Std_correct: {1e-3 / sqrt(1000) / sqrt(2):.5g}")
pdf.savefig(fig)
plt.close(fig)


def main():
pdir = DATA_DIR / "cache_chains" / "toy_shear_42"
def main(seed: int = 43):
np.random.seed(seed)
pdir = DATA_DIR / "cache_chains" / f"toy_shear_{seed}"
assert pdir.exists()
all_g_samples = []
for fpath in pdir.iterdir():
Expand All @@ -104,7 +132,8 @@ def main():
make_contour_plots(g_samples)
make_posterior_calibration(g_samples)
make_histogram_mbias(g_samples)
make_histogram_means_and_stds(g_samples)


if __name__ == "__main__":
main()
typer.run(main)
Binary file modified experiments/exp2/figs/contours.pdf
Binary file not shown.
Binary file modified experiments/exp2/figs/convergence_hist.pdf
Binary file not shown.
6 changes: 6 additions & 0 deletions experiments/exp2/figs/outliers.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Number of R-hat outliers for e1: 1
Number of R-hat outliers for e2: 2
Number of R-hat outliers for hlr: 0
Number of R-hat outliers for lf: 0
Number of R-hat outliers for x: 2
Number of R-hat outliers for y: 3
Binary file modified experiments/exp2/figs/timing.pdf
Binary file not shown.
Binary file modified experiments/exp2/figs/traces.pdf
Binary file not shown.
Binary file added experiments/exp2/figs/traces_adapt.pdf
Binary file not shown.
31 changes: 23 additions & 8 deletions experiments/exp2/make_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_ENABLE_X64"] = "True"

from pathlib import Path

import cycler
import jax.numpy as jnp
import matplotlib.pyplot as plt
Expand All @@ -20,14 +22,16 @@


def make_trace_plots(
samples_dict: dict[str, Array], truth: dict[str, Array], n_examples: int = 25
samples_dict: dict[str, Array],
truth: dict[str, Array],
fpath: str,
n_examples: int = 25,
) -> None:
"""Make example figure showing example trace plots for each parameter."""
# by default, we choose 10 random traces to plot in 1 PDF file.
fname = "figs/traces.pdf"
n_gals, _, _ = samples_dict["lf"].shape

with PdfPages(fname) as pdf:
with PdfPages(fpath) as pdf:
for _ in tqdm(range(n_examples), desc="Making traces"):
idx = np.random.choice(np.arange(0, n_gals)).item()
chains = {k: v[idx] for k, v in samples_dict.items()}
Expand All @@ -50,7 +54,7 @@ def make_trace_plots(


def make_contour_plots(
samples_dict: dict[str, Array], truth: dict[str, Array], n_examples: int = 10
samples_dict: dict[str, Array], truth: dict[str, Array], n_examples: int = 25
) -> None:
"""Make example figure showing example contour plots of galaxy properties"""
fname = "figs/contours.pdf"
Expand All @@ -74,6 +78,9 @@ def make_convergence_histograms(samples_dict: dict[str, Array]) -> None:
fname = "figs/convergence_hist.pdf"
n_gals, n_chains_per_gal, n_samples = samples_dict["lf"].shape

if Path("figs/outliers.txt").exists():
os.remove("figs/outliers.txt")

# compute convergence metrics
rhats = {p: [] for p in samples_dict}
ess = {p: [] for p in samples_dict}
Expand All @@ -88,7 +95,10 @@ def make_convergence_histograms(samples_dict: dict[str, Array]) -> None:
# print rhat outliers
for p in rhats:
rhat = np.array(rhats[p])
print(f"Number of R-hat outliers for {p}:", sum((rhat < 0.98) | (rhat > 1.1)))
_ess = np.array(ess[p])
n_outliers = sum((rhat < 0.99) | (rhat > 1.05) | (_ess < 0.25))
with open("figs/outliers.txt", "a", encoding="utf-8") as f:
print(f"Number of R-hat outliers for {p}: {n_outliers}", file=f)

with PdfPages(fname) as pdf:
for p in samples_dict:
Expand Down Expand Up @@ -160,15 +170,20 @@ def main():
)
assert fpath.exists()
results = jnp.load(fpath, allow_pickle=True).item()
samples = results[250]["samples"]
truth = results[250]["truth"]
max_n_gal = max(results.keys())
samples = results[max_n_gal]["samples"]
truth = results[max_n_gal]["truth"]

# make plots
make_trace_plots(samples, truth)
make_trace_plots(samples, truth, fpath="figs/traces.pdf")
make_contour_plots(samples, truth)
make_convergence_histograms(samples)
make_timing_plots(results)

# on adaption too
adapt_states = results[max_n_gal]["adapt_info"].state.position
make_trace_plots(adapt_states, truth, fpath="figs/traces_adapt.pdf", n_examples=50)


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions experiments/exp2/run_inference_galaxy_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _sample_prior_init(rng_key: PRNGKeyArray):

def main(
seed: int,
n_samples: int = 1000,
n_samples: int = 500,
shape_noise: float = 0.3,
sigma_e_int: float = 0.5,
slen: int = 53,
Expand Down Expand Up @@ -100,7 +100,7 @@ def main(
_run_sampling = vmap(vmap(jjit(_run_sampling1), in_axes=(0, 0, 0, None)))

results = {}
for n_gals in (1, 1, 5, 10, 20, 25, 50, 100, 250): # repeat 1 == compilation
for n_gals in (1, 1, 5, 10, 20, 25, 50, 100, 250, 500): # repeat 1 == compilation
print("n_gals:", n_gals)

# generate data and parameters
Expand Down
49 changes: 29 additions & 20 deletions notebooks/test-shear-inf-with-prior1.ipynb

Large diffs are not rendered by default.

0 comments on commit 2ada016

Please sign in to comment.