Skip to content

Commit

Permalink
stable folder support, optimizations started
Browse files Browse the repository at this point in the history
  • Loading branch information
chanderlud committed Dec 8, 2023
1 parent fefe82b commit 71f6b50
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 90 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build = "build.rs"

[dependencies]
clap = { version = "4.4", features = ["derive"] }
tokio = { version = "1.34", features = ["macros", "fs", "io-util"] }
tokio = { version = "1.34", default-features = false, features = ["macros", "fs", "io-util"] }
futures = "0.3"
log = { version = "0.4", features = ["std"] }
async-ssh2-tokio = "0.8"
Expand All @@ -20,7 +20,8 @@ rpassword = "7.3"
indicatif = "0.17"
prost = "0.12"
prost-build = "0.12"
async-channel = "2.1"
bytesize = "1.3.0"
kanal = "0.1.0-pre8"

[build-dependencies]
prost-build = "0.12.3"
Expand Down
49 changes: 31 additions & 18 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![feature(int_roundings)]

use async_channel::Receiver;
use std::error;
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
Expand All @@ -13,10 +12,12 @@ use std::sync::Arc;
use std::time::Duration;

use async_ssh2_tokio::{AuthMethod, Client, ServerCheckMethod};
use bytesize::ByteSize;
use clap::Parser;
use futures::stream::iter;
use futures::{StreamExt, TryStreamExt};
use indicatif::{ProgressBar, ProgressStyle};
use kanal::{AsyncReceiver, SendError};
use log::{debug, error, info, warn, LevelFilter};
use prost::Message;
use regex::Regex;
Expand All @@ -35,13 +36,16 @@ mod sender;
type UnlimitedQueue<T> = Arc<deadqueue::unlimited::Queue<T>>;
type Result<T> = std::result::Result<T, Error>;

const READ_BUFFER_SIZE: usize = 10_000_000;
const WRITE_BUFFER_SIZE: usize = 5_000_000;
// read buffer must be a multiple of the transfer buffer to prevent a nasty little bug
const READ_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100;
const WRITE_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100;
const TRANSFER_BUFFER_SIZE: usize = 1024;
const INDEX_SIZE: usize = std::mem::size_of::<u64>();
const ID_SIZE: usize = std::mem::size_of::<u32>();
const MAX_RETRIES: usize = 10;
const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
// UDP header + ID + INDEX + DATA
const PACKET_SIZE: usize = 8 + ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE;

// how long to wait for a job to be confirmed before requeuing it
const REQUEUE_INTERVAL: Duration = Duration::from_millis(1_000);
Expand All @@ -62,6 +66,7 @@ enum ErrorKind {
Parse(std::net::AddrParseError),
Decode(prost::DecodeError),
Join(tokio::task::JoinError),
Send(kanal::SendError),
}

impl From<io::Error> for Error {
Expand Down Expand Up @@ -96,6 +101,14 @@ impl From<tokio::task::JoinError> for Error {
}
}

impl From<SendError> for Error {
fn from(error: SendError) -> Self {
Self {
kind: ErrorKind::Send(error),
}
}
}

impl Termination for Error {
fn report(self) -> ExitCode {
ExitCode::from(match self.kind {
Expand All @@ -106,6 +119,7 @@ impl Termination for Error {
ErrorKind::Parse(_) => 3,
ErrorKind::Decode(_) => 4,
ErrorKind::Join(_) => 5,
ErrorKind::Send(_) => 6,
})
}
}
Expand All @@ -115,7 +129,7 @@ struct Options {
#[clap(
short,
long = "mode",
help = "local or remote",
hide = true, // the user does not need to set this
default_value = "local"
)]
mode: Mode,
Expand Down Expand Up @@ -162,10 +176,10 @@ struct Options {
#[clap(
short,
long = "rate",
help = "the rate to send data at (bytes per second)",
default_value = "1000000"
help = "the rate to send data at [b, kb, mb, gb, tb]",
default_value = "1mb"
)]
rate: u64,
rate: ByteSize,

#[clap(help = "where to get the data from")]
source: FileLocation,
Expand All @@ -179,7 +193,7 @@ impl Options {
let mode = if sender { "rr" } else { "rs" };

format!(
"cccp --mode {} --start-port {} --end-port {} --threads {} --log-level {} --rate {} \"{}\" \"{}\"",
"cccp --mode {} --start-port {} --end-port {} --threads {} --log-level {} --rate \"{}\" \"{}\" \"{}\"",
mode,
self.start_port,
self.end_port,
Expand All @@ -190,6 +204,10 @@ impl Options {
self.destination
)
}

fn pps(&self) -> u64 {
self.rate.0 / PACKET_SIZE as u64
}
}

#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -361,14 +379,8 @@ async fn main() -> Result<()> {
}

if options.destination.host.is_none() && options.source.host.is_none() {
panic!("at least one host must be specified")
panic!("either the source or destination must be remote");
}

// UDP header + INDEX + DATA
let packet_size = (8 + INDEX_SIZE + TRANSFER_BUFFER_SIZE) as u64;
let pps_rate = options.rate / packet_size;
debug!("{} byte/s -> {} packet/s", options.rate, pps_rate);
options.rate = pps_rate;
}

let sender = options.source.is_local();
Expand Down Expand Up @@ -451,9 +463,10 @@ async fn main() -> Result<()> {

let main_future = async {
if sender {
sender::main(options, stats, rts_stream, str_stream, remote_addr).await
sender::main(options, stats.clone(), rts_stream, str_stream, remote_addr).await
} else {
receiver::main(options, stats, rts_stream, str_stream, remote_addr).await
receiver::main(options, stats.clone(), rts_stream, str_stream, remote_addr)
.await
}
};

Expand Down Expand Up @@ -640,7 +653,7 @@ async fn read_message<R: AsyncReadExt + Unpin, M: Message + Default>(reader: &mu
/// send messages from a channel to a writer
async fn message_sender<W: AsyncWrite + Unpin, M: Message>(
mut writer: W,
receiver: Receiver<M>,
receiver: AsyncReceiver<M>,
) -> Result<()> {
while let Ok(message) = receiver.recv().await {
write_message(&mut writer, &message).await?;
Expand Down
49 changes: 35 additions & 14 deletions src/receiver/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use async_channel::Sender;
use kanal::AsyncSender;
use std::collections::HashMap;
use std::mem;
use std::net::IpAddr;
Expand Down Expand Up @@ -79,7 +79,7 @@ pub(crate) async fn main(
let confirmation_queue: ConfirmationQueue = Default::default();

// `message_sender` can now be used to send messages to the sender
let (message_sender, message_receiver) = async_channel::unbounded();
let (message_sender, message_receiver) = kanal::unbounded_async();
tokio::spawn(crate::message_sender(rts_stream, message_receiver));

let confirmation_handle = tokio::spawn(send_confirmations(
Expand Down Expand Up @@ -145,7 +145,7 @@ async fn controller(
confirmation_queue: ConfirmationQueue,
confirmed_data: Arc<AtomicUsize>,
file_path: PathBuf,
message_sender: Sender<Message>,
message_sender: AsyncSender<Message>,
) -> Result<()> {
loop {
let message: Message = read_message(&mut str_stream).await?;
Expand All @@ -164,7 +164,14 @@ async fn controller(
file_path.clone()
};

let partial_path = file_path.with_extension("partial");
// append partial extension to the existing extension, if there is one
let partial_extension = if let Some(extension) = file_path.extension() {
extension.to_str().unwrap().to_owned() + ".partial"
} else {
".partial".to_string()
};

let partial_path = file_path.with_extension(partial_extension);

let start_index = if partial_path.exists() {
info!("partial file exists, resuming transfer");
Expand All @@ -187,15 +194,29 @@ async fn controller(
path: file_path,
};

// TODO the result of writer is lost
tokio::spawn(writer(
file,
writer_queue.clone(),
confirmation_queue.clone(),
start_index,
message.id,
message_sender.clone(),
));
tokio::spawn({
let writer_queue = writer_queue.clone();
let confirmation_queue = confirmation_queue.clone();
let message_sender = message_sender.clone();

async move {
let path = file.path.clone();

let result = writer(
file,
writer_queue,
confirmation_queue,
start_index,
message.id,
message_sender,
)
.await;

if let Err(error) = result {
error!("writer for {} failed: {:?}", path.display(), error);
}
}
});

debug!("started file {:?}", details);
}
Expand All @@ -215,7 +236,7 @@ async fn controller(
}

async fn send_confirmations(
sender: Sender<Message>,
sender: AsyncSender<Message>,
queue: UnlimitedQueue<(u32, u64)>,
confirmed_data: Arc<AtomicUsize>,
) -> Result<()> {
Expand Down
21 changes: 10 additions & 11 deletions src/receiver/writer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use async_channel::{SendError, Sender};
use deadqueue::limited::Queue;
use kanal::AsyncSender;
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap};
use std::collections::HashMap;
use std::io::SeekFrom;
use std::path::PathBuf;
use std::sync::Arc;
Expand All @@ -13,7 +13,7 @@ use tokio::sync::RwLock;

use crate::items::{message, End, Message};
use crate::receiver::{ConfirmationQueue, Job, WriterQueue};
use crate::{TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE};
use crate::{Result, TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE};

#[derive(Default)]
pub(crate) struct SplitQueue {
Expand Down Expand Up @@ -75,8 +75,8 @@ pub(crate) async fn writer(
confirmation_queue: ConfirmationQueue,
mut position: u64,
id: u32,
message_sender: Sender<Message>,
) -> io::Result<()> {
message_sender: AsyncSender<Message>,
) -> Result<()> {
let file = OpenOptions::new()
.write(true)
.create(true)
Expand All @@ -92,7 +92,7 @@ pub(crate) async fn writer(
position
);

let mut cache: BTreeMap<u64, Job> = BTreeMap::new();
let mut cache: HashMap<u64, Job> = HashMap::new();

while position != file_details.file_size {
let job = writer_queue.pop(&id).await.unwrap();
Expand Down Expand Up @@ -136,9 +136,7 @@ pub(crate) async fn writer(
writer.flush().await?; // flush the writer
file_details.rename().await?; // rename the file
writer_queue.pop_queue(&id).await; // remove the queue
send_end_message(&message_sender, id)
.await
.expect("failed to send end message"); // send end message
send_end_message(&message_sender, id).await?;

Ok(())
}
Expand All @@ -158,10 +156,11 @@ async fn write_data<T: AsyncWrite + Unpin>(
writer.write_all(&buffer[..len as usize]).await // write the data
}

async fn send_end_message(sender: &Sender<Message>, id: u32) -> Result<(), SendError<Message>> {
async fn send_end_message(sender: &AsyncSender<Message>, id: u32) -> Result<()> {
let end_message = Message {
message: Some(message::Message::End(End { id })),
};

sender.send(end_message).await
sender.send(end_message).await?;
Ok(())
}
Loading

0 comments on commit 71f6b50

Please sign in to comment.