Diagnosing a pathological behavior: solution? #28
-
Dear @Joshuaalbert,
-
Hello @Joshuaalbert,

Here is my NumPyro code:

```python
import numpyro
import numpyro.distributions as dist

# Not shown: FiducialCosmo, model_fn, get_params_vec, the covariance C
# and the observed data vector cl_obs are defined elsewhere.

def model():
    # Cosmological params (test13 JEC)
    Omega_c = numpyro.sample('Omega_c', dist.Uniform(0.1, 0.4))
    sigma8 = numpyro.sample('sigma8', dist.Uniform(0.5, 1.2))
    Omega_b = numpyro.sample('Omega_b', dist.Uniform(0.01, 0.09))
    h = numpyro.sample('h', dist.Uniform(0.4, 1.0))
    n_s = numpyro.sample('n_s', dist.Uniform(0.5, 1.25))
    w0 = numpyro.sample('w0', dist.Uniform(-2.0, -0.0001))
    # Astrophysical params
    A = numpyro.sample('A', dist.Uniform(0., 2.5))
    eta = numpyro.sample('eta', dist.Uniform(0., 6.))
    # Parameters for systematics (sample sites m1..m4 and dz1..dz4)
    m = [numpyro.sample('m%d' % i, dist.Normal(0.012, 0.023))
         for i in range(1, 5)]
    dz1 = numpyro.sample('dz1', dist.Normal(0.001, 0.016))
    dz2 = numpyro.sample('dz2', dist.Normal(-0.019, 0.013))
    dz3 = numpyro.sample('dz3', dist.Normal(0.009, 0.011))
    dz4 = numpyro.sample('dz4', dist.Normal(-0.018, 0.022))
    # Now that params are defined, here is the forward model
    cosmo = FiducialCosmo(Omega_c=Omega_c, sigma8=sigma8, Omega_b=Omega_b,
                          h=h, n_s=n_s, w0=w0)
    signal = model_fn(get_params_vec(cosmo, m, [dz1, dz2, dz3, dz4], [A, eta]))
    # And here we define the likelihood
    numpyro.sample('cl_wl', dist.MultivariateNormal(signal, C), obs=cl_obs)
```

Here is my attempt to set up the equivalent JaxNS code:
```python
import jax
import jax.numpy as jnp
import jaxns.nested_sampling
# Module path of the prior classes in the JaxNS 0.x releases; adjust for your version.
from jaxns.prior_transforms import PriorChain, UniformPrior, NormalPrior

# code adapted from https://github.com/google/jax/issues/2314
def multi_gauss_logpdf(x, mean, cov):
    """Log-density of the sample x under the multivariate normal N(mean, cov)."""
    D = mean.shape[0]
    (sign, logdet) = jnp.linalg.slogdet(cov)
    p1 = D * jnp.log(2 * jnp.pi) + logdet
    diff = x - mean
    # Solve cov @ y = diff instead of forming inv(cov) explicitly
    p2 = diff @ jnp.linalg.solve(cov, diff)
    return -0.5 * (p1 + p2)

def solve(cov, cl_obs):
    def log_lik(Omega_c, sigma8, Omega_b, h, n_s, w0, A, eta,
                m1, m2, m3, m4,
                dz1, dz2, dz3, dz4,
                **kwargs):
        cosmo = FiducialCosmo(Omega_c=Omega_c, sigma8=sigma8, Omega_b=Omega_b,
                              h=h, n_s=n_s, w0=w0)
        signal = model_fn(get_params_vec(cosmo, [m1, m2, m3, m4],
                                         [dz1, dz2, dz3, dz4], [A, eta]))
        # How to write numpyro.sample('cl_wl', dist.MultivariateNormal(signal, cov), obs=cl_obs) here?
        return multi_gauss_logpdf(cl_obs, signal, cov)

    prior_chain = PriorChain(UniformPrior('Omega_c', 0.1, 0.4),
                             UniformPrior('sigma8', 0.5, 1.2),
                             UniformPrior('Omega_b', 0.01, 0.09),
                             UniformPrior('h', 0.4, 1.0),
                             UniformPrior('n_s', 0.5, 1.25),
                             UniformPrior('w0', -2.0, -0.0001),
                             UniformPrior('A', 0., 2.5),
                             UniformPrior('eta', 0., 6.),
                             NormalPrior('m1', 0.012, 0.023),
                             NormalPrior('m2', 0.012, 0.023),
                             NormalPrior('m3', 0.012, 0.023),
                             NormalPrior('m4', 0.012, 0.023),
                             NormalPrior('dz1', 0.001, 0.016),
                             NormalPrior('dz2', -0.019, 0.013),
                             NormalPrior('dz3', 0.009, 0.011),
                             NormalPrior('dz4', -0.018, 0.022))
    print('num_live_points:', prior_chain.U_ndims * 500)
    ns = jaxns.nested_sampling.NestedSampler(log_lik, prior_chain,
                                             num_live_points=prior_chain.U_ndims * 500)
    print('Go...')
    results = ns(jax.random.PRNGKey(32564))
    return results

## Go.
results = solve(C, data)
```

Now, when I try to run this code on a GPU, I get:
```text
2021-10-20 15:11:15.189842: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:461] Allocator (GPU_0_bfc) ran out of memory trying to allocate 933.17GiB (rounded to 1001989427968)requested by op
The stack trace below excludes JAX-internal frames.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
```
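A possible lower-cost way to write the Gaussian likelihood is sketched below. This is offered as a sketch, not as a diagnosis of the 933 GiB request. Because `cov` is the same for every call, its Cholesky factor and log-determinant can be computed once, outside the per-call function. Each evaluation then needs only one triangular solve instead of refactorizing `cov`. The helper name `make_gaussian_loglik` and the toy sizes are hypothetical. Whether JaxNS evaluates the likelihood for many live points at once (for example with `jax.vmap`) is an assumption here. If it does, `jax.lax.map` shows a sequential, lower-peak-memory alternative for a batch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def make_gaussian_loglik(data, cov):
    """Hypothetical helper: returns loglik(mean) for data ~ N(mean, cov), cov fixed.

    The Cholesky factor and log-determinant are computed once here, not per call.
    """
    L = jnp.linalg.cholesky(cov)                  # cov = L @ L.T, L lower triangular
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))  # log|cov| from the Cholesky diagonal
    n = data.shape[0]
    const = -0.5 * (n * jnp.log(2.0 * jnp.pi) + logdet)

    def loglik(mean):
        # Whitened residual z = L^{-1} (data - mean); the Mahalanobis term is z . z
        z = solve_triangular(L, data - mean, lower=True)
        return const - 0.5 * jnp.dot(z, z)

    return loglik


# Toy usage with made-up numbers (stand-ins for cl_obs, cov and the model signal)
key = jax.random.PRNGKey(0)
D = 5
M = jax.random.normal(key, (D, D))
cov = M @ M.T + D * jnp.eye(D)            # symmetric positive-definite covariance
data = jnp.arange(D, dtype=jnp.float32)
loglik = make_gaussian_loglik(data, cov)

means = jnp.zeros((8, D))                 # a batch of 8 candidate signals
batched_vmap = jax.vmap(loglik)(means)    # all 8 evaluated in parallel
batched_map = jax.lax.map(loglik, means)  # evaluated one at a time, lower peak memory
```

The `vmap` and `lax.map` results are identical. The difference is only in how much intermediate memory is alive at once.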
-
@jecampagne does this example require any data? If not, I could add it to the examples section.
-
Dear experts,

I have posted this message to the NumPyro forum, and they invited me to ask the question here, as their `NestedSampler` class is a wrapper around `JaxNS` (see here). I must also confess that I am not an experienced user of nested samplers, so my question may not be well formulated.

So, I have a model which I sample with the `NUTS` sampler. Schematically, I do the following (sorry, this is in the NumPyro language):

- `numpyro.sample` statements define the priors (which, in this example, are all Gaussian distributions) and the likelihood.
- I use `fix_cond_model = numpyro.handlers.condition(model, <parameter default values>)` to generate some `data`.
- Then I run the MCMC and finally get the samples.

So far so good. Now I wonder whether I can use the `NestedSampler` instead. From the function-call point of view, what I tried seems OK according to the NumPyro developers, but the sampling of the variables is clearly pathological. Here are the ArviZ KDE plots:

[Figure: ArviZ KDE plots of the sampled variables]

I am certainly missing something. Any ideas are welcome. Thanks.