diff --git a/Cargo.toml b/Cargo.toml index 2beaddd..5a865cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cccp" -version = "0.8.0" +version = "0.9.0" edition = "2021" build = "build.rs" repository = "https://github.com/chanderlud/cccp" @@ -10,10 +10,11 @@ authors = ["Chander Luderman "] [dependencies] clap = { version = "4.4", features = ["derive"] } -tokio = { version = "1.35", default-features = false, features = ["macros", "fs", "io-util"] } +tokio = { version = "1.35", default-features = false, features = ["macros", "fs", "io-util", "signal"] } futures = "0.3" log = { version = "0.4", features = ["std"] } async-ssh2-tokio = "0.8" +russh = "0.38" simple-logging = "2.0" regex = "1.10" dirs = "5.0" @@ -29,6 +30,7 @@ ctr = "0.9" aes = "0.8" whoami = "1.4" cipher = "0.4" +rand = "0.8" [target.'cfg(unix)'.dependencies] nix = { version = "0.27", features = ["fs"] } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index e69de29..271800c 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly" \ No newline at end of file diff --git a/src/cipher.rs b/src/cipher.rs index 00528ef..71b0461 100644 --- a/src/cipher.rs +++ b/src/cipher.rs @@ -1,12 +1,15 @@ -use aes::{Aes128, Aes256}; +use aes::{Aes128, Aes192, Aes256}; use chacha20::{ChaCha20, ChaCha8}; use cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; use ctr::Ctr128BE; use prost::Message; +use rand::rngs::{OsRng, StdRng}; +use rand::{RngCore, SeedableRng}; use tokio::io; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use crate::items::{Cipher, Crypto}; +use crate::Result; pub(crate) trait StreamCipherWrapper: Send + Sync { fn seek(&mut self, index: u64); @@ -17,10 +20,12 @@ impl StreamCipherWrapper for T where T: StreamCipherSeek + StreamCipher + Send + Sync, { + #[inline(always)] fn seek(&mut self, index: u64) { StreamCipherSeek::seek(self, index); } + #[inline(always)] fn apply_keystream(&mut self, buf: &mut [u8]) { StreamCipher::apply_keystream(self, buf); } @@ -32,15 +37,15 @@ pub(crate) struct CipherStream { } impl CipherStream { - pub(crate) fn new(stream: S, crypto: &Crypto) -> crate::Result { + pub(crate) fn new(stream: S, crypto: &Crypto) -> Result { Ok(Self { stream, - cipher: make_cipher(crypto)?, + cipher: crypto.make_cipher()?, }) } /// write a `Message` to the stream - pub(crate) async fn write_message(&mut self, message: &M) -> crate::Result<()> { + pub(crate) async fn write_message(&mut self, message: &M) -> Result<()> { let len = message.encoded_len(); // get the length of the message self.write_u32(len as u32).await?; // write the length of the message @@ -53,7 +58,7 @@ impl CipherStream { } /// read a `Message` from the stream - pub(crate) async fn read_message(&mut self) -> crate::Result { + pub(crate) async fn read_message(&mut self) -> Result { let len = self.read_u32().await? as usize; // read the length of the message let mut buffer = vec![0; len]; // create a buffer to read the message into @@ -87,15 +92,59 @@ impl CipherStream { } } -pub(crate) fn make_cipher(crypto: &Crypto) -> crate::Result> { - let cipher: Cipher = crypto.cipher.try_into()?; - let key = &crypto.key[..cipher.key_length()]; - let iv = &crypto.iv[..cipher.iv_length()]; - - Ok(match cipher { - Cipher::Aes128 => Box::new(Ctr128BE::::new(key.into(), iv.into())), - Cipher::Aes256 => Box::new(Ctr128BE::::new(key.into(), iv.into())), - Cipher::Chacha8 => Box::new(ChaCha8::new(key.into(), iv.into())), - Cipher::Chacha20 => Box::new(ChaCha20::new(key.into(), iv.into())), - }) +struct NoCipher; + +impl StreamCipherWrapper for NoCipher { + #[inline(always)] + fn seek(&mut self, _index: u64) {} + #[inline(always)] + fn apply_keystream(&mut self, _data: &mut [u8]) {} +} + +impl Crypto { + /// deterministically derive a new iv from the given iv + pub(crate) fn next_iv(&mut self) { + if self.cipher == i32::from(Cipher::None) { + return; + } + + // create a seed from the first 8 bytes of the iv + let seed = u64::from_be_bytes(self.iv[..8].try_into().unwrap()); + // create a random number generator from the seed + let mut rng = StdRng::seed_from_u64(seed); + let mut bytes = vec![0; self.iv.len()]; // buffer for new iv + rng.fill_bytes(&mut bytes); // fill the buffer with random bytes + self.iv = bytes; // set the new iv + } + + /// randomize the iv + pub(crate) fn random_iv(&mut self) { + if self.cipher == i32::from(Cipher::None) { + return; + } + + OsRng.fill_bytes(&mut self.iv); + } + + /// create a new cipher + pub(crate) fn make_cipher(&self) -> Result> { + let cipher: Cipher = self.cipher.try_into()?; + let key = &self.key[..cipher.key_length()]; + let iv = &self.iv[..cipher.iv_length()]; + + Ok(match cipher { + Cipher::None => Box::new(NoCipher), + Cipher::Aes128 => Box::new(Ctr128BE::::new(key.into(), iv.into())), + Cipher::Aes192 => Box::new(Ctr128BE::::new(key.into(), iv.into())), + Cipher::Aes256 => Box::new(Ctr128BE::::new(key.into(), iv.into())), + Cipher::Chacha8 => Box::new(ChaCha8::new(key.into(), iv.into())), + Cipher::Chacha20 => Box::new(ChaCha20::new(key.into(), iv.into())), + }) + } +} + +pub(crate) fn random_bytes(len: usize) -> Vec { + let mut bytes = vec![0; len]; + OsRng.fill_bytes(&mut bytes); + bytes } diff --git a/src/error.rs b/src/error.rs index aa1f132..bc9b56d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,7 +26,9 @@ pub(crate) enum ErrorKind { #[cfg(unix)] Nix(nix::Error), StripPrefix(std::path::StripPrefixError), - Ssh(async_ssh2_tokio::Error), + AsyncSsh(async_ssh2_tokio::Error), + RuSsh(russh::Error), + Base64Decode(base64::DecodeError), MissingQueue, MaxRetries, #[cfg(windows)] @@ -130,7 +132,23 @@ impl From for Error { impl From for Error { fn from(error: async_ssh2_tokio::Error) -> Self { Self { - kind: ErrorKind::Ssh(error), + kind: ErrorKind::AsyncSsh(error), + } + } +} + +impl From for Error { + fn from(error: russh::Error) -> Self { + Self { + kind: ErrorKind::RuSsh(error), + } + } +} + +impl From for Error { + fn from(error: base64::DecodeError) -> Self { + Self { + kind: ErrorKind::Base64Decode(error), } } } @@ -151,7 +169,9 @@ impl std::fmt::Display for Error { #[cfg(unix)] ErrorKind::Nix(ref error) => write!(f, "Nix error: {}", error), ErrorKind::StripPrefix(ref error) => write!(f, "StripPrefix error: {}", error), - ErrorKind::Ssh(ref error) => write!(f, "SSH error: {}", error), + ErrorKind::AsyncSsh(ref error) => write!(f, "SSH error: {}", error), + ErrorKind::RuSsh(ref error) => write!(f, "SSH error: {}", error), + ErrorKind::Base64Decode(ref error) => write!(f, "Base64 decode error: {}", error), ErrorKind::MissingQueue => write!(f, "Missing queue"), ErrorKind::MaxRetries => write!(f, "Max retries"), #[cfg(windows)] diff --git a/src/items.proto b/src/items.proto index 801f12c..75b0380 100644 --- a/src/items.proto +++ b/src/items.proto @@ -25,7 +25,7 @@ message FileDetail { string path = 1; // file path relative to the destination directory uint64 size = 2; // file size in bytes optional bytes signature = 3; // blake3 hash of file - optional Crypto crypto = 4; // encryption details + Crypto crypto = 4; // encryption details } message Crypto { @@ -35,10 +35,12 @@ message Crypto { } enum Cipher { - CHACHA8 = 0; - AES128 = 1; - CHACHA20 = 2; - AES256 = 3; + NONE = 0; + CHACHA8 = 1; + AES128 = 2; + AES192 = 3; + CHACHA20 =43; + AES256 = 5; } // the receiver already had these files @@ -81,4 +83,10 @@ message Failure { // signals the receiver that the sender won't start new transfers message Done { uint32 reason = 1; +} + +message Stats { + uint64 confirmed_packets = 1; + uint64 sent_packets = 2; + uint64 total_data = 3; } \ No newline at end of file diff --git a/src/items.rs b/src/items.rs index d1a815f..b533336 100644 --- a/src/items.rs +++ b/src/items.rs @@ -1,9 +1,17 @@ +use crate::TransferStats; use std::collections::HashMap; use std::fmt::Display; +use std::sync::atomic::Ordering::Relaxed; include!(concat!(env!("OUT_DIR"), "/cccp.items.rs")); impl Message { + pub(crate) fn manifest(manifest: &Manifest) -> Self { + Self { + message: Some(message::Message::Manifest(manifest.clone())), + } + } + pub(crate) fn start(id: u32) -> Self { Self { message: Some(message::Message::Start(Start { id })), @@ -49,16 +57,19 @@ impl Cipher { /// the length of the key in bytes pub(crate) fn key_length(&self) -> usize { match self { - Self::Chacha20 | Self::Chacha8 | Self::Aes256 => 32, + Self::None => 0, Self::Aes128 => 16, + Self::Aes192 => 24, + Self::Chacha20 | Self::Chacha8 | Self::Aes256 => 32, } } /// the length of the iv in bytes pub(crate) fn iv_length(&self) -> usize { match self { + Self::None => 0, Self::Chacha20 | Self::Chacha8 => 12, - Self::Aes256 | Self::Aes128 => 16, + Self::Aes256 | Self::Aes128 | Self::Aes192 => 16, } } } @@ -66,7 +77,9 @@ impl Cipher { impl Display for Cipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let cipher = match self { + Self::None => "NONE", Self::Aes128 => "AES128", + Self::Aes192 => "AES192", Self::Aes256 => "AES256", Self::Chacha8 => "CHACHA8", Self::Chacha20 => "CHACHA20", @@ -87,3 +100,13 @@ impl Manifest { self.files.is_empty() && self.directories.is_empty() } } + +impl Stats { + pub(crate) fn from(stats: &TransferStats) -> Self { + Self { + confirmed_packets: stats.confirmed_packets.load(Relaxed) as u64, + sent_packets: stats.sent_packets.load(Relaxed) as u64, + total_data: stats.total_data.load(Relaxed) as u64, + } + } +} diff --git a/src/main.rs b/src/main.rs index 083d109..61bb7d6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,30 +1,40 @@ #![feature(int_roundings)] use std::net::{IpAddr, SocketAddr}; +use std::ops::Not; use std::path::Path; -use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; +use std::sync::atomic::{AtomicBool, AtomicUsize}; use std::sync::Arc; use std::time::Duration; use async_ssh2_tokio::{AuthMethod, Client, ServerCheckMethod}; +use base64::engine::general_purpose::STANDARD_NO_PAD; +use base64::prelude::BASE64_STANDARD_NO_PAD; +use base64::Engine; use blake3::{Hash, Hasher}; use clap::{CommandFactory, Parser}; use futures::stream::iter; use futures::{StreamExt, TryStreamExt}; use indicatif::{ProgressBar, ProgressStyle}; use log::{debug, error, info, warn}; +use prost::Message; use rpassword::prompt_password; +use russh::{ChannelMsg, Sig}; use simple_logging::{log_to_file, log_to_stderr}; use tokio::fs::File; use tokio::io::{AsyncReadExt, BufReader}; -use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket}; -use tokio::time::{Instant, interval, sleep}; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use tokio::sync::Notify; +use tokio::time::{interval, sleep, Instant, Interval}; use tokio::{io, select}; +use tokio::signal::ctrl_c; use crate::cipher::CipherStream; +use crate::error::Error; +use crate::items::Stats; -use crate::options::{Mode, Options}; +use crate::options::{Mode, Options, SetupMode}; mod cipher; mod error; @@ -34,7 +44,7 @@ mod receiver; mod sender; // result alias used throughout -type Result = std::result::Result; +type Result = std::result::Result; // read buffer must be a multiple of the transfer buffer to prevent a nasty little bug const READ_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; @@ -43,17 +53,61 @@ const TRANSFER_BUFFER_SIZE: usize = 1024; const INDEX_SIZE: usize = std::mem::size_of::(); const ID_SIZE: usize = std::mem::size_of::(); const MAX_RETRIES: usize = 10; -const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5); // UDP header + ID + INDEX + DATA const PACKET_SIZE: usize = 8 + ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE; // how long to wait for a job to be confirmed before requeuing it const REQUEUE_INTERVAL: Duration = Duration::from_millis(1_000); -#[derive(Clone, Default)] +#[derive(Clone)] struct TransferStats { - confirmed_data: Arc, + confirmed_packets: Arc, + sent_packets: Arc, total_data: Arc, + start_time: Instant, + complete: Arc, +} + +impl Default for TransferStats { + fn default() -> Self { + Self { + confirmed_packets: Default::default(), + sent_packets: Default::default(), + total_data: Default::default(), + start_time: Instant::now(), + complete: Default::default(), + } + } +} + +impl TransferStats { + fn confirmed(&self) -> usize { + self.confirmed_packets.load(Relaxed) * TRANSFER_BUFFER_SIZE + } + + fn packet_loss(&self) -> f64 { + let sent = self.sent_packets.load(Relaxed); + let confirmed = self.confirmed_packets.load(Relaxed); + + if sent == 0 || sent < confirmed { + return 0_f64; + } + + let lost = sent - confirmed; + lost as f64 / sent as f64 + } + + fn total(&self) -> usize { + self.total_data.load(Relaxed) + } + + fn is_complete(&self) -> bool { + self.complete.load(Relaxed) + } + + fn speed(&self) -> f64 { + self.confirmed() as f64 / self.start_time.elapsed().as_secs_f64() / 1_000_000_f64 + } } #[tokio::main] @@ -61,6 +115,18 @@ async fn main() -> Result<()> { let mut options = Options::parse(); let mut command = Options::command(); + let cancel_signal = Arc::new(Notify::new()); + + tokio::spawn({ + let cancel_signal = cancel_signal.clone(); + + async move { + ctrl_c().await.expect("failed to listen for ctrl-c"); + debug!("ctrl-c received"); + cancel_signal.notify_waiters(); + } + }); + match options.mode { Mode::Local => { if let Some(path) = &options.log_file { @@ -114,164 +180,200 @@ async fn main() -> Result<()> { options.end_port = new_end; } - if options.destination.is_local() && options.source.is_local() { + let source_local = options.source.is_local(); + let destination_local = options.destination.is_local(); + + if source_local && destination_local { command .error( clap::error::ErrorKind::ValueValidation, "both the source and destination cannot be local", ) .exit(); + } else if !source_local && !destination_local { + debug!("switching ton controller mode"); + options.mode = Mode::Controller; } } - let sender = options.source.is_local(); let stats = TransferStats::default(); - let result = match options.mode { + match options.mode { Mode::Local => { + let sender = options.source.is_local(); + let (local, remote) = if sender { - (&options.source, &mut options.destination) + (&options.source, &options.destination) } else { - (&options.destination, &mut options.source) + (&options.destination, &options.source) }; - if remote.host.is_none() { - command - .error( - clap::error::ErrorKind::ValueValidation, - format!("host must be specified for remote IoSpec: {}", remote), - ) - .exit(); - } else if remote.username.is_none() { - remote.username = Some(whoami::username()); - } - debug!("local {}", local); debug!("remote {}", remote); - let mut auth_method = if let Ok(auth) = ssh_key_auth().await { - auth - } else { - password_auth()? - }; - - // unwrap is safe because we check for a host above + // unwrap is safe because host must be specified for remote IoSpec let remote_addr = remote.host.unwrap(); let remote_ip = remote_addr.ip(); - let client = loop { - match Client::connect( - remote_addr, - remote.username.as_ref().unwrap().as_str(), // unwrap is safe because we check for a username above - auth_method.clone(), - ServerCheckMethod::NoCheck, - ) - .await - { - Ok(client) => break client, - Err(error) => { - warn!("failed to connect to remote host: {}", error); - - match error { - async_ssh2_tokio::error::Error::KeyAuthFailed => { - info!("trying password auth"); - auth_method = password_auth().unwrap(); - } - _ => return Err(error.into()), - } - } - } - }; + let client = connect_client(remote_addr, &remote.username).await?; info!("connected to the remote host via ssh"); - let command_str = options.format_command(sender); - - let command_handle = tokio::spawn(async move { - info!("executing command on remote host"); - debug!("command: {}", command_str); - - client.execute(&command_str).await - }); + let command_handle = tokio::spawn(command_runner( + client, + options.format_command(sender, !options.stream_setup_mode), + sender.not().then_some(stats.sent_packets.clone()), + None, + None, + cancel_signal.clone(), + )); // receiver -> sender stream - let stream = - connect_stream(remote_ip, options.start_port, options.bind_address).await?; - let rts_stream = CipherStream::new(stream, &options.control_crypto)?; - + let rts_stream = connect_stream(remote_ip, options.start_port, &mut options).await?; // sender -> receiver stream - let stream = - connect_stream(remote_ip, options.start_port + 1, options.bind_address).await?; - let str_stream = CipherStream::new(stream, &options.control_crypto)?; + let str_stream = + connect_stream(remote_ip, options.start_port + 1, &mut options).await?; - let display_handle = tokio::spawn({ + let stats_handle = tokio::spawn({ let stats = stats.clone(); + let interval = interval(Duration::from_millis(options.progress_interval)); - print_progress(stats) + local_stats_printer(stats, interval) }); - let main_future = async { - if sender { - sender::main(options, stats, rts_stream, str_stream, remote_ip).await - } else { - receiver::main(options, stats, rts_stream, str_stream, remote_ip).await - } - }; + let main_future = run_main( + sender, + options, + stats.clone(), + rts_stream, + str_stream, + remote_ip, + ); let command_future = async { - let result = command_handle.await; - - match result { - Ok(Ok(result)) => { - if result.exit_status != 0 { - // return to terminate execution - error!("remote command failed: {:?}", result); - } else { - info!("remote client exited successfully"); - // wait forever to allow the other futures to complete - sleep(Duration::from_secs(u64::MAX)).await; - } + match command_handle.await?? { + Some(status) if status != 0 => { + error!("remote command failed with status {}", status) } - Ok(Err(error)) => error!("remote client failed: {}", error), // return to terminate execution - Err(error) => error!("failed to join remote command: {}", error), // return to terminate execution + None => error!("remote command failed to exit"), + _ => sleep(Duration::from_secs(u64::MAX)).await, // wait forever to allow the other futures to complete } + + Ok::<(), Error>(()) }; select! { - _ = command_future => Ok(()), - _ = display_handle => Ok(()), - result = main_future => result + result = command_future => result?, + result = main_future => result? } + + stats.complete.store(true, Relaxed); + stats_handle.await?; } Mode::Remote(sender) => { - // receiver -> sender stream - let listener = TcpListener::bind(("0.0.0.0", options.start_port)).await?; - let (stream, remote_addr) = listener.accept().await?; + let (rts_stream, str_stream, remote_addr) = match options.stream_setup_mode { + SetupMode::Listen => { + let (rts_stream, remote_addr) = + listen_stream(options.start_port, &mut options).await?; + let (str_stream, _) = + listen_stream(options.start_port + 1, &mut options).await?; + + (rts_stream, str_stream, remote_addr) + } + SetupMode::Connect => { + // unwrap is safe because host must be specified for remote IoSpec + let addr = options.destination.host.unwrap().ip(); - let rts_stream = CipherStream::new(stream, &options.control_crypto)?; + let rts_stream = connect_stream(addr, options.start_port, &mut options).await?; + let str_stream = + connect_stream(addr, options.start_port + 1, &mut options).await?; - // sender -> receiver stream - let listener = TcpListener::bind(("0.0.0.0", options.start_port + 1)).await?; - let (stream, _) = listener.accept().await?; + (rts_stream, str_stream, addr) + } + }; - let str_stream = CipherStream::new(stream, &options.control_crypto)?; + let stats_handle = tokio::spawn(remote_stats_printer(stats.clone())); - let remote_addr = remote_addr.ip(); + run_main( + sender, + options, + stats.clone(), + rts_stream, + str_stream, + remote_addr, + ) + .await?; - if sender { - sender::main(options, stats, rts_stream, str_stream, remote_addr).await - } else { - receiver::main(options, stats, rts_stream, str_stream, remote_addr).await - } + stats.complete.store(true, Relaxed); + stats_handle.await?; } - }; + Mode::Controller => { + // unwraps are safe because host must be specified for remote IoSpec + let sender_addr = options.source.host.unwrap(); + let receiver_addr = options.destination.host.unwrap(); + + let sender_client = connect_client(sender_addr, &options.source.username).await?; + let receiver_client = + connect_client(receiver_addr, &options.destination.username).await?; + + let sender_handle = tokio::spawn(command_runner( + sender_client, + options.format_command(false, SetupMode::Connect), // sender is inverted somewhat confusingly + Some(stats.sent_packets.clone()), + Some(stats.confirmed_packets.clone()), + Some(stats.total_data.clone()), + cancel_signal.clone(), + )); + + let receiver_handle = tokio::spawn(command_runner( + receiver_client, + options.format_command(true, SetupMode::Listen), + None, + None, + None, + cancel_signal.clone(), + )); + + let stats_handle = tokio::spawn({ + let stats = stats.clone(); + let interval = interval(Duration::from_millis(options.progress_interval)); + + local_stats_printer(stats, interval) + }); + + let sender_status = sender_handle.await??; + let receiver_status = receiver_handle.await??; - if let Err(error) = &result { - error!("{:?}", error); + if sender_status != Some(0) { + error!("sender command failed with status {:?}", sender_status); + } else if receiver_status != Some(0) { + error!("receiver command failed with status {:?}", receiver_status); + } + + stats.complete.store(true, Relaxed); + stats_handle.await?; + } } info!("exiting"); - result + Ok(()) +} + +/// selects the main function to run based on the mode +#[inline] +async fn run_main( + sender: bool, + options: Options, + stats: TransferStats, + rts_stream: CipherStream, + str_stream: CipherStream, + remote_addr: IpAddr, +) -> Result<()> { + if sender { + sender::main(options, stats, rts_stream, str_stream, remote_addr).await + } else { + receiver::main(options, stats, rts_stream, str_stream, remote_addr).await + } } /// opens the sockets that will be used to send data @@ -281,24 +383,58 @@ async fn socket_factory( remote_addr: IpAddr, threads: usize, ) -> io::Result> { - let bind_addr: IpAddr = "0.0.0.0".parse().unwrap(); - iter(start..=end) - .map(|port| { - let local_addr = SocketAddr::new(bind_addr, port); - let remote_addr = SocketAddr::new(remote_addr, port); - - async move { - let socket = UdpSocket::bind(local_addr).await?; - socket.connect(remote_addr).await?; - Ok::(socket) - } + .map(|port| async move { + let socket = UdpSocket::bind(("0.0.0.0", port)).await?; + socket.connect((remote_addr, port)).await?; + Ok::(socket) }) .buffer_unordered(threads) .try_collect() .await } +/// connects to a remote client via ssh +async fn connect_client(remote_addr: SocketAddr, username: &str) -> Result { + let mut auth_method = get_auth(&remote_addr).await?; + + loop { + match Client::connect( + remote_addr, + username, + auth_method, + ServerCheckMethod::NoCheck, + ) + .await + { + Ok(client) => break Ok(client), + Err(error) => match error { + async_ssh2_tokio::error::Error::KeyAuthFailed => { + warn!("ssh key auth failed"); + auth_method = password_auth(&remote_addr)?; + } + async_ssh2_tokio::error::Error::PasswordWrong => { + error!("invalid password"); + auth_method = password_auth(&remote_addr)?; + } + _ => return Err(error.into()), + }, + } + } +} + +/// select an auth method +async fn get_auth(host: &SocketAddr) -> io::Result { + let mut auth = ssh_key_auth().await; + + if auth.is_err() { + warn!("unable to load ssh key"); + auth = password_auth(host); + } + + auth +} + /// try to get an ssh key for authentication async fn ssh_key_auth() -> io::Result { // get the home directory of the current user @@ -321,61 +457,152 @@ async fn ssh_key_auth() -> io::Result { } /// prompt the user for a password -fn password_auth() -> io::Result { - let password = prompt_password("password: ")?; +fn password_auth(host: &SocketAddr) -> io::Result { + let password = prompt_password(format!("{} password: ", host))?; Ok(AuthMethod::with_password(&password)) } -/// print a progress bar to stdout -async fn print_progress(stats: TransferStats) { +/// print a progress bar & stats to stdout +async fn local_stats_printer(stats: TransferStats, mut interval: Interval) { let bar = ProgressBar::new(100); - let mut interval = interval(Duration::from_secs(1)); - let now = Instant::now(); bar.set_style( ProgressStyle::default_bar() - .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos:>7}% {msg}MB/s ({eta})") + .template( + "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}% [{msg}] ({eta})", + ) .unwrap() .progress_chars("=>-"), ); - loop { + while !stats.is_complete() { interval.tick().await; - let progress = - stats.confirmed_data.load(Relaxed) as f64 / stats.total_data.load(Relaxed) as f64; - let speed = stats.confirmed_data.load(Relaxed) as f64 / now.elapsed().as_secs_f64() / 1_000_000_f64; - bar.set_message(format!("{:.2}", speed)); - bar.set_position((progress * 100_f64) as u64); + + let progress = stats.confirmed() as f64 / stats.total() as f64 * 100_f64; + let speed = stats.speed(); + let packet_loss = stats.packet_loss(); + + bar.set_message(format!( + "{:.1}MB/s {:.1}% packet loss", + speed, + packet_loss * 100_f64 + )); + + bar.set_position(progress as u64); } + + bar.finish_with_message("complete"); } -/// connect to a remote client on a given port -async fn connect_stream( - remote_addr: IpAddr, - port: u16, - bind_addr: Option, -) -> Result { - let bind = match bind_addr { - Some(addr) => SocketAddr::new(addr, 0), - None => "0.0.0.0:0".parse()?, - }; +/// prints a base64 encoded stats message to stdout +async fn remote_stats_printer(stats: TransferStats) { + let mut interval = interval(Duration::from_secs(1)); + + while !stats.is_complete() { + interval.tick().await; + + let stats = Stats::from(&stats); // create a Stats message + // allocate a buffer for the message + let mut buf = Vec::with_capacity(stats.encoded_len()); + stats.encode(&mut buf).unwrap(); // infallible + let encoded = BASE64_STANDARD_NO_PAD.encode(&buf); // base64 encode the message + println!("{}", encoded); // print the encoded message + } +} + +/// runs a command on the remote host & handles the output +async fn command_runner( + client: Client, + command: String, + sent_packets: Option>, + confirmed_packets: Option>, + total_data: Option>, + cancel_signal: Arc, +) -> Result> { + debug!("executing command: {}", command); + + let mut channel = client.get_channel().await?; + let mut status: Option = None; + channel.exec(true, command).await?; - // connect to the remote client loop { - let socket = TcpSocket::new_v4()?; - socket.bind(bind)?; + select! { + _ = cancel_signal.notified() => { + debug!("cancel signal received"); + channel.signal(Sig::INT).await?; + debug!("sent INT signal"); + break; + } + message = channel.wait() => { + if let Some(message) = message { + match message { + ChannelMsg::Data { ref data } => { + let message = String::from_utf8_lossy(data).replace('\n', ""); + let buffer = STANDARD_NO_PAD.decode(message)?; + let stats = Stats::decode(&buffer[..])?; + + if let Some(sent_packets) = &sent_packets { + sent_packets.store(stats.sent_packets as usize, Relaxed); + } + + if let Some(confirmed_packets) = &confirmed_packets { + confirmed_packets.store(stats.confirmed_packets as usize, Relaxed); + } + + if let Some(total_data) = &total_data { + total_data.store(stats.total_data as usize, Relaxed); + } + } + ChannelMsg::ExtendedData { ref data, ext: 1 } => { + error!("remote stderr: {}", String::from_utf8_lossy(data)) + } + ChannelMsg::ExitStatus { exit_status } => status = Some(exit_status), + _ => {} + } + } else { + break + } + } + } + } - let remote_socket = SocketAddr::new(remote_addr, port); + debug!("command runner finished with status {:?}", status); + Ok(status) +} - if let Ok(stream) = socket.connect(remote_socket).await { - break Ok(stream); +/// connects to a listening remote client +async fn connect_stream( + remote_addr: IpAddr, + port: u16, + options: &mut Options, +) -> Result> { + let tcp_stream = loop { + if let Ok(stream) = TcpStream::connect((remote_addr, port)).await { + break stream; } else { - // give the receiver time to start listening + // give the listener time to start listening sleep(Duration::from_millis(100)).await; } - } + }; + + let stream = CipherStream::new(tcp_stream, &options.control_crypto)?; + options.control_crypto.next_iv(); + Ok(stream) +} + +/// listens for a remote client to connect +async fn listen_stream( + port: u16, + options: &mut Options, +) -> Result<(CipherStream, IpAddr)> { + let listener = TcpListener::bind(("0.0.0.0", port)).await?; + let (stream, remote_addr) = listener.accept().await?; + let stream = CipherStream::new(stream, &options.control_crypto)?; + options.control_crypto.next_iv(); + Ok((stream, remote_addr.ip())) } +/// calculate the BLAKE3 hash of a file async fn hash_file>(path: P) -> io::Result { let file = File::open(path).await?; let mut reader = BufReader::with_capacity(READ_BUFFER_SIZE, file); diff --git a/src/options.rs b/src/options.rs index fde988d..2f96c34 100644 --- a/src/options.rs +++ b/src/options.rs @@ -1,10 +1,10 @@ use std::error::Error; use std::fmt::{Display, Formatter}; -use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use std::net::{SocketAddr, ToSocketAddrs}; +use std::ops::Not; use std::path::PathBuf; use std::str::FromStr; -use aes::cipher::crypto_common::rand_core::{OsRng, RngCore}; use base64::engine::general_purpose::STANDARD_NO_PAD; use base64::Engine; use bytesize::ByteSize; @@ -13,9 +13,11 @@ use log::LevelFilter; use regex::Regex; use tokio::io; +use crate::cipher::random_bytes; use crate::items::{Cipher, Crypto}; use crate::PACKET_SIZE; +// TODO add a help item for firewall stuff const HELP_HEADING: &str = "\x1B[1m\x1B[4mAbout\x1B[0m cccp is a fast, secure, and reliable file transfer utility @@ -23,42 +25,48 @@ const HELP_HEADING: &str = "\x1B[1m\x1B[4mAbout\x1B[0m - [user@][host:{port:}]file - If no user is set for a remote host, the current user is used - If no port is provided, port 22 is used - - Either the InSpec or OutSpec should be remote, not both or neither + - At least one IoSpec should be remote \x1B[1m\x1B[4mCiphers\x1B[0m + - NONE - CHACHA8 - CHAHA20 - AES128 + - AES192 - AES256"; -#[derive(Parser, Clone, Debug)] +#[derive(Parser)] #[clap(version, about = HELP_HEADING)] pub(crate) struct Options { // the user does not need to set this - #[clap(long, hide = true, default_value = "local")] + #[clap(long, hide = true, default_value = "l")] pub(crate) mode: Mode, - /// The first port to use + // set by the controller automatically + #[clap(long, hide = true, default_value = "c")] + pub(crate) stream_setup_mode: SetupMode, + + /// First port to use #[clap(short, long, default_value_t = 50000)] pub(crate) start_port: u16, - /// The last port to use + /// Last port to use #[clap(short, long, default_value_t = 50009)] pub(crate) end_port: u16, - /// The number of threads to use + /// Parallel data streams #[clap(short, long, default_value_t = 8)] pub(crate) threads: usize, - /// The log level [debug, info, warn, error] + /// Log level [debug, info, warn, error] #[clap(short, long, default_value = "warn")] pub(crate) log_level: LevelFilter, - /// The rate to send data at [b, kb, mb, gb, tb] + /// Data send rate [b, kb, mb, gb, tb] #[clap(short, long, default_value = "1mb")] rate: ByteSize, - /// The maximum number of concurrent transfers + /// Maximum concurrent transfers #[clap(short, long, default_value_t = 100)] pub(crate) max: usize, @@ -66,6 +74,26 @@ pub(crate) struct Options { #[clap(short, long, default_value = "AES256")] pub(crate) control_crypto: Crypto, + /// Encrypt the data stream + #[clap(short = 'S', long, default_value = "NONE")] + pub(crate) stream_crypto: Crypto, + + /// Receive timeout in MS + #[clap(short = 'T', long, default_value_t = 5_000)] + pub(crate) receive_timeout: u64, + + /// Limit for concurrent jobs + #[clap(short, long, default_value_t = 1_000)] + pub(crate) job_limit: usize, + + /// Command to execute cccp + #[clap(short = 'E', long, default_value = "cccp")] + command: String, + + /// How often to print progress in MS + #[clap(short, long, default_value_t = 1_000)] + pub(crate) progress_interval: u64, + /// Verify integrity of transfers using blake3 #[clap(short, long)] pub(crate) verify: bool, @@ -74,83 +102,87 @@ pub(crate) struct Options { #[clap(short, long)] pub(crate) overwrite: bool, - /// Include subdirectories and files recursively + /// Include subdirectories recursively #[clap(short = 'R', long)] pub(crate) recursive: bool, - /// Force the transfer even if the there is not enough space + /// Do not check destination's available storage #[clap(short, long)] pub(crate) force: bool, - /// Optionally encrypt the data stream - #[clap(short = 'S', long)] - pub(crate) stream_crypto: Option, - - /// Manually specify the bind address + /// Forces the destination to be a directory #[clap(short, long)] - pub(crate) bind_address: Option, + pub(crate) directory: bool, - /// Log to a file (default: stderr) + /// Log to a file (default: stderr / local only) #[clap(short = 'L', long)] pub(crate) log_file: Option, - /// The source IoSpec (InSpec) + /// Source IoSpec (InSpec) #[clap()] pub(crate) source: IoSpec, - /// The destination IoSpec (OutSpec) + /// Destination IoSpec (OutSpec) #[clap()] pub(crate) destination: IoSpec, } impl Options { - pub(crate) fn format_command(&self, sender: bool) -> String { + /// Returns the command to run on the remote host + pub(crate) fn format_command(&self, sender: bool, mode: SetupMode) -> String { let mut arguments = vec![ - String::from("cccp"), + self.command.clone(), format!("--mode {}", if sender { "rr" } else { "rs" }), + format!("--stream-setup-mode {}", mode), format!("-s {}", self.start_port), format!("-e {}", self.end_port), format!("-t {}", self.threads), format!("-l {}", self.log_level), format!("-r \"{}\"", self.rate), - format!("--control-crypto {}", self.control_crypto), + format!("-m {}", self.max), + format!("-T {}", self.receive_timeout), + format!("-j {}", self.job_limit), + format!("-c {}", self.control_crypto), + format!("-S {}", self.stream_crypto), + format!("\"{}\"", self.source), + format!("\"{}\"", self.destination), ]; - if let Some(ref crypto) = self.stream_crypto { - arguments.push(format!(" --stream-crypto {}", crypto)) - } - + // optional arguments are inserted at index 1 so they are always before the IoSpecs if self.overwrite { - arguments.push(String::from("-o")) + arguments.insert(1, String::from("-o")) } if self.verify { - arguments.push(String::from("-v")) + arguments.insert(1, String::from("-v")) } if self.recursive { - arguments.push(String::from("-R")) + arguments.insert(1, String::from("-R")) } if self.force { - arguments.push(String::from("-f")) + arguments.insert(1, String::from("-f")) } - arguments.push(format!("\"{}\"", self.source)); - arguments.push(format!("\"{}\"", self.destination)); + if self.directory { + arguments.insert(1, String::from("-d")) + } arguments.join(" ") } + /// Calculates the send rate in packets per second pub(crate) fn pps(&self) -> u64 { self.rate.0 / PACKET_SIZE as u64 } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub(crate) enum Mode { Local, Remote(bool), // Remote(sender) + Controller, // two remotes } impl FromStr for Mode { @@ -161,14 +193,54 @@ impl FromStr for Mode { "l" => Self::Local, "rr" => Self::Remote(false), "rs" => Self::Remote(true), - "local" => Self::Local, - "remote-receiver" => Self::Remote(false), - "remote-sender" => Self::Remote(true), + "c" => Self::Controller, _ => return Err(Self::Err::unknown_mode()), }) } } +#[derive(Copy, Clone)] +pub(crate) enum SetupMode { + Listen, + Connect, +} + +impl FromStr for SetupMode { + type Err = OptionParseError; + + fn from_str(s: &str) -> Result { + Ok(match s { + "l" => Self::Listen, + "c" => Self::Connect, + _ => return Err(Self::Err::unknown_mode()), + }) + } +} + +impl Display for SetupMode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Listen => "l", + Self::Connect => "c", + } + ) + } +} + +impl Not for SetupMode { + type Output = Self; + + fn not(self) -> Self::Output { + match self { + Self::Listen => Self::Connect, + Self::Connect => Self::Listen, + } + } +} + impl FromStr for Crypto { type Err = OptionParseError; @@ -222,11 +294,11 @@ impl Display for Crypto { } /// a file located anywhere -#[derive(Clone, Debug)] +#[derive(Clone)] pub(crate) struct IoSpec { pub(crate) file_path: PathBuf, pub(crate) host: Option, - pub(crate) username: Option, + pub(crate) username: String, } impl FromStr for IoSpec { @@ -236,9 +308,13 @@ impl FromStr for IoSpec { let captures = Regex::new("(?:(?:([\\w-]+)@)?([\\w.-]+)(?::(\\d+))?:)?([ \\w/.-]+)") .unwrap() // infallible .captures(s) - .ok_or(Self::Err::malformed_io_spec("Invalid IO spec"))?; + .ok_or(Self::Err::invalid_io_spec())?; + + let username = captures + .get(1) + .map(|m| m.as_str().to_string()) + .unwrap_or(whoami::username()); - let username = captures.get(1).map(|m| m.as_str().to_string()); let port = captures .get(3) .map(|m| m.as_str().parse::()) @@ -255,9 +331,9 @@ impl FromStr for IoSpec { Ok(ip) => Ok(SocketAddr::new(ip, port)), Err(_) => { // if the host is not an ip address, try to resolve it as a domain - format!("{}:{}", host, port) + (host, port) .to_socket_addrs()? - .next() + .next() // use the first address .ok_or(Self::Err::no_such_host()) } } @@ -283,18 +359,19 @@ impl Display for IoSpec { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let file_path = self.file_path.display(); - match (&self.username, &self.host) { - (None, None) => write!(f, "{}", file_path), - (None, Some(host)) => write!(f, "{}:{}", host, file_path), - (Some(username), Some(host)) => write!(f, "{}@{}:{}", username, host, file_path), - _ => Err(std::fmt::Error), + match self.host { + None => write!(f, "{}@localhost:{}", self.username, file_path), + Some(host) => write!(f, "{}@{}:{}", self.username, host, file_path), } } } impl IoSpec { pub(crate) fn is_local(&self) -> bool { - self.host.is_none() || (self.host.is_some() && self.file_path.exists()) + match self.host { + None => true, + Some(host) => host.ip().is_loopback(), + } } } @@ -307,7 +384,7 @@ pub struct OptionParseError { enum ErrorKind { Io(io::Error), Decode(base64::DecodeError), - MalformedIoSpec(&'static str), + InvalidIoSpec, UnknownMode, InvalidCipher, InvalidKey, @@ -325,13 +402,13 @@ impl Display for OptionParseError { match &self.kind { ErrorKind::Io(error) => error.description(), ErrorKind::Decode(error) => error.description(), - ErrorKind::MalformedIoSpec(message) => message, - ErrorKind::UnknownMode => "The mode can be either sender or receiver", - ErrorKind::InvalidCipher => "Invalid cipher", - ErrorKind::InvalidKey => "Invalid key", - ErrorKind::InvalidIv => "Invalid IV", - ErrorKind::NoSuchHost => "No such host", - ErrorKind::InvalidPort => "Invalid port", + ErrorKind::InvalidIoSpec => "invalid IoSpec, refer to --help for more information", + ErrorKind::UnknownMode => "the mode can be either sender or receiver", + ErrorKind::InvalidCipher => "invalid cipher", + ErrorKind::InvalidKey => "invalid key", + ErrorKind::InvalidIv => "invalid IV", + ErrorKind::NoSuchHost => "no such host", + ErrorKind::InvalidPort => "invalid port", } ) } @@ -374,9 +451,9 @@ impl OptionParseError { } } - fn malformed_io_spec(message: &'static str) -> Self { + fn invalid_io_spec() -> Self { Self { - kind: ErrorKind::MalformedIoSpec(message), + kind: ErrorKind::InvalidIoSpec, } } @@ -398,9 +475,3 @@ impl OptionParseError { } } } - -fn random_bytes(len: usize) -> Vec { - let mut bytes = vec![0; len]; - OsRng.fill_bytes(&mut bytes); - bytes -} diff --git a/src/receiver/mod.rs b/src/receiver/mod.rs index 5db4283..f4fec35 100644 --- a/src/receiver/mod.rs +++ b/src/receiver/mod.rs @@ -20,11 +20,11 @@ use tokio::task::JoinHandle; use tokio::time::{interval, timeout}; use crate::error::Error; -use crate::items::{message, ConfirmationIndexes, Manifest, Message, StartIndex}; +use crate::items::{message, ConfirmationIndexes, Message, StartIndex}; use crate::receiver::writer::{writer, FileDetails, SplitQueue}; use crate::{ socket_factory, CipherStream, Options, Result, TransferStats, ID_SIZE, INDEX_SIZE, MAX_RETRIES, - RECEIVE_TIMEOUT, TRANSFER_BUFFER_SIZE, + TRANSFER_BUFFER_SIZE, }; mod writer; @@ -46,14 +46,27 @@ pub(crate) async fn main( ) -> Result<()> { info!("receiving {} -> {}", options.source, options.destination); - let manifest: Manifest = str_stream.read_message().await?; - let is_dir = manifest.files.len() > 1; // if multiple files are being received, the destination should be a directory - debug!( - "received manifest | files={} dirs={}", - manifest.files.len(), - manifest.directories.len() - ); + let message: Message = str_stream.read_message().await?; + + let manifest = match message.message { + Some(message::Message::Manifest(manifest)) => { + debug!( + "received manifest | files={} dirs={}", + manifest.files.len(), + manifest.directories.len() + ); + manifest + } + Some(message::Message::Done(done)) if done.reason == 0 => { + warn!("remote client found no files to send"); + return Ok(()); + } + _ => return Err(Error::unexpected_message(Box::new(message))), + }; + + // if multiple files are being received, the destination should be a directory + let is_dir = options.directory || manifest.files.len() > 1; let mut completed = Vec::new(); let filtered_files = manifest.files.into_iter().filter_map(|(id, details)| { @@ -105,7 +118,7 @@ pub(crate) async fn main( size: details.size, start_index, signature: details.signature, - crypto: details.crypto, + crypto: details.crypto.unwrap_or_default(), }, )) } @@ -183,7 +196,7 @@ pub(crate) async fn main( let confirmation_handle = tokio::spawn(send_confirmations( message_sender.clone(), confirmation_receiver, - stats.confirmed_data, + stats.confirmed_packets.clone(), )); let controller_handle = tokio::spawn(controller( @@ -194,9 +207,11 @@ pub(crate) async fn main( message_sender, )); + let receive_timeout = Duration::from_millis(options.receive_timeout); + let handles: Vec<_> = sockets .into_iter() - .map(|socket| tokio::spawn(receiver(writer_queue.clone(), socket))) + .map(|socket| tokio::spawn(receiver(writer_queue.clone(), socket, receive_timeout))) .collect(); let receiver_future = async { @@ -211,16 +226,16 @@ pub(crate) async fn main( result = confirmation_handle => { debug!("confirmation sender exited: {:?}", result); result? }, result = controller_handle => { debug!("controller exited: {:?}", result); result? }, result = receiver_future => { debug!("receivers exited: {:?}", result); result }, - result = message_sender_handle => { debug!("message sender exited: {:?}", result); result? } + result = message_sender_handle => { debug!("message sender exited: {:?}", result); result? }, } } -async fn receiver(queue: WriterQueue, socket: UdpSocket) -> Result<()> { +async fn receiver(queue: WriterQueue, socket: UdpSocket, receive_timeout: Duration) -> Result<()> { let mut buf = [0; ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE]; // buffer for receiving data let mut retries = 0; // counter to keep track of retries while retries < MAX_RETRIES { - match timeout(RECEIVE_TIMEOUT, socket.recv(&mut buf)).await { + match timeout(receive_timeout, socket.recv(&mut buf)).await { Ok(Ok(read)) if read > 0 => { retries = 0; // reset retries @@ -235,7 +250,14 @@ async fn receiver(queue: WriterQueue, socket: UdpSocket) -> Result<()> { } } Ok(Ok(_)) => warn!("0 byte read?"), // this should never happen - Ok(Err(_)) | Err(_) => retries += 1, // catch errors and timeouts + Ok(Err(error)) => { + retries += 1; + error!("recv error: {:?}", error); + } + Err(timeout) => { + retries += 1; + error!("recv timeout: {:?}", timeout); + } } } @@ -302,7 +324,7 @@ async fn controller( } Some(message::Message::Done(message)) => { match message.reason { - 0 => warn!("remote client found no files to send"), + // 0 => warn!("remote client found no files to send"), 1 => debug!("all transfers were completed before execution"), 2 => debug!("remote client completed all transfers"), _ => break Err(Error::unexpected_message(Box::new(message))), @@ -319,7 +341,7 @@ async fn controller( async fn send_confirmations( sender: AsyncSender, confirmation_receiver: AsyncReceiver<(u32, u64)>, - confirmed_data: Arc, + confirmed_packets: Arc, ) -> Result<()> { let data: Arc>> = Default::default(); @@ -358,7 +380,7 @@ async fn send_confirmations( let future = async { while let Ok(confirmation) = confirmation_receiver.recv().await { - confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); // increment the confirmed data counter + confirmed_packets.fetch_add(1, Relaxed); // increment the confirmed counter data.lock().await.push(confirmation); // push the index to the data vector } }; @@ -382,6 +404,7 @@ async fn send_messages( Ok(()) } +/// returns the amount of free space in bytes for the given path #[cfg(unix)] fn free_space(path: &Path) -> Result { use nix::sys::statvfs::statvfs; @@ -393,6 +416,7 @@ fn free_space(path: &Path) -> Result { Ok(stat.blocks_available() as u64 * stat.fragment_size()) } +/// returns the amount of free space in bytes for the given path #[cfg(windows)] fn free_space(path: &Path) -> Result { use widestring::U16CString; @@ -436,6 +460,7 @@ fn format_path(path: &Path) -> Result { Ok(path) } +/// returns the start index of the file, if it exists async fn start_index(path: &Path) -> Result { if path.exists() { let metadata = metadata(&path).await?; diff --git a/src/receiver/writer.rs b/src/receiver/writer.rs index 4d41948..6dac8da 100644 --- a/src/receiver/writer.rs +++ b/src/receiver/writer.rs @@ -9,7 +9,7 @@ use tokio::fs::{remove_file, rename, OpenOptions}; use tokio::io::{self, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter}; use tokio::sync::Mutex; -use crate::cipher::{make_cipher, StreamCipherWrapper}; +use crate::cipher::StreamCipherWrapper; use crate::error::Error; use crate::items::{Crypto, Message}; use crate::receiver::{Job, WriterQueue}; @@ -59,7 +59,7 @@ pub(crate) struct FileDetails { pub(crate) size: u64, pub(crate) start_index: u64, pub(crate) signature: Option>, - pub(crate) crypto: Option, + pub(crate) crypto: Crypto, } impl FileDetails { @@ -86,11 +86,8 @@ pub(crate) async fn writer( let mut writer = BufWriter::with_capacity(WRITE_BUFFER_SIZE, file); writer.seek(SeekFrom::Start(position)).await?; // seek to the initial position - let mut cipher = details.crypto.as_ref().map(make_cipher).transpose()?; - - if let Some(ref mut cipher) = cipher { - cipher.seek(position); - } + let mut cipher = details.crypto.make_cipher()?; + cipher.seek(position); debug!( "writer for {} starting at {}", @@ -177,14 +174,12 @@ async fn write_data( mut buffer: [u8; TRANSFER_BUFFER_SIZE], position: &mut u64, file_size: u64, - cipher: &mut Option>, + cipher: &mut Box, ) -> io::Result<()> { // calculate the length of the data to write let len = (file_size - *position).min(TRANSFER_BUFFER_SIZE as u64); - if let Some(ref mut cipher) = cipher { - cipher.apply_keystream(&mut buffer[..len as usize]); - } + cipher.apply_keystream(&mut buffer[..len as usize]); // apply the keystream *position += len; // advance the position writer.write_all(&buffer[..len as usize]).await // write the data diff --git a/src/sender/mod.rs b/src/sender/mod.rs index f0770a7..7ac2198 100644 --- a/src/sender/mod.rs +++ b/src/sender/mod.rs @@ -7,7 +7,6 @@ use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::time::Duration; -use aes::cipher::crypto_common::rand_core::{OsRng, RngCore}; use futures::stream::iter; use futures::{StreamExt, TryStreamExt}; use kanal::{AsyncReceiver, AsyncSender}; @@ -18,7 +17,6 @@ use tokio::select; use tokio::sync::{Mutex, RwLock, Semaphore}; use tokio::time::{interval, Instant}; -use crate::cipher::make_cipher; use crate::error::Error; use crate::items::{message, Confirmations, FileDetail, Manifest, Message, StartIndex}; use crate::sender::reader::reader; @@ -56,7 +54,9 @@ pub(crate) async fn main( } debug!("sending manifest"); - str_stream.write_message(&manifest).await?; + str_stream + .write_message(&Message::manifest(&manifest)) + .await?; let message: Message = str_stream.read_message().await?; @@ -93,7 +93,7 @@ pub(crate) async fn main( info!("opened sockets"); - // the reader fills the queue to 1_000 jobs, the unlimited capacity allows unconfirmed jobs to be added instantly + // the reader fills the queue to `options.sender_limit` jobs, the unlimited capacity allows unconfirmed jobs to be added instantly let (job_sender, job_receiver) = kanal::unbounded_async(); // a cache for the file chunks that have been sent but not confirmed let cache: JobCache = Default::default(); @@ -101,7 +101,7 @@ pub(crate) async fn main( // a semaphore to control the send rate let send = Arc::new(Semaphore::new(0)); // a semaphore to control the readers - let read = Arc::new(Semaphore::new(1_000)); + let read = Arc::new(Semaphore::new(options.job_limit)); // just confirmation messages let (confirmation_sender, confirmation_receiver) = kanal::unbounded_async(); @@ -119,22 +119,12 @@ pub(crate) async fn main( confirmation_receiver, cache.clone(), job_sender.clone(), - stats.confirmed_data.clone(), + stats.confirmed_packets.clone(), read.clone(), )); tokio::spawn(add_permits_at_rate(send.clone(), options.pps())); - let controller_handle = tokio::spawn(controller( - options, - str_stream, - manifest.files, - job_sender.clone(), - read, - stats.confirmed_data, - controller_receiver, - )); - let handles: Vec<_> = sockets .into_iter() .map(|socket| { @@ -144,10 +134,21 @@ pub(crate) async fn main( socket, cache.clone(), send.clone(), + stats.sent_packets.clone(), )) }) .collect(); + let controller_handle = tokio::spawn(controller( + options, + str_stream, + manifest.files, + job_sender.clone(), + read, + stats, + controller_receiver, + )); + let sender_future = async { for handle in handles { handle.await??; // propagate errors @@ -171,6 +172,7 @@ async fn sender( socket: UdpSocket, cache: JobCache, send: Arc, + sent: Arc, ) -> Result<()> { let mut retries = 0; @@ -180,7 +182,7 @@ async fn sender( // send the job data to the socket if let Err(error) = socket.send(&job.data).await { - error!("failed to send data: {}", error); + error!("send error: {}", error); job_sender.send(job).await?; // put the job back in the queue retries += 1; } else { @@ -188,6 +190,7 @@ async fn sender( job.cached_at = Some(Instant::now()); cache.write().await.insert((job.id, job.index), job); retries = 0; + sent.fetch_add(1, Relaxed); } permit.forget(); @@ -206,7 +209,7 @@ async fn controller( mut files: HashMap, job_sender: AsyncSender, read: Arc, - confirmed_data: Arc, + stats: TransferStats, controller_receiver: AsyncReceiver, ) -> Result<()> { let mut id = 0; @@ -224,7 +227,7 @@ async fn controller( &options.source.file_path, &job_sender, &read, - &confirmed_data, + &stats.total_data, ) .await?; @@ -253,7 +256,7 @@ async fn controller( failure.id ); - confirmed_data.fetch_sub(details.size as usize, Relaxed); + stats.total_data.fetch_add(details.size as usize, Relaxed); start_file_transfer( &mut control_stream, @@ -262,7 +265,7 @@ async fn controller( &options.source.file_path, &job_sender, &read, - &confirmed_data, + &stats.total_data, ) .await?; } else { @@ -306,14 +309,14 @@ async fn start_file_transfer( base_path: &Path, job_sender: &AsyncSender, read: &Arc, - confirmed_data: &Arc, + total_data: &Arc, ) -> Result<()> { control_stream.write_message(&Message::start(id)).await?; let start_index: StartIndex = control_stream.read_message().await?; - confirmed_data.fetch_add(start_index.index as usize, Relaxed); + total_data.fetch_sub(start_index.index as usize, Relaxed); - let cipher = details.crypto.as_ref().map(make_cipher).transpose()?; + let cipher = details.crypto.as_ref().unwrap().make_cipher()?; tokio::spawn({ let job_sender = job_sender.clone(); @@ -343,7 +346,7 @@ async fn receive_confirmations( confirmation_receiver: AsyncReceiver, cache: JobCache, job_sender: AsyncSender, - confirmed_data: Arc, + confirmed_packets: Arc, read: Arc, ) -> Result<()> { // this solves a problem where a confirmation is received after a job has already been requeued @@ -353,7 +356,7 @@ async fn receive_confirmations( let requeue_handle = tokio::spawn({ let cache = cache.clone(); let lost_confirmations = lost_confirmations.clone(); - let confirmed_data = confirmed_data.clone(); + let confirmed_packets = confirmed_packets.clone(); let read = read.clone(); let mut interval = interval(Duration::from_millis(100)); @@ -384,7 +387,7 @@ async fn receive_confirmations( lost_confirmations.remove(&key); read.add_permits(1); - confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); + confirmed_packets.fetch_add(1, Relaxed); } else { unconfirmed.cached_at = None; job_sender.send(unconfirmed).await?; @@ -408,7 +411,7 @@ async fn receive_confirmations( lost_confirmations.insert((id, index)); } else { read.add_permits(1); // add a permit to the reader - confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); + confirmed_packets.fetch_add(1, Relaxed); } } } @@ -453,9 +456,7 @@ async fn split_receiver( } Some(message::Message::End(_)) => controller_sender.send(message).await?, Some(message::Message::Failure(_)) => controller_sender.send(message).await?, - _ => { - error!("received {:?}", message); - } + _ => return Err(Error::unexpected_message(Box::new(message))), } } } @@ -494,10 +495,7 @@ async fn build_manifest(options: &Options, total_data: &Arc) -> Res } let mut crypto = options.stream_crypto.clone(); - - if let Some(ref mut crypto) = crypto { - OsRng.fill_bytes(&mut crypto.iv); - } + crypto.random_iv(); Ok::<(u32, FileDetail), Error>(( index as u32, @@ -505,7 +503,7 @@ async fn build_manifest(options: &Options, total_data: &Arc) -> Res path: format_dir(file.to_string_lossy()), size, signature, - crypto, + crypto: Some(crypto), }, )) }) diff --git a/src/sender/reader.rs b/src/sender/reader.rs index 4aa44a7..72a5b10 100644 --- a/src/sender/reader.rs +++ b/src/sender/reader.rs @@ -18,15 +18,13 @@ pub(crate) async fn reader( read: Arc, mut index: u64, id: u32, - mut cipher: Option>, + mut cipher: Box, ) -> Result<()> { let file = File::open(path).await?; let mut reader = BufReader::with_capacity(READ_BUFFER_SIZE, file); reader.seek(SeekFrom::Start(index)).await?; - if let Some(ref mut cipher) = cipher { - cipher.seek(index); - } + cipher.seek(index); let mut buffer = [0; ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE]; // write id to buffer it is constant for all chunks @@ -48,9 +46,7 @@ pub(crate) async fn reader( break; } - if let Some(ref mut cipher) = cipher { - cipher.apply_keystream(&mut buffer[INDEX_SIZE + ID_SIZE..INDEX_SIZE + ID_SIZE + read]); - } + cipher.apply_keystream(&mut buffer[INDEX_SIZE + ID_SIZE..INDEX_SIZE + ID_SIZE + read]); // push job to queue queue