From d2a1b933d9f9af0b4538fca607237bfb533b7ac7 Mon Sep 17 00:00:00 2001 From: Ivan Reshetnikov Date: Fri, 16 Feb 2024 14:15:14 +0100 Subject: [PATCH] fix: clean up & tests --- relay_client/src/error.rs | 40 ++++++------ relay_client/src/http.rs | 11 ++-- relay_client/src/websocket/outbound.rs | 2 +- relay_rpc/src/rpc.rs | 30 ++++----- relay_rpc/src/rpc/error.rs | 14 ++--- relay_rpc/src/rpc/tests.rs | 84 +++++++++++++++++++++----- 6 files changed, 117 insertions(+), 64 deletions(-) diff --git a/relay_client/src/error.rs b/relay_client/src/error.rs index 28262e6..b2a5f8a 100644 --- a/relay_client/src/error.rs +++ b/relay_client/src/error.rs @@ -66,23 +66,28 @@ pub enum ClientError { impl From for ClientError { fn from(err: rpc::ErrorData) -> Self { - let rpc::ErrorData { - code, - message, - data, - } = err; - Self::Rpc { - code, - message, - data, + code: err.code, + message: err.message, + data: err.data, } } } -impl ClientError { - pub fn into_service_error(self) -> Error { - match self { +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Client errors encountered while performing the request. + #[error(transparent)] + Client(ClientError), + + /// Error response received from the relay. + #[error(transparent)] + Response(#[from] rpc::Error), +} + +impl From for Error { + fn from(err: ClientError) -> Self { + match err { ClientError::Rpc { code, message, @@ -101,16 +106,7 @@ impl ClientError { } } - _ => Error::Client(self), + _ => Error::Client(err), } } } - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error(transparent)] - Client(#[from] ClientError), - - #[error(transparent)] - Response(#[from] rpc::Error), -} diff --git a/relay_client/src/http.rs b/relay_client/src/http.rs index 3839eb9..b73fbf7 100644 --- a/relay_client/src/http.rs +++ b/relay_client/src/http.rs @@ -197,7 +197,8 @@ impl Client { let payload = rpc::WatchRegister { register_auth: claims .encode(keypair) - .map_err(|err| HttpClientError::Jwt(err).into()) + .map_err(HttpClientError::Jwt) + .map_err(ClientError::from) .map_err(Error::Client)?, }; @@ -236,7 +237,8 @@ impl Client { let payload = rpc::WatchUnregister { unregister_auth: claims .encode(keypair) - .map_err(|err| HttpClientError::Jwt(err).into()) + .map_err(HttpClientError::Jwt) + .map_err(ClientError::from) .map_err(Error::Client)?, }; @@ -336,7 +338,8 @@ impl Client { .map_err(|_| HttpClientError::InvalidResponse) } .await - .map_err(|err| Error::Client(err.into()))?; + .map_err(ClientError::from) + .map_err(Error::Client)?; match response { rpc::Payload::Response(rpc::Response::Success(response)) => { @@ -345,7 +348,7 @@ impl Client { } rpc::Payload::Response(rpc::Response::Error(response)) => { - Err(ClientError::from(response.error).into_service_error()) + Err(ClientError::from(response.error).into()) } _ => Err(Error::Client(HttpClientError::InvalidResponse.into())), diff --git a/relay_client/src/websocket/outbound.rs b/relay_client/src/websocket/outbound.rs index a0692ab..6a927c6 100644 --- a/relay_client/src/websocket/outbound.rs +++ b/relay_client/src/websocket/outbound.rs @@ -63,7 +63,7 @@ where Err(err) => Err(err), }; - Poll::Ready(result.map_err(ClientError::into_service_error)) + Poll::Ready(result.map_err(Into::into)) } } diff --git a/relay_rpc/src/rpc.rs b/relay_rpc/src/rpc.rs index 26359fa..21f23f8 100644 --- a/relay_rpc/src/rpc.rs +++ b/relay_rpc/src/rpc.rs @@ -153,7 +153,7 @@ impl SuccessfulResponse { /// Validates the parameters. pub fn validate(&self) -> Result<(), PayloadError> { if self.jsonrpc.as_ref() != JSON_RPC_VERSION_STR { - Err(PayloadError::JsonRpcVersion) + Err(PayloadError::InvalidJsonRpcVersion) } else { // We can't really validate `serde_json::Value` without knowing the expected // value type. @@ -188,7 +188,7 @@ impl ErrorResponse { /// Validates the parameters. pub fn validate(&self) -> Result<(), PayloadError> { if self.jsonrpc.as_ref() != JSON_RPC_VERSION_STR { - Err(PayloadError::JsonRpcVersion) + Err(PayloadError::InvalidJsonRpcVersion) } else { Ok(()) } @@ -234,7 +234,7 @@ impl ServiceRequest for Subscribe { fn validate(&self) -> Result<(), PayloadError> { self.topic .decode() - .map_err(|_| PayloadError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; Ok(()) } @@ -262,7 +262,7 @@ impl ServiceRequest for Unsubscribe { fn validate(&self) -> Result<(), PayloadError> { self.topic .decode() - .map_err(|_| PayloadError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; // FIXME: Subscription ID validation is currently disabled, since SDKs do not // use the actual IDs generated by the relay, and instead send some randomized @@ -295,7 +295,7 @@ impl ServiceRequest for FetchMessages { fn validate(&self) -> Result<(), PayloadError> { self.topic .decode() - .map_err(|_| PayloadError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; Ok(()) } @@ -345,7 +345,7 @@ impl ServiceRequest for BatchSubscribe { } for topic in &self.topics { - topic.decode().map_err(|_| PayloadError::TopicDecoding)?; + topic.decode().map_err(|_| PayloadError::InvalidTopic)?; } Ok(()) @@ -413,7 +413,7 @@ impl ServiceRequest for BatchFetchMessages { } for topic in &self.topics { - topic.decode().map_err(|_| PayloadError::TopicDecoding)?; + topic.decode().map_err(|_| PayloadError::InvalidTopic)?; } Ok(()) @@ -460,7 +460,7 @@ impl ServiceRequest for BatchReceiveMessages { receipt .topic .decode() - .map_err(|_| PayloadError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; } Ok(()) @@ -544,7 +544,7 @@ impl ServiceRequest for Publish { fn validate(&self) -> Result<(), PayloadError> { self.topic .decode() - .map_err(|_| PayloadError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; Ok(()) } @@ -581,8 +581,8 @@ pub enum WatchError { #[error("Invalid action")] InvalidAction, - #[error("Failed to decode JWT")] - Jwt, + #[error("Invalid JWT")] + InvalidJwt, } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -651,12 +651,12 @@ impl ServiceRequest for Subscription { fn validate(&self) -> Result<(), PayloadError> { self.id .decode() - .map_err(|_| PayloadError::SubscriptionIdDecoding)?; + .map_err(|_| PayloadError::InvalidSubscriptionId)?; self.data .topic .decode() - .map_err(|_| PayloadError::TopicDecoding)?; + .map_err(|_| PayloadError::InvalidTopic)?; Ok(()) } @@ -767,11 +767,11 @@ impl Request { /// Validates the request payload. pub fn validate(&self) -> Result<(), PayloadError> { if !self.id.validate() { - return Err(PayloadError::RequestId); + return Err(PayloadError::InvalidRequestId); } if self.jsonrpc.as_ref() != JSON_RPC_VERSION_STR { - return Err(PayloadError::JsonRpcVersion); + return Err(PayloadError::InvalidJsonRpcVersion); } match &self.params { diff --git a/relay_rpc/src/rpc/error.rs b/relay_rpc/src/rpc/error.rs index 7aa9db4..aa650db 100644 --- a/relay_rpc/src/rpc/error.rs +++ b/relay_rpc/src/rpc/error.rs @@ -41,10 +41,10 @@ pub enum AuthError { OriginNotAllowed, #[error("Invalid JWT")] - JwtInvalid, + InvalidJwt, #[error("Missing JWT")] - JwtMissing, + MissingJwt, #[error("Country blocked")] CountryBlocked, @@ -65,16 +65,16 @@ pub enum PayloadError { PayloadSizeExceeded, #[error("Topic decoding failed")] - TopicDecoding, + InvalidTopic, #[error("Subscription ID decoding failed")] - SubscriptionIdDecoding, + InvalidSubscriptionId, #[error("Invalid request ID")] - RequestId, + InvalidRequestId, #[error("Invalid JSON RPC version")] - JsonRpcVersion, + InvalidJsonRpcVersion, #[error("The batch contains too many items")] BatchLimitExceeded, @@ -95,7 +95,7 @@ pub enum InternalError { Serialization, #[error("Internal error")] - Other, + Unknown, } /// Errors caught while processing the request. These are meant to be serialized diff --git a/relay_rpc/src/rpc/tests.rs b/relay_rpc/src/rpc/tests.rs index 6cd1305..1fbb0f3 100644 --- a/relay_rpc/src/rpc/tests.rs +++ b/relay_rpc/src/rpc/tests.rs @@ -281,7 +281,7 @@ fn validation() { prompt: false, }), }; - assert_eq!(request.validate(), Err(PayloadError::RequestId)); + assert_eq!(request.validate(), Err(PayloadError::InvalidRequestId)); // Invalid JSONRPC version. let request = Request { @@ -295,7 +295,7 @@ fn validation() { prompt: false, }), }; - assert_eq!(request.validate(), Err(PayloadError::JsonRpcVersion)); + assert_eq!(request.validate(), Err(PayloadError::InvalidJsonRpcVersion)); // Publish: valid. let request = Request { @@ -323,7 +323,7 @@ fn validation() { prompt: false, }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Subscribe: valid. let request = Request { @@ -345,7 +345,7 @@ fn validation() { block: false, }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Unsubscribe: valid. let request = Request { @@ -367,7 +367,7 @@ fn validation() { subscription_id: subscription_id.clone(), }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Fetch: valid. let request = Request { @@ -387,7 +387,7 @@ fn validation() { topic: Topic::from("invalid"), }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Subscription: valid. let request = Request { @@ -419,10 +419,7 @@ fn validation() { }, }), }; - assert_eq!( - request.validate(), - Err(PayloadError::SubscriptionIdDecoding) - ); + assert_eq!(request.validate(), Err(PayloadError::InvalidSubscriptionId)); // Subscription: invalid topic. let request = Request { @@ -438,7 +435,7 @@ fn validation() { }, }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Batch subscription: valid. let request = Request { @@ -487,7 +484,7 @@ fn validation() { block: false, }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Batch unsubscription: valid. let request = Request { @@ -539,7 +536,7 @@ fn validation() { }], }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Batch fetch: valid. let request = Request { @@ -580,7 +577,7 @@ fn validation() { )], }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); // Batch receive: valid. let request = Request { @@ -630,5 +627,62 @@ fn validation() { }], }), }; - assert_eq!(request.validate(), Err(PayloadError::TopicDecoding)); + assert_eq!(request.validate(), Err(PayloadError::InvalidTopic)); +} + +#[test] +fn error_tags() { + // Validate hardcoded string tags, so that we don't accidentally break + // compatibility with other SDKs as a result of refactoring. + + assert_eq!( + SubscriptionError::SubscriberLimitExceeded.tag(), + "SubscriberLimitExceeded" + ); + + assert_eq!(PublishError::TtlTooShort.tag(), "TtlTooShort"); + assert_eq!(PublishError::TtlTooLong.tag(), "TtlTooLong"); + + assert_eq!(GenericError::Unknown.tag(), "Unknown"); + + assert_eq!(WatchError::InvalidTtl.tag(), "InvalidTtl"); + assert_eq!(WatchError::InvalidServiceUrl.tag(), "InvalidServiceUrl"); + assert_eq!(WatchError::InvalidWebhookUrl.tag(), "InvalidWebhookUrl"); + assert_eq!(WatchError::InvalidAction.tag(), "InvalidAction"); + assert_eq!(WatchError::InvalidJwt.tag(), "InvalidJwt"); + + assert_eq!(AuthError::ProjectNotFound.tag(), "ProjectNotFound"); + assert_eq!( + AuthError::ProjectIdNotSpecified.tag(), + "ProjectIdNotSpecified" + ); + assert_eq!(AuthError::ProjectInactive.tag(), "ProjectInactive"); + assert_eq!(AuthError::OriginNotAllowed.tag(), "OriginNotAllowed"); + assert_eq!(AuthError::InvalidJwt.tag(), "InvalidJwt"); + assert_eq!(AuthError::MissingJwt.tag(), "MissingJwt"); + assert_eq!(AuthError::CountryBlocked.tag(), "CountryBlocked"); + + assert_eq!(PayloadError::InvalidMethod.tag(), "InvalidMethod"); + assert_eq!(PayloadError::InvalidParams.tag(), "InvalidParams"); + assert_eq!( + PayloadError::PayloadSizeExceeded.tag(), + "PayloadSizeExceeded" + ); + assert_eq!(PayloadError::InvalidTopic.tag(), "InvalidTopic"); + assert_eq!( + PayloadError::InvalidSubscriptionId.tag(), + "InvalidSubscriptionId" + ); + assert_eq!(PayloadError::InvalidRequestId.tag(), "InvalidRequestId"); + assert_eq!( + PayloadError::InvalidJsonRpcVersion.tag(), + "InvalidJsonRpcVersion" + ); + assert_eq!(PayloadError::BatchLimitExceeded.tag(), "BatchLimitExceeded"); + assert_eq!(PayloadError::BatchEmpty.tag(), "BatchEmpty"); + assert_eq!(PayloadError::Serialization.tag(), "Serialization"); + + assert_eq!(InternalError::StorageError.tag(), "StorageError"); + assert_eq!(InternalError::Serialization.tag(), "Serialization"); + assert_eq!(InternalError::Unknown.tag(), "Unknown"); }