Skip to content

Commit

Permalink
Update tensors.py
Browse files Browse the repository at this point in the history
  • Loading branch information
JakeKitchen authored Oct 15, 2024
1 parent 987feac commit d53e625
Showing 1 changed file with 41 additions and 28 deletions.
69 changes: 41 additions & 28 deletions mrmustard/math/tensor_networks/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ class Wire:
is_input: bool
is_ket: bool

_contraction_id: int = field(default_factory=random_int, init=False)
_dim: int | None = field(default=None, init=False)
_is_connected: bool = field(default=False, init=False)
def __post_init__(self):
self._contraction_id: int = random_int()
self._dim = None
self._is_connected = False

@property
def contraction_id(self) -> int:
Expand All @@ -79,7 +80,7 @@ def dim(self):

@dim.setter
def dim(self, value: int):
if self._dim is not None:
if self._dim:
raise ValueError("Cannot change the dimension of wire with specified dimension.")
self._dim = value

Expand Down Expand Up @@ -186,16 +187,18 @@ def _update_modes(
self._modes_in_bra = modes_in_bra if modes_in_bra else []
self._modes_out_bra = modes_out_bra if modes_out_bra else []

# initialize ket and bra wire dicts using dictionary comprehensions for better performance
self._input = WireGroup(
ket={mode: Wire(random_int(), mode, True, True) for mode in self._modes_in_ket},
bra={mode: Wire(random_int(), mode, True, False) for mode in self._modes_in_bra},
)
# initialize ket and bra wire dicts
self._input = WireGroup()
for mode in self._modes_in_ket:
self._input.ket |= {mode: Wire(random_int(), mode, True, True)}
for mode in self._modes_in_bra:
self._input.bra |= {mode: Wire(random_int(), mode, True, False)}

self._output = WireGroup(
ket={mode: Wire(random_int(), mode, False, True) for mode in self._modes_out_ket},
bra={mode: Wire(random_int(), mode, False, False) for mode in self._modes_out_bra},
)
self._output = WireGroup()
for mode in self._modes_out_ket:
self._output.ket |= {mode: Wire(random_int(), mode, False, True)}
for mode in self._modes_out_bra:
self._output.bra |= {mode: Wire(random_int(), mode, False, False)}

@property
def adjoint(self) -> AdjointView:
Expand Down Expand Up @@ -351,25 +354,30 @@ def shape(self, default_dim: int | None = None, out_in=False):
Returns the shape of the underlying tensor, as inferred from the dimensions of the individual
wires.
If ``out_in`` is ``False``, the shape returned is in the order ``(in_ket, in_bra, out_ket, out_bra)``
If ``out_in`` is ``False``, the shape returned is in the order ``(in_ket, in_bra, out_ket, out_bra)``.
Otherwise, it is in the order ``(out_ket, out_bra, in_ket, in_bra)``.
Args:
default_dim: The default dimension of wires with unspecified dimension.
out_in: Whether to return output shapes followed by input shapes or viceversa.
"""

def _sort_shapes(*args):
for arg in args:
if arg:
yield arg

shape_in_ket = [w.dim if w.dim else default_dim for w in self.input.ket.values()]
shape_out_ket = [w.dim if w.dim else default_dim for w in self.output.ket.values()]
shape_in_bra = [w.dim if w.dim else default_dim for w in self.input.bra.values()]
shape_out_bra = [w.dim if w.dim else default_dim for w in self.output.bra.values()]

if out_in:
combined_shape = shape_out_ket + shape_out_bra + shape_in_ket + shape_in_bra
else:
combined_shape = shape_in_ket + shape_in_bra + shape_out_ket + shape_out_bra
ret = _sort_shapes(shape_out_ket, shape_out_bra, shape_in_ket, shape_in_bra)
ret = _sort_shapes(shape_in_ket, shape_in_bra, shape_out_ket, shape_out_bra)

return tuple(combined_shape)
# pylint: disable=consider-using-generator
return tuple([item for sublist in ret for item in sublist])


class AdjointView(Tensor):
Expand All @@ -381,10 +389,10 @@ def __init__(self, tensor):
self._original = tensor
super().__init__(
name=self._original.name,
modes_in_ket=list(self._original.input.bra.keys()),
modes_out_ket=list(self._original.output.bra.keys()),
modes_in_bra=list(self._original.input.ket.keys()),
modes_out_bra=list(self._original.output.ket.keys()),
modes_in_ket=self._original.input.bra.keys(),
modes_out_ket=self._original.output.bra.keys(),
modes_in_bra=self._original.input.ket.keys(),
modes_out_bra=self._original.output.ket.keys(),
)

def value(self, shape: tuple[int]):
Expand All @@ -397,7 +405,12 @@ def value(self, shape: tuple[int]):
ComplexTensor: the unitary matrix in Fock representation
"""
# converting the given shape into a shape for the original tensor
shape_in_ket, shape_out_ket, shape_in_bra, shape_out_bra = self._original.unpack_shape(shape)
(
shape_in_ket,
shape_out_ket,
shape_in_bra,
shape_out_bra,
) = self._original.unpack_shape(shape)
shape_ret = shape_in_bra + shape_out_bra + shape_in_ket + shape_out_ket

ret = math.conj(math.astensor(self._original.value(shape_ret)))
Expand All @@ -413,10 +426,10 @@ def __init__(self, tensor):
self._original = tensor
super().__init__(
name=self._original.name,
modes_in_ket=list(self._original.output.ket.keys()),
modes_out_ket=list(self._original.input.ket.keys()),
modes_in_bra=list(self._original.output.bra.keys()),
modes_out_bra=list(self._original.input.bra.keys()),
modes_in_ket=self._original.output.ket.keys(),
modes_out_ket=self._original.input.ket.keys(),
modes_in_bra=self._original.output.bra.keys(),
modes_out_bra=self._original.input.bra.keys(),
)

def value(self, shape: tuple[int]):
Expand All @@ -430,6 +443,6 @@ def value(self, shape: tuple[int]):
"""
# converting the given shape into a shape for the original tensor
shape_in_ket, shape_out_ket, shape_in_bra, shape_out_bra = self.unpack_shape(shape)
shape_ret = shape_out_ket + shape_in_ket + shape_out_bra + shape_in_bra
shape_ret = shape_out_ket + shape_in_ket + shape_out_bra, shape_in_bra

return math.conj(self._original.value(shape_ret))

0 comments on commit d53e625

Please sign in to comment.