Skip to content

Commit

Permalink
Now we have a good implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
arsalan-motamedi committed Nov 1, 2024
1 parent a8eab48 commit 3b2844b
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
4 changes: 3 additions & 1 deletion mrmustard/lab_dev/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
)

from ..circuit_components import CircuitComponent
from ..circuit_components_utils import BtoPS


__all__ = ["State"]

Expand Down Expand Up @@ -333,6 +333,8 @@ def phase_space(self, s: float) -> tuple:
Returns:
The covariance matrix, the mean vector and the coefficient of the state in s-parametrized phase space.
"""
from ..circuit_components_utils import BtoPS

if not isinstance(self.ansatz, PolyExpAnsatz):
raise ValueError("Can calculate phase space only for Bargmann states.")

Expand Down
11 changes: 0 additions & 11 deletions mrmustard/lab_dev/states/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,17 +413,6 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent:
Returns a ``DM`` when the wires of the resulting components are compatible with
those of a ``DM``, a ``CircuitComponent`` otherwise, and a scalar if there are no wires left.
"""
from ..transformations.phasenoise import PhaseNoise

if isinstance(other, PhaseNoise):
array = self.to_fock().representation.array[0]
cutoff = array.shape[-1]
for i in range(cutoff):
for j in range(cutoff):
array[i, j] = array[i, j] * math.exp(
-0.5 * (i - j) ** 2 * other.phase_stdev.value**2
)
return DM.from_fock(self.modes, array, self.name)

result = super().__rshift__(other)
if not isinstance(result, CircuitComponent):
Expand Down
28 changes: 24 additions & 4 deletions mrmustard/lab_dev/transformations/phasenoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from typing import Sequence
from mrmustard import math, settings
from .base import Channel
from .attenuator import Attenuator
from ...physics.representations import Fock
from ..states import Ket, DM
from ..utils import make_parameter
import numpy as np

__all__ = ["PhaseNoise"]

Expand All @@ -44,6 +44,7 @@ class PhaseNoise(Channel):
"""

short_name = "P~"

# randomized : bool
def __init__(
self,
Expand All @@ -56,10 +57,29 @@ def __init__(
self._add_parameter(
make_parameter(phase_stdev_trainable, phase_stdev, "phase_stdev", phase_stdev_bounds)
)
self._representation = self.from_ansatz(
modes_in=modes, modes_out=modes, ansatz=None
).representation

def __custom_rrshift__(self, other):
r"""
Custom rrshift
"""
# check if Ket or DM: do the specific matmul
# raise exception if not
if isinstance(other, Ket):
other = other.dm()
if isinstance(other, DM):
array = other.fock_array()
for mode in self.modes:
for count, _ in enumerate(np.nditer(array)):
idx = math.zeros(len(array.shape))
temp = count
for l in range(len(idx)):
idx[-1 - l] = temp % array.shape[-1 - l]
temp = temp // array.shape[-1 - l]
array_index = tuple(idx.astype(int))
array[array_index] *= math.exp(
-0.5
* (idx[mode] - idx[other.n_modes + mode]) ** 2
* self.phase_stdev.value**2
)
return DM.from_fock(other.modes, array, self.name)

0 comments on commit 3b2844b

Please sign in to comment.