diff --git a/socksx/src/common/addresses.rs b/socksx/src/common/addresses.rs index 10f455c..c51ec61 100644 --- a/socksx/src/common/addresses.rs +++ b/socksx/src/common/addresses.rs @@ -1,19 +1,28 @@ -use crate::{constants::*, Credentials}; -use anyhow::Result; use std::convert::{TryFrom, TryInto}; use std::net::{IpAddr, SocketAddr}; + +use anyhow::Result; use tokio::io::{AsyncRead, AsyncReadExt}; use url::Url; -#[derive(Clone, Debug)] +use crate::{constants::*, Credentials}; + +/// Represents a SOCKS proxy address. +#[derive(Clone, Debug, PartialEq)] pub struct ProxyAddress { + /// The version of the SOCKS protocol. pub socks_version: u8, + /// The hostname or IP address of the proxy. pub host: String, + /// The port number of the proxy. pub port: u16, + /// Optional credentials for authentication. pub credentials: Option, } + impl ProxyAddress { + /// Creates a new `ProxyAddress` instance. pub fn new( socks_version: u8, host: String, @@ -28,12 +37,15 @@ impl ProxyAddress { } } + /// Creates a root `ProxyAddress` with predefined settings. pub fn root() -> Self { ProxyAddress::new(6, String::from("root"), 1080, None) } } + impl ToString for ProxyAddress { + // Converts the `ProxyAddress` to a string representation. fn to_string(&self) -> String { format!("socks{}://{}:{}", self.socks_version, self.host, self.port) } @@ -42,6 +54,7 @@ impl ToString for ProxyAddress { impl TryFrom for ProxyAddress { type Error = anyhow::Error; + // Converts a string to a `ProxyAddress`. fn try_from(proxy_addr: String) -> Result { let proxy_addr = Url::parse(&proxy_addr)?; @@ -74,16 +87,18 @@ impl TryFrom for ProxyAddress { } } +/// Represents a network address, which could be either a domain name or an IP address. #[derive(Clone, Debug)] pub enum Address { + /// An address represented by a domain name. Domainname { host: String, port: u16 }, + /// An address represented by an IP address. Ip(SocketAddr), } + impl Address { - /// - /// - /// + /// Creates a new `Address` instance. pub fn new>( host: S, port: u16, @@ -97,9 +112,7 @@ impl Address { } } - /// - /// - /// + /// Converts the `Address` into a byte sequence compatible with the SOCKS protocol. pub fn as_socks_bytes(&self) -> Vec { let mut bytes = vec![]; @@ -134,6 +147,7 @@ impl Address { } impl ToString for Address { + // Converts the `Address` to a string representation. fn to_string(&self) -> String { match self { Address::Domainname { host, port } => format!("{}:{}", host, port), @@ -142,6 +156,7 @@ impl ToString for Address { } } +/// Tries to convert a `SocketAddr` into an `Address`. impl TryFrom for Address { type Error = anyhow::Error; @@ -150,6 +165,7 @@ impl TryFrom for Address { } } +/// Tries to convert a `String` into an `Address`. impl TryFrom for Address { type Error = anyhow::Error; @@ -162,6 +178,7 @@ impl TryFrom for Address { } } +/// Tries to convert a `ProxyAddress` into an `Address`. impl TryFrom<&ProxyAddress> for Address { type Error = anyhow::Error; @@ -170,9 +187,7 @@ impl TryFrom<&ProxyAddress> for Address { } } -/// -/// -/// +/// Reads the destination address from a stream and returns it as an `Address`. pub async fn read_address(stream: &mut S) -> Result
where S: AsyncRead + Unpin, @@ -214,3 +229,122 @@ where Ok(Address::new(dst_addr, dst_port)) } + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + + use anyhow::Result; + + use super::*; + + #[test] + fn test_proxy_address_new() { + let proxy_address = ProxyAddress::new(5, "localhost".to_string(), 1080, None); + assert_eq!(proxy_address.socks_version, 5); + assert_eq!(proxy_address.host, "localhost"); + assert_eq!(proxy_address.port, 1080); + assert!(proxy_address.credentials.is_none()); + } + + #[test] + fn test_proxy_address_root() { + let root_address = ProxyAddress::root(); + assert_eq!(root_address.socks_version, 6); + assert_eq!(root_address.host, "root"); + assert_eq!(root_address.port, 1080); + assert!(root_address.credentials.is_none()); + } + + #[test] + fn test_address_new_domain() { + let address = Address::new("example.com", 80); + match address { + Address::Domainname { host, port } => { + assert_eq!(host, "example.com"); + assert_eq!(port, 80); + }, + _ => panic!("Expected a domain name address"), + } + } + + #[test] + fn test_address_new_ip() { + let address = Address::new("192.168.1.1", 22); + match address { + Address::Ip(socket_addr) => { + assert_eq!(socket_addr.ip().to_string(), "192.168.1.1"); + assert_eq!(socket_addr.port(), 22); + }, + _ => panic!("Expected an IP address"), + } + } + + #[test] + fn test_proxy_address_try_from_valid_string() -> Result<()> { + let proxy_str = "socks5://localhost:1080".to_string(); + let proxy_address: ProxyAddress = proxy_str.try_into()?; + assert_eq!(proxy_address.socks_version, SOCKS_VER_5); + assert_eq!(proxy_address.host, "localhost"); + assert_eq!(proxy_address.port, 1080); + Ok(()) + } + + #[test] + fn test_proxy_address_try_from_invalid_string() { + let proxy_str = "invalid://localhost:1080".to_string(); + let result: Result = proxy_str.try_into(); + assert!(result.is_err()); + } + + #[test] + fn test_address_try_from_valid_string() -> Result<()> { + let addr_str = "localhost:8000".to_string(); + let address: Address = addr_str.try_into()?; + match address { + Address::Domainname { host, port } => { + assert_eq!(host, "localhost"); + assert_eq!(port, 8000); + }, + _ => panic!("Expected a domain name address"), + } + Ok(()) + } + + #[test] + fn test_address_try_from_invalid_string() { + let addr_str = "localhost&8000".to_string(); + let result: Result
= addr_str.try_into(); + assert!(result.is_err()); + } + + #[test] + fn test_address_try_from_socket_addr() -> Result<()> { + let socket_addr: SocketAddr = "192.168.1.1:22".parse()?; + let address: Address = socket_addr.try_into()?; + match address { + Address::Ip(addr) => { + assert_eq!(addr.ip().to_string(), "192.168.1.1"); + assert_eq!(addr.port(), 22); + }, + _ => panic!("Expected an IP address"), + } + Ok(()) + } + + #[test] + fn test_address_try_from_proxy_address() -> Result<()> { + let proxy_address = ProxyAddress::new(5, "localhost".to_string(), 1080, None); + let address: Address = (&proxy_address).try_into()?; + match address { + Address::Domainname { host, port } => { + assert_eq!(host, "localhost"); + assert_eq!(port, 1080); + }, + _ => panic!("Expected a domain name address"), + } + Ok(()) + } + + // TODO: Add tests for `read_address` function once we have a way to mock the `AsyncRead`. +} diff --git a/socksx/src/common/constants.rs b/socksx/src/common/constants.rs index bbba28c..5807c4d 100644 --- a/socksx/src/common/constants.rs +++ b/socksx/src/common/constants.rs @@ -1,28 +1,50 @@ +/// SOCKS protocol version 5 identifier. pub const SOCKS_VER_5: u8 = 0x05u8; +/// SOCKS protocol version 6 identifier. pub const SOCKS_VER_6: u8 = 0x06u8; +/// Version identifier for SOCKS authentication. pub const SOCKS_AUTH_VER: u8 = 0x01u8; +/// Code for no authentication required. pub const SOCKS_AUTH_NOT_REQUIRED: u8 = 0x00u8; +/// Code for username/password authentication. pub const SOCKS_AUTH_USERNAME_PASSWORD: u8 = 0x02u8; +/// Code for no acceptable authentication methods. pub const SOCKS_AUTH_NO_ACCEPTABLE_METHODS: u8 = 0xFFu8; +/// Code for successful authentication. pub const SOCKS_AUTH_SUCCESS: u8 = 0x00u8; +/// Code for failed authentication. pub const SOCKS_AUTH_FAILED: u8 = 0x01u8; +/// Option kind for stack in SOCKS protocol. pub const SOCKS_OKIND_STACK: u16 = 0x01u16; +/// Option kind for advertising authentication methods. pub const SOCKS_OKIND_AUTH_METH_ADV: u16 = 0x02u16; +/// Option kind for selecting authentication methods. pub const SOCKS_OKIND_AUTH_METH_SEL: u16 = 0x03u16; +/// Option kind for authentication data. pub const SOCKS_OKIND_AUTH_DATA: u16 = 0x04u16; +/// Command code for no operation. pub const SOCKS_CMD_NOOP: u8 = 0x00u8; +/// Command code for establishing a TCP/IP stream connection. pub const SOCKS_CMD_CONNECT: u8 = 0x01u8; +/// Command code for establishing a TCP/IP port binding. pub const SOCKS_CMD_BIND: u8 = 0x02u8; +/// Command code for associating a UDP port. pub const SOCKS_CMD_UDP_ASSOCIATE: u8 = 0x03u8; +/// Padding byte for SOCKS protocol. pub const SOCKS_PADDING: u8 = 0x00u8; +/// Reserved byte for SOCKS protocol. pub const SOCKS_RSV: u8 = 0x00u8; +/// Address type identifier for IPv4 addresses. pub const SOCKS_ATYP_IPV4: u8 = 0x01u8; +/// Address type identifier for domain names. pub const SOCKS_ATYP_DOMAINNAME: u8 = 0x03u8; +/// Address type identifier for IPv6 addresses. pub const SOCKS_ATYP_IPV6: u8 = 0x04u8; +/// Reply code for succeeded operation. pub const SOCKS_REP_SUCCEEDED: u8 = 0x00u8; diff --git a/socksx/src/common/credentials.rs b/socksx/src/common/credentials.rs index 6c1021f..9b84a64 100644 --- a/socksx/src/common/credentials.rs +++ b/socksx/src/common/credentials.rs @@ -1,13 +1,19 @@ -#[derive(Clone, Debug)] +/// Represents the username and password credentials for SOCKS authentication. +#[derive(Clone, Debug, PartialEq)] pub struct Credentials { + /// The username as a byte vector. pub username: Vec, + /// The password as a byte vector. pub password: Vec, } impl Credentials { + /// Creates a new `Credentials` instance. /// + /// # Parameters /// - /// + /// * `username`: The username as a byte vector or convertible to a byte vector. + /// * `password`: The password as a byte vector or convertible to a byte vector. pub fn new>>( username: S, password: S, @@ -18,9 +24,11 @@ impl Credentials { Credentials { username, password } } + /// Converts the `Credentials` into a byte sequence compatible with the SOCKS authentication protocol. /// + /// # Returns /// - /// + /// Returns a vector of bytes containing the username and password in SOCKS-compatible format. pub fn as_socks_bytes(&self) -> Vec { // Append username let mut bytes = vec![self.username.len() as u8]; @@ -33,3 +41,22 @@ impl Credentials { bytes } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_credentials_new() { + let credentials = Credentials::new("username".to_string().into_bytes(), "password".to_string().into_bytes()); + assert_eq!(credentials.username, b"username".to_vec()); + assert_eq!(credentials.password, b"password".to_vec()); + } + + #[test] + fn test_credentials_as_socks_bytes() { + let credentials = Credentials::new("username".to_string().into_bytes(), "password".to_string().into_bytes()); + let socks_bytes = credentials.as_socks_bytes(); + assert_eq!(socks_bytes, vec![8, 117, 115, 101, 114, 110, 97, 109, 101, 8, 112, 97, 115, 115, 119, 111, 114, 100]); + } +} diff --git a/socksx/src/common/interface.rs b/socksx/src/common/interface.rs index fb7bce5..e85efc6 100644 --- a/socksx/src/common/interface.rs +++ b/socksx/src/common/interface.rs @@ -2,18 +2,46 @@ use anyhow::Result; use async_trait::async_trait; use tokio::net::TcpStream; +/// An asynchronous trait defining the core functionalities required for handling SOCKS requests. #[async_trait] pub trait SocksHandler { + /// Accepts a SOCKS request from a client. + /// + /// # Parameters + /// + /// * `source`: A mutable reference to the source `TcpStream` from which the request originates. + /// + /// # Returns + /// + /// Returns `Result<()>` indicating the success or failure of the operation. async fn accept_request( &self, source: &mut TcpStream, ) -> Result<()>; + /// Refuses a SOCKS request from a client. + /// + /// # Parameters + /// + /// * `source`: A reference to the source `TcpStream` from which the request originates. + /// + /// # Returns + /// + /// Returns `Result<()>` indicating the success or failure of the operation. async fn refuse_request( &self, source: &mut TcpStream, ) -> Result<()>; + /// Sets up the SOCKS connection for a given source. + /// + /// # Parameters + /// + /// * `source`: A mutable reference to the source `TcpStream`. + /// + /// # Returns + /// + /// Returns a `Result` containing the prepared `TcpStream` or an error. async fn setup( &self, source: &mut TcpStream, diff --git a/socksx/src/common/util.rs b/socksx/src/common/util.rs index f631223..f112667 100644 --- a/socksx/src/common/util.rs +++ b/socksx/src/common/util.rs @@ -1,13 +1,20 @@ -use anyhow::Result; use std::net::SocketAddr; + +use anyhow::Result; use tokio::net::{self, TcpStream}; +/// Retrieves the original destination address from a socket on a Linux system. /// +/// # Parameters /// +/// * `socket`: A reference to a socket implementing `AsRawFd`. /// +/// # Returns +/// +/// Returns a `Result` containing the original `SocketAddr` or an error. #[cfg(target_os = "linux")] pub fn get_original_dst(socket: &S) -> Result { - use nix::sys::socket::{self, sockopt, InetAddr}; + use nix::sys::socket::{self, InetAddr, sockopt}; let original_dst = socket::getsockopt(socket.as_raw_fd(), sockopt::OriginalDst)?; let original_dst = InetAddr::V4(original_dst).to_std(); @@ -16,11 +23,20 @@ pub fn get_original_dst(socket: &S) -> Result(socket: &S) -> Result { use std::str::FromStr; use windows::core::PSTR; - use windows::Win32::Networking::WinSock::{SO_ORIGINAL_DST, SOCKET, SOL_SOCKET, getsockopt}; + use windows::Win32::Networking::WinSock::{getsockopt, SO_ORIGINAL_DST, SOCKET, SOL_SOCKET}; // Attempt to recover the original destination let original_dst: String = unsafe { @@ -41,13 +57,19 @@ pub fn get_original_dst(socket: &S) -> Res } #[cfg(not(any(target_os = "linux", target_os = "windows")))] -pub fn get_original_dst(socket: S) -> Result { +pub fn get_original_dst(_socket: S) -> Result { todo!(); } +/// Resolves a given address to a `SocketAddr`. /// +/// # Parameters /// +/// * `addr`: The address, either as a domain name or IP address. /// +/// # Returns +/// +/// Returns a `Result` containing the resolved `SocketAddr` or an error. pub async fn resolve_addr>(addr: S) -> Result { let addr: String = addr.into(); @@ -64,9 +86,15 @@ pub async fn resolve_addr>(addr: S) -> Result { } } +/// Attempts to read the initial data from a TCP stream. +/// +/// # Parameters /// +/// * `stream`: A mutable reference to a `TcpStream`. /// +/// # Returns /// +/// Returns a `Result` containing an `Option` with the read data as a `Vec` or an error. pub async fn try_read_initial_data(stream: &mut TcpStream) -> Result>> { let mut initial_data = Vec::with_capacity(2usize.pow(14)); // 16KB is the max