Skip to content

Commit

Permalink
improvement to exit statuses
Browse files Browse the repository at this point in the history
  • Loading branch information
chanderlud committed Dec 4, 2023
1 parent 0a9a5ed commit b6db24c
Showing 1 changed file with 104 additions and 105 deletions.
209 changes: 104 additions & 105 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use indicatif::{ProgressBar, ProgressStyle};
use log::{debug, error, info, warn, LevelFilter};
use regex::Regex;
use rpassword::prompt_password;
use simple_logging::{log_to_file, log_to_stderr};
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::net::UdpSocket;
Expand All @@ -40,7 +41,7 @@ const REQUEUE_INTERVAL: Duration = Duration::from_millis(1_000);

#[derive(Debug)]
struct Error {
_kind: ErrorKind,
kind: ErrorKind,
}

#[derive(Debug)]
Expand All @@ -52,15 +53,27 @@ enum ErrorKind {
impl From<io::Error> for Error {
fn from(error: io::Error) -> Self {
Self {
_kind: ErrorKind::IoError(error),
kind: ErrorKind::IoError(error),
}
}
}

impl From<std::net::AddrParseError> for Error {
fn from(error: std::net::AddrParseError) -> Self {
Self {
_kind: ErrorKind::ParseError(error),
kind: ErrorKind::ParseError(error),
}
}
}

impl Error {
fn exit_code(self) -> i32 {
match self.kind {
ErrorKind::IoError(error) => match error.kind() {
io::ErrorKind::NotFound => 1,
_ => 2,
},
ErrorKind::ParseError(_) => 3,
}
}
}
Expand Down Expand Up @@ -147,7 +160,7 @@ impl Options {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
enum Mode {
Local,
RemoteSender,
Expand Down Expand Up @@ -182,87 +195,57 @@ impl FromStr for FileLocation {
type Err = CustomParseErrors;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
if s.contains('@') {
let regex = Regex::new("([^:]+)@([^:]+):(.+)").unwrap();

let captures =
regex
.captures(s)
.ok_or(CustomParseErrors::MalformedConnectionString(
"input does not match format",
))?;

let username = captures
.get(1)
.ok_or(CustomParseErrors::MalformedConnectionString(
"missing username",
))?;
let host = captures
.get(2)
.ok_or(CustomParseErrors::MalformedConnectionString("missing host"))?;
let file_path = captures
.get(3)
let (username, host, file_path_str) = if s.contains('@') {
let captures = Regex::new("([^:]+)@([^:]+):(.+)")
.unwrap()
.captures(s)
.ok_or(CustomParseErrors::MalformedConnectionString(
"missing file path",
"input does not match format",
))?;

let file_path = file_path
.as_str()
.parse()
.map_err(|_| CustomParseErrors::ParseError)?;

Ok(Self {
file_path,
host: Some(host.as_str().to_string()),
username: Some(username.as_str().to_string()),
})
(
captures.get(1).map(|m| m.as_str().to_string()),
captures.get(2).map(|m| m.as_str().to_string()),
captures.get(3).map(|m| m.as_str().to_string()),
)
} else if s.contains(':') {
let (host, file_path_str) =
s.split_once(':')
.ok_or(CustomParseErrors::MalformedConnectionString(
"input does not match known format",
))?;

let file_path = file_path_str
.parse()
.map_err(|_| CustomParseErrors::ParseError)?;

Ok(Self {
file_path,
host: Some(host.to_string()),
username: None,
})
(
None,
Some(host.to_string()),
Some(file_path_str.to_string()),
)
} else {
let file_path = s.parse().map_err(|_| CustomParseErrors::ParseError)?;

Ok(Self {
file_path,
host: None,
username: None,
})
}
(None, None, Some(s.to_string()))
};

let file_path = file_path_str
.ok_or(CustomParseErrors::MalformedConnectionString(
"missing file path",
))?
.parse()
.map_err(|_| CustomParseErrors::ParseError)?;

Ok(Self {
file_path,
host,
username,
})
}
}

impl Display for FileLocation {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
if self.username.is_none() && self.host.is_none() {
write!(f, "{}", self.file_path.display())
} else if self.username.is_none() {
write!(
f,
"{}:{}",
self.host.as_ref().unwrap(),
self.file_path.display()
)
} else {
write!(
f,
"{}@{}:{}",
self.username.as_ref().unwrap(),
self.host.as_ref().unwrap(),
self.file_path.display()
)
let file_path = self.file_path.display();

match (&self.username, &self.host) {
(None, None) => write!(f, "{}", file_path),
(None, Some(host)) => write!(f, "{}:{}", host, file_path),
(Some(username), Some(host)) => write!(f, "{}@{}:{}", username, host, file_path),
_ => Err(std::fmt::Error),
}
}
}
Expand Down Expand Up @@ -319,42 +302,46 @@ struct TransferStats {
async fn main() {
let mut options = Options::parse();

// TODO choose a better log file location
// local client should log in execution location
simple_logging::log_to_file("cccp.log", options.log_level).expect("failed to log");

if options.start_port > options.end_port {
panic!("end port must be greater than start port")
match options.mode {
Mode::Local => log_to_stderr(options.log_level),
// TODO choose a better log file location
_ => log_to_file("cccp.log", options.log_level).expect("failed to log"),
}

let port_count = options.end_port - options.start_port;
// only the local client needs to handle input validation
if options.mode == Mode::Local {
if options.start_port > options.end_port {
panic!("end port must be greater than start port")
}

if port_count < options.threads {
warn!(
"{} ports < {} threads. decreasing threads to {}",
port_count, options.threads, port_count
);
options.threads = port_count;
} else if port_count < 2 {
panic!("a minimum of two ports are required")
}
let port_count = options.end_port - options.start_port;

if options.destination.host.is_none() && options.source.host.is_none() {
panic!("at least one host must be specified")
if port_count < options.threads {
warn!(
"{} ports < {} threads. decreasing threads to {}",
port_count, options.threads, port_count
);
options.threads = port_count;
} else if port_count < 2 {
panic!("a minimum of two ports are required")
}

if options.destination.host.is_none() && options.source.host.is_none() {
panic!("at least one host must be specified")
}

// UDP header + INDEX + DATA
let packet_size = (8 + INDEX_SIZE + TRANSFER_BUFFER_SIZE) as u64;
let pps_rate = options.rate / packet_size;
info!(
"converted {} byte/s rate to {} packet/s rate",
options.rate, pps_rate
);
options.rate = pps_rate;
}

let public_address = match options.mode {
Mode::Local => {
// only the local client needs to handle the rate conversion
// UDP header + INDEX + DATA
let packet_size = (8 + INDEX_SIZE + TRANSFER_BUFFER_SIZE) as u64;
let pps_rate = options.rate / packet_size;
info!(
"converted {} byte/s rate to {} packet/s rate",
options.rate, pps_rate
);
options.rate = pps_rate;

if let Some(address) = options.bind_address.as_ref() {
address.clone()
} else {
Expand Down Expand Up @@ -393,6 +380,12 @@ async fn main() {
(&options.source, &options.destination)
};

if remote.username.is_none() {
panic!("username must be specified for remote host: {}", remote);
} else if remote.host.is_none() {
panic!("host must be specified for remote host: {}", remote);
}

debug!("local {}", local);
debug!("remote {}", remote);

Expand Down Expand Up @@ -456,10 +449,16 @@ async fn main() {

match result {
Ok(Ok(result)) => {
if result.exit_status != 0 {
error!("remote client failed: {:?}", result);
} else {
sleep(Duration::from_secs(u64::MAX)).await; // wait forever
match result.exit_status {
0 => {
info!("remote client exited successfully");
// wait forever to allow the other futures to complete
sleep(Duration::from_secs(u64::MAX)).await;
}
1 => error!("remote client failed, file not found"),
2 => error!("remote client failed, unknown IO error"),
3 => error!("remote client failed, parse error"),
_ => error!("remote client failed, unknown error"),
}
}
Ok(Err(error)) => error!("remote client failed: {}", error), // return to terminate execution
Expand All @@ -480,13 +479,13 @@ async fn main() {
Mode::RemoteSender => {
if let Err(error) = sender::main(options, stats).await {
error!("sender failed: {:?}", error);
process::exit(1); // exit with non 0 status so remote knows it failed
process::exit(error.exit_code()); // exit with non 0 status so remote knows it failed
}
}
Mode::RemoteReceiver => {
if let Err(error) = receiver::main(options, stats).await {
error!("receiver failed: {:?}", error);
process::exit(1); // exit with non 0 status so remote knows it failed
process::exit(error.exit_code()); // exit with non 0 status so remote knows it failed
}
}
}
Expand Down

0 comments on commit b6db24c

Please sign in to comment.