diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index aa29cbb..eb96bb3 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -941,14 +941,31 @@ impl<'a> Handshake<'a> { #[cfg(feature = "std")] #[derive(Debug)] pub enum ProtocolError { - Io(std::io::Error), + /// Wrap all IO errors with a flag which indicates if the error + /// possibly means the remote does not support the V2 protocol + /// and could be retried with the V1 protocol as a fallback. + Io(std::io::Error, bool), + /// Internal protocol specific errors. Internal(Error), } #[cfg(feature = "std")] impl From for ProtocolError { fn from(error: std::io::Error) -> Self { - ProtocolError::Io(error) + // Detect IO errors which possibly mean the remote doesn't understand + // the V2 protocol and immediatly closed the connection. + let retry_with_v1 = matches!( + error.kind(), + // The remote force closed the connection. + std::io::ErrorKind::ConnectionReset + // A more general error than ConnectionReset, but could be caused + // by the remote closing the connection. + | std::io::ErrorKind::ConnectionAborted + // End of file read errors can occur if the remote closes the connection, + // but the local system reads due to timing issues. + | std::io::ErrorKind::UnexpectedEof + ); + ProtocolError::Io(error, retry_with_v1) } } @@ -966,8 +983,13 @@ impl std::error::Error for ProtocolError {} impl fmt::Display for ProtocolError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ProtocolError::Io(e) => write!(f, "IO error: {:?}", e), - ProtocolError::Internal(e) => write!(f, "Internal error: {:?}", e), + ProtocolError::Io(e, retry_with_v1) => write!( + f, + "IO error: {:?}. Possibly retry on V1 protocol: {}.", + e, + if *retry_with_v1 { "Yes" } else { "No" } + ), + ProtocolError::Internal(e) => write!(f, "Internal error: {:?}.", e), } } } @@ -996,7 +1018,11 @@ where /// A `Result` containing: /// * `Ok(AsyncProtocol)`: An initialized protocol handler. /// * `Err(ProtocolError)`: An error that occurred during the handshake. - pub async fn new( + /// + /// # Errors + /// + /// * `Io` - Includes a flag for if the remote probably only understands the V1 protocol. + pub async fn new<'a>( network: Network, role: Role, garbage: Option<&[u8]>, @@ -1059,7 +1085,7 @@ where std::io::ErrorKind::WouldBlock | std::io::ErrorKind::Interrupted => { continue; } - _ => return Err(ProtocolError::Io(e)), + _ => return Err(ProtocolError::Io(e, false)), }, } }