Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consider cases when ExpressionConvertible returns float, int, duration #89

Merged
merged 8 commits into from
Jun 17, 2024
22 changes: 13 additions & 9 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def expr_matches(a: Any, b: Any) -> bool:
class ExpressionConvertible(Protocol):
"""This is the protocol an object can implement in order to be usable as an expression."""

def _to_oqpy_expression(self) -> HasToAst:
def _to_oqpy_expression(self) -> AstConvertible:
... # pragma: no cover


Expand Down Expand Up @@ -381,12 +381,17 @@ class OQPyBinaryExpression(OQPyExpression):

def __init__(
self,
op: ast.BinaryOperator,
op: ast.BinaryOperator | str,
lhs: AstConvertible,
rhs: AstConvertible,
ast_type: ast.ClassicalType | None = None,
):
super().__init__()
if isinstance(op, str):
try:
op = ast.BinaryOperator[op]
except KeyError as e:
raise ValueError(f"Invalid binary operator {op}") from e
self.op = op
self.lhs = lhs
self.rhs = rhs
Expand All @@ -398,7 +403,9 @@ def __init__(
elif isinstance(rhs, OQPyExpression):
ast_type = rhs.type
else:
raise TypeError("Neither lhs nor rhs is an expression?")
raise TypeError(
"Cannot infer ast_type from lhs or rhs. Please provide it if possible."
)
self.type = ast_type

# Adding floats to durations is not allowed. So we promote types as necessary.
Expand Down Expand Up @@ -470,17 +477,14 @@ def to_ast(self, program: Program) -> ast.Expression:
def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
"""Convert an object to an AST node."""
if hasattr(item, "_to_oqpy_expression"):
item = cast(ExpressionConvertible, item)
return item._to_oqpy_expression().to_ast(program)
item = cast(ExpressionConvertible, item)._to_oqpy_expression()
if hasattr(item, "_to_cached_oqpy_expression"):
item = cast(CachedExpressionConvertible, item)
if item._oqpy_cache_key is None:
item._oqpy_cache_key = uuid.uuid1()
if item._oqpy_cache_key not in program.expr_cache:
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression().to_ast(
program
)
return program.expr_cache[item._oqpy_cache_key]
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression()
item = program.expr_cache[item._oqpy_cache_key]
if isinstance(item, (complex, np.complexfloating)):
if item.imag == 0:
return to_ast(program, item.real)
Expand Down
8 changes: 4 additions & 4 deletions oqpy/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@ def convert_float_to_duration(time: AstConvertible, require_nonnegative: bool =
require_nonnegative: if True, raise an exception if the time value is known to
be negative.
"""
if isinstance(time, (float, int)):
if require_nonnegative and time < 0:
raise ValueError(f"Expected a non-negative duration, but got {time}")
return OQDurationLiteral(time)
if hasattr(time, "_to_oqpy_expression"):
time = cast(ExpressionConvertible, time)
time = time._to_oqpy_expression()
if hasattr(time, "_to_cached_oqpy_expression"):
time = cast(CachedExpressionConvertible, time)
time = time._to_cached_oqpy_expression()
if isinstance(time, (float, int)):
if require_nonnegative and time < 0:
raise ValueError(f"Expected a non-negative duration, but got {time}")
return OQDurationLiteral(time)
if isinstance(time, OQPyExpression):
if isinstance(time.type, (ast.UintType, ast.IntType, ast.FloatType)):
time = time * OQDurationLiteral(1)
Expand Down
24 changes: 20 additions & 4 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import oqpy
from oqpy import *
from oqpy.base import OQPyExpression, expr_matches, logical_and, logical_or
from oqpy.base import OQPyBinaryExpression, OQPyExpression, expr_matches, logical_and, logical_or
from oqpy.classical_types import OQIndexExpression
from oqpy.quantum_types import PhysicalQubits
from oqpy.timing import OQDurationLiteral
Expand Down Expand Up @@ -421,6 +421,7 @@ def test_binary_expressions():
prog.set(d, 5e-9 - d)
prog.set(d, d + convert_float_to_duration(10e-9))
prog.set(f, d / convert_float_to_duration(1))
prog.set(k, OQPyBinaryExpression("+", 2, k))

with pytest.raises(ValueError):
prog.set(f, "a" * i)
Expand All @@ -436,6 +437,8 @@ def test_binary_expressions():
prog.set(d, 5j / d)
with pytest.raises(TypeError):
prog.set(d, 5j * d)
with pytest.raises(ValueError):
OQPyBinaryExpression(".", d, d)

expected = textwrap.dedent(
"""
Expand Down Expand Up @@ -479,6 +482,7 @@ def test_binary_expressions():
d = 5.0ns - d;
d = d + 10.0ns;
f = d / 1s;
k = 2 + k;
"""
).strip()

Expand Down Expand Up @@ -1080,6 +1084,7 @@ def test_invalid_extern_declaration():
with pytest.raises(Exception, match="Argument.*"):
_ = declare_extern("invalid", [int32])


def test_defcals():
prog = Program()
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])
Expand Down Expand Up @@ -1574,29 +1579,42 @@ class A:

def _to_oqpy_expression(self):
return DurationVar(1e-7, self.name)

@dataclass
class B:
name: str

def _to_oqpy_expression(self):
return FloatVar(1e-7, self.name)

@dataclass
class C:
def _to_oqpy_expression(self):
return 1e-7

def __rmul__(self, other):
return other * self._to_oqpy_expression()

frame = FrameVar(name="f1")
prog = Program()
prog.set(A("a1"), 2)
prog.set(FloatVar(name="c1"), 3 * C())
prog.delay(A("a2"), frame)
prog.delay(B("b1"), frame)
prog.delay(C(), frame)
expected = textwrap.dedent(
"""
OPENQASM 3.0;
duration a1 = 100.0ns;
float[64] c1;
duration a2 = 100.0ns;
frame f1;
float[64] b1 = 1e-07;
a1 = 2;
c1 = 3e-07;
delay[a2] f1;
delay[b1 * 1s] f1;
delay[100.0ns] f1;
"""
).strip()
assert prog.to_qasm() == expected
Expand Down Expand Up @@ -2606,5 +2624,3 @@ def test_box_with_negative_duration():
with pytest.raises(ValueError, match="Expected a non-negative duration, but got -4e-09"):
with Box(prog, -4e-9):
pass


Loading