Skip to content

Commit

Permalink
Support call tracking within Q# code (#1791)
Browse files Browse the repository at this point in the history
This change adds callable tracking APIs to the Diagnostics namespace
that allow Q# code to register for tracking of specific callables and
getting the count of times that callable was invoked. This is useful in
katas and auto-grading scenarios to verify that exercise restrictions
are being appropriately followed.

Fixes #1154
  • Loading branch information
swernli authored Aug 15, 2024
1 parent f02df41 commit 22c80c9
Show file tree
Hide file tree
Showing 6 changed files with 363 additions and 5 deletions.
48 changes: 48 additions & 0 deletions compiler/qsc_eval/src/intrinsic/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1502,3 +1502,51 @@ fn two_qubit_rotation_neg_inf_error() {
&expect!["invalid rotation angle: -inf"],
);
}

#[test]
fn stop_counting_operation_before_start_fails() {
check_intrinsic_output(
"",
indoc! {"{
Std.Diagnostics.StopCountingOperation(I);
}"},
&expect!["callable not counted"],
);
}

#[test]
fn stop_counting_function_before_start_fails() {
check_intrinsic_output(
"",
indoc! {"{
function Foo() : Unit {}
Std.Diagnostics.StopCountingFunction(Foo);
}"},
&expect!["callable not counted"],
);
}

#[test]
fn start_counting_operation_called_twice_before_stop_fails() {
check_intrinsic_output(
"",
indoc! {"{
Std.Diagnostics.StartCountingOperation(I);
Std.Diagnostics.StartCountingOperation(I);
}"},
&expect!["callable already counted"],
);
}

#[test]
fn start_counting_function_called_twice_before_stop_fails() {
check_intrinsic_output(
"",
indoc! {"{
function Foo() : Unit {}
Std.Diagnostics.StartCountingFunction(Foo);
Std.Diagnostics.StartCountingFunction(Foo);
}"},
&expect!["callable already counted"],
);
}
81 changes: 80 additions & 1 deletion compiler/qsc_eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use qsc_fir::fir::{
use qsc_fir::ty::Ty;
use qsc_lowerer::map_fir_package_to_hir;
use rand::{rngs::StdRng, SeedableRng};
use rustc_hash::FxHashSet;
use rustc_hash::{FxHashMap, FxHashSet};
use std::ops;
use std::{
cell::RefCell,
Expand All @@ -61,6 +61,18 @@ pub enum Error {
#[diagnostic(code("Qsc.Eval.ArrayTooLarge"))]
ArrayTooLarge(#[label("this array has too many items")] PackageSpan),

#[error("callable already counted")]
#[diagnostic(help(
"counting for a given callable must be stopped before it can be started again"
))]
#[diagnostic(code("Qsc.Eval.CallableAlreadyCounted"))]
CallableAlreadyCounted(#[label] PackageSpan),

#[error("callable not counted")]
#[diagnostic(help("counting for a given callable must be started before it can be stopped"))]
#[diagnostic(code("Qsc.Eval.CallableNotCounted"))]
CallableNotCounted(#[label] PackageSpan),

#[error("invalid array length: {0}")]
#[diagnostic(code("Qsc.Eval.InvalidArrayLength"))]
InvalidArrayLength(i64, #[label("cannot be used as a length")] PackageSpan),
Expand Down Expand Up @@ -150,6 +162,8 @@ impl Error {
pub fn span(&self) -> &PackageSpan {
match self {
Error::ArrayTooLarge(span)
| Error::CallableAlreadyCounted(span)
| Error::CallableNotCounted(span)
| Error::DivZero(span)
| Error::EmptyRange(span)
| Error::IndexOutOfRange(_, span)
Expand Down Expand Up @@ -435,6 +449,8 @@ struct Scope {
frame_id: usize,
}

type CallableCountKey = (StoreItemId, bool, bool);

pub struct State {
exec_graph_stack: Vec<ExecGraph>,
idx: u32,
Expand All @@ -446,6 +462,7 @@ pub struct State {
call_stack: CallStack,
current_span: Span,
rng: RefCell<StdRng>,
call_counts: FxHashMap<CallableCountKey, i64>,
}

impl State {
Expand All @@ -466,6 +483,7 @@ impl State {
call_stack: CallStack::default(),
current_span: Span::default(),
rng,
call_counts: FxHashMap::default(),
}
}

Expand Down Expand Up @@ -962,9 +980,20 @@ impl State {

let spec = spec_from_functor_app(functor);
match &callee.implementation {
CallableImpl::Intrinsic if is_counting_call(&callee.name.name) => {
self.push_frame(Vec::new().into(), callee_id, functor);

let val = self.counting_call(&callee.name.name, arg, arg_span)?;

self.set_val_register(val);
self.leave_frame();
Ok(())
}
CallableImpl::Intrinsic => {
self.push_frame(Vec::new().into(), callee_id, functor);

self.increment_call_count(callee_id, functor);

let name = &callee.name.name;
let val = intrinsic::call(
name,
Expand Down Expand Up @@ -995,6 +1024,7 @@ impl State {
.expect("missing specialization should be a compilation error");
self.push_frame(spec_decl.exec_graph.clone(), callee_id, functor);
self.push_scope(env);
self.increment_call_count(callee_id, functor);

self.bind_args_for_spec(
env,
Expand Down Expand Up @@ -1436,6 +1466,41 @@ impl State {
span,
}
}

fn counting_call(&mut self, name: &str, arg: Value, span: PackageSpan) -> Result<Value, Error> {
let callable = if let Value::Closure(closure) = arg {
make_counting_key(closure.id, closure.functor)
} else {
let callable = arg.unwrap_global();
make_counting_key(callable.0, callable.1)
};
match name {
"StartCountingOperation" | "StartCountingFunction" => {
if self.call_counts.insert(callable, 0).is_some() {
Err(Error::CallableAlreadyCounted(span))
} else {
Ok(Value::unit())
}
}
"StopCountingOperation" | "StopCountingFunction" => {
if let Some(count) = self.call_counts.remove(&callable) {
Ok(Value::Int(count))
} else {
Err(Error::CallableNotCounted(span))
}
}
_ => panic!("unknown counting call"),
}
}

fn increment_call_count(&mut self, callee_id: StoreItemId, functor: FunctorApp) {
if let Some(count) = self
.call_counts
.get_mut(&make_counting_key(callee_id, functor))
{
*count += 1;
}
}
}

pub fn are_ctls_unique(ctls: &[Value], tup: &Value) -> bool {
Expand Down Expand Up @@ -1925,3 +1990,17 @@ fn is_updatable_in_place(env: &Env, expr: &Expr) -> (bool, bool) {
_ => (false, false),
}
}

fn is_counting_call(name: &str) -> bool {
matches!(
name,
"StartCountingOperation"
| "StopCountingOperation"
| "StartCountingFunction"
| "StopCountingFunction"
)
}

fn make_counting_key(id: StoreItemId, functor: FunctorApp) -> CallableCountKey {
(id, functor.adjoint, functor.controlled > 0)
}
4 changes: 2 additions & 2 deletions compiler/qsc_eval/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3724,7 +3724,7 @@ fn controlled_operation_with_duplicate_controls_fails() {
1,
),
item: LocalItemId(
124,
128,
),
},
caller: PackageId(
Expand Down Expand Up @@ -3774,7 +3774,7 @@ fn controlled_operation_with_target_in_controls_fails() {
1,
),
item: LocalItemId(
124,
128,
),
},
caller: PackageId(
Expand Down
2 changes: 1 addition & 1 deletion compiler/qsc_fir/src/fir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ pub enum Global<'a> {
}

/// A unique identifier for an item within a package store.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)]
pub struct StoreItemId {
/// The package ID.
pub package: PackageId,
Expand Down
130 changes: 130 additions & 0 deletions library/src/tests/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,133 @@ fn check_operations_are_equal() {
),
);
}

#[test]
fn check_start_stop_counting_operation_called_3_times() {
test_expression(
"{
import Microsoft.Quantum.Diagnostics.StartCountingOperation;
import Microsoft.Quantum.Diagnostics.StopCountingOperation;
operation op1() : Unit {}
operation op2() : Unit { op1(); }
StartCountingOperation(op1);
StartCountingOperation(op2);
op1(); op1(); op2();
(StopCountingOperation(op1), StopCountingOperation(op2))
}",
&Value::Tuple([Value::Int(3), Value::Int(1)].into()),
);
}

#[test]
fn check_start_stop_counting_operation_called_0_times() {
test_expression(
"{
import Microsoft.Quantum.Diagnostics.StartCountingOperation;
import Microsoft.Quantum.Diagnostics.StopCountingOperation;
operation op1() : Unit {}
operation op2() : Unit { op1(); }
StartCountingOperation(op1);
StartCountingOperation(op2);
(StopCountingOperation(op1), StopCountingOperation(op2))
}",
&Value::Tuple([Value::Int(0), Value::Int(0)].into()),
);
}

#[test]
fn check_lambda_counted_separately_from_operation() {
test_expression(
"{
import Microsoft.Quantum.Diagnostics.StartCountingOperation;
import Microsoft.Quantum.Diagnostics.StopCountingOperation;
operation op1() : Unit {}
StartCountingOperation(op1);
let lambda = () => op1();
StartCountingOperation(lambda);
op1();
lambda();
(StopCountingOperation(op1), StopCountingOperation(lambda))
}",
&Value::Tuple([Value::Int(2), Value::Int(1)].into()),
);
}

#[test]
fn check_multiple_controls_counted_together() {
test_expression(
"{
import Microsoft.Quantum.Diagnostics.StartCountingOperation;
import Microsoft.Quantum.Diagnostics.StopCountingOperation;
operation op1() : Unit is Adj + Ctl {}
StartCountingOperation(Controlled op1);
Controlled op1([], ());
Controlled Controlled op1([], ([], ()));
Controlled Controlled Controlled op1([], ([], ([], ())));
(StopCountingOperation(Controlled op1))
}",
&Value::Int(3),
);
}

#[test]
fn check_counting_operation_differentiates_between_body_adj_ctl() {
test_expression(
"{
import Microsoft.Quantum.Diagnostics.StartCountingOperation;
import Microsoft.Quantum.Diagnostics.StopCountingOperation;
operation op1() : Unit is Adj + Ctl {}
StartCountingOperation(op1);
StartCountingOperation(Adjoint op1);
StartCountingOperation(Controlled op1);
StartCountingOperation(Adjoint Controlled op1);
op1();
Adjoint op1(); Adjoint op1();
Controlled op1([], ()); Controlled op1([], ()); Controlled op1([], ());
Adjoint Controlled op1([], ()); Adjoint Controlled op1([], ());
Controlled Adjoint op1([], ()); Controlled Adjoint op1([], ());
(StopCountingOperation(op1), StopCountingOperation(Adjoint op1), StopCountingOperation(Controlled op1), StopCountingOperation(Adjoint Controlled op1))
}",
&Value::Tuple([Value::Int(1), Value::Int(2), Value::Int(3), Value::Int(4)].into()),
);
}

#[test]
fn check_start_stop_counting_function_called_3_times() {
test_expression(
"{
import Microsoft.Quantum.Diagnostics.StartCountingFunction;
import Microsoft.Quantum.Diagnostics.StopCountingFunction;
function f1() : Unit {}
function f2() : Unit { f1(); }
StartCountingFunction(f1);
StartCountingFunction(f2);
f1(); f1(); f2();
(StopCountingFunction(f1), StopCountingFunction(f2))
}",
&Value::Tuple([Value::Int(3), Value::Int(1)].into()),
);
}

#[test]
fn check_start_stop_counting_function_called_0_times() {
test_expression(
"{
import Microsoft.Quantum.Diagnostics.StartCountingFunction;
import Microsoft.Quantum.Diagnostics.StopCountingFunction;
function f1() : Unit {}
function f2() : Unit { f1(); }
StartCountingFunction(f1);
StartCountingFunction(f2);
(StopCountingFunction(f1), StopCountingFunction(f2))
}",
&Value::Tuple([Value::Int(0), Value::Int(0)].into()),
);
}
Loading

0 comments on commit 22c80c9

Please sign in to comment.