Skip to content

Commit

Permalink
light refactoring (#50)
Browse files Browse the repository at this point in the history
* add numpyro

* small corrections

* typing

* typing

* typing

* snr doc

* typing

* minor refactor

* ruff
  • Loading branch information
ismael-mendoza authored Nov 30, 2024
1 parent 68e15c8 commit 238d4b9
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 31 deletions.
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ Bayesian Pixel Domain shear estimation based on automatically differentiable cel

This repository contains functions to run HMC (Hamiltonian Monte Carlo) using [JAX-Galsim](https://github.com/GalSim-developers/JAX-GalSim) as a forward model to perform shear inference.


## Installation

```bash
Expand All @@ -13,15 +12,15 @@ pip install --upgrade pip
conda create -n bpd python=3.12
conda activate bpd

# Install JAX (cuda)
pip install -U "jax[cuda12]"
# Install JAX
pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || pip install -U "jax[cpu]"

# Install JAX-Galsim
pip install git+https://github.com/GalSim-developers/JAX-GalSim.git

# Install package and depedencies
git clone git@github.com:LSSTDESC/BPD.git
cd BPD
python -m pip install . -e
python -m pip install .[dev]
pip install -e .
pip install -e ".[dev]"
```
12 changes: 7 additions & 5 deletions bpd/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

import blackjax
import jax
from jax import random
from jax import Array, random
from jax._src.prng import PRNGKeyArray
from jax.typing import ArrayLike


def inference_loop(rng_key, initial_state, kernel, n_samples: int):
"""Function to run a single chain with a given kernel and obtain `n_samples`."""
def inference_loop(
rng_key: PRNGKeyArray, initial_state: ArrayLike, kernel: Callable, n_samples: int
):
"""Function to run a single chain with a given kernel and obtain samples"""

def one_step(state, rng_key):
state, info = kernel(rng_key, state)
Expand All @@ -32,7 +34,7 @@ def run_warmup_nuts(
n_warmup_steps: int = 500,
is_mass_matrix_diagonal: bool = True,
target_acceptance_rate: float = 0.8,
):
) -> tuple[ArrayLike, dict, dict]:
_logtarget = partial(logtarget, data=data)
warmup = blackjax.window_adaptation(
blackjax.nuts,
Expand Down Expand Up @@ -82,7 +84,7 @@ def run_inference_nuts(
n_warmup_steps: int = 500,
target_acceptance_rate: float = 0.80,
is_mass_matrix_diagonal: bool = True,
):
) -> Array | dict[str, Array]:
key1, key2 = random.split(rng_key)

_logtarget = partial(logtarget, data=data)
Expand Down
3 changes: 1 addition & 2 deletions bpd/diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import numpy as np
import pandas as pd
from chainconsumer import Chain, ChainConsumer, Truth
from jax import Array
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes
from numpyro.diagnostics import hpdi
from scipy import stats


def get_contour_plot(
samples_list: list[dict[str, Array]],
samples_list: list[dict[str, np.ndarray]],
names: list[str],
truth: dict[str, float],
figsize: tuple[float, float] = (7, 7),
Expand Down
4 changes: 2 additions & 2 deletions bpd/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, vmap
from jax import Array, grad, vmap
from jax.numpy.linalg import norm
from jax.typing import ArrayLike

Expand All @@ -13,7 +13,7 @@


def shear_loglikelihood_unreduced(
g: tuple[float, float], e_post, prior: Callable, interim_prior: Callable
g: tuple[float, float], e_post: Array, prior: Callable, interim_prior: Callable
) -> ArrayLike:
# Given by the inference procedure in Schneider et al. 2014
# assume single shear g
Expand Down
11 changes: 7 additions & 4 deletions bpd/measure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import numpy as np
from jax.typing import ArrayLike
import jax.numpy as jnp
from jaxtyping import ArrayLike


def get_snr(im: ArrayLike, background: float) -> float:
"""Calculate the signal-to-noise ratio of an image.
Args:
im: Image array with no background.
im: 2D image array with no background.
background: Background level.
Returns:
float: The signal-to-noise ratio.
"""
assert im.ndim == 2
assert isinstance(background, float) or background.shape == ()
return np.sqrt(np.sum(im * im / (background + im)))
return jnp.sqrt(jnp.sum(im * im / (background + im)))
10 changes: 5 additions & 5 deletions bpd/pipelines/image_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ def loglikelihood(
background: float,
free_flux: bool = True,
):
# NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments.
_draw_params = {**{"g1": 0.0, "g2": 0.0}, **params} # function is more general
# NOTE: draw_fnc should already contain `f` and `hlr` as constant arguments if fixed
_draw_params = {**{"g1": 0.0, "g2": 0.0}, **params}

# Convert log-flux to flux if provided
if free_flux:
_draw_params["f"] = 10 ** _draw_params.pop("lf")
model = draw_fnc(**_draw_params)

model = draw_fnc(**_draw_params)
likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(background))
likelihood = jnp.sum(likelihood_pp)
return likelihood
return jnp.sum(likelihood_pp)


def logtarget(
Expand Down
14 changes: 9 additions & 5 deletions bpd/prior.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import jax.numpy as jnp
from jax import Array, random
from jax._src.prng import PRNGKeyArray
from jax.numpy.linalg import norm
from jaxtyping import ArrayLike


def ellip_mag_prior(e, sigma: float):
def ellip_mag_prior(e: ArrayLike, sigma: float):
"""Unnormalized Prior for the magnitude of the ellipticity, domain is (0, 1)
This distribution is taken from Gary's 2013 paper on Bayesian shear inference.
Expand All @@ -15,7 +17,9 @@ def ellip_mag_prior(e, sigma: float):
return (1 - e**2) ** 2 * jnp.exp(-(e**2) / (2 * sigma**2))


def sample_mag_ellip_prior(rng_key, sigma: float, n: int = 1, n_bins: int = 1_000_000):
def sample_mag_ellip_prior(
rng_key: PRNGKeyArray, sigma: float, n: int = 1, n_bins: int = 1_000_000
):
"""Sample n points from Gary's ellipticity magnitude prior."""
# this part could be cached
e_array = jnp.linspace(0, 1, n_bins)
Expand All @@ -25,7 +29,7 @@ def sample_mag_ellip_prior(rng_key, sigma: float, n: int = 1, n_bins: int = 1_00
return random.choice(rng_key, e_array, shape=(n,), p=p_array)


def sample_ellip_prior(rng_key, sigma: float, n: int = 1):
def sample_ellip_prior(rng_key: PRNGKeyArray, sigma: float, n: int = 1):
"""Sample n ellipticities isotropic components with Gary's prior from magnitude."""
key1, key2 = random.split(rng_key, 2)
e_mag = sample_mag_ellip_prior(key1, sigma=sigma, n=n)
Expand Down Expand Up @@ -98,7 +102,7 @@ def inv_shear_transformation(e: Array, g: tuple[float, float]):

# get synthetic measured sheared ellipticities
def sample_synthetic_sheared_ellips_unclipped(
rng_key,
rng_key: PRNGKeyArray,
g: tuple[float, float],
n: int,
sigma_m: float,
Expand All @@ -114,7 +118,7 @@ def sample_synthetic_sheared_ellips_unclipped(


def sample_synthetic_sheared_ellips_clipped(
rng_key,
rng_key: PRNGKeyArray,
g: tuple[float, float],
sigma_m: float,
sigma_e: float,
Expand Down
13 changes: 10 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ description = "Bayesian Pixel Domain method for shear inference."
version = "0.0.1"
license = { file = "LICENSE" }
readme = "README.md"
dependencies = ["numpy >=1.18.0", "galsim >=2.3.0", "jax >=0.4.30", "jaxlib", "blackjax >=1.2.0"]
dependencies = [
"numpy >=1.18.0",
"galsim >=2.3.0",
"jax >=0.4.30",
"jaxlib",
"blackjax >=1.2.0",
"numpyro >=0.13.0",
]


[project.optional-dependencies]
Expand Down Expand Up @@ -58,8 +65,8 @@ exclude = [
line-length = 88
indent-width = 4

# Assume Python 3.8
target-version = "py310"
# Assume Python 3.12
target-version = "py312"

[tool.ruff.format]
# Like Black, use double quotes for strings.
Expand Down

0 comments on commit 238d4b9

Please sign in to comment.