Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python] Enhancements to the custom operations feature #1968

Merged
merged 3 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,23 +188,40 @@ def supportCommonCast(mlirType, otherTy, arg, FromType, ToType, PyType):

def __generalCustomOperation(self, opName, *args):
"""
Utility function for adding a generic quantum operation to the MLIR representation for the PyKernel.
Utility function for adding a generic quantum operation to the MLIR
representation for the PyKernel.

A controlled version can be invoked by passing additional arguments
to the operation. For an N-qubit operation, the last N arguments are
treated as `targets` and excess arguments as `controls`.
"""

global globalRegisteredOperations
unitary = globalRegisteredOperations[opName]

numTargets = int(np.log2(np.sqrt(unitary.size)))

targets = []
qubits = []
with self.insertPoint, self.loc:
for arg in args:
if isinstance(arg, QuakeValue):
targets.append(arg.mlirValue)
qubits.append(arg.mlirValue)
else:
emitFatalError(f"invalid argument type passed to {opName}.")

assert (numTargets == len(targets))
targets = []
controls = []

if numTargets == len(qubits):
targets = qubits
elif numTargets < len(qubits):
numControls = len(qubits) - numTargets
targets = qubits[-numTargets:]
controls = qubits[:numControls]
else:
emitFatalError(
f"too few arguments passed to {opName}, expected ({numTargets})"
)
khalatepradnya marked this conversation as resolved.
Show resolved Hide resolved

globalName = f'{nvqppPrefix}{opName}_generator_{numTargets}.rodata'
currentST = SymbolTable(self.module.operation)
Expand All @@ -216,7 +233,7 @@ def __generalCustomOperation(self, opName, *args):
quake.CustomUnitarySymbolOp([],
generator=FlatSymbolRefAttr.get(globalName),
parameters=[],
controls=[],
controls=controls,
targets=targets,
is_adj=False)
return
Expand Down Expand Up @@ -1520,6 +1537,11 @@ def getListType(eleType: type):

cudaq_runtime.pyAltLaunchKernel(self.name, self.module, *processedArgs)

def __getattr__(self, attr_name):
if hasattr(self, attr_name):
return getattr(self, attr_name)
raise AttributeError(f"'{attr_name}' is not supported on PyKernel")


setattr(PyKernel, 'h', partialmethod(__singleTargetOperation, 'h'))
setattr(PyKernel, 'x', partialmethod(__singleTargetOperation, 'x'))
Expand Down
16 changes: 7 additions & 9 deletions python/cudaq/kernel/register_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,17 @@ def kernel():
if isinstance(unitary, Callable):
raise RuntimeError("parameterized custom operations not yet supported.")

if isinstance(unitary, np.ndarray):
if (len(unitary.shape) != unitary.ndim):
raise RuntimeError(
"provide a 1D array for the matrix representation in row-major format."
)
matrix = unitary
elif isinstance(unitary, List):
if isinstance(unitary, np.matrix) or isinstance(unitary, List):
matrix = np.array(unitary)
elif isinstance(unitary, np.ndarray):
matrix = unitary
else:
raise RuntimeError("unknown type of unitary.")

# TODO: Flatten the matrix if not flattened
assert (matrix.ndim == len(matrix.shape))
matrix = matrix.flatten()
assert (
matrix.ndim == len(matrix.shape),
"provide a 1D array for the matrix representation in row-major format.")

# Size must be a power of 2
assert (matrix.size != 0)
Expand Down
15 changes: 14 additions & 1 deletion python/tests/custom/test_custom_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def kernel():


def test_builder_mode():
"""Builder-mode API """
"""Builder-mode API"""

kernel = cudaq.make_kernel()
cudaq.register_operation("custom_h",
Expand All @@ -192,6 +192,19 @@ def test_builder_mode():
check_bell(kernel)


def test_builder_mode_control():
"""Controlled operation in builder-mode"""

kernel = cudaq.make_kernel()
cudaq.register_operation("custom_x", np.array([0, 1, 1, 0]))

qubits = kernel.qalloc(2)
kernel.h(qubits[0])
kernel.custom_x(qubits[0], qubits[1])

check_bell(kernel)


def test_invalid_ctrl():
cudaq.register_operation("custom_x", np.array([0, 1, 1, 0]))

Expand Down
85 changes: 85 additions & 0 deletions python/tests/mlir/custom_op_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

# RUN: PYTHONPATH=../../ pytest -rP %s | FileCheck %s

import numpy as np
import cudaq


def test_builder_look_up():
"""A custom operation can be looked up by its name in builder mode"""

base_name = 'foo'
op_count = 3

def register_custom_operations(matrix):
prev = np.identity(2)
for t in range(op_count):
new = prev @ matrix
cudaq.register_operation(f'{base_name}_{t}', new)
prev = new

register_custom_operations(
np.array([[1, 0], [0, np.exp(np.pi * 1j * 1 / 3)]]))

kernel = cudaq.make_kernel()

qubit = kernel.qalloc(1)
ancilla = kernel.qalloc(2)

kernel.x(qubit)
kernel.h(ancilla)

for i in range(op_count):
kernel.__getattr__(f'{base_name}_{i}')(ancilla, qubit)

print(kernel)
counts = cudaq.sample(kernel)


# CHECK-LABEL: func.func @__nvqpp__mlirgen____nvqppBuilderKernel_{{.*}}() attributes {"cudaq-entrypoint"} {
# CHECK: %[[VAL_0:.*]] = arith.constant 2 : i64
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64
# CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64
# CHECK: %[[VAL_3:.*]] = quake.alloca !quake.veq<1>
# CHECK: %[[VAL_4:.*]] = quake.alloca !quake.veq<2>
# CHECK: %[[VAL_5:.*]] = cc.loop while ((%[[VAL_6:.*]] = %[[VAL_2]]) -> (i64)) {
# CHECK: %[[VAL_7:.*]] = arith.cmpi slt, %[[VAL_6]], %[[VAL_1]] : i64
# CHECK: cc.condition %[[VAL_7]](%[[VAL_6]] : i64)
# CHECK: } do {
# CHECK: ^bb0(%[[VAL_8:.*]]: i64):
# CHECK: %[[VAL_9:.*]] = quake.extract_ref %[[VAL_3]]{{\[}}%[[VAL_8]]] : (!quake.veq<1>, i64) -> !quake.ref
# CHECK: quake.x %[[VAL_9]] : (!quake.ref) -> ()
# CHECK: cc.continue %[[VAL_8]] : i64
# CHECK: } step {
# CHECK: ^bb0(%[[VAL_10:.*]]: i64):
# CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_11]] : i64
# CHECK: } {invariant}
# CHECK: %[[VAL_12:.*]] = cc.loop while ((%[[VAL_13:.*]] = %[[VAL_2]]) -> (i64)) {
# CHECK: %[[VAL_14:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_0]] : i64
# CHECK: cc.condition %[[VAL_14]](%[[VAL_13]] : i64)
# CHECK: } do {
# CHECK: ^bb0(%[[VAL_15:.*]]: i64):
# CHECK: %[[VAL_16:.*]] = quake.extract_ref %[[VAL_4]]{{\[}}%[[VAL_15]]] : (!quake.veq<2>, i64) -> !quake.ref
# CHECK: quake.h %[[VAL_16]] : (!quake.ref) -> ()
# CHECK: cc.continue %[[VAL_15]] : i64
# CHECK: } step {
# CHECK: ^bb0(%[[VAL_17:.*]]: i64):
# CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_18]] : i64
# CHECK: } {invariant}
# CHECK: quake.custom_op @__nvqpp__mlirgen__foo_0_generator_1.rodata {{\[}}%[[VAL_4]]] %[[VAL_3]] : (!quake.veq<2>, !quake.veq<1>) -> ()
# CHECK: quake.custom_op @__nvqpp__mlirgen__foo_1_generator_1.rodata {{\[}}%[[VAL_4]]] %[[VAL_3]] : (!quake.veq<2>, !quake.veq<1>) -> ()
# CHECK: quake.custom_op @__nvqpp__mlirgen__foo_2_generator_1.rodata {{\[}}%[[VAL_4]]] %[[VAL_3]] : (!quake.veq<2>, !quake.veq<1>) -> ()
# CHECK: return
# CHECK: }
# CHECK-DAG: cc.global constant @__nvqpp__mlirgen__foo_0_generator_1.rodata (dense<[{{.*}}]> : tensor<4xcomplex<f64>>) : !cc.array<complex<f64> x 4>
# CHECK-DAG: cc.global constant @__nvqpp__mlirgen__foo_1_generator_1.rodata (dense<[{{.*}}]> : tensor<4xcomplex<f64>>) : !cc.array<complex<f64> x 4>
# CHECK-DAG: cc.global constant @__nvqpp__mlirgen__foo_2_generator_1.rodata (dense<[{{.*}}]> : tensor<4xcomplex<f64>>) : !cc.array<complex<f64> x 4>
Loading