diff --git a/Cargo.toml b/Cargo.toml index 5bc70d8..2beaddd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,8 @@ chacha20 = "0.9" base64 = "0.21" ctr = "0.9" aes = "0.8" +whoami = "1.4" +cipher = "0.4" [target.'cfg(unix)'.dependencies] nix = { version = "0.27", features = ["fs"] } diff --git a/src/cipher.rs b/src/cipher.rs new file mode 100644 index 0000000..00528ef --- /dev/null +++ b/src/cipher.rs @@ -0,0 +1,101 @@ +use aes::{Aes128, Aes256}; +use chacha20::{ChaCha20, ChaCha8}; +use cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; +use ctr::Ctr128BE; +use prost::Message; +use tokio::io; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::items::{Cipher, Crypto}; + +pub(crate) trait StreamCipherWrapper: Send + Sync { + fn seek(&mut self, index: u64); + fn apply_keystream(&mut self, data: &mut [u8]); +} + +impl StreamCipherWrapper for T +where + T: StreamCipherSeek + StreamCipher + Send + Sync, +{ + fn seek(&mut self, index: u64) { + StreamCipherSeek::seek(self, index); + } + + fn apply_keystream(&mut self, buf: &mut [u8]) { + StreamCipher::apply_keystream(self, buf); + } +} + +pub(crate) struct CipherStream { + stream: S, + cipher: Box, +} + +impl CipherStream { + pub(crate) fn new(stream: S, crypto: &Crypto) -> crate::Result { + Ok(Self { + stream, + cipher: make_cipher(crypto)?, + }) + } + + /// write a `Message` to the stream + pub(crate) async fn write_message(&mut self, message: &M) -> crate::Result<()> { + let len = message.encoded_len(); // get the length of the message + self.write_u32(len as u32).await?; // write the length of the message + + let mut buffer = Vec::with_capacity(len); // create a buffer to write the message into + message.encode(&mut buffer).unwrap(); // encode the message into the buffer (infallible) + + self.write_all(&mut buffer).await?; // write the message to the writer + + Ok(()) + } + + /// read a `Message` from the stream + pub(crate) async fn read_message(&mut self) -> crate::Result { + let len = self.read_u32().await? as usize; // read the length of the message + + let mut buffer = vec![0; len]; // create a buffer to read the message into + self.read_exact(&mut buffer).await?; // read the message into the buffer + + let message = M::decode(&buffer[..])?; // decode the message + + Ok(message) + } + + async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + AsyncReadExt::read_exact(&mut self.stream, buf).await?; + self.cipher.apply_keystream(buf); + Ok(()) + } + + async fn read_u32(&mut self) -> io::Result { + let mut buf = [0; 4]; + self.read_exact(&mut buf).await?; + Ok(u32::from_be_bytes(buf)) + } + + async fn write_all(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.cipher.apply_keystream(buf); + AsyncWriteExt::write_all(&mut self.stream, buf).await + } + + async fn write_u32(&mut self, value: u32) -> io::Result<()> { + let mut buf = value.to_be_bytes(); + self.write_all(&mut buf).await + } +} + +pub(crate) fn make_cipher(crypto: &Crypto) -> crate::Result> { + let cipher: Cipher = crypto.cipher.try_into()?; + let key = &crypto.key[..cipher.key_length()]; + let iv = &crypto.iv[..cipher.iv_length()]; + + Ok(match cipher { + Cipher::Aes128 => Box::new(Ctr128BE::::new(key.into(), iv.into())), + Cipher::Aes256 => Box::new(Ctr128BE::::new(key.into(), iv.into())), + Cipher::Chacha8 => Box::new(ChaCha8::new(key.into(), iv.into())), + Cipher::Chacha20 => Box::new(ChaCha20::new(key.into(), iv.into())), + }) +} diff --git a/src/error.rs b/src/error.rs index a36fb4d..aa1f132 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,8 +1,8 @@ use std::array::TryFromSliceError; use std::fmt::Formatter; -use std::process::{ExitCode, Termination}; use kanal::{ReceiveError, SendError}; +use prost::Message; use tokio::io; use tokio::sync::AcquireError; @@ -34,6 +34,7 @@ pub(crate) enum ErrorKind { Failure(u32), EmptyPath, InvalidExtension, + UnexpectedMessage(Box), } impl From for Error { @@ -134,24 +135,6 @@ impl From for Error { } } -impl Termination for Error { - fn report(self) -> ExitCode { - ExitCode::from(match self.kind { - ErrorKind::Io(error) => match error.kind() { - io::ErrorKind::NotFound => 1, - _ => 2, - }, - ErrorKind::AddrParse(_) => 3, - ErrorKind::Decode(_) => 4, - ErrorKind::Join(_) => 5, - ErrorKind::Send(_) => 6, - ErrorKind::Receive(_) => 7, - ErrorKind::Acquire(_) => 8, - _ => 9, - }) - } -} - impl std::fmt::Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self.kind { @@ -176,6 +159,9 @@ impl std::fmt::Display for Error { ErrorKind::Failure(ref reason) => write!(f, "Failure: {}", reason), ErrorKind::EmptyPath => write!(f, "Empty path"), ErrorKind::InvalidExtension => write!(f, "Invalid extension"), + ErrorKind::UnexpectedMessage(ref message) => { + write!(f, "Unexpected message {:?}", message) + } } } } @@ -211,6 +197,12 @@ impl Error { } } + pub(crate) fn unexpected_message(message: Box) -> Self { + Self { + kind: ErrorKind::UnexpectedMessage(message), + } + } + #[cfg(windows)] pub(crate) fn status_error() -> Self { Self { diff --git a/src/items.proto b/src/items.proto index be77be8..801f12c 100644 --- a/src/items.proto +++ b/src/items.proto @@ -35,9 +35,10 @@ message Crypto { } enum Cipher { - AES = 0; - CHACHA8 = 1; + CHACHA8 = 0; + AES128 = 1; CHACHA20 = 2; + AES256 = 3; } // the receiver already had these files @@ -78,4 +79,6 @@ message Failure { } // signals the receiver that the sender won't start new transfers -message Done {} \ No newline at end of file +message Done { + uint32 reason = 1; +} \ No newline at end of file diff --git a/src/items.rs b/src/items.rs index ebeaf85..d1a815f 100644 --- a/src/items.rs +++ b/src/items.rs @@ -38,9 +38,9 @@ impl Message { } } - pub(crate) fn done() -> Self { + pub(crate) fn done(reason: u32) -> Self { Self { - message: Some(message::Message::Done(Done {})), + message: Some(message::Message::Done(Done { reason })), } } } @@ -48,14 +48,17 @@ impl Message { impl Cipher { /// the length of the key in bytes pub(crate) fn key_length(&self) -> usize { - 32 + match self { + Self::Chacha20 | Self::Chacha8 | Self::Aes256 => 32, + Self::Aes128 => 16, + } } /// the length of the iv in bytes pub(crate) fn iv_length(&self) -> usize { match self { Self::Chacha20 | Self::Chacha8 => 12, - Self::Aes => 16, + Self::Aes256 | Self::Aes128 => 16, } } } @@ -63,9 +66,10 @@ impl Cipher { impl Display for Cipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let cipher = match self { - Self::Aes => "aes", - Self::Chacha8 => "chacha8", - Self::Chacha20 => "chacha20", + Self::Aes128 => "AES128", + Self::Aes256 => "AES256", + Self::Chacha8 => "CHACHA8", + Self::Chacha20 => "CHACHA20", }; write!(f, "{}", cipher) @@ -77,3 +81,9 @@ impl StartIndex { Self { index } } } + +impl Manifest { + pub(crate) fn is_empty(&self) -> bool { + self.files.is_empty() && self.directories.is_empty() + } +} diff --git a/src/main.rs b/src/main.rs index 9a29b81..b61be9d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,30 +7,26 @@ use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::time::Duration; -use aes::Aes256; use async_ssh2_tokio::{AuthMethod, Client, ServerCheckMethod}; use blake3::{Hash, Hasher}; -use chacha20::cipher::generic_array::GenericArray; -use chacha20::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; -use chacha20::{ChaCha20, ChaCha8}; use clap::{CommandFactory, Parser}; -use ctr::Ctr128BE; use futures::stream::iter; use futures::{StreamExt, TryStreamExt}; use indicatif::{ProgressBar, ProgressStyle}; use log::{debug, error, info, warn}; -use prost::Message; use rpassword::prompt_password; use simple_logging::{log_to_file, log_to_stderr}; use tokio::fs::File; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket}; use tokio::time::{interval, sleep}; use tokio::{io, select}; -use crate::items::{Cipher, Crypto}; +use crate::cipher::CipherStream; + use crate::options::{Mode, Options}; +mod cipher; mod error; mod items; mod options; @@ -40,11 +36,6 @@ mod sender; // result alias used throughout type Result = std::result::Result; -trait StreamCipherExt: Send + Sync { - fn seek(&mut self, index: u64); - fn apply_keystream(&mut self, data: &mut [u8]); -} - // read buffer must be a multiple of the transfer buffer to prevent a nasty little bug const READ_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; const WRITE_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; @@ -59,36 +50,6 @@ const PACKET_SIZE: usize = 8 + ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE; // how long to wait for a job to be confirmed before requeuing it const REQUEUE_INTERVAL: Duration = Duration::from_millis(1_000); -impl StreamCipherExt for ChaCha20 { - fn seek(&mut self, index: u64) { - StreamCipherSeek::seek(self, index); - } - - fn apply_keystream(&mut self, buf: &mut [u8]) { - StreamCipher::apply_keystream(self, buf) - } -} - -impl StreamCipherExt for ChaCha8 { - fn seek(&mut self, index: u64) { - StreamCipherSeek::seek(self, index); - } - - fn apply_keystream(&mut self, buf: &mut [u8]) { - StreamCipher::apply_keystream(self, buf) - } -} - -impl StreamCipherExt for Ctr128BE { - fn seek(&mut self, index: u64) { - StreamCipherSeek::seek(self, index); - } - - fn apply_keystream(&mut self, buf: &mut [u8]) { - StreamCipher::apply_keystream(self, buf) - } -} - #[derive(Clone, Default)] struct TransferStats { confirmed_data: Arc, @@ -101,8 +62,14 @@ async fn main() -> Result<()> { let mut command = Options::command(); match options.mode { - Mode::Local => log_to_stderr(options.log_level), - _ => log_to_file("cccp.log", options.log_level).expect("failed to log"), + Mode::Local => { + if let Some(path) = &options.log_file { + log_to_file(path, options.log_level)? + } else { + log_to_stderr(options.log_level) + } + } + _ => log_to_file("cccp.log", options.log_level)?, } // only the local client needs to handle input validation @@ -147,14 +114,7 @@ async fn main() -> Result<()> { options.end_port = new_end; } - if options.destination.host.is_none() && options.source.host.is_none() { - command - .error( - clap::error::ErrorKind::ValueValidation, - "either the source or destination must be remote", - ) - .exit(); - } else if options.destination.is_local() && options.source.is_local() { + if options.destination.is_local() && options.source.is_local() { command .error( clap::error::ErrorKind::ValueValidation, @@ -169,39 +129,35 @@ async fn main() -> Result<()> { let result = match options.mode { Mode::Local => { - let command_str = options.format_command(sender); - let (local, remote) = if sender { (&options.source, &mut options.destination) } else { (&options.destination, &mut options.source) }; - if remote.username.is_none() { - warn!( - "username not specified for remote host (trying \"\"): {}", - remote - ); - remote.username = Some(String::new()); - } else if remote.host.is_none() { + if remote.host.is_none() { command .error( clap::error::ErrorKind::ValueValidation, - format!("host must be specified for remote host: {}", remote), + format!("host must be specified for remote IoSpec: {}", remote), ) .exit(); + } else if remote.username.is_none() { + remote.username = Some(whoami::username()); } debug!("local {}", local); debug!("remote {}", remote); - let mut auth_method = ssh_key_auth().await.unwrap_or_else(|error| { - warn!("failed to use ssh key auth: {}", error); - password_auth().unwrap() - }); + let mut auth_method = if let Ok(auth) = ssh_key_auth().await { + auth + } else { + password_auth()? + }; // unwrap is safe because we check for a host above let remote_addr = remote.host.unwrap(); + let remote_ip = remote_addr.ip(); let client = loop { match Client::connect( @@ -229,6 +185,8 @@ async fn main() -> Result<()> { info!("connected to the remote host via ssh"); + let command_str = options.format_command(sender); + let command_handle = tokio::spawn(async move { info!("executing command on remote host"); debug!("command: {}", command_str); @@ -237,15 +195,14 @@ async fn main() -> Result<()> { }); // receiver -> sender stream - let rts_stream = - connect_stream(remote_addr.ip(), options.start_port, options.bind_address).await?; + let stream = + connect_stream(remote_ip, options.start_port, options.bind_address).await?; + let rts_stream = CipherStream::new(stream, &options.control_crypto)?; + // sender -> receiver stream - let str_stream = connect_stream( - remote_addr.ip(), - options.start_port + 1, - options.bind_address, - ) - .await?; + let stream = + connect_stream(remote_ip, options.start_port + 1, options.bind_address).await?; + let str_stream = CipherStream::new(stream, &options.control_crypto)?; let display_handle = tokio::spawn({ let stats = stats.clone(); @@ -255,23 +212,9 @@ async fn main() -> Result<()> { let main_future = async { if sender { - sender::main( - options, - stats.clone(), - rts_stream, - str_stream, - remote_addr.ip(), - ) - .await + sender::main(options, stats, rts_stream, str_stream, remote_ip).await } else { - receiver::main( - options, - stats.clone(), - rts_stream, - str_stream, - remote_addr.ip(), - ) - .await + receiver::main(options, stats, rts_stream, str_stream, remote_ip).await } }; @@ -280,18 +223,13 @@ async fn main() -> Result<()> { match result { Ok(Ok(result)) => { - match result.exit_status { - 0 => { - info!("remote client exited successfully"); - // wait forever to allow the other futures to complete - sleep(Duration::from_secs(u64::MAX)).await; - } - 1 => error!("remote client failed, file not found"), - 2 => error!("remote client failed, unknown IO error"), - 3 => error!("remote client failed, parse error"), - 4 => error!("remote client failed, decode error"), - 5 => error!("remote client failed, join error"), - _ => error!("remote client failed, unknown error"), + if result.exit_status != 0 { + // return to terminate execution + error!("remote command failed: {:?}", result); + } else { + info!("remote client exited successfully"); + // wait forever to allow the other futures to complete + sleep(Duration::from_secs(u64::MAX)).await; } } Ok(Err(error)) => error!("remote client failed: {}", error), // return to terminate execution @@ -308,11 +246,15 @@ async fn main() -> Result<()> { Mode::Remote(sender) => { // receiver -> sender stream let listener = TcpListener::bind(("0.0.0.0", options.start_port)).await?; - let (rts_stream, remote_addr) = listener.accept().await?; + let (stream, remote_addr) = listener.accept().await?; + + let rts_stream = CipherStream::new(stream, &options.control_crypto)?; // sender -> receiver stream let listener = TcpListener::bind(("0.0.0.0", options.start_port + 1)).await?; - let (str_stream, _) = listener.accept().await?; + let (stream, _) = listener.accept().await?; + + let str_stream = CipherStream::new(stream, &options.control_crypto)?; let remote_addr = remote_addr.ip(); @@ -431,44 +373,6 @@ async fn connect_stream( } } -/// write a `Message` to a writer -async fn write_message( - writer: &mut W, - message: &M, - cipher: &mut Box, -) -> Result<()> { - let len = message.encoded_len(); // get the length of the message - writer.write_u32(len as u32).await?; // write the length of the message - - let mut buffer = Vec::with_capacity(len); // create a buffer to write the message into - message.encode(&mut buffer).unwrap(); // encode the message into the buffer (infallible) - cipher.apply_keystream(&mut buffer[..]); // encrypt the message - - writer.write_all(&buffer).await?; // write the message to the writer - - Ok(()) -} - -/// read a `Message` from a reader -async fn read_message< - R: AsyncReadExt + Unpin, - M: Message + Default, - C: StreamCipherExt + ?Sized, ->( - reader: &mut R, - cipher: &mut Box, -) -> Result { - let len = reader.read_u32().await? as usize; // read the length of the message - - let mut buffer = vec![0; len]; // create a buffer to read the message into - reader.read_exact(&mut buffer).await?; // read the message into the buffer - cipher.apply_keystream(&mut buffer[..]); // decrypt the message - - let message = M::decode(&buffer[..])?; // decode the message - - Ok(message) -} - async fn hash_file>(path: P) -> io::Result { let file = File::open(path).await?; let mut reader = BufReader::with_capacity(READ_BUFFER_SIZE, file); @@ -488,26 +392,3 @@ async fn hash_file>(path: P) -> io::Result { Ok(hasher.finalize()) } - -fn make_cipher(crypto: &Crypto) -> Box { - let key = GenericArray::from_slice(&crypto.key[..32]); - - match crypto.cipher.try_into() { - Ok(Cipher::Aes) => { - let iv = GenericArray::from_slice(&crypto.key[..16]); - - Box::new(Ctr128BE::::new(key, iv)) - } - Ok(Cipher::Chacha8) => { - let iv = GenericArray::from_slice(&crypto.key[..12]); - - Box::new(ChaCha8::new(key, iv)) - } - Ok(Cipher::Chacha20) => { - let iv = GenericArray::from_slice(&crypto.key[..12]); - - Box::new(ChaCha20::new(key, iv)) - } - _ => unreachable!(), - } -} diff --git a/src/options.rs b/src/options.rs index c1f4bc6..00e088f 100644 --- a/src/options.rs +++ b/src/options.rs @@ -16,13 +16,23 @@ use tokio::io; use crate::items::{Cipher, Crypto}; use crate::PACKET_SIZE; +const HELP_HEADING: &str = "\x1B[1m\x1B[4mAbout\x1B[0m + cccp is a fast, secure, and reliable file transfer utility + +\x1B[1m\x1B[4mIoSpec\x1B[0m + - [user@][host:{port:}]file + - If no user is set for a remote host, the current user is used + - If no port is provided, port 22 is used + - Either the InSpec or OutSpec should be remote, not both or neither + +\x1B[1m\x1B[4mCiphers\x1B[0m + - CHACHA8 + - CHAHA20 + - AES128 + - AES256"; + #[derive(Parser, Clone, Debug)] -#[clap( - name = "cccp", - version, - author, - about = "A fast and secure file transfer utility" -)] +#[clap(version, about = HELP_HEADING)] pub(crate) struct Options { // the user does not need to set this #[clap(long, hide = true, default_value = "local")] @@ -53,7 +63,7 @@ pub(crate) struct Options { pub(crate) max: usize, /// Encrypt the control stream - #[clap(short, long, default_value = "aes")] + #[clap(short, long, default_value = "AES256")] pub(crate) control_crypto: Crypto, /// Verify integrity of transfers using blake3 @@ -76,11 +86,15 @@ pub(crate) struct Options { #[clap(short, long)] pub(crate) bind_address: Option, - /// The source file or directory + /// Log to a file (default: stderr) + #[clap(short = 'L', long)] + pub(crate) log_file: Option, + + /// The source IoSpec (InSpec) #[clap()] pub(crate) source: IoSpec, - /// The destination file or directory + /// The destination IoSpec (OutSpec) #[clap()] pub(crate) destination: IoSpec, } @@ -151,24 +165,21 @@ impl FromStr for Crypto { type Err = OptionParseError; fn from_str(s: &str) -> Result { - let captures = Regex::new("([A-Za-z\\d]+)(?::([A-Za-z0-9+/]+))?(?::([A-Za-z0-9+/]+))?") - .unwrap() // infallible - .captures(s) - .ok_or(Self::Err::invalid_cipher_format())?; + let mut captures = s.split(':'); - // unwrap is safe because the regex requires a cipher name - let cipher_str = captures.get(1).unwrap().as_str().to_uppercase(); + // unwrap is safe because there will always be at least one capture + let cipher_str = captures.next().unwrap().to_uppercase(); let cipher = Cipher::from_str_name(&cipher_str).ok_or(Self::Err::invalid_cipher())?; let key = captures - .get(2) - .map(|m| STANDARD_NO_PAD.decode(m.as_str())) + .next() + .map(|m| STANDARD_NO_PAD.decode(m)) .transpose()? // propagate the decode error .unwrap_or_else(|| random_bytes(cipher.key_length())); let iv = captures - .get(3) - .map(|m| STANDARD_NO_PAD.decode(m.as_str())) + .next() + .map(|m| STANDARD_NO_PAD.decode(m)) .transpose()? // propagate the decode error .unwrap_or_else(|| random_bytes(cipher.iv_length())); @@ -188,6 +199,7 @@ impl FromStr for Crypto { impl Display for Crypto { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + // encode the key & iv to base64 let key = STANDARD_NO_PAD.encode(&self.key); let iv = STANDARD_NO_PAD.encode(&self.iv); @@ -213,7 +225,7 @@ impl FromStr for IoSpec { type Err = OptionParseError; fn from_str(s: &str) -> Result { - let captures = Regex::new("(?:([\\w-]+)@([\\w.-]+)(?::(\\d+))?:)?([ \\w/.-]+)") + let captures = Regex::new("(?:(?:([\\w-]+)@)?([\\w.-]+)(?::(\\d+))?:)?([ \\w/.-]+)") .unwrap() // infallible .captures(s) .ok_or(Self::Err::malformed_io_spec("Invalid IO spec"))?; @@ -290,7 +302,6 @@ enum ErrorKind { MalformedIoSpec(&'static str), UnknownMode, InvalidCipher, - InvalidCipherFormat, InvalidKey, InvalidIv, NoSuchHost, @@ -309,7 +320,6 @@ impl Display for OptionParseError { ErrorKind::MalformedIoSpec(message) => message, ErrorKind::UnknownMode => "The mode can be either sender or receiver", ErrorKind::InvalidCipher => "Invalid cipher", - ErrorKind::InvalidCipherFormat => "Invalid cipher format", ErrorKind::InvalidKey => "Invalid key", ErrorKind::InvalidIv => "Invalid IV", ErrorKind::NoSuchHost => "No such host", @@ -356,12 +366,6 @@ impl OptionParseError { } } - fn invalid_cipher_format() -> Self { - Self { - kind: ErrorKind::InvalidCipherFormat, - } - } - fn malformed_io_spec(message: &'static str) -> Self { Self { kind: ErrorKind::MalformedIoSpec(message), diff --git a/src/receiver/mod.rs b/src/receiver/mod.rs index eab36fa..a938f5c 100644 --- a/src/receiver/mod.rs +++ b/src/receiver/mod.rs @@ -13,7 +13,6 @@ use futures::{StreamExt, TryStreamExt}; use kanal::{AsyncReceiver, AsyncSender}; use log::{debug, error, info, warn}; use tokio::fs::{create_dir, metadata}; -use tokio::io::AsyncWrite; use tokio::net::{TcpStream, UdpSocket}; use tokio::select; use tokio::sync::Mutex; @@ -24,8 +23,8 @@ use crate::error::Error; use crate::items::{message, ConfirmationIndexes, Manifest, Message, StartIndex}; use crate::receiver::writer::{writer, FileDetails, SplitQueue}; use crate::{ - make_cipher, read_message, socket_factory, write_message, Options, Result, StreamCipherExt, - TransferStats, ID_SIZE, INDEX_SIZE, MAX_RETRIES, RECEIVE_TIMEOUT, TRANSFER_BUFFER_SIZE, + socket_factory, CipherStream, Options, Result, TransferStats, ID_SIZE, INDEX_SIZE, MAX_RETRIES, + RECEIVE_TIMEOUT, TRANSFER_BUFFER_SIZE, }; mod writer; @@ -41,16 +40,13 @@ struct Job { pub(crate) async fn main( options: Options, stats: TransferStats, - rts_stream: TcpStream, - mut str_stream: TcpStream, + rts_stream: CipherStream, + mut str_stream: CipherStream, remote_addr: IpAddr, ) -> Result<()> { info!("receiving {} -> {}", options.source, options.destination); - let mut str_cipher = make_cipher(&options.control_crypto); - let rts_cipher = make_cipher(&options.control_crypto); - - let manifest: Manifest = read_message(&mut str_stream, &mut str_cipher).await?; + let manifest: Manifest = str_stream.read_message().await?; let is_dir = manifest.files.len() > 1; // if multiple files are being received, the destination should be a directory debug!( "received manifest | files={} dirs={}", @@ -134,24 +130,18 @@ pub(crate) async fn main( stats.total_data.load(Relaxed) ); - write_message( - &mut str_stream, - &Message::failure(0, 1, None), - &mut str_cipher, - ) - .await?; + str_stream + .write_message(&Message::failure(0, 1, None)) + .await?; return Err(Error::failure(1)); } debug!("sending completed: {:?}", completed); // send the completed message to the remote client - write_message( - &mut str_stream, - &Message::completed(completed), - &mut str_cipher, - ) - .await?; + str_stream + .write_message(&Message::completed(completed)) + .await?; // if the destination is a directory, create it if is_dir { @@ -186,7 +176,7 @@ pub(crate) async fn main( // `message_sender` can now be used to send messages to the sender let (message_sender, message_receiver) = kanal::unbounded_async(); - tokio::spawn(send_messages(rts_stream, message_receiver, rts_cipher)); + let message_sender_handle = tokio::spawn(send_messages(rts_stream, message_receiver)); let confirmation_handle = tokio::spawn(send_confirmations( message_sender.clone(), @@ -200,7 +190,6 @@ pub(crate) async fn main( writer_queue.clone(), confirmation_sender, message_sender, - str_cipher, )); let handles: Vec<_> = sockets @@ -220,6 +209,7 @@ pub(crate) async fn main( result = confirmation_handle => { debug!("confirmation sender exited: {:?}", result); result? }, result = controller_handle => { debug!("controller exited: {:?}", result); result? }, result = receiver_future => { debug!("receivers exited: {:?}", result); result }, + result = message_sender_handle => { debug!("message sender exited: {:?}", result); result? } } } @@ -237,7 +227,10 @@ async fn receiver(queue: WriterQueue, socket: UdpSocket) -> Result<()> { let index = u64::from_be_bytes(buf[ID_SIZE..INDEX_SIZE + ID_SIZE].try_into()?); let data = buf[INDEX_SIZE + ID_SIZE..].try_into()?; - queue.send(Job { data, index }, id).await?; + if queue.send(Job { data, index }, id).await.is_err() { + // a message was received for a file that has already been completed (probably) + debug!("failed to send job for {} to writer", id); + } } Ok(Ok(_)) => warn!("0 byte read?"), // this should never happen Ok(Err(_)) | Err(_) => retries += 1, // catch errors and timeouts @@ -251,17 +244,16 @@ async fn receiver(queue: WriterQueue, socket: UdpSocket) -> Result<()> { } } -async fn controller( - mut str_stream: TcpStream, +async fn controller( + mut control_stream: CipherStream, mut files: HashMap, writer_queue: WriterQueue, confirmation_sender: AsyncSender<(u32, u64)>, message_sender: AsyncSender, - mut str_cipher: Box, ) -> Result<()> { loop { debug!("waiting for message"); - let message: Message = read_message(&mut str_stream, &mut str_cipher).await?; + let message: Message = control_stream.read_message().await?; match message.message { Some(message::Message::Start(message)) => { @@ -281,12 +273,9 @@ async fn controller( } // send the start index to the remote client - write_message( - &mut str_stream, - &StartIndex::new(details.start_index), - &mut str_cipher, - ) - .await?; + control_stream + .write_message(&StartIndex::new(details.start_index)) + .await?; writer_queue.push_queue(message.id).await; // create a queue for the writer @@ -309,12 +298,18 @@ async fn controller( } }); } - Some(message::Message::Done(_)) => { - debug!("received done message"); + Some(message::Message::Done(message)) => { + match message.reason { + 0 => warn!("remote client found no files to send"), + 1 => debug!("all transfers were completed before execution"), + 2 => debug!("remote client completed all transfers"), + _ => break Err(Error::unexpected_message(Box::new(message))), + } + message_sender.close(); break Ok(()); } - _ => unreachable!("controller received unexpected message: {:?}", message), + _ => return Err(Error::unexpected_message(Box::new(message))), } } } @@ -373,14 +368,13 @@ async fn send_confirmations( } } -/// send messages from a channel to a writer -async fn send_messages( - mut writer: W, +/// send messages from a channel to a cipher stream +async fn send_messages( + mut stream: CipherStream, receiver: AsyncReceiver, - mut cipher: Box, ) -> Result<()> { while let Ok(message) = receiver.recv().await { - write_message(&mut writer, &message, &mut cipher).await?; + stream.write_message(&message).await?; } Ok(()) diff --git a/src/receiver/writer.rs b/src/receiver/writer.rs index 4d28777..4d41948 100644 --- a/src/receiver/writer.rs +++ b/src/receiver/writer.rs @@ -9,12 +9,11 @@ use tokio::fs::{remove_file, rename, OpenOptions}; use tokio::io::{self, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter}; use tokio::sync::Mutex; +use crate::cipher::{make_cipher, StreamCipherWrapper}; use crate::error::Error; use crate::items::{Crypto, Message}; use crate::receiver::{Job, WriterQueue}; -use crate::{ - hash_file, make_cipher, Result, StreamCipherExt, TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE, -}; +use crate::{hash_file, Result, TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE}; #[derive(Default)] pub(crate) struct SplitQueue { @@ -87,7 +86,7 @@ pub(crate) async fn writer( let mut writer = BufWriter::with_capacity(WRITE_BUFFER_SIZE, file); writer.seek(SeekFrom::Start(position)).await?; // seek to the initial position - let mut cipher = details.crypto.as_ref().map(make_cipher); + let mut cipher = details.crypto.as_ref().map(make_cipher).transpose()?; if let Some(ref mut cipher) = cipher { cipher.seek(position); @@ -173,7 +172,7 @@ pub(crate) async fn writer( /// write data and advance position #[inline] -async fn write_data( +async fn write_data( writer: &mut T, mut buffer: [u8; TRANSFER_BUFFER_SIZE], position: &mut u64, diff --git a/src/sender/mod.rs b/src/sender/mod.rs index 99c0949..f0770a7 100644 --- a/src/sender/mod.rs +++ b/src/sender/mod.rs @@ -12,19 +12,19 @@ use futures::stream::iter; use futures::{StreamExt, TryStreamExt}; use kanal::{AsyncReceiver, AsyncSender}; use log::{debug, error, info, warn}; -use tokio::io::{self, AsyncReadExt}; +use tokio::io; use tokio::net::{TcpStream, UdpSocket}; use tokio::select; use tokio::sync::{Mutex, RwLock, Semaphore}; use tokio::time::{interval, Instant}; +use crate::cipher::make_cipher; use crate::error::Error; use crate::items::{message, Confirmations, FileDetail, Manifest, Message, StartIndex}; use crate::sender::reader::reader; use crate::{ - hash_file, make_cipher, read_message, socket_factory, write_message, Options, Result, - StreamCipherExt, TransferStats, ID_SIZE, INDEX_SIZE, MAX_RETRIES, REQUEUE_INTERVAL, - TRANSFER_BUFFER_SIZE, + hash_file, socket_factory, CipherStream, Options, Result, TransferStats, ID_SIZE, INDEX_SIZE, + MAX_RETRIES, REQUEUE_INTERVAL, TRANSFER_BUFFER_SIZE, }; mod reader; @@ -41,25 +41,24 @@ struct Job { pub(crate) async fn main( options: Options, stats: TransferStats, - rts_stream: TcpStream, - mut str_stream: TcpStream, + rts_stream: CipherStream, + mut str_stream: CipherStream, remote_addr: IpAddr, ) -> Result<()> { info!("sending {} -> {}", options.source, options.destination); - let mut str_cipher = make_cipher(&options.control_crypto); - let rts_cipher = make_cipher(&options.control_crypto); - let mut manifest = build_manifest(&options, &stats.total_data).await?; - debug!( - "sending manifest | files={} dirs={}", - manifest.files.len(), - manifest.directories.len() - ); - write_message(&mut str_stream, &manifest, &mut str_cipher).await?; + if manifest.is_empty() { + warn!("found no files to send"); + str_stream.write_message(&Message::done(0)).await?; + return Ok(()); + } + + debug!("sending manifest"); + str_stream.write_message(&manifest).await?; - let message: Message = read_message(&mut str_stream, &mut str_cipher).await?; + let message: Message = str_stream.read_message().await?; match message.message { Some(message::Message::Completed(completed)) => { @@ -73,7 +72,7 @@ pub(crate) async fn main( if manifest.files.is_empty() { info!("all files completed"); - write_message(&mut str_stream, &Message::done(), &mut str_cipher).await?; + str_stream.write_message(&Message::done(1)).await?; return Ok(()); } } @@ -81,7 +80,7 @@ pub(crate) async fn main( error!("received failure message {}", failure.reason); return Err(Error::failure(failure.reason)); } - _ => unreachable!("received unexpected message: {:?}", message), + _ => return Err(Error::unexpected_message(Box::new(message))), } let sockets = socket_factory( @@ -114,7 +113,6 @@ pub(crate) async fn main( rts_stream, confirmation_sender, controller_sender, - rts_cipher, )); let confirmation_handle = tokio::spawn(receive_confirmations( @@ -128,15 +126,13 @@ pub(crate) async fn main( tokio::spawn(add_permits_at_rate(send.clone(), options.pps())); let controller_handle = tokio::spawn(controller( + options, str_stream, manifest.files, job_sender.clone(), read, stats.confirmed_data, - options.source.file_path, controller_receiver, - options.max, - str_cipher, )); let handles: Vec<_> = sockets @@ -204,23 +200,20 @@ async fn sender( } } -#[allow(clippy::too_many_arguments)] -async fn controller( - mut control_stream: TcpStream, +async fn controller( + options: Options, + mut control_stream: CipherStream, mut files: HashMap, job_sender: AsyncSender, read: Arc, confirmed_data: Arc, - base_path: PathBuf, controller_receiver: AsyncReceiver, - max: usize, - mut cipher: Box, ) -> Result<()> { let mut id = 0; - let mut active: HashMap = HashMap::with_capacity(max); + let mut active: HashMap = HashMap::with_capacity(options.max); loop { - while active.len() < max && !files.is_empty() { + while active.len() < options.max && !files.is_empty() { match files.remove(&id) { None => id += 1, Some(details) => { @@ -228,11 +221,10 @@ async fn controller( &mut control_stream, id, &details, - &base_path, + &options.source.file_path, &job_sender, &read, &confirmed_data, - &mut cipher, ) .await?; @@ -244,7 +236,9 @@ async fn controller( debug!("waiting for a message"); - match controller_receiver.recv().await?.message { + let message: Message = controller_receiver.recv().await?; + + match message.message { Some(message::Message::End(end)) => { if active.remove(&end.id).is_none() { warn!("received end message for unknown file {}", end.id); @@ -265,40 +259,33 @@ async fn controller( &mut control_stream, failure.id, details, - &base_path, + &options.source.file_path, &job_sender, &read, &confirmed_data, - &mut cipher, ) .await?; } else { - warn!( - "received failure message {} for unknown file {}", - failure.reason, failure.id - ); + warn!("received failure message {:?} for unknown file", failure); } } Some(message::Message::Failure(failure)) if failure.reason == 2 => { if active.remove(&failure.id).is_some() { error!( "remote writer failed {} [TRANSFER WILL NOT BE RETRIED]", - failure.description.unwrap() + failure.description.unwrap() // the description is always present for this failure reason ); } else { warn!( - "received writer failure message {} for unknown file {}", - failure.reason, failure.id + "received writer failure message {:?} for unknown file", + failure ); } } Some(message::Message::Failure(failure)) => { - warn!( - "received unknown failure message {} for file {}", - failure.reason, failure.id - ); + warn!("received unknown failure message {:?}", failure); } - _ => unreachable!(), // only end and failure messages are sent to this receiver + _ => return Err(Error::unexpected_message(Box::new(message))), } if files.is_empty() && active.is_empty() { @@ -307,27 +294,27 @@ async fn controller( } debug!("all files completed, sending done message"); - write_message(&mut control_stream, &Message::done(), &mut cipher).await?; + control_stream.write_message(&Message::done(2)).await?; Ok(()) } -#[allow(clippy::too_many_arguments)] -async fn start_file_transfer( - mut control_stream: &mut TcpStream, +async fn start_file_transfer( + control_stream: &mut CipherStream, id: u32, details: &FileDetail, base_path: &Path, job_sender: &AsyncSender, read: &Arc, confirmed_data: &Arc, - cipher: &mut Box, ) -> Result<()> { - write_message(&mut control_stream, &Message::start(id), cipher).await?; + control_stream.write_message(&Message::start(id)).await?; - let start_index: StartIndex = read_message(&mut control_stream, cipher).await?; + let start_index: StartIndex = control_stream.read_message().await?; confirmed_data.fetch_add(start_index.index as usize, Relaxed); + let cipher = details.crypto.as_ref().map(make_cipher).transpose()?; + tokio::spawn({ let job_sender = job_sender.clone(); let read = read.clone(); @@ -335,21 +322,13 @@ async fn start_file_transfer( let base_path = base_path.to_path_buf(); let path = if base_path.is_dir() { - base_path.join(&details.path) + base_path.join(details.path) } else { base_path }; async move { - let result = reader( - path, - job_sender, - read, - start_index.index, - id, - details.crypto.as_ref().map(make_cipher), - ) - .await; + let result = reader(path, job_sender, read, start_index.index, id, cipher).await; if let Err(error) = result { error!("reader failed: {:?}", error); @@ -460,14 +439,13 @@ async fn add_permits_at_rate(semaphore: Arc, rate: u64) { } /// split the message stream into `Confirmation` and `End + Failure` messages -async fn split_receiver( - mut reader: R, +async fn split_receiver( + mut stream: CipherStream, confirmation_sender: AsyncSender, controller_sender: AsyncSender, - mut cipher: Box, ) -> Result<()> { loop { - let message: Message = read_message(&mut reader, &mut cipher).await?; + let message: Message = stream.read_message().await?; match message.message { Some(message::Message::Confirmations(confirmations)) => { @@ -482,6 +460,7 @@ async fn split_receiver( } } +/// builds a manifest of the files to send and some details about them async fn build_manifest(options: &Options, total_data: &Arc) -> Result { // collect the files and directories to send let mut files = Vec::new(); @@ -493,8 +472,8 @@ async fn build_manifest(options: &Options, total_data: &Arc) -> Res &mut dirs, options.recursive, )?; - let files_len = files.len(); - debug!("found {} files & {} dirs", files_len, dirs.len()); + + debug!("found {} files & {} dirs", files.len(), dirs.len()); let file_map: HashMap = iter(files.into_iter().enumerate()) .map(|(index, mut file)| async move { @@ -553,15 +532,19 @@ async fn build_manifest(options: &Options, total_data: &Arc) -> Res Ok(manifest) } -/// recursively collect all files and directories in a directory +/// collect all files and directories in a directory fn files_and_dirs( - source: &Path, + path: &Path, files: &mut Vec, dirs: &mut Vec, recursive: bool, ) -> io::Result<()> { - if source.is_dir() { - for entry in source.read_dir()?.filter_map(std::result::Result::ok) { + if !path.exists() { + return Ok(()); + } + + if path.is_dir() { + for entry in path.read_dir()?.filter_map(std::result::Result::ok) { let path = entry.path(); if path.is_dir() { @@ -574,7 +557,7 @@ fn files_and_dirs( } } } else { - files.push(source.to_path_buf()); + files.push(path.to_path_buf()); } Ok(()) diff --git a/src/sender/reader.rs b/src/sender/reader.rs index 788b51a..4aa44a7 100644 --- a/src/sender/reader.rs +++ b/src/sender/reader.rs @@ -8,10 +8,11 @@ use tokio::fs::File; use tokio::io::{AsyncReadExt, AsyncSeekExt, BufReader}; use tokio::sync::Semaphore; +use crate::cipher::StreamCipherWrapper; use crate::sender::Job; -use crate::{Result, StreamCipherExt, ID_SIZE, INDEX_SIZE, READ_BUFFER_SIZE, TRANSFER_BUFFER_SIZE}; +use crate::{Result, ID_SIZE, INDEX_SIZE, READ_BUFFER_SIZE, TRANSFER_BUFFER_SIZE}; -pub(crate) async fn reader( +pub(crate) async fn reader( path: PathBuf, queue: AsyncSender, read: Arc,