diff --git a/CHANGELOG.md b/CHANGELOG.md index 85f676c..624a496 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,15 @@ +# 0.3.1 + +* Optimize memory usage + # 0.2.0 * add response soa record -* fixed bugs +* fix bugs # 0.1.1 -* fixed mips build error +* fix mips build error # 0.1.0 diff --git a/Cargo.lock b/Cargo.lock index be19657..6992501 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -557,9 +557,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +checksum = "ad227c3af19d4914570ad36d30409928b75967c298feb9ea1969db3a610bb14e" dependencies = [ "equivalent", "hashbrown 0.14.0", @@ -1309,7 +1309,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.0.0", + "indexmap 2.0.1", "serde", "serde_spanned", "toml_datetime", @@ -1659,7 +1659,7 @@ dependencies = [ [[package]] name = "yadns" -version = "0.2.0" +version = "0.3.1" dependencies = [ "anyhow", "async-recursion", @@ -1670,6 +1670,7 @@ dependencies = [ "ipnet", "iprange", "lazy_static", + "once_cell", "openssl", "publicsuffix", "regex", diff --git a/Cargo.toml b/Cargo.toml index a710719..90eb861 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "yadns" -version = "0.2.0" +version = "0.3.1" authors = ["zyj"] edition = "2021" @@ -36,6 +36,7 @@ anyhow = "1.0.75" futures = "0.3.28" async-recursion = "1.0.5" publicsuffix = "2.2.3" +once_cell = "1.18.0" [target.'cfg(any(target_arch = "mips", target_arch = "mips64"))'.dependencies] trust-dns-resolver = { version = "0.23.0", default-features = false, features = ["dns-over-openssl"], optional = true } diff --git a/src/config.rs b/src/config.rs index 23676e2..88a7961 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,4 @@ use crate::ip::IpRange; -use crate::Transpose; use failure::{err_msg, Error}; use ipnet::IpNet; use serde_derive::Deserialize; @@ -315,3 +314,20 @@ pub enum RuleAction { #[serde(rename = "drop")] Drop, } + +trait Transpose { + type Output; + fn transpose(self) -> Self::Output; +} + +impl Transpose for Option> { + type Output = Result, E>; + + fn transpose(self) -> Self::Output { + match self { + Some(Ok(x)) => Ok(Some(x)), + Some(Err(e)) => Err(e), + None => Ok(None), + } + } +} diff --git a/src/filter.rs b/src/filter.rs index 9d146e0..accc857 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -1,12 +1,18 @@ use crate::{ config::{RequestRule, ResponseRule, RuleAction}, - CONFIG, STDERR, + handler_config::HandlerConfig, + loger::STDERR, }; use slog::debug; use trust_dns_proto::{op::LowerQuery, rr::RecordType}; use trust_dns_resolver::lookup::Lookup; -pub fn check_response(domain: &str, upstream_name: &str, resp: &Lookup) -> RuleAction { +pub fn check_response( + cfg: &HandlerConfig, + domain: &str, + upstream_name: &str, + resp: &Lookup, +) -> RuleAction { let answers = resp.records(); // Drop empty response @@ -31,7 +37,7 @@ pub fn check_response(domain: &str, upstream_name: &str, resp: &Lookup) -> RuleA let toggle = (range_pattern.len() - range_name.len()) % 2 == 1; // See if the range contains the IP - let range = CONFIG.app_config.ranges.get(range_name); + let range = cfg.ranges.get(range_name); range .map(|range| { answers @@ -57,18 +63,16 @@ pub fn check_response(domain: &str, upstream_name: &str, resp: &Lookup) -> RuleA .unwrap_or(true) // No ranges field means matching all ranges }; - CONFIG - .app_config - .response_rules + cfg.response_rules .iter() .find(|rule| { - check_upstream(rule) && check_ranges(rule) && check_domains(domain, &rule.domains) + check_upstream(rule) && check_ranges(rule) && check_domains(cfg, domain, &rule.domains) }) .map(|rule| rule.action) .unwrap_or(RuleAction::Accept) } -pub fn resolvers(query: &LowerQuery) -> Vec<&str> { +pub fn resolvers<'a>(cfg: &'a HandlerConfig, query: &LowerQuery) -> Vec<&'a str> { let name = query.name().to_string(); let check_type = |rule: &RequestRule| { @@ -78,11 +82,10 @@ pub fn resolvers(query: &LowerQuery) -> Vec<&str> { .unwrap_or(true) }; - let rule = CONFIG - .app_config + let rule = cfg .request_rules .iter() - .find(|r| check_domains(&name, &r.domains) && check_type(r)); + .find(|r| check_domains(cfg, &name, &r.domains) && check_type(r)); if let Some(rule) = rule { debug!(STDERR, "Query {} matches rule {:?}", name, rule); @@ -90,16 +93,11 @@ pub fn resolvers(query: &LowerQuery) -> Vec<&str> { } else { debug!(STDERR, "No rule matches for {}. Use defaults.", name); // If no rule matches, use defaults - CONFIG - .app_config - .defaults - .iter() - .map(String::as_str) - .collect() + cfg.defaults.iter().map(String::as_str).collect() } } -fn check_domains(domain: &str, domains: &Option>) -> bool { +fn check_domains(cfg: &HandlerConfig, domain: &str, domains: &Option>) -> bool { let name = domain.trim_end_matches("."); domains .as_ref() @@ -108,7 +106,7 @@ fn check_domains(domain: &str, domains: &Option>) -> bool { // Process the leading `!` let domains_tag = domains_pattern.trim_start_matches('!'); let toggle = (domains_pattern.len() - domains_tag.len()) % 2 == 1; - let domains = CONFIG.app_config.domains.get(domains_tag); + let domains = cfg.domains.get(domains_tag); domains .map(|domains| { (domains.regex_set.is_match(&name) || domains.suffix.contains(&name)) diff --git a/src/handler.rs b/src/handler.rs index f222532..6f675b3 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,10 +1,11 @@ -use crate::{config::RuleAction, filter, CONFIG, STDERR}; +use crate::{config::RuleAction, filter, handler_config::HandlerConfig, loger::STDERR}; use anyhow::Error; use async_recursion::async_recursion; use futures::{ future::{self, MapErr, MapOk}, Future, FutureExt, TryFutureExt, }; +use once_cell::sync::OnceCell; use slog::{debug, error}; use std::pin::Pin; use trust_dns_proto::op::Query; @@ -18,17 +19,27 @@ use trust_dns_server::{ server::{Request, RequestHandler, ResponseHandler, ResponseInfo}, }; +static HANDLER_CONFIG: OnceCell = OnceCell::new(); + +fn handler_config() -> &'static HandlerConfig { + HANDLER_CONFIG + .get() + .expect("HandlerConfig is not initialized") +} + /// DNS Request Handler -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Handler { //pub counter: Arc, } impl Handler { - /// Create default handler. - pub fn default() -> Self { - Handler { - // counter: Arc::new(AtomicU64::new(0)), + /// Create handler from app config. + pub fn new(cfg: HandlerConfig) -> Self { + match HANDLER_CONFIG.set(cfg) { + _ => Handler { + // counter: Arc::new(AtomicU64::new(0)), + }, } } @@ -39,7 +50,7 @@ impl Handler { mut responder: R, ) -> Result { //self.counter.fetch_add(1, Ordering::SeqCst); - let resolvers = filter::resolvers(request.query()); + let resolvers = filter::resolvers(handler_config(), request.query()); let tasks: Vec<_> = resolvers .into_iter() .map(|name| { @@ -48,7 +59,7 @@ impl Handler { let query_type = request.query().query_type().to_owned(); let name1 = name.to_owned(); let name2 = name.to_owned(); - let rs = CONFIG.app_config.resolvers.get(name); + let rs = handler_config().resolvers.get(name); rs.unwrap() .resolve(domain1, query_type) .boxed() @@ -89,7 +100,7 @@ impl Handler { match future::select_all(tasks).await { (Ok((domain, name, resp)), _index, remaining) => { //debug!(STDERR, "DNS {} result {:?}", name, resp); - match filter::check_response(&domain, &name, &resp) { + match filter::check_response(handler_config(), &domain, &name, &resp) { RuleAction::Accept => { // Ignore the remaining future tokio::spawn(future::join_all(remaining).map(|_| ())); diff --git a/src/app_config.rs b/src/handler_config.rs similarity index 95% rename from src/app_config.rs rename to src/handler_config.rs index cfdb713..5d07442 100644 --- a/src/app_config.rs +++ b/src/handler_config.rs @@ -9,7 +9,7 @@ use std::collections::HashMap; use std::sync::Arc; #[derive(Clone, Debug)] -pub struct AppConfig { +pub struct HandlerConfig { pub defaults: Arc>, pub resolvers: Arc>>, pub domains: Arc>, @@ -24,8 +24,8 @@ pub struct Domains { pub suffix: DomainSuffix, } -impl AppConfig { - pub fn new(config: Config) -> Self { +impl From for HandlerConfig { + fn from(config: Config) -> Self { // debug!(STDERR, "{:#?}", config); let resolvers: HashMap<_, _> = config .upstreams @@ -67,7 +67,7 @@ impl AppConfig { }) .collect(); - AppConfig { + HandlerConfig { defaults: Arc::new(config.default_upstreams), resolvers: Arc::new(resolvers), domains: Arc::new(domains), diff --git a/src/loger.rs b/src/loger.rs new file mode 100644 index 0000000..421cd8b --- /dev/null +++ b/src/loger.rs @@ -0,0 +1,23 @@ +use lazy_static::lazy_static; +use slog::{o, Drain, Logger}; + +lazy_static! { + pub static ref STDOUT: Logger = stdout_logger(); + pub static ref STDERR: Logger = stderr_logger(); +} + +fn stdout_logger() -> Logger { + let decorator = slog_term::TermDecorator::new().build(); + let drain = slog_term::CompactFormat::new(decorator).build().fuse(); + let drain = slog_async::Async::new(drain).build().fuse(); + + Logger::root(drain, o!()) +} + +fn stderr_logger() -> Logger { + let decorator = slog_term::TermDecorator::new().build(); + let drain = slog_term::CompactFormat::new(decorator).build(); + let drain = std::sync::Mutex::new(drain).fuse(); + + Logger::root(drain, o!()) +} diff --git a/src/main.rs b/src/main.rs index a9c24d0..a221bec 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,53 +1,35 @@ -use crate::app_config::AppConfig; use crate::config::{Config, ConfigBuilder}; use crate::handler::Handler; +use crate::loger::{STDERR, STDOUT}; use anyhow::Result; use clap::Parser; use failure::Error; -use lazy_static::lazy_static; use option::Args; use slog::{crit, debug, info}; -use slog::{o, Drain, Logger}; use std::fmt::Display; use std::fs::File; use std::io::prelude::*; -use std::net::SocketAddr; use std::process::exit; use std::time::Duration; use tokio; use tokio::net::{TcpListener, UdpSocket}; use trust_dns_server::ServerFuture; -mod app_config; mod config; mod domain; mod filter; mod handler; +mod handler_config; mod ip; +mod loger; mod option; mod resolver; -struct MainConfig { - pub bind: SocketAddr, - pub app_config: AppConfig, -} - -lazy_static! { - static ref STDOUT: Logger = stdout_logger(); - static ref STDERR: Logger = stderr_logger(); - static ref CONFIG: MainConfig = { - let config = config().unwrap_or_log(); - MainConfig { - bind: config.bind, - app_config: AppConfig::new(config), - } - }; -} - #[tokio::main] async fn main() -> Result<()> { - let bind_socket = CONFIG.bind; - let mut server = ServerFuture::new(Handler::default()); + let config = config().unwrap_or_log(); + let bind_socket = config.bind; + let mut server = ServerFuture::new(Handler::new(config.into())); let bind = UdpSocket::bind(bind_socket); info!(STDOUT, "Listening on UDP: {}", bind_socket); @@ -73,22 +55,6 @@ fn config() -> Result { builder.build() } -fn stdout_logger() -> Logger { - let decorator = slog_term::TermDecorator::new().build(); - let drain = slog_term::CompactFormat::new(decorator).build().fuse(); - let drain = slog_async::Async::new(drain).build().fuse(); - - Logger::root(drain, o!()) -} - -fn stderr_logger() -> Logger { - let decorator = slog_term::TermDecorator::new().build(); - let drain = slog_term::CompactFormat::new(decorator).build(); - let drain = std::sync::Mutex::new(drain).fuse(); - - Logger::root(drain, o!()) -} - trait ShouldSuccess { type Item; @@ -121,20 +87,3 @@ where }) } } - -trait Transpose { - type Output; - fn transpose(self) -> Self::Output; -} - -impl Transpose for Option> { - type Output = Result, E>; - - fn transpose(self) -> Self::Output { - match self { - Some(Ok(x)) => Ok(Some(x)), - Some(Err(e)) => Err(e), - None => Ok(None), - } - } -}