diff --git a/Cargo.toml b/Cargo.toml index 65513ff..9fd2be0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,8 @@ name = "cccp" version = "0.6.0" edition = "2021" build = "build.rs" +repository = "https://github.com/chanderlud/cccp" +authors = ["Chander Luderman "] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -20,7 +22,21 @@ indicatif = "0.17" prost = "0.12" bytesize = "1.3.0" kanal = "0.1.0-pre8" -blake3 = "1.5.0" +blake3 = "1.5" +chacha20 = "0.9" +rand = "0.8" +hex = "0.4" +ctr = "0.9" +aes = "0.8" +itertools = "0.12" +libc = "0.2.151" + +[target.'cfg(unix)'.dependencies] +libc = "0.2.151" + +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.52.0", features = ["Win32_Storage_FileSystem", "Win32_Foundation"] } +widestring = "1.0.2" [build-dependencies] prost-build = "0.12.3" diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..e3cb9f9 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,157 @@ +use std::array::TryFromSliceError; +use std::process::{ExitCode, Termination}; + +use kanal::{ReceiveError, SendError}; +use tokio::io; +use tokio::sync::AcquireError; + +#[derive(Debug)] +pub(crate) struct Error { + pub(crate) kind: ErrorKind, +} + +#[derive(Debug)] +pub(crate) enum ErrorKind { + Io(io::Error), + Parse(std::net::AddrParseError), + Decode(prost::DecodeError), + Join(tokio::task::JoinError), + Send(SendError), + Receive(ReceiveError), + Acquire(AcquireError), + TryFromSlice(TryFromSliceError), + #[cfg(windows)] + ContainsNull(widestring::error::ContainsNul), + #[cfg(unix)] + Nul(std::ffi::NulError), + MissingQueue, + MaxRetries, + StatusError, + Failure(u32), +} + +impl From for Error { + fn from(error: io::Error) -> Self { + Self { + kind: ErrorKind::Io(error), + } + } +} + +impl From for Error { + fn from(error: std::net::AddrParseError) -> Self { + Self { + kind: ErrorKind::Parse(error), + } + } +} + +impl From for Error { + fn from(error: prost::DecodeError) -> Self { + Self { + kind: ErrorKind::Decode(error), + } + } +} + +impl From for Error { + fn from(error: tokio::task::JoinError) -> Self { + Self { + kind: ErrorKind::Join(error), + } + } +} + +impl From for Error { + fn from(error: SendError) -> Self { + Self { + kind: ErrorKind::Send(error), + } + } +} + +impl From for Error { + fn from(error: ReceiveError) -> Self { + Self { + kind: ErrorKind::Receive(error), + } + } +} + +impl From for Error { + fn from(error: AcquireError) -> Self { + Self { + kind: ErrorKind::Acquire(error), + } + } +} + +impl From for Error { + fn from(error: TryFromSliceError) -> Self { + Self { + kind: ErrorKind::TryFromSlice(error), + } + } +} + +#[cfg(windows)] +impl From> for Error { + fn from(error: widestring::error::ContainsNul) -> Self { + Self { + kind: ErrorKind::ContainsNull(error), + } + } +} + +#[cfg(unix)] +impl From for Error { + fn from(error: std::ffi::NulError) -> Self { + Self { + kind: ErrorKind::Nul(error), + } + } +} + +impl Termination for Error { + fn report(self) -> ExitCode { + ExitCode::from(match self.kind { + ErrorKind::Io(error) => match error.kind() { + io::ErrorKind::NotFound => 1, + _ => 2, + }, + ErrorKind::Parse(_) => 3, + ErrorKind::Decode(_) => 4, + ErrorKind::Join(_) => 5, + ErrorKind::Send(_) => 6, + ErrorKind::Receive(_) => 7, + ErrorKind::Acquire(_) => 8, + _ => 9, + }) + } +} + +impl Error { + pub(crate) fn missing_queue() -> Self { + Self { + kind: ErrorKind::MissingQueue, + } + } + + pub(crate) fn max_retries() -> Self { + Self { + kind: ErrorKind::MaxRetries, + } + } + + pub(crate) fn failure(reason: u32) -> Self { + Self { + kind: ErrorKind::Failure(reason), + } + } + + pub(crate) fn status_error() -> Self { + Self { + kind: ErrorKind::StatusError, + } + } +} diff --git a/src/items.proto b/src/items.proto index 97ffcd1..ce4ca4e 100644 --- a/src/items.proto +++ b/src/items.proto @@ -11,6 +11,7 @@ message Message { End end = 5; Failure failure = 6; Done done = 7; + Completed completed = 8; } } @@ -24,6 +25,24 @@ message FileDetail { string path = 1; // file path relative to the destination directory uint64 size = 2; // file size in bytes optional bytes signature = 3; // blake3 hash of file + optional Crypto crypto = 4; // encryption details +} + +message Crypto { + Cipher cipher = 1; + bytes key = 2; + bytes iv = 3; +} + +enum Cipher { + AES = 0; + CHACHA8 = 1; + CHACHA20 = 2; +} + +// the receiver already had these files +message Completed { + repeated uint32 ids = 1; } // map of transfers and their confirmed indexes diff --git a/src/items.rs b/src/items.rs index 2367d24..2840235 100644 --- a/src/items.rs +++ b/src/items.rs @@ -1,3 +1,6 @@ +use std::collections::HashMap; +use std::fmt::Display; + include!(concat!(env!("OUT_DIR"), "/cccp.items.rs")); impl Message { @@ -7,6 +10,18 @@ impl Message { } } + pub(crate) fn confirmations(indexes: HashMap) -> Self { + Self { + message: Some(message::Message::Confirmations(Confirmations { indexes })), + } + } + + pub(crate) fn completed(ids: Vec) -> Self { + Self { + message: Some(message::Message::Completed(Completed { ids })), + } + } + pub(crate) fn end(id: u32) -> Self { Self { message: Some(message::Message::End(End { id })), @@ -25,3 +40,36 @@ impl Message { } } } + +impl Cipher { + /// the length of the key in bytes + pub(crate) fn key_length(&self) -> usize { + 32 + } + + /// the length of the iv in bytes + pub(crate) fn iv_length(&self) -> usize { + match self { + Self::Chacha20 | Self::Chacha8 => 12, + Self::Aes => 16, + } + } +} + +impl Display for Cipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let cipher = match self { + Self::Aes => "aes", + Self::Chacha8 => "chacha8", + Self::Chacha20 => "chacha20", + }; + + write!(f, "{}", cipher) + } +} + +impl StartIndex { + pub(crate) fn new(index: u64) -> Self { + Self { index } + } +} diff --git a/src/main.rs b/src/main.rs index b434d0e..3f8ba13 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,169 +1,97 @@ #![feature(int_roundings)] -use std::error; use std::fmt::{Display, Formatter}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; -use std::process::{ExitCode, Termination}; use std::str::FromStr; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::time::Duration; +use aes::Aes256; use async_ssh2_tokio::{AuthMethod, Client, ServerCheckMethod}; use blake3::{Hash, Hasher}; use bytesize::ByteSize; +use chacha20::cipher::generic_array::GenericArray; +use chacha20::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; +use chacha20::{ChaCha20, ChaCha8}; use clap::Parser; +use ctr::Ctr128BE; use futures::stream::iter; use futures::{StreamExt, TryStreamExt}; use indicatif::{ProgressBar, ProgressStyle}; -use kanal::{ReceiveError, SendError}; +use itertools::Itertools; use log::{debug, error, info, warn, LevelFilter}; use prost::Message; +use rand::{rngs::OsRng, RngCore}; use regex::Regex; use rpassword::prompt_password; use simple_logging::{log_to_file, log_to_stderr}; use tokio::fs::File; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket}; -use tokio::sync::AcquireError; use tokio::time::{interval, sleep}; use tokio::{io, select}; +use crate::items::{Cipher, Crypto}; + +mod error; mod items; mod receiver; mod sender; // result alias used throughout -type Result = std::result::Result; - -// read buffer must be a multiple of the transfer buffer to prevent a nasty little bug -const READ_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; -const WRITE_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; -const TRANSFER_BUFFER_SIZE: usize = 1024; -const INDEX_SIZE: usize = std::mem::size_of::(); -const ID_SIZE: usize = std::mem::size_of::(); -const MAX_RETRIES: usize = 10; -const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5); -// UDP header + ID + INDEX + DATA -const PACKET_SIZE: usize = 8 + ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE; - -// how long to wait for a job to be confirmed before requeuing it -const REQUEUE_INTERVAL: Duration = Duration::from_millis(1_000); - -#[derive(Debug)] -struct Error { - kind: ErrorKind, -} - -#[derive(Debug)] -enum ErrorKind { - Io(io::Error), - Parse(std::net::AddrParseError), - Decode(prost::DecodeError), - Join(tokio::task::JoinError), - Send(SendError), - Receive(ReceiveError), - Acquire(AcquireError), - MissingQueue, - MaxRetries, - Failure(u32), -} - -impl From for Error { - fn from(error: io::Error) -> Self { - Self { - kind: ErrorKind::Io(error), - } - } -} +type Result = std::result::Result; -impl From for Error { - fn from(error: std::net::AddrParseError) -> Self { - Self { - kind: ErrorKind::Parse(error), - } - } +pub(crate) trait StreamCipherExt: Send + Sync { + fn seek(&mut self, index: u64); + fn apply_keystream(&mut self, data: &mut [u8]); } -impl From for Error { - fn from(error: prost::DecodeError) -> Self { - Self { - kind: ErrorKind::Decode(error), - } +impl StreamCipherExt for ChaCha20 { + fn seek(&mut self, index: u64) { + StreamCipherSeek::seek(self, index); } -} -impl From for Error { - fn from(error: tokio::task::JoinError) -> Self { - Self { - kind: ErrorKind::Join(error), - } + fn apply_keystream(&mut self, buf: &mut [u8]) { + StreamCipher::apply_keystream(self, buf) } } -impl From for Error { - fn from(error: SendError) -> Self { - Self { - kind: ErrorKind::Send(error), - } +impl StreamCipherExt for ChaCha8 { + fn seek(&mut self, index: u64) { + StreamCipherSeek::seek(self, index); } -} -impl From for Error { - fn from(error: ReceiveError) -> Self { - Self { - kind: ErrorKind::Receive(error), - } + fn apply_keystream(&mut self, buf: &mut [u8]) { + StreamCipher::apply_keystream(self, buf) } } -impl From for Error { - fn from(error: AcquireError) -> Self { - Self { - kind: ErrorKind::Acquire(error), - } +impl StreamCipherExt for Ctr128BE { + fn seek(&mut self, index: u64) { + StreamCipherSeek::seek(self, index); } -} -impl Termination for Error { - fn report(self) -> ExitCode { - ExitCode::from(match self.kind { - ErrorKind::Io(error) => match error.kind() { - io::ErrorKind::NotFound => 1, - _ => 2, - }, - ErrorKind::Parse(_) => 3, - ErrorKind::Decode(_) => 4, - ErrorKind::Join(_) => 5, - ErrorKind::Send(_) => 6, - ErrorKind::Receive(_) => 7, - ErrorKind::Acquire(_) => 8, - _ => 9, - }) + fn apply_keystream(&mut self, buf: &mut [u8]) { + StreamCipher::apply_keystream(self, buf) } } -impl Error { - fn missing_queue() -> Self { - Self { - kind: ErrorKind::MissingQueue, - } - } - - fn max_retries() -> Self { - Self { - kind: ErrorKind::MaxRetries, - } - } +// read buffer must be a multiple of the transfer buffer to prevent a nasty little bug +const READ_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; +const WRITE_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; +const TRANSFER_BUFFER_SIZE: usize = 1024; +const INDEX_SIZE: usize = std::mem::size_of::(); +const ID_SIZE: usize = std::mem::size_of::(); +const MAX_RETRIES: usize = 10; +const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5); +// UDP header + ID + INDEX + DATA +const PACKET_SIZE: usize = 8 + ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE; - fn failure(reason: u32) -> Self { - Self { - kind: ErrorKind::Failure(reason), - } - } -} +// how long to wait for a job to be confirmed before requeuing it +const REQUEUE_INTERVAL: Duration = Duration::from_millis(1_000); #[derive(Parser, Clone, Debug)] struct Options { @@ -178,7 +106,7 @@ struct Options { short, long = "start-port", help = "the first port to use", - default_value = "50000" + default_value_t = 50000 )] start_port: u16, @@ -186,7 +114,7 @@ struct Options { short, long = "end-port", help = "the last port to use", - default_value = "50100" + default_value_t = 50099 )] end_port: u16, @@ -194,7 +122,7 @@ struct Options { short, long = "threads", help = "how many threads to use", - default_value = "100" + default_value_t = 98 )] threads: u16, @@ -225,17 +153,30 @@ struct Options { short, long = "max", help = "the maximum number of concurrent transfers", - default_value = "100" + default_value_t = 100 )] max: usize, #[clap( short, long = "verify", - help = "verify integrity of files using blake3" + help = "verify integrity of transfers using blake3" )] verify: bool, + #[clap(short, long = "overwrite", help = "overwrite existing files")] + overwrite: bool, + + #[clap( + long = "control-crypto", + help = "encrypt the control stream", + default_value = "aes" + )] + control_crypto: Crypto, + + #[clap(long = "stream-crypto", help = "encrypt the data stream")] + stream_crypto: Option, + #[clap(help = "where to get the data from")] source: FileLocation, @@ -247,14 +188,23 @@ impl Options { fn format_command(&self, sender: bool) -> String { let mode = if sender { "rr" } else { "rs" }; + let stream_crypto = if let Some(ref crypto) = self.stream_crypto { + format!(" --stream-crypto {}", crypto) + } else { + String::new() + }; + format!( - "cccp --mode {} --start-port {} --end-port {} --threads {} --log-level {} --rate \"{}\" \"{}\" \"{}\"", + "cccp --mode {} -s {} -e {} -t {} -l {} -r \"{}\"{} --control-crypto {}{} \"{}\" \"{}\"", mode, self.start_port, self.end_port, self.threads, self.log_level, self.rate, + stream_crypto, + self.control_crypto, + if self.overwrite { " -o" } else { "" }, self.source, self.destination ) @@ -287,6 +237,84 @@ impl FromStr for Mode { } } +impl FromStr for Crypto { + type Err = CustomParseErrors; + + fn from_str(s: &str) -> std::result::Result { + let s = s.to_uppercase(); + let count = s.matches(':').count(); + + if count == 0 { + let cipher = Cipher::from_str_name(&s).ok_or(CustomParseErrors::InvalidCipher)?; + + let mut key = vec![0; cipher.key_length()]; + OsRng.fill_bytes(&mut key); + + let mut iv = vec![0; cipher.iv_length()]; + OsRng.fill_bytes(&mut iv); + + Ok(Self { + cipher: cipher as i32, + key, + iv, + }) + } else if count == 1 { + let (cipher_str, key_str) = s.split_once(':').unwrap(); + + let cipher = + Cipher::from_str_name(cipher_str).ok_or(CustomParseErrors::InvalidCipher)?; + let key = hex::decode(key_str).map_err(|_| CustomParseErrors::InvalidKey)?; + + let mut iv = vec![0; cipher.iv_length()]; + OsRng.fill_bytes(&mut iv); + + if key.len() == cipher.key_length() { + Ok(Self { + cipher: cipher as i32, + key, + iv, + }) + } else { + Err(CustomParseErrors::InvalidKey) + } + } else if count == 2 { + let (cipher_str, key_str, iv_str) = s.splitn(3, ':').collect_tuple().unwrap(); + + let cipher = + Cipher::from_str_name(cipher_str).ok_or(CustomParseErrors::InvalidCipher)?; + let key = hex::decode(key_str).map_err(|_| CustomParseErrors::InvalidKey)?; + let iv = hex::decode(iv_str).map_err(|_| CustomParseErrors::InvalidKey)?; + + if key.len() == cipher.key_length() && iv.len() == cipher.iv_length() { + Ok(Self { + cipher: cipher as i32, + key, + iv, + }) + } else { + Err(CustomParseErrors::InvalidKey) + } + } else { + Err(CustomParseErrors::InvalidCipherFormat) + } + } +} + +impl Display for Crypto { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let key = hex::encode(&self.key); + let iv = hex::encode(&self.iv); + + write!( + f, + "{}:{}:{}", + Cipher::try_from(self.cipher).unwrap(), + key, + iv + ) + } +} + // a file located anywhere #[derive(Clone, Debug)] struct FileLocation { @@ -361,11 +389,14 @@ impl FileLocation { } #[derive(Clone, Debug)] -enum CustomParseErrors { +pub enum CustomParseErrors { MalformedConnectionString(&'static str), UnknownMode, ParseError, IoError, + InvalidCipher, + InvalidCipherFormat, + InvalidKey, } impl Display for CustomParseErrors { @@ -378,12 +409,15 @@ impl Display for CustomParseErrors { Self::UnknownMode => "The mode can be either sender or receiver", Self::ParseError => "Invalid file path", Self::IoError => "An error occurred while getting the size of a file", + Self::InvalidCipher => "Invalid cipher", + Self::InvalidCipherFormat => "Invalid cipher format", + Self::InvalidKey => "Invalid key", } ) } } -impl error::Error for CustomParseErrors {} +impl std::error::Error for CustomParseErrors {} impl From for CustomParseErrors { fn from(_: io::Error) -> Self { @@ -683,15 +717,17 @@ async fn connect_stream( } /// write a `Message` to a writer -async fn write_message( +async fn write_message( writer: &mut W, message: &M, + cipher: &mut Box, ) -> Result<()> { let len = message.encoded_len(); writer.write_u32(len as u32).await?; let mut buffer = Vec::with_capacity(len); message.encode(&mut buffer).unwrap(); + cipher.apply_keystream(&mut buffer[..]); writer.write_all(&buffer).await?; @@ -699,11 +735,19 @@ async fn write_message( } /// read a `Message` from a reader -async fn read_message(reader: &mut R) -> Result { +async fn read_message< + R: AsyncReadExt + Unpin, + M: Message + Default, + C: StreamCipherExt + ?Sized, +>( + reader: &mut R, + cipher: &mut Box, +) -> Result { let len = reader.read_u32().await? as usize; let mut buffer = vec![0; len]; reader.read_exact(&mut buffer).await?; + cipher.apply_keystream(&mut buffer[..]); let message = M::decode(&buffer[..])?; @@ -729,3 +773,26 @@ async fn hash_file>(path: P) -> io::Result { Ok(hasher.finalize()) } + +fn make_cipher(crypto: &Crypto) -> Box { + let key = GenericArray::from_slice(&crypto.key[..32]); + + match crypto.cipher.try_into() { + Ok(Cipher::Aes) => { + let iv = GenericArray::from_slice(&crypto.key[..16]); + + Box::new(Ctr128BE::::new(key, iv)) + } + Ok(Cipher::Chacha8) => { + let iv = GenericArray::from_slice(&crypto.key[..12]); + + Box::new(ChaCha8::new(key, iv)) + } + Ok(Cipher::Chacha20) => { + let iv = GenericArray::from_slice(&crypto.key[..12]); + + Box::new(ChaCha20::new(key, iv)) + } + _ => unreachable!(), + } +} diff --git a/src/receiver/mod.rs b/src/receiver/mod.rs index 477d3f2..caaa1ec 100644 --- a/src/receiver/mod.rs +++ b/src/receiver/mod.rs @@ -1,13 +1,13 @@ -use kanal::{AsyncReceiver, AsyncSender}; use std::collections::HashMap; use std::mem; use std::net::IpAddr; -use std::path::PathBuf; +use std::path::Path; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::time::Duration; +use kanal::{AsyncReceiver, AsyncSender}; use log::{debug, error, info, warn}; use tokio::fs::{create_dir_all, metadata}; use tokio::io::AsyncWrite; @@ -17,11 +17,12 @@ use tokio::sync::Mutex; use tokio::task::JoinHandle; use tokio::time::{interval, timeout}; -use crate::items::{message, ConfirmationIndexes, Confirmations, Manifest, Message, StartIndex}; -use crate::receiver::writer::{writer, SplitQueue}; +use crate::error::Error; +use crate::items::{message, ConfirmationIndexes, Manifest, Message, StartIndex}; +use crate::receiver::writer::{writer, FileDetails, SplitQueue}; use crate::{ - read_message, socket_factory, write_message, Error, Options, Result, TransferStats, ID_SIZE, - INDEX_SIZE, MAX_RETRIES, RECEIVE_TIMEOUT, TRANSFER_BUFFER_SIZE, + make_cipher, read_message, socket_factory, write_message, Options, Result, StreamCipherExt, + TransferStats, ID_SIZE, INDEX_SIZE, MAX_RETRIES, RECEIVE_TIMEOUT, TRANSFER_BUFFER_SIZE, }; mod writer; @@ -43,11 +44,82 @@ pub(crate) async fn main( ) -> Result<()> { info!("receiving {} -> {}", options.source, options.destination); - let manifest: Manifest = read_message(&mut str_stream).await?; + let mut str_cipher = make_cipher(&options.control_crypto); + let rts_cipher = make_cipher(&options.control_crypto); + + let manifest: Manifest = read_message(&mut str_stream, &mut str_cipher).await?; + let is_dir = manifest.files.len() > 1; // if multiple files are being received, the destination should be a directory debug!("received manifest: {:?}", manifest); - // if multiple files are being received, the destination should be a directory - if manifest.files.len() > 1 { + let mut completed = Vec::new(); + + // TODO get start indexes here and never add them to total_data + let files = manifest + .files + .into_iter() + .filter_map(|(id, details)| { + // formats the path to the file locally + let path = if is_dir { + options.destination.file_path.join(&details.path) + } else { + options.destination.file_path.clone() + }; + + if path.exists() && !options.overwrite { + completed.push(id); + None + } else { + // increment the total data counter + stats.total_data.fetch_add(details.size as usize, Relaxed); + + // append partial extension to the existing extension, if there is one + let partial_extension = if let Some(extension) = path.extension() { + extension.to_str()?.to_owned() + ".partial" + } else { + ".partial".to_string() + }; + + let partial_path = path.with_extension(partial_extension); + + Some(( + id, + FileDetails { + path, + partial_path, + size: details.size, + signature: details.signature, + crypto: details.crypto, + }, + )) + } + }) + .collect(); + + let free_space = free_space(&options.destination.file_path)?; + debug!("free space: {}", free_space); + + if free_space < stats.total_data.load(Relaxed) as u64 { + error!( + "not enough free space {} / {}", + free_space, + stats.total_data.load(Relaxed) + ); + + write_message(&mut str_stream, &Message::failure(0, 1), &mut str_cipher).await?; + + return Err(Error::failure(1)); + } + + debug!("sending completed: {:?}", completed); + // send the completed message to the remote client + write_message( + &mut str_stream, + &Message::completed(completed), + &mut str_cipher, + ) + .await?; + + if is_dir { create_dir_all(&options.destination.file_path).await?; } @@ -58,11 +130,6 @@ pub(crate) async fn main( create_dir_all(local_dir).await?; } - // set the total data to be received - for details in manifest.files.values() { - stats.total_data.fetch_add(details.size as usize, Relaxed); - } - let sockets = socket_factory( options.start_port + 2, // the first two ports are used for control messages and confirmations options.end_port, @@ -78,7 +145,7 @@ pub(crate) async fn main( // `message_sender` can now be used to send messages to the sender let (message_sender, message_receiver) = kanal::unbounded_async(); - tokio::spawn(send_messages(rts_stream, message_receiver)); + tokio::spawn(send_messages(rts_stream, message_receiver, rts_cipher)); let confirmation_handle = tokio::spawn(send_confirmations( message_sender.clone(), @@ -88,12 +155,12 @@ pub(crate) async fn main( let controller_handle = tokio::spawn(controller( str_stream, - manifest.clone(), + files, writer_queue.clone(), confirmation_sender, - stats.confirmed_data.clone(), - options.destination.file_path, + stats.confirmed_data, message_sender, + str_cipher, )); let handles: Vec<_> = sockets @@ -125,10 +192,10 @@ async fn receiver(queue: WriterQueue, socket: UdpSocket) -> Result<()> { Ok(Ok(read)) if read > 0 => { retries = 0; // reset retries - let id = u32::from_be_bytes(buf[..ID_SIZE].try_into().unwrap()); - let index = - u64::from_be_bytes(buf[ID_SIZE..INDEX_SIZE + ID_SIZE].try_into().unwrap()); - let data = buf[INDEX_SIZE + ID_SIZE..].try_into().unwrap(); + // slice the buffer into the id, index, and data + let id = u32::from_be_bytes(buf[..ID_SIZE].try_into()?); + let index = u64::from_be_bytes(buf[ID_SIZE..INDEX_SIZE + ID_SIZE].try_into()?); + let data = buf[INDEX_SIZE + ID_SIZE..].try_into()?; queue.send(Job { data, index }, id).await?; } @@ -144,62 +211,44 @@ async fn receiver(queue: WriterQueue, socket: UdpSocket) -> Result<()> { } } -async fn controller( +async fn controller( mut str_stream: TcpStream, - mut files: Manifest, + mut files: HashMap, writer_queue: WriterQueue, confirmation_sender: AsyncSender<(u32, u64)>, confirmed_data: Arc, - file_path: PathBuf, message_sender: AsyncSender, + mut str_cipher: Box, ) -> Result<()> { loop { - let message: Message = read_message(&mut str_stream).await?; + let message: Message = read_message(&mut str_stream, &mut str_cipher).await?; match message.message { Some(message::Message::Start(message)) => { debug!("received start message: {:?}", message); - let details = files.files.remove(&message.id).unwrap(); - - writer_queue.push_queue(message.id).await; // create a queue for the writer - - let file_path = if file_path.is_dir() { - file_path.join(&details.path) - } else { - file_path.clone() - }; - - // append partial extension to the existing extension, if there is one - let partial_extension = if let Some(extension) = file_path.extension() { - extension.to_str().unwrap().to_owned() + ".partial" - } else { - ".partial".to_string() - }; + let details = files.remove(&message.id).unwrap(); - let partial_path = file_path.with_extension(partial_extension); - - let start_index = if partial_path.exists() { + let start_index = if details.partial_path.exists() { info!("partial file exists, resuming transfer"); - let metadata = metadata(&partial_path).await?; + let metadata = metadata(&details.partial_path).await?; // the file is written sequentially, so we can calculate the start index by rounding down to the nearest multiple of the transfer buffer size - metadata.len().div_floor(TRANSFER_BUFFER_SIZE as u64) - * TRANSFER_BUFFER_SIZE as u64 + let chunks = metadata.len().div_floor(TRANSFER_BUFFER_SIZE as u64); + chunks * TRANSFER_BUFFER_SIZE as u64 } else { 0 }; - confirmed_data.fetch_add(start_index as usize, Relaxed); - // send the start index to the remote client - write_message(&mut str_stream, &StartIndex { index: start_index }).await?; + write_message( + &mut str_stream, + &StartIndex::new(start_index), + &mut str_cipher, + ) + .await?; - let file = writer::FileDetails { - size: details.size, - partial_path, - path: file_path, - signature: details.signature, - }; + writer_queue.push_queue(message.id).await; // create a queue for the writer + confirmed_data.fetch_add(start_index as usize, Relaxed); tokio::spawn({ let writer_queue = writer_queue.clone(); @@ -207,10 +256,8 @@ async fn controller( let message_sender = message_sender.clone(); async move { - let path = file.path.clone(); - - let result = writer( - file, + let result = writer::( + details, writer_queue, confirmation_sender, start_index, @@ -220,7 +267,7 @@ async fn controller( .await; if let Err(error) = result { - error!("writer for {} failed: {:?}", path.display(), error); + error!("writer failed: {:?}", error); } } }); @@ -266,23 +313,16 @@ async fn send_confirmations( let confirmations = mem::take(&mut *data); drop(data); // release the lock on data - let map: HashMap = HashMap::new(); - // group the confirmations by id - let map = confirmations.into_iter().fold(map, |mut map, (id, index)| { - map.entry(id).or_default().inner.push(index); - map - }); - - let message = Message { - message: Some(message::Message::Confirmations(Confirmations { - indexes: map, - })), - }; - sender - .send(message) - .await - .expect("failed to send confirmations"); + let map: HashMap = + confirmations + .into_iter() + .fold(HashMap::new(), |mut map, (id, index)| { + map.entry(id).or_default().inner.push(index); + map + }); + + sender.send(Message::confirmations(map)).await?; } } }); @@ -302,13 +342,61 @@ async fn send_confirmations( } /// send messages from a channel to a writer -async fn send_messages( +async fn send_messages( mut writer: W, receiver: AsyncReceiver, + mut cipher: Box, ) -> Result<()> { while let Ok(message) = receiver.recv().await { - write_message(&mut writer, &message).await?; + write_message(&mut writer, &message, &mut cipher).await?; } Ok(()) } + +// TODO this is buggy af +#[cfg(unix)] +fn free_space(path: &Path) -> Result { + use std::ffi::CString; + use std::os::unix::ffi::OsStrExt; + + let dir = CString::new(path.as_os_str().as_bytes())?; + + unsafe { + let mut buf: mem::MaybeUninit = mem::MaybeUninit::uninit(); + let result = libc::statvfs(dir.as_ptr(), buf.as_mut_ptr()); + + if result == 0 { + let stat = buf.assume_init(); + Ok(stat.f_frsize as u64 * stat.f_bavail as u64) + } else { + Err(Error::status_error()) + } + } +} + +#[cfg(windows)] +fn free_space(path: &Path) -> Result { + use widestring::U16CString; + use windows_sys::Win32::Storage::FileSystem; + + let path = U16CString::from_os_str(path)?; + + let mut free_bytes = 0_u64; + let mut total_bytes = 0_u64; + + let status = unsafe { + FileSystem::GetDiskFreeSpaceExW( + path.as_ptr(), + &mut free_bytes, + &mut total_bytes, + std::ptr::null_mut(), + ) + }; + + if status == 0 { + Err(Error::status_error()) + } else { + Ok(free_bytes) + } +} diff --git a/src/receiver/writer.rs b/src/receiver/writer.rs index f4e1f1e..4443f9f 100644 --- a/src/receiver/writer.rs +++ b/src/receiver/writer.rs @@ -9,9 +9,12 @@ use tokio::fs::{remove_file, rename, OpenOptions}; use tokio::io::{self, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter}; use tokio::sync::Mutex; -use crate::items::Message; +use crate::error::Error; +use crate::items::{Crypto, Message}; use crate::receiver::{Job, WriterQueue}; -use crate::{hash_file, Error, Result, TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE}; +use crate::{ + hash_file, make_cipher, Result, StreamCipherExt, TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE, +}; #[derive(Default)] pub(crate) struct SplitQueue { @@ -37,6 +40,7 @@ impl SplitQueue { receivers.get(id).cloned() } + // TODO benchmark this pub(crate) async fn send(&self, job: Job, id: u32) -> Result<()> { let sender_option = { let senders = self.senders.lock().await; @@ -57,6 +61,7 @@ pub(crate) struct FileDetails { pub(crate) partial_path: PathBuf, pub(crate) size: u64, pub(crate) signature: Option>, + pub(crate) crypto: Option, } impl FileDetails { @@ -66,7 +71,7 @@ impl FileDetails { } } -pub(crate) async fn writer( +pub(crate) async fn writer( details: FileDetails, writer_queue: WriterQueue, confirmation_sender: AsyncSender<(u32, u64)>, @@ -83,6 +88,12 @@ pub(crate) async fn writer( let mut writer = BufWriter::with_capacity(WRITE_BUFFER_SIZE, file); writer.seek(SeekFrom::Start(position)).await?; // seek to the initial position + let mut cipher = details.crypto.as_ref().map(make_cipher); + + if let Some(ref mut cipher) = cipher { + cipher.seek(position); + } + debug!( "writer for {} starting at {}", details.path.display(), @@ -109,14 +120,28 @@ pub(crate) async fn writer( } // if the chunk is at the current position, write it Ordering::Equal => { - write_data(&mut writer, &job.data, &mut position, details.size).await?; + write_data( + &mut writer, + job.data, + &mut position, + details.size, + &mut cipher, + ) + .await?; confirmation_sender.send((id, job.index)).await?; } } // write all concurrent chunks from the cache while let Some(job) = cache.remove(&position) { - write_data(&mut writer, &job.data, &mut position, details.size).await?; + write_data( + &mut writer, + job.data, + &mut position, + details.size, + &mut cipher, + ) + .await?; } } @@ -147,15 +172,20 @@ pub(crate) async fn writer( /// write data and advance position #[inline] -async fn write_data( +async fn write_data( writer: &mut T, - buffer: &[u8], + mut buffer: [u8; TRANSFER_BUFFER_SIZE], position: &mut u64, file_size: u64, + cipher: &mut Option>, ) -> io::Result<()> { // calculate the length of the data to write let len = (file_size - *position).min(TRANSFER_BUFFER_SIZE as u64); + if let Some(ref mut cipher) = cipher { + cipher.apply_keystream(&mut buffer[..len as usize]); + } + *position += len; // advance the position writer.write_all(&buffer[..len as usize]).await // write the data } diff --git a/src/sender/mod.rs b/src/sender/mod.rs index e159c88..6c0008e 100644 --- a/src/sender/mod.rs +++ b/src/sender/mod.rs @@ -1,4 +1,3 @@ -use kanal::{AsyncReceiver, AsyncSender}; use std::collections::{HashMap, HashSet}; use std::net::IpAddr; use std::path::{Path, PathBuf}; @@ -7,18 +6,23 @@ use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::time::Duration; +use aes::cipher::crypto_common::rand_core::OsRng; +use kanal::{AsyncReceiver, AsyncSender}; use log::{debug, error, info, warn}; +use rand::RngCore; use tokio::io::{self, AsyncReadExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::select; use tokio::sync::{Mutex, RwLock, Semaphore}; use tokio::time::{interval, Instant}; -use crate::items::{message, Confirmations, FileDetail, Manifest, Message, StartIndex}; +use crate::error::Error; +use crate::items::{message, Confirmations, Crypto, FileDetail, Manifest, Message, StartIndex}; use crate::sender::reader::reader; use crate::{ - hash_file, read_message, socket_factory, write_message, Error, Options, Result, TransferStats, - ID_SIZE, INDEX_SIZE, MAX_RETRIES, REQUEUE_INTERVAL, TRANSFER_BUFFER_SIZE, + hash_file, make_cipher, read_message, socket_factory, write_message, Options, Result, + StreamCipherExt, TransferStats, ID_SIZE, INDEX_SIZE, MAX_RETRIES, REQUEUE_INTERVAL, + TRANSFER_BUFFER_SIZE, }; mod reader; @@ -41,15 +45,38 @@ pub(crate) async fn main( ) -> Result<()> { info!("sending {} -> {}", options.source, options.destination); - let manifest = build_manifest( + let mut str_cipher = make_cipher(&options.control_crypto); + let rts_cipher = make_cipher(&options.control_crypto); + + let mut manifest = build_manifest( options.source.file_path.clone(), options.verify, &stats.total_data, + &options.stream_crypto, ) .await?; debug!("sending manifest: {:?}", manifest); - write_message(&mut str_stream, &manifest).await?; + write_message(&mut str_stream, &manifest, &mut str_cipher).await?; + + let message: Message = read_message(&mut str_stream, &mut str_cipher).await?; + + match message.message { + Some(message::Message::Completed(completed)) => { + debug!("received {} completed ids", completed.ids.len()); + + for id in completed.ids { + if let Some(details) = manifest.files.remove(&id) { + stats.total_data.fetch_sub(details.size as usize, Relaxed); + } + } + } + Some(message::Message::Failure(failure)) => { + error!("received failure message {}", failure.reason); + return Err(Error::failure(failure.reason)); + } + _ => unreachable!(), + } let sockets = socket_factory( options.start_port + 2, // the first two ports are used for control messages and confirmations @@ -81,6 +108,7 @@ pub(crate) async fn main( rts_stream, confirmation_sender, controller_sender, + rts_cipher, )); let confirmation_handle = tokio::spawn(receive_confirmations( @@ -95,13 +123,14 @@ pub(crate) async fn main( let controller_handle = tokio::spawn(controller( str_stream, - manifest, + manifest.files, job_sender.clone(), read, stats.confirmed_data, options.source.file_path, controller_receiver, options.max, + str_cipher, )); let handles: Vec<_> = sockets @@ -169,38 +198,43 @@ async fn sender( } } +// TODO there is something wrong with the controller, seems to be starting the same file multiple times or something #[allow(clippy::too_many_arguments)] -async fn controller( +async fn controller( mut control_stream: TcpStream, - mut files: Manifest, + mut files: HashMap, job_sender: AsyncSender, read: Arc, confirmed_data: Arc, base_path: PathBuf, controller_receiver: AsyncReceiver, max: usize, + mut cipher: Box, ) -> Result<()> { let mut id = 0; let mut active = 0; loop { - while active < max { - match files.files.get(&id) { - None => break, - Some(file_details) => { - start_file_transfer( - &mut control_stream, - id, - file_details, - &base_path, - &job_sender, - &read, - &confirmed_data, - ) - .await?; - - id += 1; - active += 1; + if !files.is_empty() { + while active < max { + match files.get(&id) { + None => id += 1, + Some(details) => { + start_file_transfer( + &mut control_stream, + id, + details, + &base_path, + &job_sender, + &read, + &confirmed_data, + &mut cipher, + ) + .await?; + + active += 1; + id += 1 + } } } } @@ -211,63 +245,80 @@ async fn controller( Some(message::Message::End(end)) => { debug!("received end message {} | active {}", end.id, active); - files.files.remove(&end.id); + files.remove(&end.id); active -= 1; } Some(message::Message::Failure(failure)) => { - debug!("received failure message {:?}", failure); - - if let Some(file_details) = files.files.get(&failure.id) { - start_file_transfer( - &mut control_stream, - failure.id, - file_details, - &base_path, - &job_sender, - &read, - &confirmed_data, - ) - .await?; + if failure.reason == 0 { + if let Some(details) = files.get(&failure.id) { + warn!( + "transfer {} failed signature verification, retrying...", + failure.id + ); + + confirmed_data.fetch_sub(details.size as usize, Relaxed); + + start_file_transfer( + &mut control_stream, + failure.id, + details, + &base_path, + &job_sender, + &read, + &confirmed_data, + &mut cipher, + ) + .await?; + } else { + warn!( + "received failure message {} for unknown file {}", + failure.reason, failure.id + ); + } } else { - warn!("received failure message for unknown file {}", failure.id); + warn!( + "received failure message {} for unknown file {}", + failure.reason, failure.id + ); } } _ => unreachable!(), // only end and failure messages are sent to this receiver } - if files.files.is_empty() && active == 0 { + if files.is_empty() && active == 0 { break; } } debug!("all files completed, sending done message"); - write_message(&mut control_stream, &Message::done()).await?; + write_message(&mut control_stream, &Message::done(), &mut cipher).await?; Ok(()) } -async fn start_file_transfer( +#[allow(clippy::too_many_arguments)] +async fn start_file_transfer( mut control_stream: &mut TcpStream, id: u32, - file_details: &FileDetail, + details: &FileDetail, base_path: &Path, job_sender: &AsyncSender, read: &Arc, confirmed_data: &Arc, + cipher: &mut Box, ) -> Result<()> { - write_message(&mut control_stream, &Message::start(id)).await?; + write_message(&mut control_stream, &Message::start(id), cipher).await?; - let file_path = base_path.join(&file_details.path); - - let start_index: StartIndex = read_message(&mut control_stream).await?; + let start_index: StartIndex = read_message(&mut control_stream, cipher).await?; confirmed_data.fetch_add(start_index.index as usize, Relaxed); tokio::spawn(reader( - file_path, + base_path.join(&details.path), job_sender.clone(), read.clone(), start_index.index, id, + details.crypto.as_ref().map(make_cipher), )); Ok(()) @@ -392,13 +443,14 @@ fn files_and_dirs( } /// split the message stream into `Confirmation` and `End + Failure` messages -async fn split_receiver( +async fn split_receiver( mut reader: R, confirmation_sender: AsyncSender, controller_sender: AsyncSender, + mut cipher: Box, ) -> Result<()> { loop { - let message: Message = read_message(&mut reader).await?; + let message: Message = read_message(&mut reader, &mut cipher).await?; match message.message { Some(message::Message::Confirmations(confirmations)) => { @@ -417,6 +469,7 @@ async fn build_manifest( source: PathBuf, verify: bool, total_data: &Arc, + crypto: &Option, ) -> Result { // collect the files and directories to send let mut files = Vec::new(); @@ -425,6 +478,7 @@ async fn build_manifest( let mut file_map: HashMap = HashMap::with_capacity(files.len()); + // TODO add concurrency for (index, mut file) in files.into_iter().enumerate() { let size = tokio::fs::metadata(&file).await?.len(); total_data.fetch_add(size as usize, Relaxed); @@ -442,14 +496,22 @@ async fn build_manifest( file = file.strip_prefix(&source).unwrap().to_path_buf(); } + // TODO windows only let path = file.to_string_lossy().replace('\\', "/"); + let mut crypto = crypto.clone(); + + if let Some(ref mut crypto) = crypto { + OsRng.fill_bytes(&mut crypto.iv); + } + file_map.insert( index as u32, FileDetail { path, size, signature, + crypto, }, ); } @@ -463,7 +525,7 @@ async fn build_manifest( dir.to_string_lossy().to_string() } }) - .map(|dir| dir.replace('\\', "/")) + .map(|dir| dir.replace('\\', "/")) // TODO windows only .collect(); let manifest = Manifest { diff --git a/src/sender/reader.rs b/src/sender/reader.rs index 0b0cc62..788b51a 100644 --- a/src/sender/reader.rs +++ b/src/sender/reader.rs @@ -1,28 +1,34 @@ -use kanal::AsyncSender; use std::io::SeekFrom; use std::path::PathBuf; use std::sync::Arc; +use kanal::AsyncSender; use log::debug; use tokio::fs::File; use tokio::io::{AsyncReadExt, AsyncSeekExt, BufReader}; use tokio::sync::Semaphore; use crate::sender::Job; -use crate::{Result, ID_SIZE, INDEX_SIZE, READ_BUFFER_SIZE, TRANSFER_BUFFER_SIZE}; +use crate::{Result, StreamCipherExt, ID_SIZE, INDEX_SIZE, READ_BUFFER_SIZE, TRANSFER_BUFFER_SIZE}; -pub(crate) async fn reader( +pub(crate) async fn reader( path: PathBuf, queue: AsyncSender, read: Arc, mut index: u64, id: u32, + mut cipher: Option>, ) -> Result<()> { let file = File::open(path).await?; let mut reader = BufReader::with_capacity(READ_BUFFER_SIZE, file); reader.seek(SeekFrom::Start(index)).await?; + if let Some(ref mut cipher) = cipher { + cipher.seek(index); + } + let mut buffer = [0; ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE]; + // write id to buffer it is constant for all chunks buffer[..ID_SIZE].copy_from_slice(&id.to_be_bytes()); debug!("starting reader at index {}", index); @@ -33,7 +39,7 @@ pub(crate) async fn reader( // write index to buffer buffer[ID_SIZE..INDEX_SIZE + ID_SIZE].copy_from_slice(&index.to_be_bytes()); - // read data into buffer after checksum and index + // read data into buffer after id and index let read = reader.read(&mut buffer[INDEX_SIZE + ID_SIZE..]).await?; // check if EOF @@ -41,6 +47,10 @@ pub(crate) async fn reader( break; } + if let Some(ref mut cipher) = cipher { + cipher.apply_keystream(&mut buffer[INDEX_SIZE + ID_SIZE..INDEX_SIZE + ID_SIZE + read]); + } + // push job to queue queue .send(Job {