From 71f4c5595843a2156583df6a249e9d3c48699de9 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 28 Nov 2023 08:24:22 +0100 Subject: [PATCH] simplify logic of logpsi_U wrapper (#11) @alleSini99 --- .../infidelity/overlap/operator.py | 8 ++--- netket_fidelity/utils/__init__.py | 2 +- netket_fidelity/utils/sampling_Ustate.py | 32 +++++++++++++++++-- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/netket_fidelity/infidelity/overlap/operator.py b/netket_fidelity/infidelity/overlap/operator.py index 5c00036..6a8709b 100644 --- a/netket_fidelity/infidelity/overlap/operator.py +++ b/netket_fidelity/infidelity/overlap/operator.py @@ -2,15 +2,13 @@ import jax.numpy as jnp -import flax -from netket import jax as nkjax from netket.operator import AbstractOperator, DiscreteJaxOperator from netket.utils.types import DType from netket.utils.numbers import is_scalar from netket.vqs import VariationalState, MCState, FullSumState -from netket_fidelity.utils.sampling_Ustate import _logpsi_U +from netket_fidelity.utils.sampling_Ustate import make_logpsi_U_afun class InfidelityOperatorStandard(AbstractOperator): @@ -72,12 +70,12 @@ def InfidelityUPsi( "an instance of DiscreteJaxOperator." ) - logpsiU = nkjax.HashablePartial(_logpsi_U, state._apply_fun) + logpsiU, variables_U = make_logpsi_U_afun(state._apply_fun, U, state.variables) target = MCState( sampler=state.sampler, apply_fun=logpsiU, n_samples=state.n_samples, - variables=flax.core.copy(state.variables, {"unitary": U}), + variables=variables_U, ) return InfidelityOperatorStandard(target, cv_coeff=cv_coeff, dtype=dtype) diff --git a/netket_fidelity/utils/__init__.py b/netket_fidelity/utils/__init__.py index a599539..ac817ba 100644 --- a/netket_fidelity/utils/__init__.py +++ b/netket_fidelity/utils/__init__.py @@ -1,5 +1,5 @@ from .expect import expect_2distr -from .sampling_Ustate import _logpsi_U +from .sampling_Ustate import make_logpsi_U_afun, _logpsi_U_fun from netket.utils import _hide_submodules diff --git a/netket_fidelity/utils/sampling_Ustate.py b/netket_fidelity/utils/sampling_Ustate.py index 12ad488..f9452d4 100644 --- a/netket_fidelity/utils/sampling_Ustate.py +++ b/netket_fidelity/utils/sampling_Ustate.py @@ -1,9 +1,37 @@ import jax - import flax +from netket import jax as nkjax + + +def make_logpsi_U_afun(logpsi_fun, U, variables): + """Wraps an apply_fun into another one that multiplies it by an + Unitary transformation U. + + This wrapper is made such that the Unitary is passed as the model_state + of the new wrapped function, and therefore changes to the angles/coefficients + of the Unitary should not trigger recompilation. + + Args: + logpsi_fun: a function that takes as input variables and samples + U: a {class}`nk.operator.JaxDiscreteOperator` + variables: The variables used to call *logpsi_fun* + + Returns: + A tuple, where the first element is a new function with the same signature as + the original **logpsi_fun** and a set of new variables to be used to call it. + """ + # wrap apply_fun into logpsi logpsi_U + logpsiU_fun = nkjax.HashablePartial(_logpsi_U_fun, logpsi_fun) + + # Insert a new 'model_state' key to store the Unitary. This only works + # if U is a pytree that can be flattened/unflattened. + new_variables = flax.core.copy(variables, {"unitary": U}) + + return logpsiU_fun, new_variables + -def _logpsi_U(apply_fun, variables, x, *args): +def _logpsi_U_fun(apply_fun, variables, x, *args): """ This should be used as a wrapper to the original apply function, adding to the `variables` dictionary (in model_state) a new key `unitary` with