Skip to content

Commit

Permalink
rename exactstate
Browse files Browse the repository at this point in the history
improve driver


black
  • Loading branch information
PhilipVinc committed May 22, 2023
1 parent b53a6a5 commit f42071d
Show file tree
Hide file tree
Showing 14 changed files with 217 additions and 62 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ I_op = nkf.infidelity.InfidelityOperator(phi, U=U, U_dagger=U, is_unitary=True,
# Create the driver
optimizer = nk.optimizer.Sgd(learning_rate=0.01)
te = nkf.driver.infidelity_optimizer.InfidelityOptimizer(phi, U, psi, optimizer, U_dagger=U, is_unitary=True, cv_coeff=-0.5)
te = nkf.driver.infidelity_optimizer.InfidelityOptimizer(phi, optimizer, U=U, U_dagger=U, variational_state=psi, is_unitary=True, cv_coeff=-0.5)
# Run the driver
te.run(n_iter=100)
Expand Down
161 changes: 140 additions & 21 deletions netket_fidelity/driver/infidelity_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,140 @@
from typing import Optional
from netket.stats import Stats

from netket.driver.abstract_variational_driver import AbstractVariationalDriver

from .infidelity_optimizer_common import info
from netket.optimizer import (
identity_preconditioner,
PreconditionerT,
)

from netket_fidelity.infidelity import InfidelityOperator

from .infidelity_optimizer_common import info


class InfidelityOptimizer(AbstractVariationalDriver):
def __init__(
self,
target_state,
U,
vstate,
optimizer,
*,
variational_state,
U=None,
U_dagger=None,
sr=None,
preconditioner: PreconditionerT = identity_preconditioner,
is_unitary=False,
cv_coeff=None,
cv_coeff=-0.5,
):
super().__init__(vstate, optimizer, minimized_quantity_name="Infidelity")
r"""
Constructs a driver training the state to match the target state.
The target state is either `math`:\ket{\psi}` or `math`:\hat{U}\ket{\psi}`
depending on the provided inputs.
Operator I_op computing the infidelity I among two variational states |ψ⟩ and |Φ⟩ as:
.. math::
I = 1 - |⟨ψ|Φ⟩|^2 / ⟨ψ|ψ⟩ ⟨Φ|Φ⟩ = 1 - ⟨ψ|I_op|ψ⟩ / ⟨ψ|ψ⟩
where:
.. math::
I_op = |Φ⟩⟨Φ| / ⟨Φ|Φ⟩
The state |Φ⟩ can be an autonomous state |Φ⟩ =|ϕ⟩ or an operator U applied to it, namely
|Φ⟩ = U|ϕ⟩. I_op is defined by the state |ϕ⟩ (called target) and, possibly, by the operator U.
If U is not passed, it is assumed |Φ⟩ =|ϕ⟩.
The Monte Carlo estimator of I is:
..math::
I = \mathbb{E}_{χ}[ I_loc(σ,η) ] = \mathbb{E}_{χ}[ ⟨σ|Φ⟩ ⟨η|ψ⟩ / ⟨σ|ψ⟩ ⟨η|Φ⟩ ]
where χ(σ, η) = |Ψ(σ)|^2 |Φ(η)|^2 / ⟨ψ|ψ⟩ ⟨Φ|Φ⟩. In practice, since I is a real quantity, Re{I_loc(σ,η)}
is used. This estimator can be utilized both when |Φ⟩ =|ϕ⟩ and when |Φ⟩ = U|ϕ⟩, with U a (unitary or
non-unitary) operator. In the second case, we have to sample from U|ϕ⟩ and this is implemented in
the function :ref:`jax.:ref:`InfidelityUPsi`. This works only with the operators provdided in the package.
We remark that sampling from U|ϕ⟩ requires to compute connected elements of U and so is more expensive
than sampling from an autonomous state. The choice of this estimator is specified by passing
`sample_Upsi=True`, while the flag argument `is_unitary` indicates whether U is unitary or not.
If U is unitary, the following alternative estimator can be used:
..math::
I = \mathbb{E}_{χ'}[ I_loc(σ,η) ] = \mathbb{E}_{χ}[ ⟨σ|U|ϕ⟩ ⟨η|ψ⟩ / ⟨σ|U^{\dagger}|ψ⟩ ⟨η|ϕ⟩ ].
where χ'(σ, η) = |Ψ(σ)|^2 |ϕ(η)|^2 / ⟨ψ|ψ⟩ ⟨ϕ|ϕ⟩. This estimator is more efficient since it does not
require to sample from U|ϕ⟩, but only from |ϕ⟩. This choice of the estimator is the default and it works only
with `is_unitary==True` (besides `sample_Upsi=False`). When |Φ⟩ = |ϕ⟩ the two estimators coincides.
To reduce the variance of the estimator, the Control Variates (CV) method can be applied. This consists
in modifying the estimator into:
..math::
I_loc^{CV} = Re{I_loc(σ,η)} - c (|1 - I_loc(σ,η)^2| - 1)
where c ∈ \mathbb{R}. The constant c is chosen to minimize the variance of I_loc^{CV} as:
..math::
c* = Cov_{χ}[ |1-I_loc|^2, Re{1-I_loc}] / Var_{χ}[ |1-I_loc|^2 ],
where Cov[..., ...] indicates the covariance and Var[...] the variance. In the relevant limit
|Ψ⟩ →|Φ⟩, we have c*→-1/2. The value -1/2 is adopted as default value for c in the infidelity
estimator. To not apply CV, set c=0.
Args:
target_state: target variational state |ϕ⟩.
optimizer: the optimizer to use to use (from optax)
variational_state: the variational state to train
U: operator U.
U_dagger: dagger operator U^{\dagger}.
cv_coeff: Control Variates coefficient c.
is_unitary: flag specifiying the unitarity of U. If True with `sample_Upsi=False`, the second estimator is used.
dtype: The dtype of the output of expectation value and gradient.
sample_Upsi: flag specifiying whether to sample from |ϕ⟩ or from U|ϕ⟩. If False with `is_unitary=False`, an error occurs.
preconditioner: Determines which preconditioner to use for the loss gradient.
This must be a tuple of `(object, solver)` as documented in the section
`preconditioners` in the documentation. The standard preconditioner
included with NetKet is Stochastic Reconfiguration. By default, no
preconditioner is used and the bare gradient is passed to the optimizer.
"""
super().__init__(
variational_state, optimizer, minimized_quantity_name="Infidelity"
)

self._cv = cv_coeff

self.preconditioner = preconditioner

self.sr = sr
self._I_op = InfidelityOperator(
target_state, U=U, U_dagger=U, is_unitary=True, cv_coeff=-1 / 2
target_state, U=U, U_dagger=U, is_unitary=True, cv_coeff=cv_coeff
)

def _forward_and_backward(self):
self.state.reset()
self._I_op.target.reset()

I_stats, I_grad = self.state.expect_and_grad(self._I_op)

# TODO
self._loss_stats = I_stats
self._loss_grad = I_grad
self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._I_op)

if self.sr is not None:
self._S = self.state.quantum_geometric_tensor(self.sr)
self._dp = self._S(self._loss_grad)
else:
self._dp = self._loss_grad
# if it's the identity it does
self._dp = self.preconditioner(self.state, self._loss_grad, self.step_count)

return self._dp

@property
def cv(self) -> Optional[float]:
"""
Return the coefficient for the Control Variates
"""
return self._cv

@property
def infidelity(self) -> Stats:
"""
Expand All @@ -51,6 +143,36 @@ def infidelity(self) -> Stats:
"""
return self._loss_stats

@property
def preconditioner(self):
"""
The preconditioner used to modify the gradient.
This is a function with the following signature
.. code-block:: python
precondtioner(vstate: VariationalState,
grad: PyTree,
step: Optional[Scalar] = None)
Where the first argument is a variational state, the second argument
is the PyTree of the gradient to precondition and the last optional
argument is the step, used to change some parameters along the
optimisation.
Often, this is taken to be :func:`nk.optimizer.SR`. If it is set to
`None`, then the identity is used.
"""
return self._preconditioner

@preconditioner.setter
def preconditioner(self, val: Optional[PreconditionerT]):
if val is None:
val = identity_preconditioner

self._preconditioner = val

def __repr__(self):
return (
"InfidelityOptimiser("
Expand All @@ -69,6 +191,3 @@ def info(self, depth=0):
]
]
return "\n{}".format(" " * 3 * (depth + 1)).join([str(self)] + lines)

def info(self):
pass
11 changes: 9 additions & 2 deletions netket_fidelity/infidelity/logic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from typing import Optional

from netket.operator import AbstractOperator, Adjoint
from netket.vqs import VariationalState, ExactState
from netket.vqs import VariationalState
from netket.utils.types import DType

import netket

if hasattr(netket.vqs, "FullSumState"):
from netket.vqs import FullSumState
else:
from netket.vqs import ExactState as FullSumState

from .overlap import InfidelityOperatorStandard, InfidelityUPsi
from .overlap_U import InfidelityOperatorUPsi

Expand Down Expand Up @@ -120,7 +127,7 @@ def InfidelityOperator(
"use operators coming from `netket_fidelity`."
)

if isinstance(target, ExactState):
if isinstance(target, FullSumState):
return InfidelityOperatorUPsi(
U,
target,
Expand Down
26 changes: 17 additions & 9 deletions netket_fidelity/infidelity/overlap/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,29 @@

from netket import jax as nkjax
from netket.utils.dispatch import TrueT
from netket.vqs import ExactState, expect, expect_and_grad
from netket.vqs import expect, expect_and_grad
from netket.utils import mpi
from netket.stats import Stats

# support future netket
import netket

if hasattr(netket.vqs, "FullSumState"):
from netket.vqs import FullSumState
else:
from netket.vqs import ExactState as FullSumState

from .operator import InfidelityOperatorStandard


@expect.dispatch
def infidelity(vstate: ExactState, op: InfidelityOperatorStandard):
def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")
if not isinstance(op.target, ExactState):
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

return infidelity_sampling_ExactState(
return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
Expand All @@ -30,19 +38,19 @@ def infidelity(vstate: ExactState, op: InfidelityOperatorStandard):


@expect_and_grad.dispatch
def infidelity(
vstate: ExactState,
def infidelity( # noqa: F811
vstate: FullSumState,
op: InfidelityOperatorStandard,
use_covariance: TrueT,
*,
mutable,
):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")
if not isinstance(op.target, ExactState):
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

return infidelity_sampling_ExactState(
return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
Expand All @@ -53,7 +61,7 @@ def infidelity(


@partial(jax.jit, static_argnames=("afun", "return_grad"))
def infidelity_sampling_ExactState(
def infidelity_sampling_FullSumState(
afun,
params,
model_state,
Expand Down
2 changes: 1 addition & 1 deletion netket_fidelity/infidelity/overlap/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def infidelity(vstate: MCState, op: InfidelityOperatorStandard):


@expect_and_grad.dispatch
def infidelity(
def infidelity( # noqa: F811
vstate: MCState,
op: InfidelityOperatorStandard,
use_covariance: TrueT,
Expand Down
13 changes: 11 additions & 2 deletions netket_fidelity/infidelity/overlap/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
from netket.operator import AbstractOperator
from netket.utils.types import DType
from netket.utils.numbers import is_scalar
from netket.vqs import VariationalState, ExactState, MCState
from netket.vqs import VariationalState, MCState

# support future netket
import netket

if hasattr(netket.vqs, "FullSumState"):
from netket.vqs import FullSumState
else:
from netket.vqs import ExactState as FullSumState


from netket_fidelity.utils.sampling_Ustate import _logpsi_U

Expand All @@ -29,7 +38,7 @@ def __init__(
if (not is_scalar(cv_coeff)) or jnp.iscomplex(cv_coeff):
raise TypeError("`cv_coeff` should be a real scalar number or None.")

if isinstance(target, ExactState):
if isinstance(target, FullSumState):
cv_coeff = None

self._target = target
Expand Down
Loading

0 comments on commit f42071d

Please sign in to comment.