diff --git a/lints/missing_owner_check/src/lib.rs b/lints/missing_owner_check/src/lib.rs index ba0f62b..1bdfa48 100644 --- a/lints/missing_owner_check/src/lib.rs +++ b/lints/missing_owner_check/src/lib.rs @@ -181,24 +181,6 @@ fn is_safe_to_account_info<'tcx>(cx: &LateContext<'tcx>, expr: &'tcx Expr<'tcx>) } } -/// if `expr` is a method call of `def_path` return the receiver else None -fn is_expr_method_call<'tcx>( - cx: &LateContext<'tcx>, - expr: &Expr<'tcx>, - def_path: &[&str], -) -> Option<&'tcx Expr<'tcx>> { - if_chain! { - if let ExprKind::MethodCall(_, recv, _, _) = expr.kind; - if let Some(def_id) = cx.typeck_results().type_dependent_def_id(expr.hir_id); - if match_def_path(cx, def_id, def_path); - then { - Some(recv) - } else { - None - } - } -} - fn contains_owner_use<'tcx>( cx: &LateContext<'tcx>, body: &'tcx Body<'tcx>, @@ -209,19 +191,26 @@ fn contains_owner_use<'tcx>( }) } -/// Checks if `expr` is references `field` on `account_expr` -fn uses_given_field<'tcx>( +fn contains_key_check<'tcx>( + cx: &LateContext<'tcx>, + body: &'tcx Body<'tcx>, + account_expr: &Expr<'tcx>, +) -> bool { + visit_expr_no_bodies(body.value, |expr| compares_key(cx, expr, account_expr)) +} + +fn compares_key<'tcx>( cx: &LateContext<'tcx>, expr: &Expr<'tcx>, account_expr: &Expr<'tcx>, - field: &str, ) -> bool { if_chain! { - if let ExprKind::Field(object, field_name) = expr.kind; - // TODO: add check for key, is_signer - if field_name.as_str() == field; - let mut spanless_eq = SpanlessEq::new(cx); - if spanless_eq.eq_expr(account_expr, object); + // check if the expr is comparison expression + if let ExprKind::Binary(op, lhs, rhs) = expr.kind; + // == or != + if matches!(op.node, BinOpKind::Eq | BinOpKind::Ne); + // check if lhs or rhs accesses key of `account_expr` + if expr_accesses_key(cx, lhs, account_expr) || expr_accesses_key(cx, rhs, account_expr); then { true } else { @@ -230,6 +219,17 @@ fn uses_given_field<'tcx>( } } +// Return true if the expr access key of account_expr(AccountInfo) +fn expr_accesses_key<'tcx>( + cx: &LateContext<'tcx>, + expr: &Expr<'tcx>, + account_expr: &Expr<'tcx>, +) -> bool { + // Anchor AccountInfo: `.key()` and Solana AccountInfo: `.key` field. + calls_method_on_expr(cx, expr, account_expr, &paths::ANCHOR_LANG_KEY) + || uses_given_field(cx, expr, account_expr, "key") +} + /// Checks if `expr` is a method call of `path` on `account_expr` fn calls_method_on_expr<'tcx>( cx: &LateContext<'tcx>, @@ -251,41 +251,41 @@ fn calls_method_on_expr<'tcx>( } } -// Return true if the expr access key of account_expr(AccountInfo) -fn expr_accesses_key<'tcx>( +/// Checks if `expr` is references `field` on `account_expr` +fn uses_given_field<'tcx>( cx: &LateContext<'tcx>, expr: &Expr<'tcx>, account_expr: &Expr<'tcx>, + field: &str, ) -> bool { - // Anchor AccountInfo: `.key()` and Solana AccountInfo: `.key` field. - calls_method_on_expr(cx, expr, account_expr, &paths::ANCHOR_LANG_KEY) - || uses_given_field(cx, expr, account_expr, "key") -} - -fn contains_key_check<'tcx>( - cx: &LateContext<'tcx>, - body: &'tcx Body<'tcx>, - account_expr: &Expr<'tcx>, -) -> bool { - visit_expr_no_bodies(body.value, |expr| compares_key(cx, expr, account_expr)) + if_chain! { + if let ExprKind::Field(object, field_name) = expr.kind; + // TODO: add check for key, is_signer + if field_name.as_str() == field; + let mut spanless_eq = SpanlessEq::new(cx); + if spanless_eq.eq_expr(account_expr, object); + then { + true + } else { + false + } + } } -fn compares_key<'tcx>( +/// if `expr` is a method call of `def_path` return the receiver else None +fn is_expr_method_call<'tcx>( cx: &LateContext<'tcx>, expr: &Expr<'tcx>, - account_expr: &Expr<'tcx>, -) -> bool { + def_path: &[&str], +) -> Option<&'tcx Expr<'tcx>> { if_chain! { - // check if the expr is comparison expression - if let ExprKind::Binary(op, lhs, rhs) = expr.kind; - // == or != - if matches!(op.node, BinOpKind::Eq | BinOpKind::Ne); - // check if lhs or rhs accesses key of `account_expr` - if expr_accesses_key(cx, lhs, account_expr) || expr_accesses_key(cx, rhs, account_expr); + if let ExprKind::MethodCall(_, recv, _, _) = expr.kind; + if let Some(def_id) = cx.typeck_results().type_dependent_def_id(expr.hir_id); + if match_def_path(cx, def_id, def_path); then { - true + Some(recv) } else { - false + None } } }