Skip to content

Commit

Permalink
Add has_rule()
Browse files Browse the repository at this point in the history
  • Loading branch information
amaury1093 committed Oct 11, 2023
1 parent 105aa8a commit b7c7876
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 21 deletions.
37 changes: 30 additions & 7 deletions core/src/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,47 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum Rules {
pub enum Rule {
/// Don't perform catch-all check.
SkipCatchAll,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct RulesByDomain {
pub rules: Vec<Rules>,
struct RulesByDomain {
rules: Vec<Rule>,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct AllRules {
struct AllRules {
/// Apply rules by domain name, i.e. after the @ symbol.
pub by_domain: HashMap<String, RulesByDomain>,
by_domain: HashMap<String, RulesByDomain>,
/// Apply rules by the MX host. Since each domain potentially has multiple
/// MX records, we match by their suffix.
pub by_mx_suffix: HashMap<String, RulesByDomain>,
by_mx_suffix: HashMap<String, RulesByDomain>,
}

pub(crate) static ALL_RULES: Lazy<AllRules> =
static ALL_RULES: Lazy<AllRules> =
Lazy::new(|| serde_json::from_str::<AllRules>(include_str!("rules.json")).unwrap());

fn does_domain_have_rule(domain: &str, rule: &Rule) -> bool {
if let Some(v) = ALL_RULES.by_domain.get(domain) {
return v.rules.contains(rule);
}

false
}

fn does_mx_have_rule(host: &str, rule: &Rule) -> bool {
for (k, v) in ALL_RULES.by_mx_suffix.iter() {
if host.ends_with(k) {
return v.rules.contains(rule);
}
}

false
}

/// Check if either the domain or the MX host has any given rule.
pub fn has_rule(domain: &str, host: &str, rule: &Rule) -> bool {
does_domain_have_rule(domain, rule) || does_mx_have_rule(host, rule)
}
17 changes: 3 additions & 14 deletions core/src/smtp/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ use std::iter;
use std::str::FromStr;
use std::time::Duration;

use super::{gmail::is_gmail, outlook::is_hotmail, parser, yahoo::is_yahoo};
use super::parser;
use super::{SmtpDetails, SmtpError};
use crate::{
rules::{Rules, ALL_RULES},
rules::{has_rule, Rule},
util::{constants::LOG_TARGET, input_output::CheckEmailInput},
};

Expand Down Expand Up @@ -223,18 +223,7 @@ async fn smtp_is_catch_all(
host: &str,
) -> Result<bool, SmtpError> {
// Skip catch-all check for known providers.
if let Some(d) = ALL_RULES.by_domain.get(domain) {
if d.rules.contains(&Rules::SkipCatchAll) {
return Ok(false);
}
}
for (key, val) in ALL_RULES.by_mx_suffix.iter() {
if host.ends_with(key) && val.rules.contains(&Rules::SkipCatchAll) {
return Ok(false);
}
}

if is_gmail(&host) || is_hotmail(&host) || is_yahoo(&host) {
if has_rule(domain, host, &Rule::SkipCatchAll) {
return Ok(false);
}

Expand Down

0 comments on commit b7c7876

Please sign in to comment.