Skip to content

Commit

Permalink
add lookup ip strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
zyj committed Jun 14, 2024
1 parent a4715f4 commit 050f6e6
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
authors = ["zyj"]
edition = "2021"
name = "yadns"
version = "0.5.0"
version = "0.5.1"

[features]
default = ["default-doh-rustls"]
Expand Down
14 changes: 8 additions & 6 deletions examples/template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ bind = "127.0.0.1:5300" # the address that ya-dns listens on
[resolver_opts]
# Specify the timeout for a request. Defaults to 5 seconds
timeout = 5
# The strategy for the Resolver to use when lookup Ipv4 or Ipv6 addresses. Available lookup ip strategy options:
# Ipv4Only, Only query for A (Ipv4) records
# Ipv6Only, Only query for AAAA (Ipv6) records
# Ipv4AndIpv6, Query for A and AAAA in parallel
# Ipv6thenIpv4, Query for Ipv6 if that fails, query for Ipv4
# Ipv4thenIpv6, Query for Ipv4 if that fails, query for Ipv6 (default)
# The strategy for the Resolver to use when lookup Ipv4 or Ipv6 addresses.
# Available lookup ip strategy options:
# None, Query records by client query type (default)
# Ipv4Only, Only query for A (Ipv4) records
# Ipv6Only, Only query for AAAA (Ipv6) records
# Ipv4AndIpv6, Query for A and AAAA in parallel
# Ipv6thenIpv4, Query for Ipv6 if that fails, query for Ipv4
# Ipv4thenIpv6, Query for Ipv4 if that fails, query for Ipv6
strategy = "Ipv4thenIpv6"
# Cache size is in number of records (some records can be large)
cache_size = 32
Expand Down
17 changes: 10 additions & 7 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl ConfigBuilder {
#[derive(Debug)]
pub struct ResolverOpts {
pub timeout: Duration,
pub ip_strategy: LookupIpStrategy,
pub ip_strategy: Option<LookupIpStrategy>,
pub cache_size: usize,
}

Expand All @@ -153,6 +153,8 @@ struct ResolverOptsConfig {

#[derive(Debug, Deserialize)]
enum StrategyType {
#[serde(rename = "None")]
None,
#[serde(rename = "Ipv4Only")]
Ipv4Only,
#[serde(rename = "Ipv6Only")]
Expand All @@ -172,13 +174,14 @@ impl ResolverOptsConfig {
ip_strategy: self
.strategy
.map(|s| match s {
StrategyType::Ipv4Only => LookupIpStrategy::Ipv4Only,
StrategyType::Ipv6Only => LookupIpStrategy::Ipv6Only,
StrategyType::Ipv4AndIpv6 => LookupIpStrategy::Ipv4AndIpv6,
StrategyType::Ipv6thenIpv4 => LookupIpStrategy::Ipv6thenIpv4,
StrategyType::Ipv4thenIpv6 => LookupIpStrategy::Ipv4thenIpv6,
StrategyType::Ipv4Only => Some(LookupIpStrategy::Ipv4Only),
StrategyType::Ipv6Only => Some(LookupIpStrategy::Ipv6Only),
StrategyType::Ipv4AndIpv6 => Some(LookupIpStrategy::Ipv4AndIpv6),
StrategyType::Ipv6thenIpv4 => Some(LookupIpStrategy::Ipv6thenIpv4),
StrategyType::Ipv4thenIpv6 => Some(LookupIpStrategy::Ipv4thenIpv6),
StrategyType::None => None,
})
.unwrap_or_default(),
.unwrap_or(None),
cache_size: self.cache_size.unwrap_or(32),
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,12 @@ impl Handler {
.map(|name| {
let domain1 = request.query().name().to_string();
let domain2 = request.query().name().to_string();
//let query_type = request.query().query_type();
let query_type = request.query().query_type();
let name1 = name.to_string();
let name2 = name.to_string();
let rs = handler_config().resolvers.get(name);
rs.unwrap()
// .resolve(domain1, query_type)
.resolve(domain1)
.resolve(domain1, query_type)
.boxed()
.map_ok(move |resp| (domain2, name1, resp))
.map_err(move |e| (name2, e))
Expand Down
19 changes: 12 additions & 7 deletions src/handler_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,24 @@ impl From<Config> for HandlerConfig {
// debug!(STDERR, "{:#?}", config);
let mut opts = ResolverOpts::default();
opts.timeout = config.resolver_opts.timeout;
opts.ip_strategy = config.resolver_opts.ip_strategy;
opts.ip_strategy = config.resolver_opts.ip_strategy.unwrap_or_default();
opts.cache_size = config.resolver_opts.cache_size;
let lookup_ip_only = config.resolver_opts.ip_strategy.is_some();
let resolvers: HashMap<_, _> = config
.upstreams
.iter()
.map(|(name, upstream)| {
(
name.to_owned(),
match upstream {
Upstream::TcpUpstream { address, proxy } => {
Arc::new(resolver::tcp_resolver(address, opts.to_owned(), proxy))
}
Upstream::UdpUpstream { address } => {
Arc::new(resolver::udp_resolver(address, opts.to_owned()))
}
Upstream::TcpUpstream { address, proxy } => Arc::new(
resolver::tcp_resolver(address, opts.to_owned(), lookup_ip_only, proxy),
),
Upstream::UdpUpstream { address } => Arc::new(resolver::udp_resolver(
address,
opts.to_owned(),
lookup_ip_only,
)),
#[cfg(feature = "dns-over-tls")]
Upstream::TlsUpstream {
address,
Expand All @@ -54,6 +57,7 @@ impl From<Config> for HandlerConfig {
address,
tls_host,
opts.to_owned(),
lookup_ip_only,
proxy,
)),
#[cfg(feature = "dns-over-https")]
Expand All @@ -65,6 +69,7 @@ impl From<Config> for HandlerConfig {
address,
tls_host,
opts.to_owned(),
lookup_ip_only,
proxy,
)),
},
Expand Down
78 changes: 51 additions & 27 deletions src/resolver.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use hickory_proto::rr::RecordType;
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::error::ResolveError;
use hickory_resolver::lookup::Lookup;
Expand All @@ -9,47 +10,56 @@ use crate::resolver_runtime_provider::{ProxyConnectionProvider, ProxyRuntimeProv
#[derive(Clone, Debug)]
pub struct RecursiveResolver {
pub resolver: AsyncResolver<ProxyConnectionProvider>,
pub lookup_ip_only: bool,
}

impl RecursiveResolver {
pub fn new(
resolver_config: ResolverConfig,
options: ResolverOpts,
provider: ProxyConnectionProvider,
lookup_ip_only: bool,
) -> Self {
let resolver = AsyncResolver::new(resolver_config, options, provider);
RecursiveResolver { resolver }
}
/*
pub async fn resolve(
&self,
domain: String,
record_type: RecordType,
) -> Result<Lookup, ResolveError> {
self.resolver.lookup(domain, record_type).await
RecursiveResolver {
resolver,
lookup_ip_only,
}
*/
pub async fn resolve(&self, domain: String) -> Result<Lookup, ResolveError> {
match self.resolver.lookup_ip(domain).await {
Ok(res) => Ok(res.as_lookup().to_owned()),
Err(e) => Err(e),
}

pub async fn resolve(
&self,
domain: String,
record_type: RecordType,
) -> Result<Lookup, ResolveError> {
match self.lookup_ip_only {
true => match self.resolver.lookup_ip(domain).await {
Ok(res) => Ok(res.as_lookup().to_owned()),
Err(e) => Err(e),
},
false => self.resolver.lookup(domain, record_type).await,
}
}
}

pub fn udp_resolver(address: &Vec<SocketAddr>, options: ResolverOpts) -> RecursiveResolver {
pub fn udp_resolver(
address: &Vec<SocketAddr>,
options: ResolverOpts,
lookup_ip_only: bool,
) -> RecursiveResolver {
let mut resolver_config = ResolverConfig::new();
address.iter().for_each(|addr| {
resolver_config.add_name_server(NameServerConfig::new(*addr, Protocol::Udp));
});
let runtime_provider = ProxyRuntimeProvider::new(None);
let provider = ProxyConnectionProvider::new(runtime_provider);
RecursiveResolver::new(resolver_config, options, provider)
RecursiveResolver::new(resolver_config, options, provider, lookup_ip_only)
}

pub fn tcp_resolver(
address: &Vec<SocketAddr>,
options: ResolverOpts,
lookup_ip_only: bool,
proxy: &Option<String>,
) -> RecursiveResolver {
let mut resolver_config = ResolverConfig::new();
Expand All @@ -58,14 +68,15 @@ pub fn tcp_resolver(
});
let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap()));
let provider = ProxyConnectionProvider::new(runtime_provider);
RecursiveResolver::new(resolver_config, options, provider)
RecursiveResolver::new(resolver_config, options, provider, lookup_ip_only)
}

#[cfg(feature = "dns-over-tls")]
pub fn tls_resolver(
address: &Vec<SocketAddr>,
tls_host: &String,
options: ResolverOpts,
lookup_ip_only: bool,
proxy: &Option<String>,
) -> RecursiveResolver {
let mut resolver_config = ResolverConfig::new();
Expand All @@ -76,14 +87,15 @@ pub fn tls_resolver(
});
let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap()));
let provider = ProxyConnectionProvider::new(runtime_provider);
RecursiveResolver::new(resolver_config, options, provider)
RecursiveResolver::new(resolver_config, options, provider, lookup_ip_only)
}

#[cfg(feature = "dns-over-https")]
pub fn https_resolver(
address: &Vec<SocketAddr>,
tls_host: &String,
options: ResolverOpts,
lookup_ip_only: bool,
proxy: &Option<String>,
) -> RecursiveResolver {
let mut resolver_config = ResolverConfig::new();
Expand All @@ -94,7 +106,7 @@ pub fn https_resolver(
});
let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap()));
let provider = ProxyConnectionProvider::new(runtime_provider);
RecursiveResolver::new(resolver_config, options, provider)
RecursiveResolver::new(resolver_config, options, provider, lookup_ip_only)
}

#[cfg(test)]
Expand All @@ -107,8 +119,8 @@ mod tests {
fn udp_resolver_test() {
let dns_addr = "8.8.8.8:53".parse::<SocketAddr>().unwrap();
let io_loop = Runtime::new().unwrap();
let resolver = udp_resolver(&vec![dns_addr], ResolverOpts::default());
let lookup_future = resolver.resolve(String::from("www.example.com"));
let resolver = udp_resolver(&vec![dns_addr], ResolverOpts::default(), false);
let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A);
let response = io_loop.block_on(lookup_future).unwrap();
let a = response
.record_iter()
Expand All @@ -123,8 +135,8 @@ mod tests {
fn tcp_resolver_test() {
let dns_addr = "8.8.8.8:53".parse::<SocketAddr>().unwrap();
let io_loop = Runtime::new().unwrap();
let resolver = tcp_resolver(&vec![dns_addr], ResolverOpts::default(), &None);
let lookup_future = resolver.resolve(String::from("www.example.com"));
let resolver = tcp_resolver(&vec![dns_addr], ResolverOpts::default(), false, &None);
let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A);
let response = io_loop.block_on(lookup_future).unwrap();
let a = response
.record_iter()
Expand All @@ -141,8 +153,14 @@ mod tests {
let dns_addr = "8.8.8.8:853".parse::<SocketAddr>().unwrap();
let dns_host = String::from("dns.google");
let io_loop = Runtime::new().unwrap();
let resolver = tls_resolver(&vec![dns_addr], &dns_host, ResolverOpts::default(), &None);
let lookup_future = resolver.resolve(String::from("www.example.com"));
let resolver = tls_resolver(
&vec![dns_addr],
&dns_host,
ResolverOpts::default(),
false,
&None,
);
let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A);
let response = io_loop.block_on(lookup_future).unwrap();
let a = response
.record_iter()
Expand All @@ -159,8 +177,14 @@ mod tests {
let dns_addr = "8.8.8.8:443".parse::<SocketAddr>().unwrap();
let dns_host = String::from("dns.google");
let io_loop = Runtime::new().unwrap();
let resolver = https_resolver(&vec![dns_addr], &dns_host, ResolverOpts::default(), &None);
let lookup_future = resolver.resolve(String::from("www.example.com"));
let resolver = https_resolver(
&vec![dns_addr],
&dns_host,
ResolverOpts::default(),
false,
&None,
);
let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A);
let response = io_loop.block_on(lookup_future).unwrap();
let a = response
.record_iter()
Expand Down

0 comments on commit 050f6e6

Please sign in to comment.