diff --git a/src/duration.rs b/src/duration.rs deleted file mode 100644 index 107780e45..000000000 --- a/src/duration.rs +++ /dev/null @@ -1,393 +0,0 @@ -//! This module contains a common [`Duration`] struct which is able to parse -//! human-readable duration formats, like `2y 2h 20m 42s` or`15d 2m 2s`. It -//! additionally implements many required traits, like [`Derivative`], -//! [`JsonSchema`], [`Deserialize`], and [`Serialize`]. -//! -//! Furthermore, it implements [`Deref`], which enables us to use all associated -//! functions of [`std::time::Duration`] without re-implementing the public -//! functions on our own type. -//! -//! All operators should opt for [`Duration`] instead of the plain -//! [`std::time::Duration`] when dealing with durations of any form, like -//! timeouts or retries. - -use std::{ - fmt::{Display, Write}, - num::ParseIntError, - ops::{Add, AddAssign, Deref, Sub, SubAssign}, - str::FromStr, -}; - -use derivative::Derivative; -use schemars::{ - gen::SchemaGenerator, - schema::{InstanceType, Schema, SchemaObject}, - JsonSchema, -}; -use serde::{de::Visitor, Deserialize, Serialize}; -use strum::Display; -use thiserror::Error; - -const DAYS_IN_MONTH: f64 = 30.436875; -const DAYS_IN_YEAR: f64 = 365.2425; - -const YEARS_FACTOR: u64 = (DAYS_FACTOR as f64 * DAYS_IN_YEAR) as u64; -const MONTHS_FACTOR: u64 = (DAYS_FACTOR as f64 * DAYS_IN_MONTH) as u64; -const WEEKS_FACTOR: u64 = DAYS_FACTOR * 7; -const DAYS_FACTOR: u64 = HOURS_FACTOR * 24; -const HOURS_FACTOR: u64 = MINUTES_FACTOR * 60; -const MINUTES_FACTOR: u64 = SECONDS_FACTOR * 60; -const SECONDS_FACTOR: u64 = 1; - -#[derive(Debug, Error, PartialEq)] -pub enum DurationParseError { - #[error("failed to parse string as number")] - ParseIntError(#[from] ParseIntError), - - #[error("expected a character, found number")] - ExpectedChar, - - #[error("found invalid character")] - InvalidInput, - - #[error("found invalid unit")] - InvalidUnit, - - #[error("number overflow")] - NumberOverflow, -} - -/// A [`Duration`] which is capable of parsing human-readable duration formats, -/// like `2y 2h 20m 42s` or `15d 2m 2s`. It additionally provides many required -/// trait implementations, which makes it suited for use in CRDs for example. -/// -/// The maximum granularity currently supported is **seconds**. Support for -/// milliseconds can be added, when there is the need for it. -#[derive(Clone, Copy, Debug, Derivative, Hash, PartialEq, PartialOrd)] -pub struct Duration(std::time::Duration); - -#[derive(Copy, Clone, Debug, Display)] -enum DurationParseState { - Value, - Space, - Init, - Unit, - End, -} - -impl FromStr for Duration { - type Err = DurationParseError; - - fn from_str(input: &str) -> Result { - let input = input.trim(); - if input.is_empty() || !input.is_ascii() { - return Err(DurationParseError::InvalidInput); - } - - let mut state = DurationParseState::Init; - let mut buffer = String::new(); - let mut iter = input.chars(); - let mut is_unit = false; - let mut cur = 0 as char; - - let mut dur = std::time::Duration::from_secs(0); - let mut val = 0; - - loop { - state = match state { - DurationParseState::Init => match iter.next() { - Some(c) => { - cur = c; - - match c { - '0'..='9' => DurationParseState::Value, - 'a'..='z' => DurationParseState::Unit, - ' ' => DurationParseState::Space, - _ => return Err(DurationParseError::InvalidInput), - } - } - None => DurationParseState::End, - }, - DurationParseState::Value => { - if is_unit { - return Err(DurationParseError::ExpectedChar); - } - - buffer.push(cur); - DurationParseState::Init - } - DurationParseState::Unit => { - if !is_unit { - is_unit = true; - - val = buffer.parse::()?; - buffer.clear(); - } - - buffer.push(cur); - DurationParseState::Init - } - DurationParseState::Space => { - if !is_unit { - return Err(DurationParseError::ExpectedChar); - } - - let factor = parse_unit(&buffer)?; - - dur = dur - .checked_add(std::time::Duration::from_secs(val * factor)) - .ok_or(DurationParseError::NumberOverflow)?; - - is_unit = false; - buffer.clear(); - - DurationParseState::Init - } - DurationParseState::End => { - if !is_unit { - return Err(DurationParseError::ExpectedChar); - } - - let factor = parse_unit(&buffer)?; - - dur = dur - .checked_add(std::time::Duration::from_secs(val * factor)) - .ok_or(DurationParseError::NumberOverflow)?; - - break; - } - } - } - - Ok(Duration(dur)) - } -} - -fn parse_unit(buffer: &str) -> Result { - let factor = match buffer { - "seconds" | "second" | "secs" | "sec" | "s" => SECONDS_FACTOR, - "minutes" | "minute" | "mins" | "min" | "m" => MINUTES_FACTOR, - "hours" | "hour" | "hrs" | "hr" | "h" => HOURS_FACTOR, - "days" | "day" | "d" => DAYS_FACTOR, - "weeks" | "week" | "w" => WEEKS_FACTOR, - "months" | "month" | "M" => MONTHS_FACTOR, - "years" | "year" | "y" => YEARS_FACTOR, - _ => return Err(DurationParseError::InvalidUnit), - }; - - Ok(factor) -} - -struct DurationVisitor; - -impl<'de> Visitor<'de> for DurationVisitor { - type Value = Duration; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a string in any of the supported formats") - } - - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, - { - let dur = v.parse::().map_err(serde::de::Error::custom)?; - Ok(dur) - } -} - -impl<'de> Deserialize<'de> for Duration { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(DurationVisitor) - } -} - -impl Serialize for Duration { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&self.to_string()) - } -} - -impl JsonSchema for Duration { - fn schema_name() -> String { - "Duration".into() - } - - fn json_schema(_: &mut SchemaGenerator) -> Schema { - SchemaObject { - instance_type: Some(InstanceType::String.into()), - ..Default::default() - } - .into() - } -} - -impl Display for Duration { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut secs = self.0.as_secs(); - let mut formatted = String::new(); - - for (factor, unit) in [ - (YEARS_FACTOR, "y"), - (MONTHS_FACTOR, "M"), - (DAYS_FACTOR, "d"), - (HOURS_FACTOR, "h"), - (MINUTES_FACTOR, "m"), - (SECONDS_FACTOR, "s"), - ] { - let whole = secs / factor; - let rest = secs % factor; - - if whole > 0 { - write!(formatted, "{}{} ", whole, unit)?; - } - - secs = rest; - } - - write!(f, "{}", formatted.trim_end()) - } -} - -impl Deref for Duration { - type Target = std::time::Duration; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl From for Duration { - fn from(value: std::time::Duration) -> Self { - Self(value) - } -} - -impl Add for Duration { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self::from(self.0 + rhs.0) - } -} - -impl Sub for Duration { - type Output = Duration; - - fn sub(self, rhs: Self) -> Self::Output { - Self::from(self.0 - rhs.0) - } -} - -impl AddAssign for Duration { - fn add_assign(&mut self, rhs: Self) { - self.0.add_assign(rhs.0) - } -} - -impl SubAssign for Duration { - fn sub_assign(&mut self, rhs: Self) { - self.0.sub_assign(rhs.0) - } -} - -impl Duration { - /// Creates a new [`Duration`] from the specified number of whole seconds. - pub const fn from_secs(secs: u64) -> Self { - Self(std::time::Duration::from_secs(secs)) - } -} - -#[cfg(test)] -mod test { - use super::*; - use rstest::rstest; - - #[rstest] - #[case("2y 2h 20m 42s", 63122346)] - #[case("15d 2m 2s", 1296122)] - #[case("1h", 3600)] - #[case("1m", 60)] - #[case("1s", 1)] - fn parse(#[case] input: &str, #[case] output: u64) { - let dur: Duration = input.parse().unwrap(); - assert_eq!(dur.as_secs(), output); - } - - #[rstest] - #[case("2y2", DurationParseError::ExpectedChar)] - #[case("-1y", DurationParseError::InvalidInput)] - #[case("1Y", DurationParseError::InvalidInput)] - #[case("1ä", DurationParseError::InvalidInput)] - #[case("1q", DurationParseError::InvalidUnit)] - #[case(" ", DurationParseError::InvalidInput)] - fn parse_invalid(#[case] input: &str, #[case] expected_err: DurationParseError) { - let err = Duration::from_str(input).unwrap_err(); - assert_eq!(err, expected_err) - } - - #[rstest] - #[case("2y 2h 20m 42s")] - #[case("15d 2m 2s")] - #[case("1h")] - #[case("1m")] - #[case("1s")] - fn to_string(#[case] duration: &str) { - let dur: Duration = duration.parse().unwrap(); - assert_eq!(dur.to_string(), duration); - } - - #[test] - fn deserialize() { - #[derive(Deserialize)] - struct S { - dur: Duration, - } - - let s: S = serde_yaml::from_str("dur: 15d 2m 2s").unwrap(); - assert_eq!(s.dur.as_secs(), 1296122); - } - - #[test] - fn serialize() { - #[derive(Serialize)] - struct S { - dur: Duration, - } - - let s = S { - dur: "15d 2m 2s".parse().unwrap(), - }; - assert_eq!(serde_yaml::to_string(&s).unwrap(), "dur: 15d 2m 2s\n"); - } - - #[test] - fn add_ops() { - let mut dur1 = Duration::from_secs(20); - let dur2 = Duration::from_secs(10); - - let dur = dur1 + dur2; - assert_eq!(dur.as_secs(), 30); - - dur1 += dur2; - assert_eq!(dur1.as_secs(), 30); - } - - #[test] - fn sub_ops() { - let mut dur1 = Duration::from_secs(20); - let dur2 = Duration::from_secs(10); - - let dur = dur1 - dur2; - assert_eq!(dur.as_secs(), 10); - - dur1 -= dur2; - assert_eq!(dur1.as_secs(), 10); - } -} diff --git a/src/duration/mod.rs b/src/duration/mod.rs new file mode 100644 index 000000000..d667781e7 --- /dev/null +++ b/src/duration/mod.rs @@ -0,0 +1,335 @@ +//! This module contains a common [`Duration`] struct which is able to parse +//! human-readable duration formats, like `2y 2h 20m 42s` or`15d 2m 2s`. It +//! additionally implements many required traits, like [`Derivative`], +//! [`JsonSchema`], [`Deserialize`], and [`Serialize`]. +//! +//! Furthermore, it implements [`Deref`], which enables us to use all associated +//! functions of [`std::time::Duration`] without re-implementing the public +//! functions on our own type. +//! +//! All operators should opt for [`Duration`] instead of the plain +//! [`std::time::Duration`] when dealing with durations of any form, like +//! timeouts or retries. + +use std::{ + fmt::{Display, Write}, + num::ParseIntError, + ops::{Deref, DerefMut}, + str::FromStr, +}; + +use derivative::Derivative; +use schemars::JsonSchema; +use strum::IntoEnumIterator; +use thiserror::Error; + +mod serde_impl; +pub use serde_impl::*; + +#[derive(Debug, Error, PartialEq)] +pub enum DurationParseError { + #[error("invalid input, either empty or contains non-ascii characters")] + InvalidInput, + + #[error("failed to parse duration fragment")] + FragmentError(#[from] DurationFragmentParseError), +} + +#[derive(Clone, Copy, Debug, Derivative, Hash, PartialEq, PartialOrd, JsonSchema)] +pub struct Duration(std::time::Duration); + +impl FromStr for Duration { + type Err = DurationParseError; + + fn from_str(s: &str) -> Result { + let input = s.trim(); + + // An empty or non-ascii input is invalid + if input.is_empty() || !input.is_ascii() { + return Err(DurationParseError::InvalidInput); + } + + let parts: Vec<&str> = input.split(' ').collect(); + + let values: Vec = parts + .iter() + .map(|p| p.parse::()) + .map(|r| r.map(|f| f.millis())) + .collect::, DurationFragmentParseError>>()?; + + // NOTE (Techassi): This derefernce is super weird, but + // Duration::from_millis doesn't accept a u128, but returns a u128 + // when as_millis is called. + Ok(Self(std::time::Duration::from_millis( + values.iter().fold(0, |acc, v| acc + (*v as u64)), + ))) + } +} + +impl Display for Duration { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.0.is_zero() { + return write!(f, "0{}", DurationUnit::Milliseconds); + } + + let mut millis = self.0.as_millis(); + let mut formatted = String::new(); + + for unit in DurationUnit::iter() { + let whole = millis / unit.millis(); + let rest = millis % unit.millis(); + + if whole > 0 { + write!(formatted, "{}{} ", whole, unit)?; + } + + millis = rest; + } + + write!(f, "{}", formatted.trim_end()) + } +} + +impl Deref for Duration { + type Target = std::time::Duration; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Duration { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for Duration { + fn from(value: std::time::Duration) -> Self { + Self(value) + } +} + +impl Duration { + /// Creates a new [`Duration`] from the specified number of whole seconds. + pub const fn from_secs(secs: u64) -> Self { + Self(std::time::Duration::from_secs(secs)) + } +} + +#[derive(Debug, strum::EnumString, strum::Display, strum::AsRefStr, strum::EnumIter)] +pub enum DurationUnit { + #[strum(serialize = "d")] + Days, + + #[strum(serialize = "h")] + Hours, + + #[strum(serialize = "m")] + Minutes, + + #[strum(serialize = "s")] + Seconds, + + #[strum(serialize = "ms")] + Milliseconds, +} + +impl DurationUnit { + /// Returns the number of whole milliseconds for each supported + /// [`DurationUnit`]. + pub fn millis(&self) -> u128 { + use DurationUnit::*; + + match self { + Days => 24 * Hours.millis(), + Hours => 60 * Minutes.millis(), + Minutes => 60 * Seconds.millis(), + Seconds => 1000, + Milliseconds => 1, + } + } +} + +#[derive(Debug, Error, PartialEq)] +pub enum DurationFragmentParseError { + #[error("invalid input, either empty or contains non-ascii characters")] + InvalidInput, + + #[error("expected number, the duration fragment must start with a numeric character")] + ExpectedNumber, + + #[error("expected character, the duration fragments must end with an alphabetic character")] + ExpectedCharacter, + + #[error("failed to parse fragment value as integer")] + ParseIntError(#[from] ParseIntError), + + #[error("failed to parse fragment unit")] + UnitParseError, +} + +#[derive(Debug)] +pub struct DurationFragment { + value: u128, + unit: DurationUnit, +} + +impl FromStr for DurationFragment { + type Err = DurationFragmentParseError; + + fn from_str(s: &str) -> Result { + let input = s.trim(); + + // An empty is invalid, non-ascii characters are already ruled out by + // the Duration impl + if input.is_empty() { + return Err(DurationFragmentParseError::InvalidInput); + } + + let mut chars = input.char_indices().peekable(); + let mut end_index = 0; + + // First loop through all numeric characters + while let Some((i, _)) = chars.next_if(|(_, c)| char::is_numeric(*c)) { + end_index = i + 1; + } + + // Parse the numeric characters as a u128 + let value = if end_index != 0 { + s[0..end_index].parse::()? + } else { + return Err(DurationFragmentParseError::ExpectedNumber); + }; + + // Loop through all alphabetic characters + let start_index = end_index; + while let Some((i, _)) = chars.next_if(|(_, c)| char::is_alphabetic(*c)) { + end_index = i + 1; + } + + // Parse the alphabetic characters as a supported duration unit + let unit = if end_index != 0 { + s[start_index..end_index] + .parse::() + .map_err(|_| DurationFragmentParseError::UnitParseError)? + } else { + return Err(DurationFragmentParseError::ExpectedCharacter); + }; + + // If there are characters left which are not alphabetic, we return an + // error + if chars.peek().is_some() { + return Err(DurationFragmentParseError::ExpectedCharacter); + } + + Ok(Self { value, unit }) + } +} + +impl Display for DurationFragment { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}{}", self.value, self.unit) + } +} + +impl DurationFragment { + pub fn millis(&self) -> u128 { + self.value * self.unit.millis() + } +} + +#[cfg(test)] +mod test { + use super::*; + use rstest::rstest; + use serde::{Deserialize, Serialize}; + + #[rstest] + #[case("15d 2m 2s", 1296122)] + #[case("1h", 3600)] + #[case("1m", 60)] + #[case("1s", 1)] + fn parse(#[case] input: &str, #[case] output: u64) { + let dur: Duration = input.parse().unwrap(); + assert_eq!(dur.as_secs(), output); + } + + #[rstest] + #[case( + "2d2", + DurationParseError::FragmentError(DurationFragmentParseError::ExpectedCharacter) + )] + #[case( + "-1y", + DurationParseError::FragmentError(DurationFragmentParseError::ExpectedNumber) + )] + #[case( + "1D", + DurationParseError::FragmentError(DurationFragmentParseError::UnitParseError) + )] + #[case("1ä", DurationParseError::InvalidInput)] + #[case(" ", DurationParseError::InvalidInput)] + fn parse_invalid(#[case] input: &str, #[case] expected_err: DurationParseError) { + let err = Duration::from_str(input).unwrap_err(); + assert_eq!(err, expected_err) + } + + #[rstest] + #[case("15d 2m 2s")] + #[case("1h 20m")] + #[case("1m")] + #[case("1s")] + fn to_string(#[case] duration: &str) { + let dur: Duration = duration.parse().unwrap(); + assert_eq!(dur.to_string(), duration); + } + + #[test] + fn deserialize() { + #[derive(Deserialize)] + struct S { + dur: Duration, + } + + let s: S = serde_yaml::from_str("dur: 15d 2m 2s").unwrap(); + assert_eq!(s.dur.as_secs(), 1296122); + } + + #[test] + fn serialize() { + #[derive(Serialize)] + struct S { + dur: Duration, + } + + let s = S { + dur: "15d 2m 2s".parse().unwrap(), + }; + assert_eq!(serde_yaml::to_string(&s).unwrap(), "dur: 15d 2m 2s\n"); + } + + // #[test] + // fn add_ops() { + // let mut dur1 = Duration::from_secs(20); + // let dur2 = Duration::from_secs(10); + + // let dur = dur1 + dur2; + // assert_eq!(dur.as_secs(), 30); + + // dur1 += dur2; + // assert_eq!(dur1.as_secs(), 30); + // } + + // #[test] + // fn sub_ops() { + // let mut dur1 = Duration::from_secs(20); + // let dur2 = Duration::from_secs(10); + + // let dur = dur1 - dur2; + // assert_eq!(dur.as_secs(), 10); + + // dur1 -= dur2; + // assert_eq!(dur1.as_secs(), 10); + // } +} diff --git a/src/duration/serde_impl.rs b/src/duration/serde_impl.rs new file mode 100644 index 000000000..04f1b43f5 --- /dev/null +++ b/src/duration/serde_impl.rs @@ -0,0 +1,39 @@ +use serde::{de::Visitor, Deserialize, Serialize}; + +use crate::duration::Duration; + +struct DurationVisitor; + +impl<'de> Visitor<'de> for DurationVisitor { + type Value = Duration; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string in any of the supported formats") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + let dur = v.parse::().map_err(serde::de::Error::custom)?; + Ok(dur) + } +} + +impl<'de> Deserialize<'de> for Duration { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(DurationVisitor) + } +} + +impl Serialize for Duration { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +}