diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index b838492..a9b5a26 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -372,7 +372,7 @@ impl PacketHandler { } /// Split the handler into separate reader and a writer. - pub fn split(self) -> (PacketReader, PacketWriter) { + pub fn into_split(self) -> (PacketReader, PacketWriter) { (self.packet_reader, self.packet_writer) } diff --git a/proxy/src/bin/async.rs b/proxy/src/bin/async.rs index d1bfabe..fc5e1e9 100644 --- a/proxy/src/bin/async.rs +++ b/proxy/src/bin/async.rs @@ -8,10 +8,9 @@ use bitcoin::Network; use bytes::BytesMut; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::select; /// Validate and bootstrap proxy connection. -async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { +async fn proxy_conn(client: TcpStream) -> Result<(), bip324_proxy::Error> { let remote_ip = bip324_proxy::peek_addr(&client) .await .expect("peek address"); @@ -80,47 +79,63 @@ async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { bip324::Error::MessageLengthTooSmall => continue, e => panic!("unable to authenticate garbage {}", e), }, - _ => break, + _ => { + println!("Channel authenticated."); + break; + } } } } } - println!("Channel authenticated."); - println!("Splitting channels."); let packet_handler = handshake.finalize().expect("finished handshake"); - let (mut client_reader, mut client_writer) = client.split(); - let (mut remote_reader, mut remote_writer) = remote.split(); - let (mut decrypter, mut encrypter) = packet_handler.split(); + let (mut client_reader, mut client_writer) = client.into_split(); + let (mut remote_reader, mut remote_writer) = remote.into_split(); + let (mut decrypter, mut encrypter) = packet_handler.into_split(); - println!("Setting up proxy loop."); - loop { - select! { - res = read_v1(&mut client_reader) => { - match res { - Ok(msg) => { - println!("Read {} message from client, writing to remote.", msg.command()); - write_v2(&mut remote_writer, &mut encrypter, msg).await.expect("write v2 message"); - }, - Err(err) => { - panic!("unable to read v1 from client {}", err); - }, + println!("Setting up proxy loops."); + tokio::spawn(async move { + loop { + let res = read_v1(&mut client_reader).await; + match res { + Ok(msg) => { + println!( + "Read {} message from client, writing to remote.", + msg.command() + ); + write_v2(&mut remote_writer, &mut encrypter, msg) + .await + .expect("write to remote"); } - }, - res = read_v2(&mut remote_reader, &mut decrypter) => { - match res { - Ok(msg) => { - println!("Read {} message from remote, writing to client.", msg.command()); - write_v1(&mut client_writer, msg).await.expect("write v1 message"); - }, - Err(err) => { - panic!("unable to read v2 from client {}", err); - }, + Err(e) => { + panic!("unable to read from client {}", e); + } + } + } + }); + + tokio::spawn(async move { + loop { + let res = read_v2(&mut remote_reader, &mut decrypter).await; + match res { + Ok(msg) => { + println!( + "Read {} message from remote, writing to client.", + msg.command() + ); + write_v1(&mut client_writer, msg) + .await + .expect("write to client"); + } + Err(e) => { + panic!("unable to read from remote {}", e); } - }, + } } - } + }); + + Ok(()) } #[tokio::main] diff --git a/proxy/src/bin/v1.rs b/proxy/src/bin/v1.rs index 54218e7..cdc6bf7 100644 --- a/proxy/src/bin/v1.rs +++ b/proxy/src/bin/v1.rs @@ -3,45 +3,65 @@ use bip324_proxy::{read_v1, write_v1}; use tokio::net::{TcpListener, TcpStream}; -use tokio::select; /// Validate and bootstrap proxy connection. -async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { +async fn proxy_conn(client: TcpStream) -> Result<(), bip324_proxy::Error> { let remote_ip = bip324_proxy::peek_addr(&client).await?; println!("Initialing remote connection {}.", remote_ip); - let mut remote = TcpStream::connect(remote_ip).await?; + let remote = TcpStream::connect(remote_ip).await?; - let (mut client_reader, mut client_writer) = client.split(); - let (mut remote_reader, mut remote_writer) = remote.split(); + let (mut client_reader, mut client_writer) = client.into_split(); + let (mut remote_reader, mut remote_writer) = remote.into_split(); println!("Setting up proxy loop."); - loop { - select! { - res = read_v1(&mut client_reader) => { - match res { - Ok(msg) => { - println!("Read {} message from client, writing to remote.", msg.command()); - write_v1(&mut remote_writer, msg).await?; - }, - Err(e) => { - return Err(e); - }, + + // Spawning two threads instead of selecting on one due + // to the read calls not being cancelable. A select + // drops other futures when one is ready, so it is + // possible that it drops one with half read state. + + tokio::spawn(async move { + loop { + let res = read_v1(&mut client_reader).await; + match res { + Ok(msg) => { + println!( + "Read {} message from client, writing to remote.", + msg.command() + ); + write_v1(&mut remote_writer, msg) + .await + .expect("write to remote"); } - }, - res = read_v1(&mut remote_reader) => { - match res { - Ok(msg) => { - println!("Read {} message from remote, writing to client.", msg.command()); - write_v1(&mut client_writer, msg).await?; - }, - Err(e) => { - return Err(e); - }, + Err(e) => { + panic!("unable to read from client {}", e); } - }, + } } - } + }); + + tokio::spawn(async move { + loop { + let res = read_v1(&mut remote_reader).await; + match res { + Ok(msg) => { + println!( + "Read {} message from remote, writing to client.", + msg.command() + ); + write_v1(&mut client_writer, msg) + .await + .expect("write to client"); + } + Err(e) => { + panic!("unable to read from remote {}", e); + } + } + } + }); + + Ok(()) } #[tokio::main] diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 3d8ebb5..0dba9e7 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -106,6 +106,8 @@ pub async fn peek_addr(client: &TcpStream) -> Result { } /// Read a v1 message off of the input stream. +/// +/// This future is not cancelable since state is read multiple times. pub async fn read_v1(input: &mut T) -> Result { let mut header_bytes = [0; V1_HEADER_BYTES]; input.read_exact(&mut header_bytes).await?; @@ -116,17 +118,20 @@ pub async fn read_v1(input: &mut T) -> Result( input: &mut T, decrypter: &mut PacketReader,