diff --git a/jormungandr/src/settings/start/config.rs b/jormungandr/src/settings/start/config.rs index da11491000..971c0523b2 100644 --- a/jormungandr/src/settings/start/config.rs +++ b/jormungandr/src/settings/start/config.rs @@ -79,11 +79,14 @@ pub struct Tls { pub struct Cors { /// If none provided, echos request origin #[serde(default)] - pub allowed_origins: Vec, + pub allowed_origins: Vec, /// If none provided, CORS responses won't be cached pub max_age_secs: Option, } +#[derive(Debug, Clone, Default, Serialize, PartialEq, Eq)] +pub struct CorsOrigin(String); + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct P2pConfig { @@ -389,6 +392,49 @@ impl<'de> Deserialize<'de> for InterestLevel { } } +impl<'de> Deserialize<'de> for CorsOrigin { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct CorsOriginVisitor; + impl<'de> Visitor<'de> for CorsOriginVisitor { + type Value = CorsOrigin; + + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "an origin in format http[s]://example.com[:3000]",) + } + + fn visit_str<'a, E>(self, v: &'a str) -> std::result::Result + where + E: serde::de::Error, + { + use serde::de::Unexpected; + + let uri = warp::http::uri::Uri::from_str(v).map_err(E::custom)?; + if let Some(s) = uri.scheme_str() { + if s != "http" && s != "https" { + return Err(E::invalid_value(Unexpected::Str(v), &self)); + } + } + if let Some(p) = uri.path_and_query() { + if p.as_str() != "/" { + return Err(E::invalid_value(Unexpected::Str(v), &self)); + } + } + Ok(CorsOrigin(v.trim_end_matches('/').to_owned())) + } + } + deserializer.deserialize_str(CorsOriginVisitor) + } +} + +impl AsRef for CorsOrigin { + fn as_ref(&self) -> &str { + &self.0 + } +} + mod filter_level_opt_serde { use super::*;