From b7c7876b6fc4f92fb030932eb0d9db5c3d65b274 Mon Sep 17 00:00:00 2001 From: Amaury <1293565+amaury1729@users.noreply.github.com> Date: Wed, 11 Oct 2023 15:47:29 +0200 Subject: [PATCH] Add has_rule() --- core/src/rules.rs | 37 ++++++++++++++++++++++++++++++------- core/src/smtp/connect.rs | 17 +++-------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/core/src/rules.rs b/core/src/rules.rs index 9d9b8fc01..8a521a1a2 100644 --- a/core/src/rules.rs +++ b/core/src/rules.rs @@ -25,24 +25,47 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[derive(Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub enum Rules { +pub enum Rule { /// Don't perform catch-all check. SkipCatchAll, } #[derive(Debug, Deserialize, Serialize)] -pub struct RulesByDomain { - pub rules: Vec, +struct RulesByDomain { + rules: Vec, } #[derive(Debug, Deserialize, Serialize)] -pub struct AllRules { +struct AllRules { /// Apply rules by domain name, i.e. after the @ symbol. - pub by_domain: HashMap, + by_domain: HashMap, /// Apply rules by the MX host. Since each domain potentially has multiple /// MX records, we match by their suffix. - pub by_mx_suffix: HashMap, + by_mx_suffix: HashMap, } -pub(crate) static ALL_RULES: Lazy = +static ALL_RULES: Lazy = Lazy::new(|| serde_json::from_str::(include_str!("rules.json")).unwrap()); + +fn does_domain_have_rule(domain: &str, rule: &Rule) -> bool { + if let Some(v) = ALL_RULES.by_domain.get(domain) { + return v.rules.contains(rule); + } + + false +} + +fn does_mx_have_rule(host: &str, rule: &Rule) -> bool { + for (k, v) in ALL_RULES.by_mx_suffix.iter() { + if host.ends_with(k) { + return v.rules.contains(rule); + } + } + + false +} + +/// Check if either the domain or the MX host has any given rule. +pub fn has_rule(domain: &str, host: &str, rule: &Rule) -> bool { + does_domain_have_rule(domain, rule) || does_mx_have_rule(host, rule) +} diff --git a/core/src/smtp/connect.rs b/core/src/smtp/connect.rs index 495c166ec..5a1524384 100644 --- a/core/src/smtp/connect.rs +++ b/core/src/smtp/connect.rs @@ -27,10 +27,10 @@ use std::iter; use std::str::FromStr; use std::time::Duration; -use super::{gmail::is_gmail, outlook::is_hotmail, parser, yahoo::is_yahoo}; +use super::parser; use super::{SmtpDetails, SmtpError}; use crate::{ - rules::{Rules, ALL_RULES}, + rules::{has_rule, Rule}, util::{constants::LOG_TARGET, input_output::CheckEmailInput}, }; @@ -223,18 +223,7 @@ async fn smtp_is_catch_all( host: &str, ) -> Result { // Skip catch-all check for known providers. - if let Some(d) = ALL_RULES.by_domain.get(domain) { - if d.rules.contains(&Rules::SkipCatchAll) { - return Ok(false); - } - } - for (key, val) in ALL_RULES.by_mx_suffix.iter() { - if host.ends_with(key) && val.rules.contains(&Rules::SkipCatchAll) { - return Ok(false); - } - } - - if is_gmail(&host) || is_hotmail(&host) || is_yahoo(&host) { + if has_rule(domain, host, &Rule::SkipCatchAll) { return Ok(false); }