Skip to content

Commit

Permalink
big update
Browse files Browse the repository at this point in the history
  • Loading branch information
chanderlud committed Dec 4, 2023
1 parent b6db24c commit 0b1a8f2
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 245 deletions.
144 changes: 71 additions & 73 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#![feature(int_roundings)]

use std::error;
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::process::{ExitCode, Termination};
use std::str::FromStr;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::Arc;
use std::time::Duration;
use std::{error, process};

use async_ssh2_tokio::{AuthMethod, Client, ServerCheckMethod};
use clap::Parser;
Expand All @@ -19,7 +22,7 @@ use rpassword::prompt_password;
use simple_logging::{log_to_file, log_to_stderr};
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::net::UdpSocket;
use tokio::net::{TcpListener, TcpSocket, UdpSocket};
use tokio::time::{interval, sleep};
use tokio::{io, select};

Expand All @@ -31,6 +34,7 @@ type LimitedQueue<T> = Arc<deadqueue::limited::Queue<T>>;
type Result<T> = std::result::Result<T, Error>;

const READ_BUFFER_SIZE: usize = 10_000_000;
const WRITE_BUFFER_SIZE: usize = 5_000_000;
const TRANSFER_BUFFER_SIZE: usize = 1024;
const INDEX_SIZE: usize = std::mem::size_of::<u64>();
const MAX_RETRIES: usize = 10;
Expand Down Expand Up @@ -66,15 +70,15 @@ impl From<std::net::AddrParseError> for Error {
}
}

impl Error {
fn exit_code(self) -> i32 {
match self.kind {
impl Termination for Error {
fn report(self) -> ExitCode {
ExitCode::from(match self.kind {
ErrorKind::IoError(error) => match error.kind() {
io::ErrorKind::NotFound => 1,
_ => 2,
},
ErrorKind::ParseError(_) => 3,
}
})
}
}

Expand Down Expand Up @@ -116,7 +120,7 @@ struct Options {
short,
long = "log-level",
help = "log level [debug, info, warn, error]",
default_value = "info"
default_value = "warn"
)]
log_level: LevelFilter,

Expand All @@ -125,7 +129,7 @@ struct Options {
long = "bind-address",
help = "manually specify the address to listen on"
)]
bind_address: Option<String>,
bind_address: Option<IpAddr>,

#[clap(
short,
Expand Down Expand Up @@ -163,8 +167,7 @@ impl Options {
#[derive(Debug, Clone, PartialEq)]
enum Mode {
Local,
RemoteSender,
RemoteReceiver,
Remote(bool), // Remote(sender)
}

impl FromStr for Mode {
Expand All @@ -173,17 +176,17 @@ impl FromStr for Mode {
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
Ok(match s {
"l" => Self::Local,
"rr" => Self::RemoteReceiver,
"rs" => Self::RemoteSender,
"rr" => Self::Remote(false),
"rs" => Self::Remote(true),
"local" => Self::Local,
"remote-receiver" => Self::RemoteReceiver,
"remote-sender" => Self::RemoteSender,
"remote-receiver" => Self::Remote(false),
"remote-sender" => Self::Remote(true),
_ => return Err(CustomParseErrors::UnknownMode),
})
}
}

// a file located anywhere:tm:
// a file located anywhere
#[derive(Clone, Debug)]
struct FileLocation {
file_path: PathBuf,
Expand Down Expand Up @@ -251,8 +254,8 @@ impl Display for FileLocation {
}

impl FileLocation {
fn is_local(&self, local_address: &str) -> bool {
self.host.is_none() || self.host.as_ref().unwrap() == local_address
fn is_local(&self) -> bool {
self.host.is_none() || (self.host.is_some() && self.file_path.exists())
}

async fn file_size(&self) -> io::Result<u64> {
Expand Down Expand Up @@ -299,12 +302,11 @@ struct TransferStats {
}

#[tokio::main]
async fn main() {
async fn main() -> Result<()> {
let mut options = Options::parse();

match options.mode {
Mode::Local => log_to_stderr(options.log_level),
// TODO choose a better log file location
_ => log_to_file("cccp.log", options.log_level).expect("failed to log"),
}

Expand All @@ -324,6 +326,15 @@ async fn main() {
options.threads = port_count;
} else if port_count < 2 {
panic!("a minimum of two ports are required")
} else if port_count > options.threads {
warn!(
"{} ports > {} threads. changing port range to {}-{}",
port_count,
options.threads,
options.start_port,
options.start_port + options.threads
);
options.end_port = options.start_port + options.threads;
}

if options.destination.host.is_none() && options.source.host.is_none() {
Expand All @@ -333,51 +344,21 @@ async fn main() {
// UDP header + INDEX + DATA
let packet_size = (8 + INDEX_SIZE + TRANSFER_BUFFER_SIZE) as u64;
let pps_rate = options.rate / packet_size;
info!(
"converted {} byte/s rate to {} packet/s rate",
options.rate, pps_rate
);
debug!("{} byte/s -> {} packet/s", options.rate, pps_rate);
options.rate = pps_rate;
}

let public_address = match options.mode {
Mode::Local => {
if let Some(address) = options.bind_address.as_ref() {
address.clone()
} else {
reqwest::get("http://api.ipify.org")
.await
.unwrap()
.text()
.await
.unwrap()
}
}
Mode::RemoteSender => options.source.host.as_ref().unwrap().clone(),
Mode::RemoteReceiver => options.destination.host.as_ref().unwrap().clone(),
};

info!("the bind address is: {}", public_address);

let sender = options.source.is_local(&public_address);
let sender = options.source.is_local();
let stats = TransferStats::default();

match options.mode {
Mode::Local => {
if options.destination.is_local(&public_address) {
debug!("destination is local");
options.destination.host = Some(public_address.clone());
} else {
debug!("source is local");
options.source.host = Some(public_address.clone());
}

let command = options.format_command(sender);

let (local, remote) = if options.destination.is_local(&public_address) {
(&options.destination, &options.source)
} else {
let (local, remote) = if sender {
(&options.source, &options.destination)
} else {
(&options.destination, &options.source)
};

if remote.username.is_none() {
Expand All @@ -397,9 +378,11 @@ async fn main() {
}
};

let remote_addr: IpAddr = remote.host.as_ref().unwrap().parse().unwrap();

let client = loop {
match Client::connect(
(remote.host.as_ref().unwrap().as_str(), 22),
(remote_addr, 22),
remote.username.as_ref().unwrap().as_str(),
auth_method.clone(),
ServerCheckMethod::NoCheck,
Expand Down Expand Up @@ -430,6 +413,26 @@ async fn main() {
client.execute(&command).await
});

let bind = match options.bind_address {
Some(addr) => SocketAddr::new(addr, 0),
None => "0.0.0.0:0".parse().unwrap(),
};

// connect to the remote client on the first port in the range
let control_stream = loop {
let socket = TcpSocket::new_v4().unwrap();
socket.bind(bind).unwrap();

let remote_socket = SocketAddr::new(remote_addr, options.start_port);

if let Ok(stream) = socket.connect(remote_socket).await {
break stream;
} else {
// give the receiver time to start listening
sleep(Duration::from_millis(100)).await;
}
};

let display_handle = tokio::spawn({
let stats = stats.clone();

Expand All @@ -438,9 +441,9 @@ async fn main() {

let main_future = async {
if sender {
sender::main(options, stats).await
sender::main(options, stats, control_stream, remote_addr).await
} else {
receiver::main(options, stats).await
receiver::main(options, stats, control_stream, remote_addr).await
}
};

Expand Down Expand Up @@ -469,28 +472,23 @@ async fn main() {
select! {
_ = command_future => {},
_ = display_handle => {},
result = main_future => {
if let Err(error) = result {
error!("{} failed: {:?}", if sender { "sender" } else { "receiver" }, error);
}
}
}
}
Mode::RemoteSender => {
if let Err(error) = sender::main(options, stats).await {
error!("sender failed: {:?}", error);
process::exit(error.exit_code()); // exit with non 0 status so remote knows it failed
result = main_future => result?
}
}
Mode::RemoteReceiver => {
if let Err(error) = receiver::main(options, stats).await {
error!("receiver failed: {:?}", error);
process::exit(error.exit_code()); // exit with non 0 status so remote knows it failed
}
Mode::Remote(sender) => {
let listener = TcpListener::bind(("0.0.0.0", options.start_port)).await?;
let (control_stream, remote_addr) = listener.accept().await?;

if sender {
sender::main(options, stats, control_stream, remote_addr.ip()).await?;
} else {
receiver::main(options, stats, control_stream, remote_addr.ip()).await?;
};
}
}

info!("exiting");
Ok(())
}

// opens the sockets that will be used to send data
Expand Down
81 changes: 0 additions & 81 deletions src/receiver/metadata.rs

This file was deleted.

Loading

0 comments on commit 0b1a8f2

Please sign in to comment.