diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 381fd9d..d5af57d 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -11,6 +11,8 @@ extern crate std; mod chacha20poly1305; mod fschacha20poly1305; mod hkdf; +#[cfg(feature = "std")] +pub mod serde; use core::fmt; @@ -226,7 +228,7 @@ impl PacketReader { /// # Arguments /// /// - `ciphertext` - The message from the peer. - /// - `aad` - Optional authentication for the peer, currently only used for the first round of messages. + /// - `aad` - Optional authentication for the peer, currently only used for the first round of messages. /// /// # Errors /// @@ -370,7 +372,7 @@ impl PacketHandler { } /// Split the handler into separate reader and a writer. - pub fn split(self) -> (PacketReader, PacketWriter) { + pub fn into_split(self) -> (PacketReader, PacketWriter) { (self.packet_reader, self.packet_writer) } diff --git a/protocol/src/serde.rs b/protocol/src/serde.rs new file mode 100644 index 0000000..9c27012 --- /dev/null +++ b/protocol/src/serde.rs @@ -0,0 +1,315 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 + +//! A subset of commands are represented with a single byte in V2 instead of the 12-byte ASCII encoding like V1. +//! +//! ID mappings defined in [BIP324](https://github.com/bitcoin/bips/blob/master/bip-0324.mediawiki#user-content-v2_Bitcoin_P2P_message_structure). + +use core::fmt; +use std::io; + +use alloc::vec::Vec; +use bitcoin::{ + block, + consensus::{encode, Decodable, Encodable}, + p2p::message::{CommandString, NetworkMessage}, + VarInt, +}; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Error { + Serialize, + Deserialize, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Serialize => write!(f, "Unable to serialize"), + Error::Deserialize => write!(f, "Unable to deserialize"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Serialize => None, + Error::Deserialize => None, + } + } +} + +/// Serialize message in v2 format to buffer. +pub fn serialize(msg: NetworkMessage) -> Result, Error> { + let mut buffer = Vec::new(); + match &msg { + NetworkMessage::Addr(_) => { + buffer.push(1u8); + } + NetworkMessage::Inv(_) => { + buffer.push(14u8); + } + NetworkMessage::GetData(_) => { + buffer.push(11u8); + } + NetworkMessage::NotFound(_) => { + buffer.push(17u8); + } + NetworkMessage::GetBlocks(_) => { + buffer.push(9u8); + } + NetworkMessage::GetHeaders(_) => { + buffer.push(12u8); + } + NetworkMessage::MemPool => { + buffer.push(15u8); + } + NetworkMessage::Tx(_) => { + buffer.push(21u8); + } + NetworkMessage::Block(_) => { + buffer.push(2u8); + } + NetworkMessage::Headers(_) => { + buffer.push(13u8); + } + NetworkMessage::Ping(_) => { + buffer.push(18u8); + } + NetworkMessage::Pong(_) => { + buffer.push(19u8); + } + NetworkMessage::MerkleBlock(_) => { + buffer.push(16u8); + } + NetworkMessage::FilterLoad(_) => { + buffer.push(8u8); + } + NetworkMessage::FilterAdd(_) => { + buffer.push(6u8); + } + NetworkMessage::FilterClear => { + buffer.push(7u8); + } + NetworkMessage::GetCFilters(_) => { + buffer.push(22u8); + } + NetworkMessage::CFilter(_) => { + buffer.push(23u8); + } + NetworkMessage::GetCFHeaders(_) => { + buffer.push(24u8); + } + NetworkMessage::CFHeaders(_) => { + buffer.push(25u8); + } + NetworkMessage::GetCFCheckpt(_) => { + buffer.push(26u8); + } + NetworkMessage::CFCheckpt(_) => { + buffer.push(27u8); + } + NetworkMessage::SendCmpct(_) => { + buffer.push(20u8); + } + NetworkMessage::CmpctBlock(_) => { + buffer.push(4u8); + } + NetworkMessage::GetBlockTxn(_) => { + buffer.push(10u8); + } + NetworkMessage::BlockTxn(_) => { + buffer.push(3u8); + } + NetworkMessage::FeeFilter(_) => { + buffer.push(5u8); + } + NetworkMessage::AddrV2(_) => { + buffer.push(28u8); + } + // Messages which are not optimized and use the zero-byte + 12 following bytes to encode command in ascii. + NetworkMessage::Version(_) + | NetworkMessage::Verack + | NetworkMessage::SendHeaders + | NetworkMessage::GetAddr + | NetworkMessage::WtxidRelay + | NetworkMessage::SendAddrV2 + | NetworkMessage::Alert(_) + | NetworkMessage::Reject(_) => { + buffer.push(0u8); + msg.command() + .consensus_encode(&mut buffer) + .map_err(|_| Error::Serialize)?; + } + NetworkMessage::Unknown { + command, + payload: _, + } => { + buffer.push(0u8); + command + .consensus_encode(&mut buffer) + .map_err(|_| Error::Serialize)?; + } + } + + msg.consensus_encode(&mut buffer) + .map_err(|_| Error::Serialize)?; + + Ok(buffer) +} + +/// Deserialize v2 message into NetworkMessage. +pub fn deserialize(buffer: &[u8]) -> Result { + let short_id = buffer[0]; + let mut payload_buffer = &buffer[1..]; + match short_id { + // Zero-byte means the command is encoded in the next 12 bytes. + 0u8 => { + // Next 12 bytes have encoded command. + let mut command_buffer = &buffer[1..13]; + let command = CommandString::consensus_decode(&mut command_buffer) + .map_err(|_| Error::Deserialize)?; + // Rest of buffer is payload. + payload_buffer = &buffer[13..]; + // There are a handful of "known" messages which don't use a short ID, otherwise Unknown. + match command.as_ref() { + "version" => Ok(NetworkMessage::Version( + Decodable::consensus_decode(&mut payload_buffer) + .map_err(|_| Error::Deserialize)?, + )), + "verack" => Ok(NetworkMessage::Verack), + "sendheaders" => Ok(NetworkMessage::SendHeaders), + "getaddr" => Ok(NetworkMessage::GetAddr), + "wtxidrelay" => Ok(NetworkMessage::WtxidRelay), + "sendaddrv2" => Ok(NetworkMessage::SendAddrV2), + "alert" => Ok(NetworkMessage::Alert( + Decodable::consensus_decode(&mut payload_buffer) + .map_err(|_| Error::Deserialize)?, + )), + "reject" => Ok(NetworkMessage::Reject( + Decodable::consensus_decode(&mut payload_buffer) + .map_err(|_| Error::Deserialize)?, + )), + _ => Ok(NetworkMessage::Unknown { + command, + payload: payload_buffer.to_vec(), + }), + } + } + // The following single byte IDs map to command short IDs. + 1u8 => Ok(NetworkMessage::Addr( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 2u8 => Ok(NetworkMessage::Block( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 3u8 => Ok(NetworkMessage::BlockTxn( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 4u8 => Ok(NetworkMessage::CmpctBlock( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 5u8 => Ok(NetworkMessage::FeeFilter( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 6u8 => Ok(NetworkMessage::FilterAdd( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 7u8 => Ok(NetworkMessage::FilterClear), + 8u8 => Ok(NetworkMessage::FilterLoad( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 9u8 => Ok(NetworkMessage::GetBlocks( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 10u8 => Ok(NetworkMessage::GetBlockTxn( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 11u8 => Ok(NetworkMessage::GetData( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 12u8 => Ok(NetworkMessage::GetHeaders( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + // This one gets a little weird and needs a bit of love in the future. + 13u8 => Ok(NetworkMessage::Headers( + HeaderDeserializationWrapper::consensus_decode(&mut payload_buffer) + .map_err(|_| Error::Deserialize)? + .0, + )), + 14u8 => Ok(NetworkMessage::Inv( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 15u8 => Ok(NetworkMessage::MemPool), + 16u8 => Ok(NetworkMessage::MerkleBlock( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 17u8 => Ok(NetworkMessage::NotFound( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 18u8 => Ok(NetworkMessage::Ping( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 19u8 => Ok(NetworkMessage::Pong( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 20u8 => Ok(NetworkMessage::SendCmpct( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 21u8 => Ok(NetworkMessage::Tx( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 22u8 => Ok(NetworkMessage::GetCFilters( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 23u8 => Ok(NetworkMessage::CFilter( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 24u8 => Ok(NetworkMessage::GetCFHeaders( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 25u8 => Ok(NetworkMessage::CFHeaders( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 26u8 => Ok(NetworkMessage::GetCFCheckpt( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 27u8 => Ok(NetworkMessage::CFCheckpt( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + 28u8 => Ok(NetworkMessage::AddrV2( + Decodable::consensus_decode(&mut payload_buffer).map_err(|_| Error::Deserialize)?, + )), + + // Unsupported short ID. + _ => Err(Error::Deserialize), + } +} + +// Copied from rust-bitcoin internals. +// +// Only the deserialized side needs to be copied over since +// the serialize side is applied at the NetworkMessage level. +struct HeaderDeserializationWrapper(Vec); + +impl Decodable for HeaderDeserializationWrapper { + #[inline] + fn consensus_decode_from_finite_reader( + r: &mut R, + ) -> Result { + let len = VarInt::consensus_decode(r)?.0; + // should be above usual number of items to avoid + // allocation + let mut ret = Vec::with_capacity(core::cmp::min(1024 * 16, len as usize)); + for _ in 0..len { + ret.push(Decodable::consensus_decode(r)?); + if u8::consensus_decode(r)? != 0u8 { + return Err(encode::Error::ParseFailed( + "Headers message should not contain transactions", + )); + } + } + Ok(HeaderDeserializationWrapper(ret)) + } +} diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index a5c9100..e2daa66 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" rust-version = "1.56.1" [dependencies] -bitcoin = { version = "0.31.2", default-features = false, features = ["no-std"] } +bitcoin = { version = "0.31.2" } tokio = { version = "1.37.0", features = ["full"] } bytes = "1.6.0" hex = { package = "hex-conservative", version = "0.2.0" } diff --git a/proxy/src/bin/async.rs b/proxy/src/bin/async.rs index 00562ba..fb63b1d 100644 --- a/proxy/src/bin/async.rs +++ b/proxy/src/bin/async.rs @@ -1,19 +1,24 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 +use core::panic; + use bip324::{Handshake, Role}; use bip324_proxy::{read_v1, read_v2, write_v1, write_v2}; use bitcoin::Network; use bytes::BytesMut; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::select; /// Validate and bootstrap proxy connection. -async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { - let remote_ip = bip324_proxy::peek_addr(&client).await?; +async fn proxy_conn(client: TcpStream) -> Result<(), bip324_proxy::Error> { + let remote_ip = bip324_proxy::peek_addr(&client) + .await + .expect("peek address"); println!("Reaching out to {}.", remote_ip); - let mut remote = TcpStream::connect(remote_ip).await?; + let mut remote = TcpStream::connect(remote_ip) + .await + .expect("connect to remote"); println!("Initiating handshake."); let mut local_material_message = vec![0u8; 64]; @@ -23,14 +28,22 @@ async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { None, &mut local_material_message, ) - .unwrap(); - remote.write_all(&local_material_message).await?; + .expect("generate handshake"); + + remote + .write_all(&local_material_message) + .await + .expect("send local materials"); + println!("Sent handshake to remote."); // 64 bytes ES. let mut remote_material_message = [0u8; 64]; println!("Reading handshake response from remote."); - remote.read_exact(&mut remote_material_message).await?; + remote + .read_exact(&mut remote_material_message) + .await + .expect("read remote materials"); println!("Completing materials."); let mut local_garbage_terminator_message = [0u8; 36]; @@ -39,13 +52,16 @@ async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { remote_material_message, &mut local_garbage_terminator_message, ) - .unwrap(); + .expect("complete materials"); println!("Sending garbage terminator and version packet."); - remote.write_all(&local_garbage_terminator_message).await?; + remote + .write_all(&local_garbage_terminator_message) + .await + .expect("send garbage and version"); // Keep pulling bytes from the buffer until the garbage is flushed. - // TODO: Fix arbitrary size. + // Capacity is arbitrary, could use some tuning. let mut remote_garbage_and_version_buffer = BytesMut::with_capacity(4096); loop { println!("Authenticating garbage and version packet..."); @@ -53,7 +69,7 @@ async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { .read_buf(&mut remote_garbage_and_version_buffer) .await; match read { - Err(e) => break Err(bip324_proxy::Error::Network(e)), + Err(e) => panic!("unable to read garbage {}", e), _ => { let auth = handshake.authenticate_garbage_and_version(&remote_garbage_and_version_buffer); @@ -61,49 +77,60 @@ async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { Err(e) => match e { // Read again if too small, other wise surface error. bip324::Error::MessageLengthTooSmall => continue, - e => break Err(bip324_proxy::Error::Cipher(e)), + e => panic!("unable to authenticate garbage {}", e), }, - _ => break Ok(()), + _ => { + println!("Channel authenticated."); + break; + } } } } - }?; + } - println!("Channel authenticated."); + let packet_handler = handshake.finalize().expect("finished handshake"); println!("Splitting channels."); - let packet_handler = handshake.finalize().expect("finished handshake"); - let (mut client_reader, mut client_writer) = client.split(); - let (mut remote_reader, mut remote_writer) = remote.split(); - let (mut decrypter, mut encrypter) = packet_handler.split(); + let (mut client_reader, mut client_writer) = client.into_split(); + let (mut remote_reader, mut remote_writer) = remote.into_split(); + let (mut decrypter, mut encrypter) = packet_handler.into_split(); - println!("Setting up proxy loop."); - loop { - select! { - res = read_v1(&mut client_reader) => { - match res { - Ok(msg) => { - println!("Read {} message from client, writing to remote.", msg.cmd); - write_v2(&mut remote_writer, &mut encrypter, msg).await?; - }, - Err(e) => { - return Err(e); - }, - } - }, - res = read_v2(&mut remote_reader, &mut decrypter) => { - match res { - Ok(msg) => { - println!("Read {} message from remote, writing to client.", msg.cmd); - write_v1(&mut client_writer, msg).await?; - }, - Err(e) => { - return Err(e); - }, - } - }, + println!("Setting up proxy loops."); + + // Spawning two threads instead of selecting on one due + // to the IO calls not being cancellation safe. A select + // drops other futures when one is ready, so it is + // possible that it drops one with half read state. + + tokio::spawn(async move { + loop { + let msg = read_v1(&mut client_reader).await.expect("read from client"); + println!( + "Read {} message from client, writing to remote.", + msg.command() + ); + write_v2(&mut remote_writer, &mut encrypter, msg) + .await + .expect("write to remote"); } - } + }); + + tokio::spawn(async move { + loop { + let msg = read_v2(&mut remote_reader, &mut decrypter) + .await + .expect("read from remote"); + println!( + "Read {} message from remote, writing to client.", + msg.command() + ); + write_v1(&mut client_writer, msg) + .await + .expect("write to client"); + } + }); + + Ok(()) } #[tokio::main] @@ -124,10 +151,10 @@ async fn main() { tokio::spawn(async move { match proxy_conn(stream).await { Ok(_) => { - println!("Ended connection with no errors."); + println!("Proxy establilshed."); } Err(e) => { - println!("Ended connection with error: {e}."); + println!("Connection ended with error: {e}."); } }; }); diff --git a/proxy/src/bin/v1.rs b/proxy/src/bin/v1.rs index 7567f0f..8df7a62 100644 --- a/proxy/src/bin/v1.rs +++ b/proxy/src/bin/v1.rs @@ -3,45 +3,51 @@ use bip324_proxy::{read_v1, write_v1}; use tokio::net::{TcpListener, TcpStream}; -use tokio::select; /// Validate and bootstrap proxy connection. -async fn proxy_conn(mut client: TcpStream) -> Result<(), bip324_proxy::Error> { +async fn proxy_conn(client: TcpStream) -> Result<(), bip324_proxy::Error> { let remote_ip = bip324_proxy::peek_addr(&client).await?; println!("Initialing remote connection {}.", remote_ip); - let mut remote = TcpStream::connect(remote_ip).await?; + let remote = TcpStream::connect(remote_ip).await?; - let (mut client_reader, mut client_writer) = client.split(); - let (mut remote_reader, mut remote_writer) = remote.split(); + let (mut client_reader, mut client_writer) = client.into_split(); + let (mut remote_reader, mut remote_writer) = remote.into_split(); println!("Setting up proxy loop."); - loop { - select! { - res = read_v1(&mut client_reader) => { - match res { - Ok(msg) => { - println!("Read {} message from client, writing to remote.", msg.cmd); - write_v1(&mut remote_writer, msg).await?; - }, - Err(e) => { - return Err(e); - }, - } - }, - res = read_v1(&mut remote_reader) => { - match res { - Ok(msg) => { - println!("Read {} message from remote, writing to client.", msg.cmd); - write_v1(&mut client_writer, msg).await?; - }, - Err(e) => { - return Err(e); - }, - } - }, + + // Spawning two threads instead of selecting on one due + // to the IO calls not being cancellation safe. A select + // drops other futures when one is ready, so it is + // possible that it drops one with half read state. + + tokio::spawn(async move { + loop { + let msg = read_v1(&mut client_reader).await.expect("read from client"); + println!( + "Read {} message from client, writing to remote.", + msg.command() + ); + write_v1(&mut remote_writer, msg) + .await + .expect("write to remote"); } - } + }); + + tokio::spawn(async move { + loop { + let msg = read_v1(&mut remote_reader).await.expect("read from remote"); + println!( + "Read {} message from remote, writing to client.", + msg.command() + ); + write_v1(&mut client_writer, msg) + .await + .expect("write to client"); + } + }); + + Ok(()) } #[tokio::main] @@ -62,7 +68,7 @@ async fn main() { tokio::spawn(async move { match proxy_conn(stream).await { Ok(_) => { - println!("Ended connection with no errors."); + println!("Proxy establilshed."); } Err(e) => { println!("Ended connection with error: {e}."); diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index bd6e65a..342378a 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -3,16 +3,19 @@ //! Helper functions for bitcoin p2p proxies. //! //! The V1 and V2 p2p protocols have different header encodings, so a proxy has to do -//! a little more work than just encrypt/decrypt. +//! a little more work than just encrypt/decrypt. The [NetworkMessage](bitcoin::p2p::message::NetworkMessage) +//! type is the intermediate state for messages. The V1 side can use the RawNetworkMessage wrapper, but the V2 side +//! cannot since things like the checksum are not relevant (those responsibilites are pushed +//! onto the transport in V2). use std::fmt; use std::net::SocketAddr; +use bip324::serde::{deserialize, serialize}; use bip324::ReceivedMessage; use bip324::{PacketReader, PacketWriter}; -use bitcoin::consensus::Decodable; -use bitcoin::hashes::sha256d; -use bitcoin::hashes::Hash; +use bitcoin::consensus::{Decodable, Encodable}; +use bitcoin::p2p::message::{NetworkMessage, RawNetworkMessage}; use bitcoin::p2p::{Address, Magic}; use hex::prelude::*; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -28,47 +31,12 @@ const VERSION_COMMAND: [u8; 12] = [ 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x00, ]; -/// A subset of commands are represented with a single byte -/// in V2 instead of the 12-byte ASCII encoding like V1. The -/// indexes of the commands in the list corresponds to their -/// ID in the protocol, but needs +1 since the zero indexed -/// is reserved to indicated a 12-bytes representation. -const V2_SHORTID_COMMANDS: &[&str] = &[ - "addr", - "block", - "blocktxn", - "cmpctblock", - "feefilter", - "filteradd", - "filterclear", - "filterload", - "getblocks", - "getblocktxn", - "getdata", - "getheaders", - "headers", - "inv", - "mempool", - "merkleblock", - "notfound", - "ping", - "pong", - "sendcmpct", - "tx", - "getcfilters", - "cfilter", - "getcfheaders", - "cfheaders", - "getcfcheckpt", - "cfcheckpt", - "addrv2", -]; - /// An error occured while establishing the proxy connection or during the main loop. #[derive(Debug)] pub enum Error { WrongNetwork, WrongCommand, + Serde, Network(std::io::Error), Cipher(bip324::Error), } @@ -76,10 +44,11 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Error::WrongNetwork => write!(f, "Recieved message on wrong network"), - Error::Network(e) => write!(f, "Network error {}", e), - Error::WrongCommand => write!(f, "Recieved message with wrong command"), - Error::Cipher(e) => write!(f, "Cipher encryption/decrytion error {}", e), + Error::WrongNetwork => write!(f, "recieved message on wrong network"), + Error::Network(e) => write!(f, "network {}", e), + Error::WrongCommand => write!(f, "recieved message with wrong command"), + Error::Cipher(e) => write!(f, "cipher encryption/decrytion error {}", e), + Error::Serde => write!(f, "unable to serialize command"), } } } @@ -91,6 +60,7 @@ impl std::error::Error for Error { Error::WrongNetwork => None, Error::WrongCommand => None, Error::Cipher(e) => Some(e), + Error::Serde => None, } } } @@ -108,12 +78,6 @@ impl From for Error { } } -/// Parsed message. -pub struct Message { - pub cmd: String, - pub payload: Vec, -} - /// Peek the input stream and pluck the remote address based on the version message. pub async fn peek_addr(client: &TcpStream) -> Result { println!("Validating client connection."); @@ -141,28 +105,37 @@ pub async fn peek_addr(client: &TcpStream) -> Result { Ok(socket_addr) } -/// Read a network message off of the input stream. -pub async fn read_v1(input: &mut T) -> Result { - let mut header_bytes = [0; V1_HEADER_BYTES]; +/// Read a v1 message off of the input stream. +/// +/// This future is not cancellation safe since state is read multiple times and depends on read_exact. +pub async fn read_v1(input: &mut T) -> Result { + let mut header_bytes = [0u8; V1_HEADER_BYTES]; input.read_exact(&mut header_bytes).await?; - let cmd = to_ascii(header_bytes[4..16].try_into().expect("12 bytes")); let payload_len = u32::from_le_bytes( header_bytes[16..20] .try_into() .expect("4 header length bytes"), ); - let mut payload = vec![0u8; payload_len as usize]; - input.read_exact(&mut payload).await?; + let mut full_bytes = vec![0u8; V1_HEADER_BYTES + payload_len as usize]; + full_bytes[0..V1_HEADER_BYTES].copy_from_slice(&header_bytes[..]); + let payload_bytes = &mut full_bytes[V1_HEADER_BYTES..]; + input.read_exact(payload_bytes).await?; - Ok(Message { cmd, payload }) + let message = RawNetworkMessage::consensus_decode(&mut &full_bytes[..]).expect("decode v1"); + + // todo: drop this clone? + Ok(message.payload().clone()) } +/// Read a v2 message off the input stream. +/// +/// This future is not cancellation safe since state is read multiple times and depends on read_exact. pub async fn read_v2( input: &mut T, decrypter: &mut PacketReader, -) -> Result { +) -> Result { let mut length_bytes = [0u8; 3]; input.read_exact(&mut length_bytes).await?; let packet_bytes_len = decrypter.decypt_len(length_bytes); @@ -177,76 +150,35 @@ pub async fn read_v2( .message .expect("not a decoy"); - // If packet is using short or full ID. - let (cmd, cmd_index) = if contents.starts_with(&[0u8]) { - (to_ascii(contents[1..13].try_into().expect("12 bytes")), 13) - } else { - ( - V2_SHORTID_COMMANDS[(contents[0] as u8 - 1) as usize].to_string(), - 1, - ) - }; - - let payload = contents[cmd_index..].to_vec(); - Ok(Message { cmd, payload }) + let message = deserialize(&contents).map_err(|_| Error::Serde)?; + Ok(message) } -/// Write the message to the output stream as a v1 packet. -pub async fn write_v1(output: &mut T, msg: Message) -> Result<(), Error> { - let mut write_bytes = vec![]; - // 4 bytes of network magic. - write_bytes.extend_from_slice(DEFAULT_MAGIC.to_bytes().as_slice()); - // 12 bytes for the command as encoded ascii. - write_bytes.extend_from_slice(from_ascii(msg.cmd).as_slice()); - // 4 bytes for length, little endian. - let length_bytes = (msg.payload.len() as u32).to_le_bytes(); - write_bytes.extend_from_slice(length_bytes.as_slice()); - // First 4 bytes of double sha256 digest is checksum. - let checksum: [u8; 4] = sha256d::Hash::hash(msg.payload.as_slice()).as_byte_array()[..4] - .try_into() - .expect("4 byte checksum"); - write_bytes.extend_from_slice(checksum.as_slice()); - // Finally write the payload. - write_bytes.extend_from_slice(msg.payload.as_slice()); - Ok(output.write_all(&write_bytes).await?) +/// Write message to the output stream using v1. +pub async fn write_v1( + output: &mut T, + msg: NetworkMessage, +) -> Result<(), Error> { + let raw = RawNetworkMessage::new(DEFAULT_MAGIC, msg); + let mut buffer = vec![]; + raw.consensus_encode(&mut buffer) + .map_err(|_| Error::Serde)?; + output.write_all(&buffer[..]).await?; + output.flush().await?; + Ok(()) } -/// Write the network message to the output stream. +/// Write the network message to the output stream using v2. pub async fn write_v2( output: &mut T, encrypter: &mut PacketWriter, - msg: Message, + msg: NetworkMessage, ) -> Result<(), Error> { - let mut contents = vec![]; - let shortid_index = V2_SHORTID_COMMANDS.iter().position(|w| w == &&msg.cmd[..]); - match shortid_index { - Some(id) => { - let encoded_id = (id + 1) as u8; - contents.push(encoded_id); - } - None => { - contents.push(0u8); - contents.extend_from_slice(from_ascii(msg.cmd).as_slice()); - } - } - - contents.extend_from_slice(msg.payload.as_slice()); + let payload = serialize(msg).map_err(|_| Error::Serde)?; let write_bytes = encrypter - .prepare_packet_with_alloc(&contents, None, false) + .prepare_packet_with_alloc(&payload, None, false) .expect("encryption"); - Ok(output.write_all(&write_bytes).await?) -} - -fn to_ascii(bytes: [u8; 12]) -> String { - String::from_utf8(bytes.to_vec()) - .expect("ascii") - .trim_end_matches("00") - .to_string() -} - -fn from_ascii(ascii: String) -> [u8; 12] { - let mut output_bytes = [0u8; 12]; - let cmd_bytes = ascii.as_bytes(); - output_bytes[0..cmd_bytes.len()].copy_from_slice(cmd_bytes); - output_bytes + output.write_all(&write_bytes[..]).await?; + output.flush().await?; + Ok(()) }