Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tun): try to match hostnames when clash dns used #609

Merged
merged 8 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ jobs:

- uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.toolchain || 'nightly' }}
toolchain: ${{ matrix.toolchain || 'nightly-2024-09-20' }} # until https://github.com/rust-lang/rust-clippy/issues/13457
target: ${{ matrix.target }}
components: ${{ matrix.components || 'rustfmt, clippy' }}

Expand Down
4 changes: 2 additions & 2 deletions clash/tests/data/config/rules.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mixed-port: 8899
tun:
enable: true
device-id: "dev://utun1989"
route-all: true
route-all: false
gateway: "198.19.0.1/32"
so-mark: 3389
# routes:
Expand Down Expand Up @@ -53,7 +53,7 @@ dns:

allow-lan: true
mode: rule
log-level: debug
log-level: trace
external-controller: :9090
external-ui: "public"
# secret: "clash-rs"
Expand Down
98 changes: 56 additions & 42 deletions clash_lib/src/app/dispatcher/dispatcher_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
internal::proxy::{PROXY_DIRECT, PROXY_GLOBAL},
},
proxy::{datagram::UdpPacket, AnyInboundDatagram},
session::Session,
session::{Session, SocksAddr},
};
use futures::{SinkExt, StreamExt};
use std::{
Expand Down Expand Up @@ -75,40 +75,51 @@ impl Dispatcher {
}

#[instrument(skip(self, sess, lhs))]
pub async fn dispatch_stream<S>(&self, sess: Session, mut lhs: S)
pub async fn dispatch_stream<S>(&self, mut sess: Session, mut lhs: S)
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
let sess = if self.resolver.fake_ip_enabled() {
match sess.destination {
crate::session::SocksAddr::Ip(addr) => {
let ip = addr.ip();
let dest: SocksAddr = match &sess.destination {
crate::session::SocksAddr::Ip(socket_addr) => {
if self.resolver.fake_ip_enabled() {
trace!("looking up fake ip: {}", socket_addr.ip());
let ip = socket_addr.ip();
if self.resolver.is_fake_ip(ip).await {
let host = self.resolver.reverse_lookup(ip).await;
match host {
Some(host) => {
let mut sess = sess;
sess.destination = crate::session::SocksAddr::Domain(
host,
addr.port(),
);
sess
}
Some(host) => (host, socket_addr.port())
.try_into()
.expect("must be valid domain"),
None => {
error!("failed to reverse lookup fake ip: {}", ip);
return;
}
}
} else {
sess
(*socket_addr).into()
}
} else {
trace!("looking up resolve cache ip: {}", socket_addr.ip());
if let Some(resolved) =
self.resolver.cached_for(socket_addr.ip()).await
{
(resolved, socket_addr.port())
.try_into()
.expect("must be valid domain")
} else {
(*socket_addr).into()
}
}
crate::session::SocksAddr::Domain(..) => sess,
}
} else {
sess
crate::session::SocksAddr::Domain(host, port) => {
(host.to_owned(), *port)
.try_into()
.expect("must be valid domain")
}
};

sess.destination = dest.clone();

let mode = *self.mode.lock().unwrap();
let (outbound_name, rule) = match mode {
RunMode::Global => (PROXY_GLOBAL, None),
Expand Down Expand Up @@ -253,28 +264,17 @@ impl Dispatcher {
while let Some(packet) = local_r.next().await {
let mut sess = sess.clone();
sess.source = packet.src_addr.clone().must_into_socket_addr();
sess.destination = packet.dst_addr.clone();

// populate fake ip for route matching
let sess = if resolver.fake_ip_enabled() {
trace!("looking up fake ip for {sess}");
match sess.destination {
crate::session::SocksAddr::Ip(addr) => {
let ip = addr.ip();

let dest: SocksAddr = match &packet.dst_addr {
crate::session::SocksAddr::Ip(socket_addr) => {
if resolver.fake_ip_enabled() {
let ip = socket_addr.ip();
if resolver.is_fake_ip(ip).await {
trace!("fake ip detected");
let host = resolver.reverse_lookup(ip).await;
match host {
Some(host) => {
trace!("fake ip resolved to {}", host);
let mut sess = sess;
sess.destination =
crate::session::SocksAddr::Domain(
host,
addr.port(),
);
sess
}
Some(host) => (host, socket_addr.port())
.try_into()
.expect("must be valid domain"),
None => {
error!(
"failed to reverse lookup fake ip: {}",
Expand All @@ -284,18 +284,32 @@ impl Dispatcher {
}
}
} else {
sess
(*socket_addr).into()
}
} else if let Some(resolved) =
resolver.cached_for(socket_addr.ip()).await
{
(resolved, socket_addr.port())
.try_into()
.expect("must be valid domain")
} else {
(*socket_addr).into()
}
crate::session::SocksAddr::Domain(..) => sess,
}
} else {
sess
crate::session::SocksAddr::Domain(host, port) => {
(host.to_owned(), *port)
.try_into()
.expect("must be valid domain")
}
};
sess.destination = dest.clone();

// mutate packet for fake ip
let mut packet = packet;
packet.dst_addr = sess.destination.clone();
// resolve is done in OutboundDatagramImpl so it's fine to have
// (Domain, port) here. ideally the OutboundDatagramImpl should only
// do Ip though?
packet.dst_addr = dest;

let mode = *mode.lock().unwrap();

Expand Down
6 changes: 3 additions & 3 deletions clash_lib/src/app/dns/dummy_keys.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/// Test certificate and key
/// host: dns.example.com
/// TODO(#51): use real certificate and key
//! Test certificate and key
//! host: dns.example.com
//! TODO(#51): use real certificate and key

pub static TEST_CERT: &str = include_str!("test/test.cert");

Expand Down
3 changes: 3 additions & 0 deletions clash_lib/src/app/dns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ pub trait ClashResolver: Sync + Send {
enhanced: bool,
) -> anyhow::Result<Option<std::net::Ipv6Addr>>;

async fn cached_for(&self, ip: std::net::IpAddr) -> Option<String>;

/// Used for DNS Server
async fn exchange(&self, message: op::Message) -> anyhow::Result<op::Message>;

/// Only used for look up fake IP
Expand Down
54 changes: 49 additions & 5 deletions clash_lib/src/app/dns/resolver/enhanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
time::Duration,
};
use tokio::sync::RwLock;
use tracing::{debug, error, instrument, warn};
use tracing::{debug, error, instrument, trace, warn};

use hickory_proto::{op, rr};

Expand Down Expand Up @@ -46,6 +46,9 @@ pub struct EnhancedResolver {
policy: Option<trie::StringTrie<Vec<ThreadSafeDNSClient>>>,

fake_dns: Option<ThreadSafeFakeDns>,

reverse_lookup_cache:
Option<Arc<RwLock<lru_time_cache::LruCache<net::IpAddr, String>>>>,
}

impl EnhancedResolver {
Expand Down Expand Up @@ -75,6 +78,8 @@ impl EnhancedResolver {
policy: None,

fake_dns: None,

reverse_lookup_cache: None,
}
}

Expand All @@ -94,6 +99,8 @@ impl EnhancedResolver {
policy: None,

fake_dns: None,

reverse_lookup_cache: None,
});

Self {
Expand Down Expand Up @@ -199,6 +206,17 @@ impl EnhancedResolver {
}
_ => None,
},

reverse_lookup_cache: Some(Arc::new(RwLock::new(
lru_time_cache::LruCache::with_expiry_duration_and_capacity(
Duration::from_secs(3), /* should be shorter than TTL so
* client won't be connecting to a
* different server after the ip is
* reverse mapped to hostname and
* being resolved again */
4096,
),
))),
}
}

Expand Down Expand Up @@ -251,7 +269,7 @@ impl EnhancedResolver {
m.add_query(q);
m.set_recursion_desired(true);

match self.exchange(m).await {
match self.exchange(&m).await {
Ok(result) => {
let ip_list = EnhancedResolver::ip_list_of_message(&result);
if !ip_list.is_empty() {
Expand All @@ -264,14 +282,14 @@ impl EnhancedResolver {
}
}

async fn exchange(&self, message: op::Message) -> anyhow::Result<op::Message> {
async fn exchange(&self, message: &op::Message) -> anyhow::Result<op::Message> {
if let Some(q) = message.query() {
if let Some(lru) = &self.lru_cache {
if let Some(cached) = lru.read().await.peek(q.to_string().as_str()) {
return Ok(cached.clone());
}
}
self.exchange_no_cache(&message).await
self.exchange_no_cache(message).await
} else {
Err(anyhow!("invalid query"))
}
Expand Down Expand Up @@ -436,6 +454,13 @@ impl EnhancedResolver {
})
.collect()
}

async fn save_reverse_lookup(&self, ip: net::IpAddr, domain: String) {
if let Some(lru) = &self.reverse_lookup_cache {
trace!("reverse lookup cache insert: {} -> {}", ip, domain);
lru.write().await.insert(ip, domain);
}
}
}

#[async_trait]
Expand Down Expand Up @@ -545,8 +570,27 @@ impl ClashResolver for EnhancedResolver {
}
}

async fn cached_for(&self, ip: net::IpAddr) -> Option<String> {
if let Some(lru) = &self.reverse_lookup_cache {
if let Some(cached) = lru.read().await.peek(&ip) {
trace!("reverse lookup cache hit: {} -> {}", ip, cached);
return Some(cached.clone());
}
}

None
}

async fn exchange(&self, message: op::Message) -> anyhow::Result<op::Message> {
self.exchange(message).await
let rv = self.exchange(&message).await?;
let hostname = message.query().unwrap().name().to_ascii();
let ip_list = EnhancedResolver::ip_list_of_message(&rv);
if !ip_list.is_empty() {
for ip in ip_list {
self.save_reverse_lookup(ip, hostname.clone()).await;
}
}
Ok(rv)
}

fn ipv6(&self) -> bool {
Expand Down
4 changes: 4 additions & 0 deletions clash_lib/src/app/dns/resolver/system_linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ impl ClashResolver for SystemResolver {
Ok(response.iter().map(|x| x.0).choose(&mut rand::thread_rng()))
}

async fn cached_for(&self, _: std::net::IpAddr) -> Option<String> {
None
}

async fn exchange(
&self,
_: hickory_proto::op::Message,
Expand Down
4 changes: 4 additions & 0 deletions clash_lib/src/app/dns/resolver/system_non_linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ impl ClashResolver for SystemResolver {
Ok(response.into_iter().choose(&mut rand::thread_rng()))
}

async fn cached_for(&self, _: std::net::IpAddr) -> Option<String> {
None
}

async fn exchange(
&self,
_: hickory_proto::op::Message,
Expand Down
7 changes: 1 addition & 6 deletions clash_lib/src/app/router/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::app::router::rules::final_::Final;
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};

use hyper::Uri;
use tracing::{debug, error, info};
use tracing::{error, info};

use super::{
dns::ThreadSafeDNSResolver,
Expand Down Expand Up @@ -93,10 +93,6 @@ impl Router {
&& r.should_resolve_ip()
&& !sess_resolved
{
debug!(
"rule `{r}` resolving domain {} locally",
sess.destination.domain().unwrap()
);
if let Ok(Some(ip)) = self
.dns_resolver
.resolve(sess.destination.domain().unwrap(), false)
Expand All @@ -114,7 +110,6 @@ impl Router {
r.target(),
r.type_name()
);
debug!("matched rule details: {}", r);
return (r.target(), Some(r));
}
}
Expand Down
Loading