diff --git a/src/main.rs b/src/main.rs index cd97b20..35be8c1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,15 @@ +#![feature(int_roundings)] + +use std::error; use std::fmt::{Display, Formatter}; use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; +use std::process::{ExitCode, Termination}; use std::str::FromStr; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::time::Duration; -use std::{error, process}; use async_ssh2_tokio::{AuthMethod, Client, ServerCheckMethod}; use clap::Parser; @@ -19,7 +22,7 @@ use rpassword::prompt_password; use simple_logging::{log_to_file, log_to_stderr}; use tokio::fs::File; use tokio::io::AsyncReadExt; -use tokio::net::UdpSocket; +use tokio::net::{TcpListener, TcpSocket, UdpSocket}; use tokio::time::{interval, sleep}; use tokio::{io, select}; @@ -31,6 +34,7 @@ type LimitedQueue = Arc>; type Result = std::result::Result; const READ_BUFFER_SIZE: usize = 10_000_000; +const WRITE_BUFFER_SIZE: usize = 5_000_000; const TRANSFER_BUFFER_SIZE: usize = 1024; const INDEX_SIZE: usize = std::mem::size_of::(); const MAX_RETRIES: usize = 10; @@ -66,15 +70,15 @@ impl From for Error { } } -impl Error { - fn exit_code(self) -> i32 { - match self.kind { +impl Termination for Error { + fn report(self) -> ExitCode { + ExitCode::from(match self.kind { ErrorKind::IoError(error) => match error.kind() { io::ErrorKind::NotFound => 1, _ => 2, }, ErrorKind::ParseError(_) => 3, - } + }) } } @@ -116,7 +120,7 @@ struct Options { short, long = "log-level", help = "log level [debug, info, warn, error]", - default_value = "info" + default_value = "warn" )] log_level: LevelFilter, @@ -125,7 +129,7 @@ struct Options { long = "bind-address", help = "manually specify the address to listen on" )] - bind_address: Option, + bind_address: Option, #[clap( short, @@ -163,8 +167,7 @@ impl Options { #[derive(Debug, Clone, PartialEq)] enum Mode { Local, - RemoteSender, - RemoteReceiver, + Remote(bool), // Remote(sender) } impl FromStr for Mode { @@ -173,17 +176,17 @@ impl FromStr for Mode { fn from_str(s: &str) -> std::result::Result { Ok(match s { "l" => Self::Local, - "rr" => Self::RemoteReceiver, - "rs" => Self::RemoteSender, + "rr" => Self::Remote(false), + "rs" => Self::Remote(true), "local" => Self::Local, - "remote-receiver" => Self::RemoteReceiver, - "remote-sender" => Self::RemoteSender, + "remote-receiver" => Self::Remote(false), + "remote-sender" => Self::Remote(true), _ => return Err(CustomParseErrors::UnknownMode), }) } } -// a file located anywhere:tm: +// a file located anywhere #[derive(Clone, Debug)] struct FileLocation { file_path: PathBuf, @@ -251,8 +254,8 @@ impl Display for FileLocation { } impl FileLocation { - fn is_local(&self, local_address: &str) -> bool { - self.host.is_none() || self.host.as_ref().unwrap() == local_address + fn is_local(&self) -> bool { + self.host.is_none() || (self.host.is_some() && self.file_path.exists()) } async fn file_size(&self) -> io::Result { @@ -299,12 +302,11 @@ struct TransferStats { } #[tokio::main] -async fn main() { +async fn main() -> Result<()> { let mut options = Options::parse(); match options.mode { Mode::Local => log_to_stderr(options.log_level), - // TODO choose a better log file location _ => log_to_file("cccp.log", options.log_level).expect("failed to log"), } @@ -324,6 +326,15 @@ async fn main() { options.threads = port_count; } else if port_count < 2 { panic!("a minimum of two ports are required") + } else if port_count > options.threads { + warn!( + "{} ports > {} threads. changing port range to {}-{}", + port_count, + options.threads, + options.start_port, + options.start_port + options.threads + ); + options.end_port = options.start_port + options.threads; } if options.destination.host.is_none() && options.source.host.is_none() { @@ -333,51 +344,21 @@ async fn main() { // UDP header + INDEX + DATA let packet_size = (8 + INDEX_SIZE + TRANSFER_BUFFER_SIZE) as u64; let pps_rate = options.rate / packet_size; - info!( - "converted {} byte/s rate to {} packet/s rate", - options.rate, pps_rate - ); + debug!("{} byte/s -> {} packet/s", options.rate, pps_rate); options.rate = pps_rate; } - let public_address = match options.mode { - Mode::Local => { - if let Some(address) = options.bind_address.as_ref() { - address.clone() - } else { - reqwest::get("http://api.ipify.org") - .await - .unwrap() - .text() - .await - .unwrap() - } - } - Mode::RemoteSender => options.source.host.as_ref().unwrap().clone(), - Mode::RemoteReceiver => options.destination.host.as_ref().unwrap().clone(), - }; - - info!("the bind address is: {}", public_address); - - let sender = options.source.is_local(&public_address); + let sender = options.source.is_local(); let stats = TransferStats::default(); match options.mode { Mode::Local => { - if options.destination.is_local(&public_address) { - debug!("destination is local"); - options.destination.host = Some(public_address.clone()); - } else { - debug!("source is local"); - options.source.host = Some(public_address.clone()); - } - let command = options.format_command(sender); - let (local, remote) = if options.destination.is_local(&public_address) { - (&options.destination, &options.source) - } else { + let (local, remote) = if sender { (&options.source, &options.destination) + } else { + (&options.destination, &options.source) }; if remote.username.is_none() { @@ -397,9 +378,11 @@ async fn main() { } }; + let remote_addr: IpAddr = remote.host.as_ref().unwrap().parse().unwrap(); + let client = loop { match Client::connect( - (remote.host.as_ref().unwrap().as_str(), 22), + (remote_addr, 22), remote.username.as_ref().unwrap().as_str(), auth_method.clone(), ServerCheckMethod::NoCheck, @@ -430,6 +413,26 @@ async fn main() { client.execute(&command).await }); + let bind = match options.bind_address { + Some(addr) => SocketAddr::new(addr, 0), + None => "0.0.0.0:0".parse().unwrap(), + }; + + // connect to the remote client on the first port in the range + let control_stream = loop { + let socket = TcpSocket::new_v4().unwrap(); + socket.bind(bind).unwrap(); + + let remote_socket = SocketAddr::new(remote_addr, options.start_port); + + if let Ok(stream) = socket.connect(remote_socket).await { + break stream; + } else { + // give the receiver time to start listening + sleep(Duration::from_millis(100)).await; + } + }; + let display_handle = tokio::spawn({ let stats = stats.clone(); @@ -438,9 +441,9 @@ async fn main() { let main_future = async { if sender { - sender::main(options, stats).await + sender::main(options, stats, control_stream, remote_addr).await } else { - receiver::main(options, stats).await + receiver::main(options, stats, control_stream, remote_addr).await } }; @@ -469,28 +472,23 @@ async fn main() { select! { _ = command_future => {}, _ = display_handle => {}, - result = main_future => { - if let Err(error) = result { - error!("{} failed: {:?}", if sender { "sender" } else { "receiver" }, error); - } - } - } - } - Mode::RemoteSender => { - if let Err(error) = sender::main(options, stats).await { - error!("sender failed: {:?}", error); - process::exit(error.exit_code()); // exit with non 0 status so remote knows it failed + result = main_future => result? } } - Mode::RemoteReceiver => { - if let Err(error) = receiver::main(options, stats).await { - error!("receiver failed: {:?}", error); - process::exit(error.exit_code()); // exit with non 0 status so remote knows it failed - } + Mode::Remote(sender) => { + let listener = TcpListener::bind(("0.0.0.0", options.start_port)).await?; + let (control_stream, remote_addr) = listener.accept().await?; + + if sender { + sender::main(options, stats, control_stream, remote_addr.ip()).await?; + } else { + receiver::main(options, stats, control_stream, remote_addr.ip()).await?; + }; } } info!("exiting"); + Ok(()) } // opens the sockets that will be used to send data diff --git a/src/receiver/metadata.rs b/src/receiver/metadata.rs deleted file mode 100644 index fc7738a..0000000 --- a/src/receiver/metadata.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::io::SeekFrom; -use std::path::{Path, PathBuf}; - -use tokio::fs::{remove_file, File, OpenOptions}; -use tokio::io; -use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; - -pub(crate) struct Metadata { - pub(crate) initial_index: u64, // the initial index - writer: File, // the metadata file - file_path: PathBuf, // the path to the metadata file -} - -impl Metadata { - /// initialize the metadata object - pub(crate) async fn new(file_path: &Path) -> io::Result { - let file_path = Self::format_path(file_path); - - if file_path.exists() { - if let Ok(metadata) = Self::load(&file_path).await { - return Ok(metadata); - } - // if loading existing metadata fails, create new metadata - } - - let writer = File::create(&file_path).await?; // create metadata file - - Ok(Self { - initial_index: 0, - writer, - file_path, - }) - } - - /// complete an index - pub(crate) async fn complete(&mut self, index: u64) -> io::Result<()> { - self.writer.write_u64(index).await // write index to file - } - - /// remove metadata file - pub(crate) async fn remove(&self) -> io::Result<()> { - remove_file(&self.file_path).await - } - - /// load metadata from file - async fn load(file_path: &Path) -> io::Result { - // format the path to the metadata file - let file_path = Self::format_path(file_path); - - // open the file for reading and writing - let mut file = OpenOptions::new() - .read(true) - .write(true) - .open(&file_path) - .await?; - - // seek to the end of the file - let len = file.seek(SeekFrom::End(0)).await?; - - let mut buf = [0; 8]; // create buffer for data - - // if the file is not empty, seek back 8 bytes from the end and read the last index - if len > 0 { - file.seek(SeekFrom::End(-8)).await?; - file.read_exact(&mut buf).await?; - } - - // return a new instance of the Metadata struct - Ok(Self { - initial_index: u64::from_be_bytes(buf), - writer: file, - file_path, - }) - } - - /// formats the path to the metadata file - #[inline] - fn format_path(path: &Path) -> PathBuf { - path.with_extension("metadata") - } -} diff --git a/src/receiver/mod.rs b/src/receiver/mod.rs index b4a14e9..3db1fec 100644 --- a/src/receiver/mod.rs +++ b/src/receiver/mod.rs @@ -1,4 +1,5 @@ use std::mem; +use std::net::IpAddr; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; @@ -6,22 +7,20 @@ use std::time::Duration; use deadqueue::limited::Queue; use log::{debug, error, info, warn}; -use tokio::fs::rename; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use tokio::fs::{metadata, rename}; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpStream, UdpSocket}; use tokio::sync::Mutex; use tokio::task::JoinHandle; use tokio::time::{interval, timeout}; use tokio::{io, select}; -use crate::receiver::metadata::Metadata; use crate::receiver::writer::writer; use crate::{ socket_factory, LimitedQueue, Options, Result, TransferStats, UnlimitedQueue, INDEX_SIZE, MAX_RETRIES, RECEIVE_TIMEOUT, TRANSFER_BUFFER_SIZE, }; -mod metadata; mod writer; type WriterQueue = LimitedQueue; @@ -30,10 +29,14 @@ type WriterQueue = LimitedQueue; struct Job { data: [u8; TRANSFER_BUFFER_SIZE], // the file chunk index: u64, // the index of the file chunk - len: usize, // the length of the file chunk } -pub(crate) async fn main(mut options: Options, stats: TransferStats) -> Result<()> { +pub(crate) async fn main( + mut options: Options, + stats: TransferStats, + mut control_stream: TcpStream, + remote_addr: IpAddr, +) -> Result<()> { if options.destination.file_path.is_dir() { info!("destination is a folder, reformatting path with target file"); @@ -45,13 +48,20 @@ pub(crate) async fn main(mut options: Options, stats: TransferStats) -> Result<( info!("receiving {} -> {}", options.source, options.destination); - let meta_data = Metadata::new(&options.destination.file_path).await?; + let partial_path = options.destination.file_path.with_extension("partial"); + + let start_index = if partial_path.exists() { + info!("partial file exists, resuming transfer"); + let metadata = metadata(&partial_path).await?; + // the file is written sequentially, so we can calculate the start index by rounding down to the nearest multiple of the transfer buffer size + metadata.len().div_floor(TRANSFER_BUFFER_SIZE as u64) * TRANSFER_BUFFER_SIZE as u64 + } else { + 0 + }; + stats .confirmed_data - .fetch_add(meta_data.initial_index as usize, Relaxed); - - let listener = TcpListener::bind(("0.0.0.0", options.start_port)).await?; - let (mut control_stream, _remote_addr) = listener.accept().await?; + .fetch_add(start_index as usize, Relaxed); // receive the file size from the remote client let file_size = control_stream.read_u64().await?; @@ -59,13 +69,13 @@ pub(crate) async fn main(mut options: Options, stats: TransferStats) -> Result<( debug!("received file size: {}", file_size); // send the start index to the remote client - debug!("sending start index {}", meta_data.initial_index); - control_stream.write_u64(meta_data.initial_index).await?; + debug!("sending start index {}", start_index); + control_stream.write_u64(start_index).await?; let sockets = socket_factory( options.start_port + 1, // the first port is used for control messages options.end_port, - options.source.host.unwrap().as_str().parse()?, + remote_addr, options.threads, ) .await?; @@ -76,11 +86,11 @@ pub(crate) async fn main(mut options: Options, stats: TransferStats) -> Result<( let confirmation_queue: UnlimitedQueue = Default::default(); let writer_handle = tokio::spawn(writer( - options.destination.file_path.with_extension("partial"), + partial_path.clone(), writer_queue.clone(), file_size, confirmation_queue.clone(), - meta_data, + start_index, )); let confirmation_handle = tokio::spawn(send_confirmations( @@ -107,7 +117,7 @@ pub(crate) async fn main(mut options: Options, stats: TransferStats) -> Result<( // rename the partial file to the original file rename( - &options.destination.file_path.with_extension("partial"), + &partial_path, &options.destination.file_path, ) .await?; @@ -119,20 +129,19 @@ pub(crate) async fn main(mut options: Options, stats: TransferStats) -> Result<( } async fn receiver(queue: WriterQueue, socket: UdpSocket) { - let mut buf = [0; INDEX_SIZE + TRANSFER_BUFFER_SIZE]; - let mut retries = 0; + let mut buf = [0; INDEX_SIZE + TRANSFER_BUFFER_SIZE]; // buffer for receiving data + let mut retries = 0; // counter to keep track of retries while retries < MAX_RETRIES { match timeout(RECEIVE_TIMEOUT, socket.recv(&mut buf)).await { Ok(Ok(read)) if read > 0 => { - retries = 0; + retries = 0; // reset retries let data = buf[INDEX_SIZE..].try_into().unwrap(); let index = u64::from_be_bytes(buf[..INDEX_SIZE].try_into().unwrap()); - let len = read - INDEX_SIZE; - queue.push(Job { data, index, len }).await; + queue.push(Job { data, index }).await; } - Ok(Ok(_)) => warn!("0 byte read?"), - Ok(Err(_)) | Err(_) => retries += 1, + Ok(Ok(_)) => warn!("0 byte read?"), // this should never happen + Ok(Err(_)) | Err(_) => retries += 1, // catch errors and timeouts } } } @@ -159,30 +168,35 @@ async fn send_confirmations( continue; } + // take the data out of the mutex let indexes = mem::take(&mut *data); - drop(data); + drop(data); // release the lock on data - send_indexes(&mut control_stream, &indexes).await?; + write_indexes(&mut control_stream, &indexes).await?; } } }); let future = async { loop { - let index = queue.pop().await; - confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); - data.lock().await.push(index); + let index = queue.pop().await; // wait for a confirmation + confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); // increment the confirmed data counter + data.lock().await.push(index); // push the index to the data vector } }; + // propagate errors from the sender thread while executing the future select! { result = sender_handle => result?, _ = future => Ok(()) } } -// sends an array of indexes to the socket -async fn send_indexes(control_stream: &mut TcpStream, data: &[u64]) -> io::Result<()> { +// writes an array of u64 values to the control stream +async fn write_indexes( + control_stream: &mut T, + data: &[u64], +) -> io::Result<()> { let length = data.len() as u64; control_stream.write_u64(length).await?; diff --git a/src/receiver/writer.rs b/src/receiver/writer.rs index 2b7bb2a..6064733 100644 --- a/src/receiver/writer.rs +++ b/src/receiver/writer.rs @@ -5,31 +5,26 @@ use std::path::PathBuf; use log::{debug, info}; use tokio::fs::OpenOptions; -use tokio::io::{self, AsyncSeekExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{self, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter}; -use crate::receiver::metadata::Metadata; use crate::receiver::{Job, WriterQueue}; -use crate::{UnlimitedQueue, TRANSFER_BUFFER_SIZE}; +use crate::{UnlimitedQueue, TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE}; pub(crate) async fn writer( path: PathBuf, writer_queue: WriterQueue, file_size: u64, confirmation_queue: UnlimitedQueue, - mut metadata: Metadata, + mut position: u64, ) -> io::Result<()> { - let mut writer = OpenOptions::new() + let file = OpenOptions::new() .write(true) .create(true) .open(path) .await?; - let mut position = metadata.initial_index; - - if position > 0 { - position += TRANSFER_BUFFER_SIZE as u64; // the chunk at index has already been written - writer.seek(SeekFrom::Start(position)).await?; // seek to the position - } + let mut writer = BufWriter::with_capacity(WRITE_BUFFER_SIZE, file); + writer.seek(SeekFrom::Start(position)).await?; // seek to the initial position debug!("starting writer at position {}", position); @@ -49,21 +44,19 @@ pub(crate) async fn writer( } // if the chunk is at the current position, write it Ordering::Equal => { - write_data(&mut writer, &job.data, &mut position, job.len).await?; + write_data(&mut writer, &job.data, &mut position, file_size).await?; confirmation_queue.push(job.index); - metadata.complete(job.index).await?; } } - // write all concurrent chunks from `cache` + // write all concurrent chunks from the cache while let Some(job) = cache.remove(&position) { - write_data(&mut writer, &job.data, &mut position, job.len).await?; - metadata.complete(job.index).await?; + write_data(&mut writer, &job.data, &mut position, file_size).await?; } } info!("writer wrote all expected bytes"); - metadata.remove().await?; // remove the metadata file + writer.flush().await?; Ok(()) } @@ -74,8 +67,11 @@ async fn write_data( writer: &mut T, buffer: &[u8], position: &mut u64, - len: usize, + file_size: u64, ) -> io::Result<()> { - *position += len as u64; // advance the position - writer.write_all(&buffer[..len]).await // write the data + // calculate the length of the data to write + let len = (file_size - *position).min(TRANSFER_BUFFER_SIZE as u64); + + *position += len; // advance the position + writer.write_all(&buffer[..len as usize]).await // write the data } diff --git a/src/sender/mod.rs b/src/sender/mod.rs index b4a492d..61bb435 100644 --- a/src/sender/mod.rs +++ b/src/sender/mod.rs @@ -1,4 +1,5 @@ use std::collections::{BTreeMap, HashSet}; +use std::net::IpAddr; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; @@ -9,11 +10,11 @@ use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::select; use tokio::sync::{Mutex, RwLock, Semaphore}; -use tokio::time::{interval, sleep, Instant}; +use tokio::time::{interval, Instant}; use crate::{ - socket_factory, Options, Result, TransferStats, UnlimitedQueue, MAX_RETRIES, REQUEUE_INTERVAL, - TRANSFER_BUFFER_SIZE, + socket_factory, Options, Result, TransferStats, UnlimitedQueue, INDEX_SIZE, MAX_RETRIES, + REQUEUE_INTERVAL, TRANSFER_BUFFER_SIZE, }; mod reader; @@ -22,23 +23,21 @@ type JobQueue = UnlimitedQueue; type JobCache = Arc>>; struct Job { - data: Vec, // index (8 bytes) + the file chunk - index: u64, // the index of the file chunk + data: [u8; INDEX_SIZE + TRANSFER_BUFFER_SIZE], + index: u64, cached_at: Option, } -pub(crate) async fn main(options: Options, stats: TransferStats) -> Result<()> { - info!("sending {:?} to {:?}", options.source, options.destination); +pub(crate) async fn main( + options: Options, + stats: TransferStats, + mut control_stream: TcpStream, + remote_addr: IpAddr, +) -> Result<()> { + info!("sending {} -> {}", options.source, options.destination); - let remote_address = options.destination.host.unwrap().parse()?; let file_size = options.source.file_size().await?; stats.total_data.store(file_size as usize, Relaxed); - - // give the receiver time to start listening - sleep(Duration::from_millis(1_000)).await; - - // connect to the remote client on the first port in the range - let mut control_stream = TcpStream::connect((remote_address, options.start_port)).await?; // send the file size to the remote client control_stream.write_u64(file_size).await?; @@ -50,7 +49,7 @@ pub(crate) async fn main(options: Options, stats: TransferStats) -> Result<()> { let sockets = socket_factory( options.start_port + 1, // the first port is used for control messages options.end_port, - remote_address, + remote_addr, options.threads, ) .await?; @@ -96,19 +95,7 @@ pub(crate) async fn main(options: Options, stats: TransferStats) -> Result<()> { } }; - // let reader_future = async { - // _ = reader_handle.await; - // info!("reader exited"); - // - // while !queue.is_empty() && !cache.read().await.is_empty() { - // sleep(Duration::from_secs(1)).await; - // } - // - // info!("the queue and cache emptied, so hopefully all the data was sent"); - // }; - select! { - // _ = reader_future => {}, _ = sender_future => error!("senders exited"), result = confirmation_handle => { // the confirmation receiver never exits unless an error occurs diff --git a/src/sender/reader.rs b/src/sender/reader.rs index 6b335dc..fbe5e1b 100644 --- a/src/sender/reader.rs +++ b/src/sender/reader.rs @@ -18,21 +18,16 @@ pub(crate) async fn reader( ) -> Result<()> { let file = File::open(path).await?; let mut reader = BufReader::with_capacity(READ_BUFFER_SIZE, file); - - let mut buffer = vec![0; INDEX_SIZE + TRANSFER_BUFFER_SIZE]; - - if index > 0 { - index += TRANSFER_BUFFER_SIZE as u64; - } - reader.seek(SeekFrom::Start(index)).await?; + let mut buffer = [0; INDEX_SIZE + TRANSFER_BUFFER_SIZE]; + debug!("starting reader at index {}", index); loop { let permit = read.acquire().await.unwrap(); - // write index + // write index to buffer buffer[..INDEX_SIZE].copy_from_slice(&index.to_be_bytes()); // read data into buffer after checksum and index @@ -43,15 +38,15 @@ pub(crate) async fn reader( break; } - // push job with index, checksum, and data + // push job to queue queue.push(Job { - data: buffer[..INDEX_SIZE + read].to_vec(), + data: buffer, index, cached_at: None, }); - index += read as u64; - permit.forget(); + index += read as u64; // increment index by bytes read + permit.forget(); // release permit } Ok(())