-
Notifications
You must be signed in to change notification settings - Fork 0
/
modular-exponential-gate
101 lines (88 loc) · 3.79 KB
/
modular-exponential-gate
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Defines the modular exponential gate used in Shor's algorithm."""
class ModularExp(cirq.ArithmeticGate):
    """Quantum modular exponentiation.

    This class represents the unitary which multiplies base raised to exponent
    into the target modulo the given modulus. More precisely, it represents the
    unitary V which computes modular exponentiation x**e mod n:

        V|y⟩|e⟩ = |y * x**e mod n⟩ |e⟩     0 <= y < n
        V|y⟩|e⟩ = |y⟩ |e⟩                  n <= y

    where y is the target register, e is the exponent register, x is the base
    and n is the modulus. Consequently,

        V|y⟩|e⟩ = (U**e|y⟩)|e⟩

    where U is the unitary defined as

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y
    """

    def __init__(
        self,
        target: Sequence[int],
        exponent: Union[int, Sequence[int]],
        base: int,
        modulus: int
    ) -> None:
        """Creates the gate.

        Args:
            target: the target register; its length is the number of qubits
                holding y. Entries are presumably per-qubit dimensions
                (e.g. 2), following cirq.ArithmeticGate conventions -- confirm
                against callers.
            exponent: either a register description (a sequence, one entry
                per qubit holding e) or a classical integer exponent.
            base: the classical base x.
            modulus: the classical modulus n.

        Raises:
            ValueError: if the target register has fewer qubits than are
                needed to represent ``modulus``.
        """
        # Every value 0 <= y < modulus must be representable in the target.
        if len(target) < modulus.bit_length():
            raise ValueError(
                f'Register with {len(target)} qubits is too small for modulus'
                f' {modulus}'
            )
        self.target = target
        self.exponent = exponent
        self.base = base
        self.modulus = modulus

    def registers(self) -> Sequence[Union[int, Sequence[int]]]:
        """Returns (target, exponent, base, modulus).

        This is the same order that `with_registers` and `apply` expect.
        """
        return self.target, self.exponent, self.base, self.modulus

    def with_registers(
        self, *new_registers: Union[int, Sequence[int]]
    ) -> 'ModularExp':
        """Returns a new ModularExp object with new registers.

        Args:
            *new_registers: exactly four values in the order
                (target, exponent, base, modulus).

        Raises:
            ValueError: if the count is not four, if the target is not a
                sequence, or if the base or modulus is not an int.
        """
        if len(new_registers) != 4:
            raise ValueError(
                f'Expected 4 registers (target, exponent, base, '
                f'modulus), but got {len(new_registers)}'
            )
        target, exponent, base, modulus = new_registers
        if not isinstance(target, Sequence):
            raise ValueError(
                f'Target must be a qubit register, got {type(target)}'
            )
        if not isinstance(base, int):
            raise ValueError(
                f'Base must be a classical constant, got {type(base)}'
            )
        if not isinstance(modulus, int):
            raise ValueError(
                f'Modulus must be a classical constant, got {type(modulus)}'
            )
        return ModularExp(target, exponent, base, modulus)

    def apply(self, *register_values: int) -> int:
        """Applies modular exponentiation to the registers.

        Four values should be passed in. They are, in order:
          - the target
          - the exponent
          - the base
          - the modulus

        Note that the target and exponent should be qubit registers, while
        the base and modulus should be constant parameters that control the
        resulting unitary.

        Returns:
            ``(target * base**exponent) % modulus`` when ``target < modulus``;
            otherwise ``target`` unchanged.

        Raises:
            ValueError: if the number of values passed is not four.
        """
        # Raise instead of assert so the check survives `python -O`.
        if len(register_values) != 4:
            raise ValueError(
                f'Expected 4 register values (target, exponent, base, '
                f'modulus), but got {len(register_values)}'
            )
        target, exponent, base, modulus = register_values
        # Basis states outside [0, modulus) are left untouched.
        if target >= modulus:
            return target
        # Three-argument pow reduces modulo `modulus` at every step. Computing
        # base**exponent first would build an integer whose size grows
        # exponentially with the exponent register; the result is identical
        # for the non-negative exponents a register can hold.
        return (target * pow(base, exponent, modulus)) % modulus

    def _circuit_diagram_info_(
        self, args: cirq.CircuitDiagramInfoArgs
    ) -> cirq.CircuitDiagramInfo:
        """Returns a 'CircuitDiagramInfo' object for printing circuits.

        This function just returns information on how to print this operation
        out in a circuit diagram so that the registers are labeled
        appropriately as exponent ('e') and target ('t').
        """
        assert args.known_qubits is not None
        wire_symbols = [f't{i}' for i in range(len(self.target))]
        e_str = str(self.exponent)
        # A quantum exponent register gets its own per-qubit wire labels; a
        # classical exponent is printed inline in the gate label instead.
        if isinstance(self.exponent, Sequence):
            e_str = 'e'
            wire_symbols += [f'e{i}' for i in range(len(self.exponent))]
        # The first wire carries the descriptive label for the whole gate.
        wire_symbols[0] = f'ModularExp(t*{self.base}**{e_str} % {self.modulus})'
        return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))