Skip to content

Commit

Permalink
Simplify: remove expect_2distr (#15)
Browse files Browse the repository at this point in the history
@alleSini99 what do you think? Isn't this much simpler ? (We should
check that it is as fast... but I think it is)

cc @lgravina1997
  • Loading branch information
PhilipVinc authored May 18, 2024
1 parent a167b6d commit bb237a3
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 203 deletions.
30 changes: 21 additions & 9 deletions netket_fidelity/infidelity/overlap/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from netket import jax as nkjax
from netket.utils import mpi

from netket_fidelity.utils import expect_2distr

from .operator import InfidelityOperatorStandard

Expand Down Expand Up @@ -76,7 +75,10 @@ def infidelity_sampling_MCState(
σ_t = sigma_t.reshape(-1, N)

def expect_kernel(params):
def kernel_fun(params, params_t, σ, σ_t):
def kernel_fun(params_all, samples_all):
params, params_t = params_all
σ, σ_t = samples_all

W = {"params": params, **model_state}
W_t = {"params": params_t, **model_state_t}

Expand All @@ -91,14 +93,24 @@ def kernel_fun(params, params_t, σ, σ_t):
lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real
)

return expect_2distr(
log_pdf,
log_pdf_t,
def log_pdf_joint(params_all, samples_all):
params, params_t = params_all
σ, σ_t = samples_all
log_pdf_vals = log_pdf(params, σ)
log_pdf_t_vals = log_pdf_t(params_t, σ_t)
return log_pdf_vals + log_pdf_t_vals

return nkjax.expect(
log_pdf_joint,
kernel_fun,
params,
params_t,
σ,
σ_t,
(
params,
params_t,
),
(
σ,
σ_t,
),
n_chains=n_chains_t,
)

Expand Down
4 changes: 2 additions & 2 deletions netket_fidelity/infidelity/overlap_U/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def expect_fun(params):
F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True)

F_grad = F_vjp_fun(jnp.ones_like(F))[0]
F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_map(lambda x: -x, F_grad)
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad)
I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0)

return I_stats, I_grad
30 changes: 21 additions & 9 deletions netket_fidelity/infidelity/overlap_U/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from netket.vqs import MCState, expect, expect_and_grad, get_local_kernel_arguments
from netket.utils import mpi

from netket_fidelity.utils import expect_2distr

from .operator import InfidelityOperatorUPsi

Expand Down Expand Up @@ -113,7 +112,10 @@ def infidelity_sampling_MCState(
xp_t_ravel = jnp.vstack(xp_t_splitted)

def expect_kernel(params):
def kernel_fun(params, params_t, σ, σ_t):
def kernel_fun(params_all, samples_all):
params, params_t = params_all
σ, σ_t = samples_all

W = {"params": params, **model_state}
W_t = {"params": params_t, **model_state_t}

Expand All @@ -139,14 +141,24 @@ def kernel_fun(params, params_t, σ, σ_t):
lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real
)

return expect_2distr(
log_pdf,
log_pdf_t,
def log_pdf_joint(params_all, samples_all):
params, params_t = params_all
σ, σ_t = samples_all
log_pdf_vals = log_pdf(params, σ)
log_pdf_t_vals = log_pdf_t(params_t, σ_t)
return log_pdf_vals + log_pdf_t_vals

return nkjax.expect(
log_pdf_joint,
kernel_fun,
params,
params_t,
σ,
σ_t,
(
params,
params_t,
),
(
σ,
σ_t,
),
n_chains=n_chains_t,
)

Expand Down
1 change: 0 additions & 1 deletion netket_fidelity/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .expect import expect_2distr
from .sampling_Ustate import make_logpsi_U_afun, _logpsi_U_fun

from netket.utils import _hide_submodules
Expand Down
182 changes: 0 additions & 182 deletions netket_fidelity/utils/expect.py

This file was deleted.

0 comments on commit bb237a3

Please sign in to comment.