Skip to content

Commit

Permalink
[skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
zyj committed Sep 28, 2023
1 parent 05138cd commit 2ebcd83
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 29 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ serde = "1.0.188"
serde_derive = "1.0.188"
regex = "1.9.5"
async-trait = "0.1.73"
tracing = "0.1.37"
thiserror = "1.0.48"
anyhow = "1.0.75"
futures = "0.3.28"
async-recursion = "1.0.5"
Expand Down
31 changes: 11 additions & 20 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{config::RuleAction, filter, CONFIG, STDERR};
use anyhow::Error;
use async_recursion::async_recursion;
use futures::{
future::{self, MapErr, MapOk},
Expand All @@ -17,25 +18,15 @@ use trust_dns_server::{
server::{Request, RequestHandler, ResponseHandler, ResponseInfo},
};

#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Invalid OpCode {0:}")]
InvalidOpCode(OpCode),
#[error("Invalid MessageType {0:}")]
InvalidMessageType(MessageType),
#[error("IO error: {0:}")]
Io(#[from] std::io::Error),
}

/// DNS Request Handler
#[derive(Clone)]
pub struct Handler {
//pub counter: Arc<AtomicU64>,
}

impl Handler {
/// Create new handler from command-line options.
pub fn new() -> Self {
/// Create default handler.
pub fn default() -> Self {
Handler {
// counter: Arc::new(AtomicU64::new(0)),
}
Expand Down Expand Up @@ -151,22 +142,22 @@ impl Handler {
async fn do_handle_request<R: ResponseHandler>(
&self,
request: &Request,
response: R,
mut response: R,
) -> Result<ResponseInfo, Error> {
debug!(
STDERR,
"DNS requests are forwarded to [{}].",
request.query()
);
// make sure the request is a query
if request.op_code() != OpCode::Query {
return Err(Error::InvalidOpCode(request.op_code()));
// make sure the request is a query and the message type is a query
if request.op_code() != OpCode::Query || request.message_type() != MessageType::Query {
let builder = MessageResponseBuilder::from_message_request(request);
let mut header = Header::response_from_request(request.header());
header.set_response_code(ResponseCode::Refused);
let res = builder.build_no_records(header);
return Ok(response.send_response(res).await?);
}

// make sure the message type is a query
if request.message_type() != MessageType::Query {
return Err(Error::InvalidMessageType(request.message_type()));
}
self.do_handle_request_default(request, response).await
}
}
Expand Down
7 changes: 2 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,18 @@ lazy_static! {
};
}

const TCP_TIMEOUT: Duration = Duration::from_secs(10);

#[tokio::main]
async fn main() -> Result<()> {
let bind_socket = CONFIG.bind;
let request_handler = Handler::new();
let mut server = ServerFuture::new(request_handler);
let mut server = ServerFuture::new(Handler::default());

let bind = UdpSocket::bind(bind_socket);
info!(STDOUT, "Listening on UDP: {}", bind_socket);
server.register_socket(bind.await?);

let bind = TcpListener::bind(bind_socket);
info!(STDOUT, "Listening on TCP: {}", bind_socket);
server.register_listener(bind.await?, TCP_TIMEOUT);
server.register_listener(bind.await?, Duration::from_secs(10));

server.block_until_done().await?;

Expand Down

0 comments on commit 2ebcd83

Please sign in to comment.