From 050f6e6312e6fc11f39ec9094ce19eafc00d725e Mon Sep 17 00:00:00 2001 From: zyj Date: Fri, 14 Jun 2024 17:03:37 +0800 Subject: [PATCH] add lookup ip strategy --- Cargo.toml | 2 +- examples/template.toml | 14 ++++---- src/config.rs | 17 +++++---- src/handler.rs | 5 ++- src/handler_config.rs | 19 ++++++---- src/resolver.rs | 78 +++++++++++++++++++++++++++--------------- 6 files changed, 84 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3ab0856..974a715 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ authors = ["zyj"] edition = "2021" name = "yadns" -version = "0.5.0" +version = "0.5.1" [features] default = ["default-doh-rustls"] diff --git a/examples/template.toml b/examples/template.toml index e3532cd..1d5fb6f 100644 --- a/examples/template.toml +++ b/examples/template.toml @@ -7,12 +7,14 @@ bind = "127.0.0.1:5300" # the address that ya-dns listens on [resolver_opts] # Specify the timeout for a request. Defaults to 5 seconds timeout = 5 -# The strategy for the Resolver to use when lookup Ipv4 or Ipv6 addresses. Available lookup ip strategy options: -# Ipv4Only, Only query for A (Ipv4) records -# Ipv6Only, Only query for AAAA (Ipv6) records -# Ipv4AndIpv6, Query for A and AAAA in parallel -# Ipv6thenIpv4, Query for Ipv6 if that fails, query for Ipv4 -# Ipv4thenIpv6, Query for Ipv4 if that fails, query for Ipv6 (default) +# The strategy for the Resolver to use when lookup Ipv4 or Ipv6 addresses. +# Available lookup ip strategy options: +# None, Query records by client query type (default) +# Ipv4Only, Only query for A (Ipv4) records +# Ipv6Only, Only query for AAAA (Ipv6) records +# Ipv4AndIpv6, Query for A and AAAA in parallel +# Ipv6thenIpv4, Query for Ipv6 if that fails, query for Ipv4 +# Ipv4thenIpv6, Query for Ipv4 if that fails, query for Ipv6 strategy = "Ipv4thenIpv6" # Cache size is in number of records (some records can be large) cache_size = 32 diff --git a/src/config.rs b/src/config.rs index 7e17d4a..976155f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -140,7 +140,7 @@ impl ConfigBuilder { #[derive(Debug)] pub struct ResolverOpts { pub timeout: Duration, - pub ip_strategy: LookupIpStrategy, + pub ip_strategy: Option, pub cache_size: usize, } @@ -153,6 +153,8 @@ struct ResolverOptsConfig { #[derive(Debug, Deserialize)] enum StrategyType { + #[serde(rename = "None")] + None, #[serde(rename = "Ipv4Only")] Ipv4Only, #[serde(rename = "Ipv6Only")] @@ -172,13 +174,14 @@ impl ResolverOptsConfig { ip_strategy: self .strategy .map(|s| match s { - StrategyType::Ipv4Only => LookupIpStrategy::Ipv4Only, - StrategyType::Ipv6Only => LookupIpStrategy::Ipv6Only, - StrategyType::Ipv4AndIpv6 => LookupIpStrategy::Ipv4AndIpv6, - StrategyType::Ipv6thenIpv4 => LookupIpStrategy::Ipv6thenIpv4, - StrategyType::Ipv4thenIpv6 => LookupIpStrategy::Ipv4thenIpv6, + StrategyType::Ipv4Only => Some(LookupIpStrategy::Ipv4Only), + StrategyType::Ipv6Only => Some(LookupIpStrategy::Ipv6Only), + StrategyType::Ipv4AndIpv6 => Some(LookupIpStrategy::Ipv4AndIpv6), + StrategyType::Ipv6thenIpv4 => Some(LookupIpStrategy::Ipv6thenIpv4), + StrategyType::Ipv4thenIpv6 => Some(LookupIpStrategy::Ipv4thenIpv6), + StrategyType::None => None, }) - .unwrap_or_default(), + .unwrap_or(None), cache_size: self.cache_size.unwrap_or(32), } } diff --git a/src/handler.rs b/src/handler.rs index 5481f22..72c6764 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -55,13 +55,12 @@ impl Handler { .map(|name| { let domain1 = request.query().name().to_string(); let domain2 = request.query().name().to_string(); - //let query_type = request.query().query_type(); + let query_type = request.query().query_type(); let name1 = name.to_string(); let name2 = name.to_string(); let rs = handler_config().resolvers.get(name); rs.unwrap() - // .resolve(domain1, query_type) - .resolve(domain1) + .resolve(domain1, query_type) .boxed() .map_ok(move |resp| (domain2, name1, resp)) .map_err(move |e| (name2, e)) diff --git a/src/handler_config.rs b/src/handler_config.rs index 67c6704..db6174a 100644 --- a/src/handler_config.rs +++ b/src/handler_config.rs @@ -30,8 +30,9 @@ impl From for HandlerConfig { // debug!(STDERR, "{:#?}", config); let mut opts = ResolverOpts::default(); opts.timeout = config.resolver_opts.timeout; - opts.ip_strategy = config.resolver_opts.ip_strategy; + opts.ip_strategy = config.resolver_opts.ip_strategy.unwrap_or_default(); opts.cache_size = config.resolver_opts.cache_size; + let lookup_ip_only = config.resolver_opts.ip_strategy.is_some(); let resolvers: HashMap<_, _> = config .upstreams .iter() @@ -39,12 +40,14 @@ impl From for HandlerConfig { ( name.to_owned(), match upstream { - Upstream::TcpUpstream { address, proxy } => { - Arc::new(resolver::tcp_resolver(address, opts.to_owned(), proxy)) - } - Upstream::UdpUpstream { address } => { - Arc::new(resolver::udp_resolver(address, opts.to_owned())) - } + Upstream::TcpUpstream { address, proxy } => Arc::new( + resolver::tcp_resolver(address, opts.to_owned(), lookup_ip_only, proxy), + ), + Upstream::UdpUpstream { address } => Arc::new(resolver::udp_resolver( + address, + opts.to_owned(), + lookup_ip_only, + )), #[cfg(feature = "dns-over-tls")] Upstream::TlsUpstream { address, @@ -54,6 +57,7 @@ impl From for HandlerConfig { address, tls_host, opts.to_owned(), + lookup_ip_only, proxy, )), #[cfg(feature = "dns-over-https")] @@ -65,6 +69,7 @@ impl From for HandlerConfig { address, tls_host, opts.to_owned(), + lookup_ip_only, proxy, )), }, diff --git a/src/resolver.rs b/src/resolver.rs index 9e93a5f..78cd5a4 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -1,3 +1,4 @@ +use hickory_proto::rr::RecordType; use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts}; use hickory_resolver::error::ResolveError; use hickory_resolver::lookup::Lookup; @@ -9,6 +10,7 @@ use crate::resolver_runtime_provider::{ProxyConnectionProvider, ProxyRuntimeProv #[derive(Clone, Debug)] pub struct RecursiveResolver { pub resolver: AsyncResolver, + pub lookup_ip_only: bool, } impl RecursiveResolver { @@ -16,40 +18,48 @@ impl RecursiveResolver { resolver_config: ResolverConfig, options: ResolverOpts, provider: ProxyConnectionProvider, + lookup_ip_only: bool, ) -> Self { let resolver = AsyncResolver::new(resolver_config, options, provider); - RecursiveResolver { resolver } - } - /* - pub async fn resolve( - &self, - domain: String, - record_type: RecordType, - ) -> Result { - self.resolver.lookup(domain, record_type).await + RecursiveResolver { + resolver, + lookup_ip_only, } - */ - pub async fn resolve(&self, domain: String) -> Result { - match self.resolver.lookup_ip(domain).await { - Ok(res) => Ok(res.as_lookup().to_owned()), - Err(e) => Err(e), + } + + pub async fn resolve( + &self, + domain: String, + record_type: RecordType, + ) -> Result { + match self.lookup_ip_only { + true => match self.resolver.lookup_ip(domain).await { + Ok(res) => Ok(res.as_lookup().to_owned()), + Err(e) => Err(e), + }, + false => self.resolver.lookup(domain, record_type).await, } } } -pub fn udp_resolver(address: &Vec, options: ResolverOpts) -> RecursiveResolver { +pub fn udp_resolver( + address: &Vec, + options: ResolverOpts, + lookup_ip_only: bool, +) -> RecursiveResolver { let mut resolver_config = ResolverConfig::new(); address.iter().for_each(|addr| { resolver_config.add_name_server(NameServerConfig::new(*addr, Protocol::Udp)); }); let runtime_provider = ProxyRuntimeProvider::new(None); let provider = ProxyConnectionProvider::new(runtime_provider); - RecursiveResolver::new(resolver_config, options, provider) + RecursiveResolver::new(resolver_config, options, provider, lookup_ip_only) } pub fn tcp_resolver( address: &Vec, options: ResolverOpts, + lookup_ip_only: bool, proxy: &Option, ) -> RecursiveResolver { let mut resolver_config = ResolverConfig::new(); @@ -58,7 +68,7 @@ pub fn tcp_resolver( }); let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap())); let provider = ProxyConnectionProvider::new(runtime_provider); - RecursiveResolver::new(resolver_config, options, provider) + RecursiveResolver::new(resolver_config, options, provider, lookup_ip_only) } #[cfg(feature = "dns-over-tls")] @@ -66,6 +76,7 @@ pub fn tls_resolver( address: &Vec, tls_host: &String, options: ResolverOpts, + lookup_ip_only: bool, proxy: &Option, ) -> RecursiveResolver { let mut resolver_config = ResolverConfig::new(); @@ -76,7 +87,7 @@ pub fn tls_resolver( }); let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap())); let provider = ProxyConnectionProvider::new(runtime_provider); - RecursiveResolver::new(resolver_config, options, provider) + RecursiveResolver::new(resolver_config, options, provider, lookup_ip_only) } #[cfg(feature = "dns-over-https")] @@ -84,6 +95,7 @@ pub fn https_resolver( address: &Vec, tls_host: &String, options: ResolverOpts, + lookup_ip_only: bool, proxy: &Option, ) -> RecursiveResolver { let mut resolver_config = ResolverConfig::new(); @@ -94,7 +106,7 @@ pub fn https_resolver( }); let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap())); let provider = ProxyConnectionProvider::new(runtime_provider); - RecursiveResolver::new(resolver_config, options, provider) + RecursiveResolver::new(resolver_config, options, provider, lookup_ip_only) } #[cfg(test)] @@ -107,8 +119,8 @@ mod tests { fn udp_resolver_test() { let dns_addr = "8.8.8.8:53".parse::().unwrap(); let io_loop = Runtime::new().unwrap(); - let resolver = udp_resolver(&vec![dns_addr], ResolverOpts::default()); - let lookup_future = resolver.resolve(String::from("www.example.com")); + let resolver = udp_resolver(&vec![dns_addr], ResolverOpts::default(), false); + let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A); let response = io_loop.block_on(lookup_future).unwrap(); let a = response .record_iter() @@ -123,8 +135,8 @@ mod tests { fn tcp_resolver_test() { let dns_addr = "8.8.8.8:53".parse::().unwrap(); let io_loop = Runtime::new().unwrap(); - let resolver = tcp_resolver(&vec![dns_addr], ResolverOpts::default(), &None); - let lookup_future = resolver.resolve(String::from("www.example.com")); + let resolver = tcp_resolver(&vec![dns_addr], ResolverOpts::default(), false, &None); + let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A); let response = io_loop.block_on(lookup_future).unwrap(); let a = response .record_iter() @@ -141,8 +153,14 @@ mod tests { let dns_addr = "8.8.8.8:853".parse::().unwrap(); let dns_host = String::from("dns.google"); let io_loop = Runtime::new().unwrap(); - let resolver = tls_resolver(&vec![dns_addr], &dns_host, ResolverOpts::default(), &None); - let lookup_future = resolver.resolve(String::from("www.example.com")); + let resolver = tls_resolver( + &vec![dns_addr], + &dns_host, + ResolverOpts::default(), + false, + &None, + ); + let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A); let response = io_loop.block_on(lookup_future).unwrap(); let a = response .record_iter() @@ -159,8 +177,14 @@ mod tests { let dns_addr = "8.8.8.8:443".parse::().unwrap(); let dns_host = String::from("dns.google"); let io_loop = Runtime::new().unwrap(); - let resolver = https_resolver(&vec![dns_addr], &dns_host, ResolverOpts::default(), &None); - let lookup_future = resolver.resolve(String::from("www.example.com")); + let resolver = https_resolver( + &vec![dns_addr], + &dns_host, + ResolverOpts::default(), + false, + &None, + ); + let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A); let response = io_loop.block_on(lookup_future).unwrap(); let a = response .record_iter()