From 8a2d95a30aaac1096cc49bf24b5f50444805dcf0 Mon Sep 17 00:00:00 2001 From: Alastair Holmes <42404303+AlastairHolmes@users.noreply.github.com> Date: Fri, 29 Sep 2023 17:00:33 +0200 Subject: [PATCH] fix: loop_select conditions (PRO-587) (#4061) * fix: loop_select conditions * review improvements * clippy * fixes --- .../src/witness/common/chain_source/shared.rs | 5 +- .../chunked_by_vault/deposit_addresses.rs | 7 +- .../chunked_by_vault/egress_items.rs | 8 +- utilities/src/with_std/loop_select.rs | 179 ++++++++++++++++-- 4 files changed, 172 insertions(+), 27 deletions(-) diff --git a/engine/src/witness/common/chain_source/shared.rs b/engine/src/witness/common/chain_source/shared.rs index 5aa07fa66f..11aa359542 100644 --- a/engine/src/witness/common/chain_source/shared.rs +++ b/engine/src/witness/common/chain_source/shared.rs @@ -3,7 +3,6 @@ use tokio::sync::oneshot; use utilities::{ loop_select, spmc, task_scope::{Scope, OR_CANCEL}, - UnendingStream, }; use crate::witness::common::ExternalChainSource; @@ -54,8 +53,8 @@ where if let Some(response_sender) = request_receiver.next() => { let receiver = sender.receiver(); let _result = response_sender.send((receiver, inner_client.clone())); - }, - let item = inner_stream.next_or_pending() => { + } else disable, + if let Some(item) = inner_stream.next() => { // This branch failing causes `sender` to be dropped, this causes the proxy/duplicate streams to also end. let _result = sender.send(item).await; }, let _ = sender.closed() => { break }, diff --git a/engine/src/witness/common/chunked_chain_source/chunked_by_vault/deposit_addresses.rs b/engine/src/witness/common/chunked_chain_source/chunked_by_vault/deposit_addresses.rs index 7c1829c697..b088af677c 100644 --- a/engine/src/witness/common/chunked_chain_source/chunked_by_vault/deposit_addresses.rs +++ b/engine/src/witness/common/chunked_chain_source/chunked_by_vault/deposit_addresses.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use cf_chains::ChainState; use frame_support::CloneNoBound; use futures::FutureExt; -use futures_core::FusedStream; use futures_util::{stream, StreamExt}; use pallet_cf_ingress_egress::DepositChannelDetails; use state_chain_runtime::PalletInstanceAlias; @@ -11,7 +10,6 @@ use tokio::sync::watch; use utilities::{ loop_select, task_scope::{Scope, OR_CANCEL}, - UnendingStream, }; use crate::{ @@ -206,10 +204,9 @@ where |(mut chain_stream, mut state)| async move { loop_select!( if !state.ready_headers.is_empty() => break Some((state.ready_headers.pop().unwrap(), (chain_stream, state))), - if chain_stream.is_terminated() && state.pending_headers.is_empty() => break None, - let header = chain_stream.next_or_pending() => { + if let Some(header) = chain_stream.next() => { state.add_headers(std::iter::once(header)); - }, + } else disable then if state.pending_headers.is_empty() => break None, let _ = state.receiver.changed().map(|result| result.expect(OR_CANCEL)) => { let pending_headers = std::mem::take(&mut state.pending_headers); state.add_headers(pending_headers); diff --git a/engine/src/witness/common/chunked_chain_source/chunked_by_vault/egress_items.rs b/engine/src/witness/common/chunked_chain_source/chunked_by_vault/egress_items.rs index 69770a4f49..d54e6f486a 100644 --- a/engine/src/witness/common/chunked_chain_source/chunked_by_vault/egress_items.rs +++ b/engine/src/witness/common/chunked_chain_source/chunked_by_vault/egress_items.rs @@ -3,10 +3,9 @@ use std::sync::Arc; use crate::witness::common::chain_source::{ChainClient, ChainStream}; use cf_chains::{Chain, ChainCrypto}; use frame_support::CloneNoBound; -use futures_core::FusedStream; use futures_util::{stream, StreamExt}; use state_chain_runtime::PalletInstanceAlias; -use utilities::{loop_select, task_scope::Scope, UnendingStream}; +use utilities::{loop_select, task_scope::Scope}; use crate::{ state_chain_observer::client::{storage_api::StorageApi, StateChainStreamApi}, @@ -138,15 +137,14 @@ where (chain_stream.fuse(), self.receiver.clone()), |(mut chain_stream, receiver)| async move { loop_select!( - if chain_stream.is_terminated() => break None, - let header = chain_stream.next_or_pending() => { + if let Some(header) = chain_stream.next() => { // Always get the latest tx out ids. // NB: There is a race condition here. If we're not watching for a particular egress id (because our state chain is slow for some reason) at the time // it arrives on external chain, we won't witness it. This is pretty unlikely since the time between the egress id being set on the SC and the tx // being confirmed on the external chain is quite large. We should fix this eventually though. PRO-689 let tx_out_ids = receiver.borrow().clone(); break Some((header.map_data(|header| (header.data, tx_out_ids)), (chain_stream, receiver))) - }, + } else break None, ) }, ) diff --git a/utilities/src/with_std/loop_select.rs b/utilities/src/with_std/loop_select.rs index ef74faa782..0bcd809a79 100644 --- a/utilities/src/with_std/loop_select.rs +++ b/utilities/src/with_std/loop_select.rs @@ -3,10 +3,17 @@ pub use futures::future::ready as internal_ready; #[doc(hidden)] pub use tokio::select as internal_tokio_select; +#[doc(hidden)] +pub fn is_bit_set(mask: u64, bit: u64) -> bool { + mask & (1u64 << bit) == (1u64 << bit) +} + #[macro_export] macro_rules! inner_loop_select { - ({ $($processed:tt)* } let $pattern:pat = $expression:expr => $body:block, $($unprocessed:tt)*) => { + ($disabled_mask:ident, $count:expr, { $($processed:tt)* } let $pattern:pat = $expression:expr => $body:block, $($unprocessed:tt)*) => { $crate::inner_loop_select!( + $disabled_mask, + $count + 1u64, { $($processed)* x = $expression => { @@ -17,8 +24,10 @@ macro_rules! inner_loop_select { $($unprocessed)* ) }; - ({ $($processed:tt)* } if let $pattern:pat = $expression:expr => $body:block, $($unprocessed:tt)*) => { + ($disabled_mask:ident, $count:expr, { $($processed:tt)* } if let $pattern:pat = $expression:expr => $body:block, $($unprocessed:tt)*) => { $crate::inner_loop_select!( + $disabled_mask, + $count + 1u64, { $($processed)* x = $expression => { @@ -30,8 +39,10 @@ macro_rules! inner_loop_select { $($unprocessed)* ) }; - ({ $($processed:tt)* } if let $pattern:pat = $expression:expr => $body:block else break $extra:expr, $($unprocessed:tt)*) => { + ($disabled_mask:ident, $count:expr, { $($processed:tt)* } if let $pattern:pat = $expression:expr => $body:block else break $extra:expr, $($unprocessed:tt)*) => { $crate::inner_loop_select!( + $disabled_mask, + $count + 1u64, { $($processed)* x = $expression => { @@ -43,11 +54,35 @@ macro_rules! inner_loop_select { $($unprocessed)* ) }; - ({ $($processed:tt)* } if $enable_expression:expr => let $pattern:pat = $expression:expr => $body:block, $($unprocessed:tt)*) => { + ($disabled_mask:ident, $count:expr, { $($processed:tt)* } if let $pattern:pat = $expression:expr => $body:block else disable $(then if $disable_break_expression:expr => break $($extra:expr)?)?, $($unprocessed:tt)*) => { + $crate::inner_loop_select!( + $disabled_mask, + $count + 1u64, + { + $($processed)* + x = async { $expression.await } /* async await block ensures $expression is evaluated after condition */, if !$crate::loop_select::is_bit_set($disabled_mask, $count) => { + if let $pattern = x { + $body + } else { + $disabled_mask |= 1u64 << $count; + } + }, + $( + _ = $crate::loop_select::internal_ready(()), if $crate::loop_select::is_bit_set($disabled_mask, $count) && $disable_break_expression => { + break $($extra)? + }, + )? + } + $($unprocessed)* + ) + }; + ($disabled_mask:ident, $count:expr, { $($processed:tt)* } if $enable_expression:expr => let $pattern:pat = $expression:expr => $body:block, $($unprocessed:tt)*) => { $crate::inner_loop_select!( + $disabled_mask, + $count + 1u64, { $($processed)* - x = async { $expression.await }, if $enable_expression => { + x = async { $expression.await } /* async await block ensures $expression is evaluated after condition */, if $enable_expression => { let $pattern = x; $body }, @@ -55,11 +90,13 @@ macro_rules! inner_loop_select { $($unprocessed)* ) }; - ({ $($processed:tt)* } if $enable_expression:expr => if let $pattern:pat = $expression:expr => $body:block, $($unprocessed:tt)*) => { + ($disabled_mask:ident, $count:expr, { $($processed:tt)* } if $enable_expression:expr => if let $pattern:pat = $expression:expr => $body:block, $($unprocessed:tt)*) => { $crate::inner_loop_select!( + $disabled_mask, + $count + 1u64, { $($processed)* - x = async { $expression.await }, if $enable_expression => { + x = async { $expression.await } /* async await block ensures $expression is evaluated after condition */, if $enable_expression => { if let $pattern = x { $body } else { break } @@ -68,11 +105,13 @@ macro_rules! inner_loop_select { $($unprocessed)* ) }; - ({ $($processed:tt)* } if $enable_expression:expr => if let $pattern:pat = $expression:expr => $body:block else break $extra:expr, $($unprocessed:tt)*) => { + ($disabled_mask:ident, $count:expr, { $($processed:tt)* } if $enable_expression:expr => if let $pattern:pat = $expression:expr => $body:block else break $extra:expr, $($unprocessed:tt)*) => { $crate::inner_loop_select!( + $disabled_mask, + $count + 1u64, { $($processed)* - x = async { $expression.await }, if $enable_expression => { + x = async { $expression.await } /* async await block ensures $expression is evaluated after condition */, if $enable_expression => { if let $pattern = x { $body } else { break $extra } @@ -81,8 +120,10 @@ macro_rules! inner_loop_select { $($unprocessed)* ) }; - ({ $($processed:tt)* } if $expression:expr => break $($extra:expr)?, $($unprocessed:tt)*) => { + ($disabled_mask:ident, $count:expr, { $($processed:tt)* } if $expression:expr => break $($extra:expr)?, $($unprocessed:tt)*) => { $crate::inner_loop_select!( + $disabled_mask, + $count + 1u64, { $($processed)* _ = $crate::loop_select::internal_ready(()), if $expression => { @@ -92,7 +133,7 @@ macro_rules! inner_loop_select { $($unprocessed)* ) }; - ({ $($processed:tt)+ }) => { + ($disabled_mask:ident, $count:expr, { $($processed:tt)+ }) => { loop { $crate::loop_select::internal_tokio_select!( $($processed)+ @@ -103,9 +144,11 @@ macro_rules! inner_loop_select { #[macro_export] macro_rules! loop_select { - ($($cases:tt)+) => { - $crate::inner_loop_select!({} $($cases)+) - } + ($($cases:tt)+) => {{ + #[allow(unused, unused_mut)] + let mut disabled_mask = 0u64; + $crate::inner_loop_select!(disabled_mask, 0u64, {} $($cases)+) + }} } #[cfg(test)] @@ -438,4 +481,112 @@ mod test_loop_select { } ); } + + #[allow(clippy::unit_cmp)] + #[tokio::test] + async fn disabled_branches() { + // Break condition works + + assert_eq!( + 'c', + loop_select!( + if let 'a' = futures::future::ready('b') => { panic!() } else disable then if true => break 'c', + ) + ); + assert_eq!( + (), + loop_select!( + if let 'a' = futures::future::ready('b') => { panic!() } else disable then if true => break, + ) + ); + + // Disabled conditions don't run + + { + let mut condition_run = false; + + assert_eq!( + (), + loop_select!( + if let 2 = futures::future::ready({ + if !condition_run { + condition_run = true; + 1 + } else { + panic!() + } + }) => { + panic!() + } else disable, + if condition_run => break, + ) + ); + } + { + let mut condition_run = false; + + assert_eq!( + (), + loop_select!( + if let 2 = futures::future::ready({ + if !condition_run { + condition_run = true; + 1 + } else { + panic!() + } + }) => { + panic!() + } else disable then if condition_run => break, + ) + ); + } + + // Disabled branches don't run + + { + let mut condition_run = false; + + assert_eq!( + (), + loop_select!( + if let false = futures::future::ready(condition_run) => { + if condition_run { + panic!() + } else { + condition_run = true; + } + } else disable then if condition_run => break, + ) + ); + } + + // Branches run until disabled + + { + let mut i = 0; + assert_eq!( + (), + loop_select!( + if let 0..=10 = futures::future::ready(i) => { + i += 1; + } else disable then if true => break, + ) + ); + assert_eq!(i, 11); + } + { + let mut i = 0; + assert_eq!( + (), + loop_select!( + if let 0..=10 = futures::future::ready(i) => { + i += 1; + } else disable, + if i == 11 => break, + ) + ); + assert_eq!(i, 11); + } + } }