Skip to content

Commit

Permalink
add configuration for the resolver.Closes #26
Browse files Browse the repository at this point in the history
  • Loading branch information
zyj committed Jun 13, 2024
1 parent a468d7a commit a4715f4
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 31 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.5.0

- add configuration for the resolver

# 0.4.0

- add socks5 and http proxy support
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
authors = ["zyj"]
edition = "2021"
name = "yadns"
version = "0.4.2"
version = "0.5.0"

[features]
default = ["default-doh-rustls"]
Expand Down
2 changes: 1 addition & 1 deletion examples/chinadns-domain.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ files = ["direct-list.txt"]

[ranges]
[ranges.cn]
# download from https://www.ipdeny.com/ipblocks/data/aggregated/cn.zone
# download from https://www.ipdeny.com/ipblocks/data/countries/cn.zone
files = ["cn.txt"]

[[requests]]
Expand Down
2 changes: 1 addition & 1 deletion examples/chinadns.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ tls-host = "cloudflare-dns.com"

[ranges]
[ranges.cn]
# download from https://www.ipdeny.com/ipblocks/data/aggregated/cn.zone
# download from https://www.ipdeny.com/ipblocks/data/countries/cn.zone
files = ["cn.txt"]

[[responses]]
Expand Down
14 changes: 14 additions & 0 deletions examples/template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@
# Root privilege may be required if you specify a port below 1024.
bind = "127.0.0.1:5300" # the address that ya-dns listens on

# Configuration for the Resolver
[resolver_opts]
# Specify the timeout for a request. Defaults to 5 seconds
timeout = 5
# The strategy for the Resolver to use when lookup Ipv4 or Ipv6 addresses. Available lookup ip strategy options:
# Ipv4Only, Only query for A (Ipv4) records
# Ipv6Only, Only query for AAAA (Ipv6) records
# Ipv4AndIpv6, Query for A and AAAA in parallel
# Ipv6thenIpv4, Query for Ipv6 if that fails, query for Ipv4
# Ipv4thenIpv6, Query for Ipv4 if that fails, query for Ipv6 (default)
strategy = "Ipv4thenIpv6"
# Cache size is in number of records (some records can be large)
cache_size = 32

# DNS requests will be forwarded to all the upstream servers set up here
# except those with `default = false`.
[upstreams]
Expand Down
61 changes: 61 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::ip::IpRange;
use hickory_proto::rr::RecordType;
use hickory_resolver::config::LookupIpStrategy;
use ipnet::AddrParseError;
use ipnet::IpNet;
use serde_derive::Deserialize;
Expand All @@ -10,6 +11,7 @@ use std::io::BufReader;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::str::FromStr;
use std::time::Duration;
use thiserror::Error;

#[derive(Error, Debug)]
Expand All @@ -27,6 +29,7 @@ pub enum ConfigError {
pub struct Config {
pub bind: SocketAddr,
pub default_upstreams: Vec<String>,
pub resolver_opts: ResolverOpts,
pub upstreams: HashMap<String, Upstream>,
pub domains: HashMap<String, Domains>,
pub ranges: HashMap<String, IpRange>,
Expand All @@ -37,6 +40,7 @@ pub struct Config {
#[derive(Debug, Deserialize)]
pub struct ConfigBuilder {
bind: SocketAddr,
resolver_opts: Option<ResolverOptsConfig>,
upstreams: HashMap<String, UpstreamConfig>,
domains: Option<HashMap<String, DomainsConf>>,
ranges: Option<HashMap<String, IpRangeConf>>,
Expand Down Expand Up @@ -71,6 +75,15 @@ impl ConfigBuilder {
pub fn build(self) -> Result<Config, ConfigError> {
let mut default_upstreams = Vec::new();

let resolver_opts = self
.resolver_opts
.unwrap_or(ResolverOptsConfig {
timeout: None,
strategy: None,
cache_size: None,
})
.build();

let upstreams = self
.upstreams
.into_iter()
Expand Down Expand Up @@ -114,6 +127,7 @@ impl ConfigBuilder {
Ok(Config {
bind: self.bind,
default_upstreams,
resolver_opts,
upstreams,
domains,
ranges,
Expand All @@ -123,6 +137,53 @@ impl ConfigBuilder {
}
}

#[derive(Debug)]
pub struct ResolverOpts {
pub timeout: Duration,
pub ip_strategy: LookupIpStrategy,
pub cache_size: usize,
}

#[derive(Debug, Deserialize)]
struct ResolverOptsConfig {
timeout: Option<u64>,
strategy: Option<StrategyType>,
cache_size: Option<usize>,
}

#[derive(Debug, Deserialize)]
enum StrategyType {
#[serde(rename = "Ipv4Only")]
Ipv4Only,
#[serde(rename = "Ipv6Only")]
Ipv6Only,
#[serde(rename = "Ipv4AndIpv6")]
Ipv4AndIpv6,
#[serde(rename = "Ipv6thenIpv4")]
Ipv6thenIpv4,
#[serde(rename = "Ipv4thenIpv6")]
Ipv4thenIpv6,
}

impl ResolverOptsConfig {
fn build(self) -> ResolverOpts {
ResolverOpts {
timeout: Duration::from_secs(self.timeout.unwrap_or(5)),
ip_strategy: self
.strategy
.map(|s| match s {
StrategyType::Ipv4Only => LookupIpStrategy::Ipv4Only,
StrategyType::Ipv6Only => LookupIpStrategy::Ipv6Only,
StrategyType::Ipv4AndIpv6 => LookupIpStrategy::Ipv4AndIpv6,
StrategyType::Ipv6thenIpv4 => LookupIpStrategy::Ipv6thenIpv4,
StrategyType::Ipv4thenIpv6 => LookupIpStrategy::Ipv4thenIpv6,
})
.unwrap_or_default(),
cache_size: self.cache_size.unwrap_or(32),
}
}
}

#[derive(Debug, Deserialize)]
pub struct UpstreamConfig {
address: Vec<String>,
Expand Down
5 changes: 3 additions & 2 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ impl Handler {
.map(|name| {
let domain1 = request.query().name().to_string();
let domain2 = request.query().name().to_string();
let query_type = request.query().query_type();
//let query_type = request.query().query_type();
let name1 = name.to_string();
let name2 = name.to_string();
let rs = handler_config().resolvers.get(name);
rs.unwrap()
.resolve(domain1, query_type)
// .resolve(domain1, query_type)
.resolve(domain1)
.boxed()
.map_ok(move |resp| (domain2, name1, resp))
.map_err(move |e| (name2, e))
Expand Down
23 changes: 19 additions & 4 deletions src/handler_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::domain::DomainSuffix;
use crate::ip::IpRange;
use crate::resolver;
use crate::resolver::RecursiveResolver;
use hickory_resolver::config::ResolverOpts;
use regex::RegexSet;
use std::collections::HashMap;
use std::sync::Arc;
Expand All @@ -27,6 +28,10 @@ pub struct Domains {
impl From<Config> for HandlerConfig {
fn from(config: Config) -> Self {
// debug!(STDERR, "{:#?}", config);
let mut opts = ResolverOpts::default();
opts.timeout = config.resolver_opts.timeout;
opts.ip_strategy = config.resolver_opts.ip_strategy;
opts.cache_size = config.resolver_opts.cache_size;
let resolvers: HashMap<_, _> = config
.upstreams
.iter()
Expand All @@ -35,23 +40,33 @@ impl From<Config> for HandlerConfig {
name.to_owned(),
match upstream {
Upstream::TcpUpstream { address, proxy } => {
Arc::new(resolver::tcp_resolver(address, proxy))
Arc::new(resolver::tcp_resolver(address, opts.to_owned(), proxy))
}
Upstream::UdpUpstream { address } => {
Arc::new(resolver::udp_resolver(address))
Arc::new(resolver::udp_resolver(address, opts.to_owned()))
}
#[cfg(feature = "dns-over-tls")]
Upstream::TlsUpstream {
address,
tls_host,
proxy,
} => Arc::new(resolver::tls_resolver(address, tls_host, proxy)),
} => Arc::new(resolver::tls_resolver(
address,
tls_host,
opts.to_owned(),
proxy,
)),
#[cfg(feature = "dns-over-https")]
Upstream::HttpsUpstream {
address,
tls_host,
proxy,
} => Arc::new(resolver::https_resolver(address, tls_host, proxy)),
} => Arc::new(resolver::https_resolver(
address,
tls_host,
opts.to_owned(),
proxy,
)),
},
)
})
Expand Down
56 changes: 34 additions & 22 deletions src/resolver.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use hickory_proto::rr::RecordType;
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
use hickory_resolver::error::ResolveError;
use hickory_resolver::lookup::Lookup;
Expand All @@ -21,40 +20,52 @@ impl RecursiveResolver {
let resolver = AsyncResolver::new(resolver_config, options, provider);
RecursiveResolver { resolver }
}

pub async fn resolve(
&self,
domain: String,
record_type: RecordType,
) -> Result<Lookup, ResolveError> {
self.resolver.lookup(domain, record_type).await
/*
pub async fn resolve(
&self,
domain: String,
record_type: RecordType,
) -> Result<Lookup, ResolveError> {
self.resolver.lookup(domain, record_type).await
}
*/
pub async fn resolve(&self, domain: String) -> Result<Lookup, ResolveError> {
match self.resolver.lookup_ip(domain).await {
Ok(res) => Ok(res.as_lookup().to_owned()),
Err(e) => Err(e),
}
}
}

pub fn udp_resolver(address: &Vec<SocketAddr>) -> RecursiveResolver {
pub fn udp_resolver(address: &Vec<SocketAddr>, options: ResolverOpts) -> RecursiveResolver {
let mut resolver_config = ResolverConfig::new();
address.iter().for_each(|addr| {
resolver_config.add_name_server(NameServerConfig::new(*addr, Protocol::Udp));
});
let runtime_provider = ProxyRuntimeProvider::new(None);
let provider = ProxyConnectionProvider::new(runtime_provider);
RecursiveResolver::new(resolver_config, ResolverOpts::default(), provider)
RecursiveResolver::new(resolver_config, options, provider)
}

pub fn tcp_resolver(address: &Vec<SocketAddr>, proxy: &Option<String>) -> RecursiveResolver {
pub fn tcp_resolver(
address: &Vec<SocketAddr>,
options: ResolverOpts,
proxy: &Option<String>,
) -> RecursiveResolver {
let mut resolver_config = ResolverConfig::new();
address.iter().for_each(|addr| {
resolver_config.add_name_server(NameServerConfig::new(*addr, Protocol::Tcp));
});
let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap()));
let provider = ProxyConnectionProvider::new(runtime_provider);
RecursiveResolver::new(resolver_config, ResolverOpts::default(), provider)
RecursiveResolver::new(resolver_config, options, provider)
}

#[cfg(feature = "dns-over-tls")]
pub fn tls_resolver(
address: &Vec<SocketAddr>,
tls_host: &String,
options: ResolverOpts,
proxy: &Option<String>,
) -> RecursiveResolver {
let mut resolver_config = ResolverConfig::new();
Expand All @@ -65,13 +76,14 @@ pub fn tls_resolver(
});
let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap()));
let provider = ProxyConnectionProvider::new(runtime_provider);
RecursiveResolver::new(resolver_config, ResolverOpts::default(), provider)
RecursiveResolver::new(resolver_config, options, provider)
}

#[cfg(feature = "dns-over-https")]
pub fn https_resolver(
address: &Vec<SocketAddr>,
tls_host: &String,
options: ResolverOpts,
proxy: &Option<String>,
) -> RecursiveResolver {
let mut resolver_config = ResolverConfig::new();
Expand All @@ -82,7 +94,7 @@ pub fn https_resolver(
});
let runtime_provider = ProxyRuntimeProvider::new(proxy.to_owned().map(|p| p.parse().unwrap()));
let provider = ProxyConnectionProvider::new(runtime_provider);
RecursiveResolver::new(resolver_config, ResolverOpts::default(), provider)
RecursiveResolver::new(resolver_config, options, provider)
}

#[cfg(test)]
Expand All @@ -95,8 +107,8 @@ mod tests {
fn udp_resolver_test() {
let dns_addr = "8.8.8.8:53".parse::<SocketAddr>().unwrap();
let io_loop = Runtime::new().unwrap();
let resolver = udp_resolver(&vec![dns_addr]);
let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A);
let resolver = udp_resolver(&vec![dns_addr], ResolverOpts::default());
let lookup_future = resolver.resolve(String::from("www.example.com"));
let response = io_loop.block_on(lookup_future).unwrap();
let a = response
.record_iter()
Expand All @@ -111,8 +123,8 @@ mod tests {
fn tcp_resolver_test() {
let dns_addr = "8.8.8.8:53".parse::<SocketAddr>().unwrap();
let io_loop = Runtime::new().unwrap();
let resolver = tcp_resolver(&vec![dns_addr], &None);
let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A);
let resolver = tcp_resolver(&vec![dns_addr], ResolverOpts::default(), &None);
let lookup_future = resolver.resolve(String::from("www.example.com"));
let response = io_loop.block_on(lookup_future).unwrap();
let a = response
.record_iter()
Expand All @@ -129,8 +141,8 @@ mod tests {
let dns_addr = "8.8.8.8:853".parse::<SocketAddr>().unwrap();
let dns_host = String::from("dns.google");
let io_loop = Runtime::new().unwrap();
let resolver = tls_resolver(&vec![dns_addr], &dns_host, &None);
let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A);
let resolver = tls_resolver(&vec![dns_addr], &dns_host, ResolverOpts::default(), &None);
let lookup_future = resolver.resolve(String::from("www.example.com"));
let response = io_loop.block_on(lookup_future).unwrap();
let a = response
.record_iter()
Expand All @@ -147,8 +159,8 @@ mod tests {
let dns_addr = "8.8.8.8:443".parse::<SocketAddr>().unwrap();
let dns_host = String::from("dns.google");
let io_loop = Runtime::new().unwrap();
let resolver = https_resolver(&vec![dns_addr], &dns_host, &None);
let lookup_future = resolver.resolve(String::from("www.example.com"), RecordType::A);
let resolver = https_resolver(&vec![dns_addr], &dns_host, ResolverOpts::default(), &None);
let lookup_future = resolver.resolve(String::from("www.example.com"));
let response = io_loop.block_on(lookup_future).unwrap();
let a = response
.record_iter()
Expand Down

0 comments on commit a4715f4

Please sign in to comment.