Skip to content

Commit

Permalink
simplify logic of logpsi_U wrapper (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipVinc authored Nov 28, 2023
1 parent 0a19a8b commit 71f4c55
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
8 changes: 3 additions & 5 deletions netket_fidelity/infidelity/overlap/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

import jax.numpy as jnp

import flax

from netket import jax as nkjax
from netket.operator import AbstractOperator, DiscreteJaxOperator
from netket.utils.types import DType
from netket.utils.numbers import is_scalar
from netket.vqs import VariationalState, MCState, FullSumState

from netket_fidelity.utils.sampling_Ustate import _logpsi_U
from netket_fidelity.utils.sampling_Ustate import make_logpsi_U_afun


class InfidelityOperatorStandard(AbstractOperator):
Expand Down Expand Up @@ -72,12 +70,12 @@ def InfidelityUPsi(
"an instance of DiscreteJaxOperator."
)

logpsiU = nkjax.HashablePartial(_logpsi_U, state._apply_fun)
logpsiU, variables_U = make_logpsi_U_afun(state._apply_fun, U, state.variables)
target = MCState(
sampler=state.sampler,
apply_fun=logpsiU,
n_samples=state.n_samples,
variables=flax.core.copy(state.variables, {"unitary": U}),
variables=variables_U,
)

return InfidelityOperatorStandard(target, cv_coeff=cv_coeff, dtype=dtype)
2 changes: 1 addition & 1 deletion netket_fidelity/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .expect import expect_2distr
from .sampling_Ustate import _logpsi_U
from .sampling_Ustate import make_logpsi_U_afun, _logpsi_U_fun

from netket.utils import _hide_submodules

Expand Down
32 changes: 30 additions & 2 deletions netket_fidelity/utils/sampling_Ustate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
import jax

import flax

from netket import jax as nkjax


def make_logpsi_U_afun(logpsi_fun, U, variables):
    """Wrap an apply_fun into a new one whose output is multiplied by a
    Unitary transformation U.

    The wrapper stores the Unitary inside the ``model_state`` of the new
    function's variables, so changes to the angles/coefficients of the
    Unitary should not trigger recompilation.

    Args:
        logpsi_fun: a function that takes as input variables and samples
        U: a {class}`nk.operator.JaxDiscreteOperator`
        variables: The variables used to call *logpsi_fun*

    Returns:
        A tuple, where the first element is a new function with the same
        signature as the original **logpsi_fun** and the second is the set
        of new variables to be used to call it.
    """
    # Partially apply the original function; HashablePartial keeps the
    # wrapper hashable so jitted callers treat it as a stable function.
    wrapped_apply = nkjax.HashablePartial(_logpsi_U_fun, logpsi_fun)

    # Store the Unitary under a dedicated 'unitary' key of the variables.
    # This only works if U is a pytree that can be flattened/unflattened.
    variables_with_U = flax.core.copy(variables, {"unitary": U})

    return wrapped_apply, variables_with_U


def _logpsi_U(apply_fun, variables, x, *args):
def _logpsi_U_fun(apply_fun, variables, x, *args):
"""
This should be used as a wrapper to the original apply function, adding
to the `variables` dictionary (in model_state) a new key `unitary` with
Expand Down

0 comments on commit 71f4c55

Please sign in to comment.