Skip to content

Commit

Permalink
Use trust-dns crates
Browse files Browse the repository at this point in the history
- Uses trust dns to create the DNS server, and answer to queries
- Clean code
  • Loading branch information
LIAUD Corentin authored and cocool97 committed Nov 27, 2022
1 parent 2d17c33 commit 4ca543b
Show file tree
Hide file tree
Showing 12 changed files with 955 additions and 179 deletions.
824 changes: 781 additions & 43 deletions Cargo.lock

Large diffs are not rendered by default.

53 changes: 17 additions & 36 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ license = "MIT"
name = "rhole"
readme = "README.md"
repository = "https://github.com/cocool97/rhole"
version = "0.1.1"
version = "0.1.2"

[[bin]]
name = "rhole"
Expand All @@ -19,41 +19,22 @@ opt-level = 'z'
strip = "debuginfo"

[dependencies]
anyhow = { version = "= 1.0.66", default-features = false, features = ["std"] }
bytes = { version = "= 1.2.1", default_features = false }
clap = { version = "= 4.0.22", default-features = false, features = [
"color",
"derive",
"std",
"suggestions",
] }
dns-message-parser = { version = "= 0.7.0", default-features = false }
env_logger = { version = "= 0.9.3", default-features = false, features = [
"atty",
"humantime",
"termcolor",
] }
log = { version = "= 0.4.17", default-features = false }
regex = { version = "= 1.7.0", default-features = false, features = [
"perf",
"std",
"unicode",
] }
reqwest = { version = "= 0.11.12", default-features = false, features = ["json"] }
serde = { version = "= 1.0.147", default-features = false, features = [
"derive",
"std",
] }
serde_yaml = { version = "= 0.9.14", default-features = false }
sled = { version = "= 0.34.7", default-features = false, features = [
"no_metrics",
] }
tokio = { version = "= 1.21.2", default-features = false, features = [
"fs",
"macros",
"net",
"rt-multi-thread",
] }
anyhow = { version = "= 1.0.66" }
async-trait = { version = "= 0.1.58" }
bytes = { version = "= 1.3.0" }
clap = { version = "= 4.0.27", features = ["derive"] }
dns-message-parser = { version = "= 0.7.0" }
env_logger = { version = "= 0.10.0" }
log = { version = "= 0.4.17" }
regex = { version = "= 1.7.0" }
reqwest = { version = "= 0.11.13" }
serde = { version = "= 1.0.147", features = ["derive"] }
serde_yaml = { version = "= 0.9.14" }
sled = { version = "= 0.34.7" }
tokio = { version = "= 1.22.0", features = ["fs", "macros", "rt-multi-thread"] }
trust-dns-client = { version = "= 0.22.0" }
trust-dns-resolver = { version = "= 0.22.0", features = ["tokio-runtime"] }
trust-dns-server = { version = "= 0.22.0", features = ["tokio-rustls"] }

[package.metadata.generate-rpm]
assets = [
Expand Down
17 changes: 17 additions & 0 deletions config-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
database:
path: "rhole.db"
proxy_server:
addr: "8.8.8.8"
port: 53
sources:
update_interval: 5
entries:
- source_type: !File
location: hosts.txt
comment: Global hosts file
# - source_type: !Network
# location: http://sbc.io/hosts/alternates/fakenews-gambling-porn-social/hosts
# comment: Remote hosts file
net:
listen_addr: "127.0.0.1"
listen_port: 5053
5 changes: 4 additions & 1 deletion src/controllers/blacklist_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use regex::RegexBuilder;
use reqwest::Url;
use sled::Db;

use crate::{models::{DatabaseConfig, SourceEntry, SourceType}, utils};
use crate::{
models::{DatabaseConfig, SourceEntry, SourceType},
utils,
};

use super::NetworkController;

Expand Down
2 changes: 1 addition & 1 deletion src/controllers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ mod network_controller;
mod requests_controller;
pub use blacklist_controller::BlacklistController;
pub use network_controller::NetworkController;
pub use requests_controller::InboundConnectionsController;
pub use requests_controller::RequestsController;
159 changes: 75 additions & 84 deletions src/controllers/requests_controller.rs
Original file line number Diff line number Diff line change
@@ -1,108 +1,99 @@
use std::{net::SocketAddr, ops::Deref, sync::Arc};

use anyhow::{anyhow, Result};
use bytes::Bytes;
use dns_message_parser::rr::{A, RR};
use anyhow::Result;
use sled::Db;
use tokio::net::UdpSocket;
use trust_dns_client::op::{MessageType, ResponseCode};
use trust_dns_resolver::{
config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts},
name_server::{GenericConnection, GenericConnectionProvider, TokioRuntime},
AsyncResolver,
};
use trust_dns_server::{
authority::MessageResponseBuilder,
server::{Request, RequestHandler, ResponseHandler, ResponseInfo},
};

use crate::models::ProxyServer;
use crate::models::{dns_default_response, ProxyServer};

pub struct InboundConnectionsController {
proxy_server: Arc<ProxyServer>,
blacklist: Arc<Db>,
pub struct RequestsController {
pub blacklist: Db,
resolver: AsyncResolver<GenericConnection, GenericConnectionProvider<TokioRuntime>>,
}

impl InboundConnectionsController {
pub fn new(proxy_server: ProxyServer, blacklist: Db) -> Self {
Self {
proxy_server: Arc::new(proxy_server),
blacklist: Arc::new(blacklist),
}
}

pub async fn listen(self, addr: &str, port: u16) -> Result<()> {
let socket = Arc::new(UdpSocket::bind((addr, port)).await?);
loop {
let mut buffer = [0u8; 4096];
impl RequestsController {
pub async fn new(blacklist: Db, proxy: ProxyServer) -> Result<Self> {
let mut resolver_config = ResolverConfig::default();
resolver_config.add_name_server(NameServerConfig::new(proxy.try_into()?, Protocol::Udp));
let resolver = AsyncResolver::tokio(resolver_config, ResolverOpts::default())?;

// TODO: handle errors here
let (number_of_bytes, origin_addr) = socket.recv_from(&mut buffer).await?;
log::debug!("Received request from {}", origin_addr);

let blacklist = self.blacklist.clone();
let proxy = self.proxy_server.clone();
let core_socket = socket.clone();
let bytes = bytes::Bytes::copy_from_slice(&buffer[..number_of_bytes]);
tokio::spawn(async move {
if let Err(e) = Self::handle_packet(
core_socket.deref(),
bytes,
blacklist.deref(),
origin_addr,
proxy.deref(),
)
.await
{
log::error!("Error when handling packet: {e}");
}
});
}
Ok(Self {
blacklist,
resolver,
})
}

async fn handle_packet(
socket: &UdpSocket,
buffer: Bytes,
blacklist: &Db,
origin_addr: SocketAddr,
proxy: &ProxyServer,
) -> Result<()> {
let mut packet = dns_message_parser::Dns::decode(buffer.clone())?;
let question = packet
.questions
.first()
.ok_or_else(|| anyhow!("no questions in request..."))?;
async fn inner_handle_request<R: ResponseHandler>(
&self,
request: &Request,
response_handle: &mut R,
) -> Result<ResponseInfo> {
let query = request.query();
let query_question = query.name().to_string();
let query_type = query.query_type();

// TODO: Treat all questions !
// TODO: Check trailing dots
let mut rev_address = String::new();
for component in question
.domain_name
.to_string()
.trim_end_matches('.')
.split('.')
.rev()
{
for component in query_question.trim_end_matches('.').split('.').rev() {
rev_address.push_str(component);

if let Ok(Some(_)) = blacklist.get(&rev_address) {
log::warn!(
"[{}] Domain {} is blacklisted. Ignoring it.",
origin_addr.ip(),
question.domain_name
);
if let Ok(Some(_)) = self.blacklist.get(&rev_address) {
log::warn!("[{}] Blacklisted domain {}", request.src(), query_question);

packet.answers = vec![RR::A(A {
domain_name: question.domain_name.clone(),
ttl: 86400,
ipv4_addr: "0.0.0.0".parse()?,
})];
let response = dns_default_response(request, ResponseCode::Refused);

socket.send_to(&packet.encode()?, origin_addr).await?;
return Ok(());
return Ok(response_handle.send_response(response).await?);
}

rev_address.push('.');
}

let proxy_socket = UdpSocket::bind("0.0.0.0:0").await?;
proxy_socket.send_to(&buffer, proxy.to_addr()).await?;
let mut res_buffer = [0u8; 4096];
let received_size = proxy_socket.recv(&mut res_buffer).await?;
socket
.send_to(&res_buffer[..received_size], origin_addr)
.await?;
let response = self.resolver.lookup(query.name(), query_type).await?;

let response_builder = MessageResponseBuilder::from_message_request(request);
let mut response = response_builder.build(
*request.header(),
response.records(),
vec![],
vec![],
request.additionals(),
);

Ok(())
response
.header_mut()
.set_message_type(MessageType::Response);

Ok(response_handle.send_response(response).await?)
}
}

#[async_trait::async_trait]
impl RequestHandler for RequestsController {
async fn handle_request<R: ResponseHandler>(
&self,
request: &Request,
mut response_handle: R,
) -> ResponseInfo {
match self
.inner_handle_request(request, &mut response_handle)
.await
{
Ok(r) => r,
Err(_) => {
let response = dns_default_response(request, ResponseCode::ServFail);
response_handle
.send_response(response)
.await
.expect("Cannot send response to client")
}
}
}
}
22 changes: 14 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ mod controllers;
mod models;
mod utils;

use crate::controllers::RequestsController;
pub use crate::models::Config;

use anyhow::Result;
use clap::Parser;
use controllers::{BlacklistController, InboundConnectionsController};
use controllers::BlacklistController;
use models::Opts;
use tokio::net::UdpSocket;
use trust_dns_server::ServerFuture;

#[tokio::main]
async fn main() -> Result<()> {
Expand All @@ -24,14 +27,17 @@ async fn main() -> Result<()> {
let blacklist_controller =
BlacklistController::init_from_sources(config.sources.entries, config.database).await?;

let inbound_connections_controller = InboundConnectionsController::new(
config.proxy_server,
blacklist_controller.get_blacklist(),
);
let socket = UdpSocket::bind((config.net.listen_addr.as_str(), config.net.listen_port)).await?;

inbound_connections_controller
.listen(config.net.listen_addr.as_str(), config.net.listen_port)
.await?;
let mut server = ServerFuture::new(
RequestsController::new(
blacklist_controller.get_blacklist(),
config.proxy_server.clone(),
)
.await?,
);
server.register_socket(socket);
server.block_until_done().await?;

Ok(())
}
20 changes: 15 additions & 5 deletions src/models/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{fmt::Display, path::Path};
use std::{
fmt::Display,
net::{AddrParseError, Ipv4Addr, SocketAddr, SocketAddrV4},
path::Path,
str::FromStr,
};

use anyhow::Result;
use serde::Deserialize;
Expand All @@ -23,7 +28,7 @@ pub struct NetConfig {
pub listen_port: u16,
}

#[derive(Deserialize)]
#[derive(Clone, Deserialize)]
pub struct ProxyServer {
pub addr: String,
pub port: u16,
Expand Down Expand Up @@ -64,9 +69,14 @@ impl Display for SourceType {
}
}

impl ProxyServer {
pub fn to_addr(&self) -> String {
format!("{}:{}", self.addr, self.port)
impl TryInto<SocketAddr> for ProxyServer {
type Error = AddrParseError;

fn try_into(self) -> Result<SocketAddr, Self::Error> {
Ok(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::from_str(&self.addr)?,
self.port,
)))
}
}

Expand Down
27 changes: 27 additions & 0 deletions src/models/dns_default_response.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use trust_dns_client::{
op::{MessageType, ResponseCode},
rr::Record,
};
use trust_dns_server::{
authority::{MessageResponse, MessageResponseBuilder},
server::Request,
};

pub fn dns_default_response(
request: &Request,
response_code: ResponseCode,
) -> MessageResponse<
impl Iterator<Item = &Record> + Send,
impl Iterator<Item = &Record> + Send,
impl Iterator<Item = &Record> + Send,
impl Iterator<Item = &Record> + Send,
> {
let response_builder = MessageResponseBuilder::from_message_request(request);
let mut response = response_builder.error_msg(request.header(), response_code);

response
.header_mut()
.set_message_type(MessageType::Response);

response
}
Loading

0 comments on commit 4ca543b

Please sign in to comment.