diff --git a/Cargo.lock b/Cargo.lock index 763688914b..1e819c3dda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5242,8 +5242,10 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" name = "infra_utils" version = "0.0.0" dependencies = [ + "pretty_assertions", "rstest", "tokio", + "tracing", ] [[package]] diff --git a/crates/infra_utils/Cargo.toml b/crates/infra_utils/Cargo.toml index 6b42af283d..09fc7fb752 100644 --- a/crates/infra_utils/Cargo.toml +++ b/crates/infra_utils/Cargo.toml @@ -10,7 +10,10 @@ description = "Infrastructure utility." workspace = true [dependencies] -tokio = { workspace = true, features = ["process"] } +tokio = { workspace = true, features = ["process", "time"] } +tracing.workspace = true [dev-dependencies] +pretty_assertions.workspace = true rstest.workspace = true +tokio = { workspace = true, features = ["macros", "rt"] } diff --git a/crates/infra_utils/src/lib.rs b/crates/infra_utils/src/lib.rs index f744151bf9..0d8ac5d50e 100644 --- a/crates/infra_utils/src/lib.rs +++ b/crates/infra_utils/src/lib.rs @@ -1,2 +1,3 @@ pub mod command; pub mod path; +pub mod run_until; diff --git a/crates/infra_utils/src/run_until.rs b/crates/infra_utils/src/run_until.rs new file mode 100644 index 0000000000..a2db3ebe04 --- /dev/null +++ b/crates/infra_utils/src/run_until.rs @@ -0,0 +1,95 @@ +use tokio::time::{sleep, Duration}; +use tracing::{debug, error, info, trace, warn}; + +#[cfg(test)] +#[path = "run_until_test.rs"] +mod run_until_test; + +/// Struct to hold trace configuration +pub struct TraceConfig { + pub level: LogLevel, + pub message: String, +} + +/// Enum for dynamically setting trace level +#[derive(Clone, Copy)] +pub enum LogLevel { + Trace, + Debug, + Info, + Warn, + Error, +} + +/// Runs an asynchronous function until a condition is met or max attempts are reached. +/// +/// # Arguments +/// - `interval`: Time between each attempt (in milliseconds). +/// - `max_attempts`: Maximum number of attempts. +/// - `executable`: An asynchronous function to execute, which returns a value of type `T`. +/// - `condition`: A closure that takes a value of type `T` and returns `true` if the condition is +/// met. +/// - `trace_config`: Optional trace configuration for logging. +/// +/// # Returns +/// - `Option`: Returns `Some(value)` if the condition is met within the attempts, otherwise +/// `None`. +pub async fn run_until( + interval: u64, + max_attempts: usize, + mut executable: F, + condition: C, + trace_config: Option, +) -> Option +where + T: Clone + Send + std::fmt::Debug + 'static, + F: FnMut() -> T + Send, + C: Fn(&T) -> bool + Send + Sync, +{ + for attempt in 1..=max_attempts { + let result = executable(); + + // Log attempt message. + if let Some(config) = &trace_config { + let attempt_message = format!( + "{}: Attempt {}/{}, Value {:?}", + config.message, attempt, max_attempts, result + ); + log_message(config.level, &attempt_message); + } + + // Check if the condition is met. + if condition(&result) { + if let Some(config) = &trace_config { + let success_message = format!( + "{}: Condition met on attempt {}/{}", + config.message, attempt, max_attempts + ); + log_message(config.level, &success_message); + } + return Some(result); + } + + // Wait for the interval before the next attempt. + sleep(Duration::from_millis(interval)).await; + } + + if let Some(config) = &trace_config { + let failure_message = + format!("{}: Condition not met after {} attempts.", config.message, max_attempts); + log_message(config.level, &failure_message); + } + + None +} + +/// Logs a message at the specified level +fn log_message(level: LogLevel, message: &str) { + match level { + LogLevel::Trace => trace!("{}", message), + LogLevel::Debug => debug!("{}", message), + LogLevel::Info => info!("{}", message), + LogLevel::Warn => warn!("{}", message), + LogLevel::Error => error!("{}", message), + } +} diff --git a/crates/infra_utils/src/run_until_test.rs b/crates/infra_utils/src/run_until_test.rs new file mode 100644 index 0000000000..ac0a735a37 --- /dev/null +++ b/crates/infra_utils/src/run_until_test.rs @@ -0,0 +1,46 @@ +use pretty_assertions::assert_eq; +use rstest::rstest; + +use crate::run_until::run_until; + +#[rstest] +#[tokio::test] +async fn test_run_until_condition_met() { + // Mock executable that increments a counter. + let mut counter = 0; + let mock_executable = || { + counter += 1; + counter + }; + + // Condition: stop when the counter reaches 3. + let condition = |&result: &i32| result >= 3; + + // Run the function with a short interval and a maximum of 5 attempts. + let result = run_until(100, 5, mock_executable, condition, None).await; + + // Assert that the condition was met and the returned value is correct. + assert_eq!(result, Some(3)); + assert_eq!(counter, 3); // Counter should stop at 3 since the condition is met. +} + +#[rstest] +#[tokio::test] +async fn test_run_until_condition_not_met() { + // Mock executable that increments a counter. + let mut counter = 0; + let mock_executable = || { + counter += 1; + counter + }; + + // Condition: stop when the counter reaches 3. + let condition = |&result: &i32| result >= 3; + + // Test that it stops when the maximum attempts are exceeded without meeting the condition. + let failed_result = run_until(100, 2, mock_executable, condition, None).await; + + // The condition is not met within 2 attempts, so the result should be None. + assert_eq!(failed_result, None); + assert_eq!(counter, 2); // Counter should stop at 2 because of max attempts. +}