diff --git a/Cargo.lock b/Cargo.lock index 09488a2..6e0f5b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -63,7 +63,7 @@ dependencies = [ "flate2", "futures-core", "h2", - "http", + "http 0.2.12", "httparse", "httpdate", "itoa", @@ -98,7 +98,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d22475596539443685426b6bdadb926ad0ecaefdfc5fb05e5e3441f15463c511" dependencies = [ "bytestring", - "http", + "http 0.2.12", "regex", "serde", "tracing", @@ -683,6 +683,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "der" version = "0.6.1" @@ -870,6 +876,7 @@ checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-core", "futures-macro", + "futures-sink", "futures-task", "pin-project-lite", "pin-utils", @@ -930,7 +937,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.12", "indexmap 2.2.6", "slab", "tokio", @@ -986,6 +993,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-range" version = "0.1.5" @@ -2154,6 +2172,18 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "tokio-tungstenite" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6989540ced10490aaf14e6bad2e3d33728a2813310a0c71d1574304c49631cd" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.10" @@ -2198,6 +2228,24 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3" +[[package]] +name = "tungstenite" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e2ce1e47ed2994fd43b04c8f618008d4cabdd5ee34027cf14f9d918edd9c8" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typemap-ors" version = "1.0.0" @@ -2285,6 +2333,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "1.8.0" @@ -2308,12 +2362,13 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "vnts" -version = "1.2.9" +version = "1.2.11" dependencies = [ "actix-files", "actix-web", "actix-web-static-files", "aes-gcm", + "anyhow", "async-trait", "chrono", "clap", @@ -2342,6 +2397,7 @@ dependencies = [ "static-files", "thiserror", "tokio", + "tokio-tungstenite", "tokio-util", "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index e94a28e..f35d20a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vnts" -version = "1.2.9" +version = "1.2.11" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -22,6 +22,7 @@ ring = { version = "0.17", optional = true } rand = "0.8" sha2 = { version = "0.10", features = ["oid"] } colored = "2.1" +anyhow = "1.0.82" thiserror = "1" chrono = "0.4" @@ -36,6 +37,7 @@ socket2 = { version = "0.5", features = ["all"] } actix-web = { version = "4.5", optional = true } actix-files = { version = "0.6", optional = true } actix-web-static-files = { version = "4.0.1", optional = true } +tokio-tungstenite = "0.23.1" serde = { version = "1", features = ["derive"] } crossbeam-utils = "0.8" diff --git a/src/core/server/mod.rs b/src/core/server/mod.rs index 8ef90fe..2c05d49 100644 --- a/src/core/server/mod.rs +++ b/src/core/server/mod.rs @@ -12,6 +12,7 @@ mod tcp; mod udp; #[cfg(feature = "web")] mod web; +mod websocket; pub async fn start( udp: std::net::UdpSocket, diff --git a/src/core/server/tcp.rs b/src/core/server/tcp.rs index 7f8c610..ed644e5 100644 --- a/src/core/server/tcp.rs +++ b/src/core/server/tcp.rs @@ -7,6 +7,8 @@ use tokio::net::tcp::OwnedReadHalf; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::{channel, Sender}; +const TCP_MAX_PACKET_SIZE: usize = (1 << 24) - 1; + pub async fn start(tcp: TcpListener, handler: PacketHandler) { if let Err(e) = accept(tcp, handler).await { log::error!("accept {:?}", e); @@ -17,24 +19,47 @@ async fn accept(tcp: TcpListener, handler: PacketHandler) -> io::Result<()> { loop { let (stream, addr) = tcp.accept().await?; let _ = stream.set_nodelay(true); - stream_handle(stream, addr, handler.clone()).await; + tokio::spawn(stream_handle(stream, addr, handler.clone())); } } async fn stream_handle(stream: TcpStream, addr: SocketAddr, handler: PacketHandler) { + { + let mut buf = [0u8; 1]; + match stream.peek(&mut buf).await { + Ok(len) => { + if len == 0 { + log::warn!("数据流读取失败 {}", addr); + return; + } + if buf[0] != 0 { + //可能是ws协议 + crate::core::server::websocket::handle_websocket_connection( + stream, addr, handler, + ) + .await; + return; + } + } + Err(e) => { + log::warn!("数据流读取失败 {:?} {}", e, addr); + return; + } + } + } + let (r, mut w) = stream.into_split(); let (sender, mut receiver) = channel::>(100); tokio::spawn(async move { while let Some(data) = receiver.recv().await { let len = data.len(); + if len > TCP_MAX_PACKET_SIZE { + log::warn!("超过了tcp的最大长度传输 地址{}", addr); + return; + } if let Err(e) = w - .write_all(&[ - (len >> 24) as u8, - (len >> 16) as u8, - (len >> 8) as u8, - len as u8, - ]) + .write_all(&[0, (len >> 16) as u8, (len >> 8) as u8, len as u8]) .await { log::info!("发送失败,链接终止:{:?},{:?}", addr, e); @@ -65,10 +90,11 @@ async fn tcp_read( let sender = Some(sender); loop { read.read_exact(&mut head).await?; - let len = ((head[0] as usize) << 24) - | ((head[1] as usize) << 16) - | ((head[2] as usize) << 8) - | head[3] as usize; + if head[0] != 0 { + log::warn!("tcp数据流错误 来源地址 {}", addr); + return Ok(()); + } + let len = ((head[1] as usize) << 16) | ((head[2] as usize) << 8) | head[3] as usize; if len < 12 || len > buf.len() { return Err(io::Error::new( io::ErrorKind::InvalidData, diff --git a/src/core/server/websocket/mod.rs b/src/core/server/websocket/mod.rs new file mode 100644 index 0000000..cc8b48f --- /dev/null +++ b/src/core/server/websocket/mod.rs @@ -0,0 +1,69 @@ +use crate::core::service::PacketHandler; +use crate::protocol::NetPacket; +use anyhow::Context; +use futures_util::{SinkExt, StreamExt}; +use std::net::SocketAddr; +use tokio::net::TcpStream; +use tokio::sync::mpsc::channel; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message; + +pub async fn handle_websocket_connection( + stream: TcpStream, + addr: SocketAddr, + handler: PacketHandler, +) { + tokio::spawn(async move { + if let Err(e) = handle_websocket_connection0(stream, addr, handler).await { + log::warn!("websocket err {:?} {}", e, addr); + } + }); +} + +async fn handle_websocket_connection0( + stream: TcpStream, + addr: SocketAddr, + handler: PacketHandler, +) -> anyhow::Result<()> { + let ws_stream = accept_async(stream) + .await + .with_context(|| format!("Error during WebSocket handshake {}", addr))?; + + let (mut ws_write, mut ws_read) = ws_stream.split(); + + let (sender, mut receiver) = channel::>(100); + tokio::spawn(async move { + while let Some(data) = receiver.recv().await { + if let Err(e) = ws_write.send(Message::Binary(data)).await { + log::warn!("websocket err {:?} {}", e, addr); + break; + } + } + let _ = ws_write.close().await; + }); + let sender = Some(sender); + while let Some(msg) = ws_read.next().await { + let msg = msg.with_context(|| format!("Error during WebSocket read {}", addr))?; + match msg { + Message::Text(txt) => log::info!("Received text message: {} {}", txt, addr), + Message::Binary(mut data) => { + let packet = NetPacket::new0(data.len(), &mut data)?; + if let Some(rs) = handler.handle(packet, addr, &sender).await { + if sender + .as_ref() + .unwrap() + .send(rs.buffer().to_vec()) + .await + .is_err() + { + break; + } + } + } + Message::Ping(_) | Message::Pong(_) => (), + Message::Close(_) => break, + _ => {} + } + } + return Ok(()); +} diff --git a/src/main.rs b/src/main.rs index e3fa18f..21b81cc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -258,8 +258,8 @@ async fn main() { log::info!("监听udp端口: {:?}", port); println!("监听udp端口: {:?}", port); let tcp = create_tcp(port).unwrap(); - log::info!("监听tcp端口: {:?}", port); - println!("监听tcp端口: {:?}", port); + log::info!("监听tcp/ws端口: {:?}", port); + println!("监听tcp/ws端口: {:?}", port); #[cfg(feature = "web")] let http = if web_port != 0 { let http = create_tcp(web_port).unwrap(); diff --git a/src/protocol/body.rs b/src/protocol/body.rs index 7395be5..07d39c0 100644 --- a/src/protocol/body.rs +++ b/src/protocol/body.rs @@ -1,10 +1,264 @@ -#![allow(dead_code)] use std::{fmt, io}; pub const ENCRYPTION_RESERVED: usize = 16 + 32 + 12; pub const AES_GCM_ENCRYPTION_RESERVED: usize = 32; pub const RSA_ENCRYPTION_RESERVED: usize = 32; +pub const RANDOM_RESERVED: usize = 4; +pub const FINGER_RESERVED: usize = 12; +pub const TAG_RESERVED: usize = 16; + +/* ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| random(32) | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| finger(32) | +| finger(32) | +| finger(32) | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +pub trait SecretTail { + fn buffer(&self) -> &[u8]; + fn exist_finger(&self) -> bool; + fn random_buf(&self) -> &[u8] { + let buf = self.buffer(); + let mut end = buf.len(); + if self.exist_finger() { + end -= FINGER_RESERVED; + } + &buf[end - RANDOM_RESERVED..end] + } + fn finger(&self) -> &[u8] { + if self.exist_finger() { + let buf = self.buffer(); + let end = buf.len(); + &buf[end - FINGER_RESERVED..end] + } else { + &[] + } + } +} + +pub trait SecretTailMut: SecretTail { + fn buffer_mut(&mut self) -> &mut [u8]; + fn set_random(&mut self, random: &[u8]) { + let f = self.exist_finger(); + let buf = self.buffer_mut(); + let mut end = buf.len(); + if f { + end -= FINGER_RESERVED; + } + buf[end - RANDOM_RESERVED..end].copy_from_slice(random); + } + fn set_finger(&mut self, finger: &[u8]) -> io::Result<()> { + if self.exist_finger() { + if finger.len() != FINGER_RESERVED { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "finger.len != 12", + )); + } + let buf = self.buffer_mut(); + let end = buf.len(); + buf[end - FINGER_RESERVED..end].copy_from_slice(finger); + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "not exist finger", + )) + } + } +} + +/* aead加密数据体 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 数据体 | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | tag(32) | + | tag(32) | + | tag(32) | + | tag(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | random(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | finger(32) | + | finger(32) | + | finger(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + 注:finger用于快速校验数据是否被修改,上层可使用token、协议头参与计算finger, + 确保服务端和客户端都能感知修改(服务端不能解密也能校验指纹) +*/ +pub struct AEADSecretBody { + buffer: B, + exist_finger: bool, +} + +impl> AEADSecretBody { + pub fn new(buffer: B, exist_finger: bool) -> io::Result> { + let len = buffer.as_ref().len(); + let min_len = if exist_finger { + TAG_RESERVED + RANDOM_RESERVED + FINGER_RESERVED + } else { + TAG_RESERVED + RANDOM_RESERVED + }; + // 不能大于udp最大载荷长度 + if len < min_len || len > 65535 - 20 - 8 - 12 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("AEADSecretBody length overflow {}", len), + )); + } + Ok(AEADSecretBody { + buffer, + exist_finger, + }) + } + pub fn data(&self) -> &[u8] { + let mut end = self.buffer.as_ref().len() - TAG_RESERVED - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &self.buffer.as_ref()[..end] + } + pub fn tag(&self) -> &[u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &self.buffer.as_ref()[end - TAG_RESERVED..end] + } +} + +impl> SecretTail for AEADSecretBody { + #[inline] + fn buffer(&self) -> &[u8] { + self.buffer.as_ref() + } + #[inline] + fn exist_finger(&self) -> bool { + self.exist_finger + } +} + +impl + AsMut<[u8]>> SecretTailMut for AEADSecretBody { + #[inline] + fn buffer_mut(&mut self) -> &mut [u8] { + self.buffer.as_mut() + } +} + +impl + AsMut<[u8]>> AEADSecretBody { + /// 数据部分 + pub fn data_mut(&mut self) -> &mut [u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED - TAG_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &mut self.buffer.as_mut()[..end] + } + /// 数据和tag部分 + pub fn data_tag_mut(&mut self) -> &mut [u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &mut self.buffer.as_mut()[..end] + } + pub fn set_tag(&mut self, tag: &[u8]) -> io::Result<()> { + if tag.len() != 16 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "tag.len != 16")); + } + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + self.buffer.as_mut()[end - TAG_RESERVED..end].copy_from_slice(tag); + Ok(()) + } +} + +/* 带随机数的加密数据体 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 数据体 | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | random(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | finger(32) | + | finger(32) | + | finger(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + 注:finger用于快速校验数据是否被修改,上层可使用token、协议头参与计算finger, + 确保服务端和客户端都能感知修改(服务端不能解密也能校验指纹) +*/ +pub struct IVSecretBody { + buffer: B, + exist_finger: bool, +} + +impl> IVSecretBody { + pub fn new(buffer: B, exist_finger: bool) -> io::Result> { + let len = buffer.as_ref().len(); + let min_len = if exist_finger { + FINGER_RESERVED + RANDOM_RESERVED + } else { + RANDOM_RESERVED + }; + // 不能大于udp最大载荷长度 + if len < min_len || len > 65535 - 20 - 8 - 12 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("IVSecretBody length overflow {}", len), + )); + } + Ok(IVSecretBody { + buffer, + exist_finger, + }) + } + pub fn data(&self) -> &[u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &self.buffer.as_ref()[..end] + } +} + +impl + AsMut<[u8]>> IVSecretBody { + pub fn data_mut(&mut self) -> &mut [u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &mut self.buffer.as_mut()[..end] + } +} + +impl> SecretTail for IVSecretBody { + #[inline] + fn buffer(&self) -> &[u8] { + self.buffer.as_ref() + } + #[inline] + fn exist_finger(&self) -> bool { + self.exist_finger + } +} + +impl + AsMut<[u8]>> SecretTailMut for IVSecretBody { + #[inline] + fn buffer_mut(&mut self) -> &mut [u8] { + self.buffer.as_mut() + } +} + /* aes_gcm加密数据体 0 15 31 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -280,7 +534,7 @@ impl> RsaSecretBody { pub fn new(buffer: B) -> io::Result> { let len = buffer.as_ref().len(); // 不能大于udp最大载荷长度 - if !(32..=65535 - 20 - 8 - 12).contains(&len) { + if len < 32 || len > 65535 - 20 - 8 - 12 { return Err(io::Error::new( io::ErrorKind::InvalidData, "length overflow", @@ -305,7 +559,7 @@ impl> RsaSecretBody { &self.buffer.as_ref()[end..] } pub fn buffer(&self) -> &[u8] { - self.buffer.as_ref() + &self.buffer.as_ref() } } diff --git a/src/protocol/control_packet.rs b/src/protocol/control_packet.rs index 3ae37b5..dbd1bd6 100644 --- a/src/protocol/control_packet.rs +++ b/src/protocol/control_packet.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] use std::net::Ipv4Addr; use std::{fmt, io}; @@ -6,14 +5,20 @@ use std::{fmt, io}; pub enum Protocol { /// ping请求 /* - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | time | echo | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | time | echo | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ Ping, - /// 维持连接,内容同ping + /* + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | time | echo | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ Pong, /// 打洞请求 PunchRequest, @@ -39,9 +44,9 @@ impl From for Protocol { } } -impl From for u8 { - fn from(val: Protocol) -> Self { - match val { +impl Into for Protocol { + fn into(self) -> u8 { + match self { Protocol::Ping => 1, Protocol::Pong => 2, Protocol::PunchRequest => 3, @@ -86,8 +91,8 @@ pub type PongPacket = PingPacket; impl> PingPacket { pub fn new(buffer: B) -> io::Result> { let len = buffer.as_ref().len(); - if len != 4 { - return Err(io::Error::new(io::ErrorKind::InvalidData, "len != 4")); + if len < 4 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "len < 4")); } Ok(PingPacket { buffer }) } @@ -127,8 +132,8 @@ pub struct AddrPacket { impl> AddrPacket { pub fn new(buffer: B) -> io::Result> { let len = buffer.as_ref().len(); - if len != 6 { - return Err(io::Error::new(io::ErrorKind::InvalidData, "len != 6")); + if len < 6 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "len < 6")); } Ok(AddrPacket { buffer }) } diff --git a/src/protocol/error_packet.rs b/src/protocol/error_packet.rs index 018ca14..8014f44 100644 --- a/src/protocol/error_packet.rs +++ b/src/protocol/error_packet.rs @@ -1,6 +1,4 @@ -#![allow(dead_code)] - -use tokio::io; +use std::io; #[derive(Eq, PartialEq, Copy, Clone, Debug)] pub enum Protocol { diff --git a/src/protocol/extension.rs b/src/protocol/extension.rs new file mode 100644 index 0000000..1e43e06 --- /dev/null +++ b/src/protocol/extension.rs @@ -0,0 +1,141 @@ +/* 扩展协议 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 扩展数据(n) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 扩展数据(n) | type(8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 注:扩展数据的长度由type决定 +*/ + +use anyhow::anyhow; +use std::io; + +use crate::protocol::NetPacket; + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum ExtensionTailType { + Compression, + Unknown(u8), +} + +impl From for ExtensionTailType { + fn from(value: u8) -> Self { + if value == 0 { + ExtensionTailType::Compression + } else { + ExtensionTailType::Unknown(value) + } + } +} + +pub enum ExtensionTailPacket { + Compression(CompressionExtensionTail), + Unknown, +} + +impl + AsMut<[u8]>> NetPacket { + /// 分离尾部数据 + pub fn split_tail_packet(&mut self) -> anyhow::Result> { + if self.is_extension() { + let payload = self.payload(); + if let Some(v) = payload.last() { + return match ExtensionTailType::from(*v) { + ExtensionTailType::Compression => { + let data_len = self.data_len - 4; + self.set_data_len(data_len)?; + self.set_extension_flag(false); + Ok(ExtensionTailPacket::Compression( + CompressionExtensionTail::new( + &self.raw_buffer()[data_len..data_len + 4], + ), + )) + } + ExtensionTailType::Unknown(e) => Err(anyhow!("unknown extension {}", e)), + }; + } + } + Err(anyhow!("not extension")) + } + /// 追加压缩扩展 + pub fn append_compression_extension_tail( + &mut self, + ) -> io::Result> { + let len = self.data_len; + //增加数据长度 + self.set_data_len(self.data_len + 4)?; + self.set_extension_flag(true); + let mut tail = CompressionExtensionTail::new(&mut self.buffer_mut()[len..]); + tail.init(); + return Ok(tail); + } +} + +/* 扩展协议 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | algorithm(8) | | type(8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 注:扩展数据的长度由type决定 +*/ +/// 压缩扩展 +pub struct CompressionExtensionTail { + buffer: B, +} + +impl> CompressionExtensionTail { + pub fn new(buffer: B) -> CompressionExtensionTail { + assert_eq!(buffer.as_ref().len(), 4); + CompressionExtensionTail { buffer } + } +} + +impl> CompressionExtensionTail { + pub fn algorithm(&self) -> CompressionAlgorithm { + self.buffer.as_ref()[0].into() + } +} + +impl + AsMut<[u8]>> CompressionExtensionTail { + pub fn init(&mut self) { + self.buffer.as_mut().fill(0); + } + pub fn set_algorithm(&mut self, algorithm: CompressionAlgorithm) { + self.buffer.as_mut()[0] = algorithm.into() + } +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum CompressionAlgorithm { + #[cfg(feature = "lz4_compress")] + Lz4, + #[cfg(feature = "zstd_compress")] + Zstd, + Unknown(u8), +} + +impl From for CompressionAlgorithm { + fn from(value: u8) -> Self { + match value { + #[cfg(feature = "lz4_compress")] + 1 => CompressionAlgorithm::Lz4, + #[cfg(feature = "zstd_compress")] + 2 => CompressionAlgorithm::Zstd, + v => CompressionAlgorithm::Unknown(v), + } + } +} + +impl From for u8 { + fn from(value: CompressionAlgorithm) -> Self { + match value { + #[cfg(feature = "lz4_compress")] + CompressionAlgorithm::Lz4 => 1, + #[cfg(feature = "zstd_compress")] + CompressionAlgorithm::Zstd => 2, + CompressionAlgorithm::Unknown(val) => val, + } + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 4edc6d3..4fbc95a 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,7 +1,6 @@ #![allow(dead_code)] use crate::protocol::body::ENCRYPTION_RESERVED; -use std::fmt::Formatter; use std::net::Ipv4Addr; use std::{fmt, io}; @@ -9,7 +8,7 @@ use std::{fmt, io}; 0 15 31 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |e |s |u |u| 版本(4) | 协议(8) | 上层协议(8) | 初始ttl(4) | 生存时间(4) | + |e |s |x |u| 版本(4) | 协议(8) | 上层协议(8) | 初始ttl(4) | 生存时间(4) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | 源ip地址(32) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -17,13 +16,14 @@ use std::{fmt, io}; +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | 数据体 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - 注:e为是否加密标志,s为服务端通信包标志,u未使用 + 注:e为是否加密标志,s为服务端通信包标志,x扩展标志,u未使用 */ pub const HEAD_LEN: usize = 12; pub mod body; pub mod control_packet; pub mod error_packet; +pub mod extension; pub mod ip_turn_packet; pub mod other_turn_packet; pub mod service_packet; @@ -44,9 +44,9 @@ impl From for Version { } } -impl From for u8 { - fn from(val: Version) -> Self { - match val { +impl Into for Version { + fn into(self) -> u8 { + match self { Version::V2 => 2, Version::Unknown(val) => val, } @@ -81,9 +81,9 @@ impl From for Protocol { } } -impl From for u8 { - fn from(val: Protocol) -> Self { - match val { +impl Into for Protocol { + fn into(self) -> u8 { + match self { Protocol::Service => 1, Protocol::Error => 2, Protocol::Control => 3, @@ -104,6 +104,10 @@ pub struct NetPacket { } impl> NetPacket { + pub fn unchecked(buffer: B) -> Self { + let data_len = buffer.as_ref().len(); + Self { data_len, buffer } + } pub fn new(buffer: B) -> io::Result> { let data_len = buffer.as_ref().len(); Self::new0(data_len, buffer) @@ -126,14 +130,15 @@ impl> NetPacket { "length overflow", )); } - if HEAD_LEN > data_len { + if data_len < 12 { return Err(io::Error::new( io::ErrorKind::InvalidData, - "length overflow", + "data_len too short", )); } Ok(NetPacket { data_len, buffer }) } + #[inline] pub fn buffer(&self) -> &[u8] { &self.buffer.as_ref()[..self.data_len] } @@ -160,6 +165,10 @@ impl> NetPacket { pub fn is_gateway(&self) -> bool { self.buffer.as_ref()[0] & 0x40 == 0x40 } + /// 扩展协议 + pub fn is_extension(&self) -> bool { + self.buffer.as_ref()[0] & 0x20 == 0x20 + } pub fn version(&self) -> Version { Version::from(self.buffer.as_ref()[0] & 0x0F) } @@ -192,6 +201,9 @@ impl> NetPacket { } impl + AsMut<[u8]>> NetPacket { + pub fn head_mut(&mut self) -> &mut [u8] { + &mut self.buffer.as_mut()[..12] + } pub fn buffer_mut(&mut self) -> &mut [u8] { &mut self.buffer.as_mut()[..self.data_len] } @@ -204,12 +216,18 @@ impl + AsMut<[u8]>> NetPacket { } pub fn set_gateway_flag(&mut self, is_gateway: bool) { if is_gateway { - // 后面的版本再改为0x40,改了之后不兼容1.2.5之前的版本 - self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x50 + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x40 } else { self.buffer.as_mut()[0] = self.buffer.as_ref()[0] & 0xBF }; } + pub fn set_extension_flag(&mut self, is_extension: bool) { + if is_extension { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x20 + } else { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] & 0xDF + }; + } pub fn set_default_version(&mut self) { let v: u8 = Version::V2.into(); self.buffer.as_mut()[0] = (self.buffer.as_ref()[0] & 0xF0) | (0x0F & v); @@ -266,6 +284,10 @@ impl + AsMut<[u8]>> NetPacket { self.data_len = data_len; Ok(()) } + pub fn set_payload_len(&mut self, payload_len: usize) -> io::Result<()> { + let data_len = HEAD_LEN + payload_len; + self.set_data_len(data_len) + } pub fn set_data_len_max(&mut self) { self.data_len = self.buffer.as_ref().len(); } @@ -287,19 +309,3 @@ impl> fmt::Debug for NetPacket { .finish() } } - -impl> fmt::Display for NetPacket { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("NetPacket") - .field("version", &self.version()) - .field("gateway", &self.is_gateway()) - .field("encrypt", &self.is_encrypt()) - .field("protocol", &self.protocol()) - .field("transport_protocol", &self.transport_protocol()) - .field("ttl", &self.ttl()) - .field("source_ttl", &self.source_ttl()) - .field("source", &self.source()) - .field("destination", &self.destination()) - .finish() - } -}