From 7555c2f131d0e0dda6d9326971a592b6ee7e7ea4 Mon Sep 17 00:00:00 2001 From: threema-donat <129288638+threema-donat@users.noreply.github.com> Date: Tue, 30 Apr 2024 08:45:28 +0200 Subject: [PATCH] fixup! feat: Add option to set a request timeout --- src/client.rs | 125 +++++++++++++++++++++++--------------------------- 1 file changed, 58 insertions(+), 67 deletions(-) diff --git a/src/client.rs b/src/client.rs index e407cd6..32224b6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -61,7 +61,7 @@ pub struct Client { /// Uses [`Endpoint::Production`] by default. #[derive(Debug, Clone)] -pub struct ClientOptions { +pub struct ClientBuilder { /// The timeout of the HTTP requests pub request_timeout_secs: Option, /// The timeout for idle sockets being kept alive @@ -74,7 +74,7 @@ pub struct ClientOptions { pub connector: Option, } -impl Default for ClientOptions { +impl Default for ClientBuilder { fn default() -> Self { Self { pool_idle_timeout_secs: Some(600), @@ -86,33 +86,50 @@ impl Default for ClientOptions { } } -impl ClientOptions { - pub fn new(endpoint: Endpoint) -> Self { - Self { - endpoint, - ..Default::default() - } - } - - pub fn with_connector(mut self, connector: HyperConnector) -> Self { +impl ClientBuilder { + pub fn connector(mut self, connector: HyperConnector) -> Self { self.connector = Some(connector); self } - pub fn with_signer(mut self, signer: Signer) -> Self { + pub fn signer(mut self, signer: Signer) -> Self { self.signer = Some(signer); self } - pub fn with_request_timeout(mut self, seconds: u64) -> Self { + pub fn request_timeout(mut self, seconds: u64) -> Self { self.request_timeout_secs = Some(seconds); self } - pub fn with_pool_idle_timeout(mut self, seconds: u64) -> Self { + pub fn pool_idle_timeout(mut self, seconds: u64) -> Self { self.pool_idle_timeout_secs = Some(seconds); self } + + pub fn endpoint(mut self, endpoint: Endpoint) -> Self { + self.endpoint = endpoint; + self + } + + pub fn build(self) -> Client { + let ClientBuilder { + request_timeout_secs, + pool_idle_timeout_secs, + endpoint, + signer, + connector, + } = self; + let http_client = HttpClient::builder(TokioExecutor::new()) + .pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs)) + .http2_only(true) + .build(connector.unwrap_or_else(default_connector)); + + Client { + http_client, + options: ConnectionOptions::new(endpoint, signer, request_timeout_secs), + } + } } #[derive(Debug, Clone)] @@ -134,23 +151,8 @@ impl ConnectionOptions { } impl Client { - fn new(options: ClientOptions) -> Self { - let ClientOptions { - request_timeout_secs, - pool_idle_timeout_secs, - endpoint, - signer, - connector, - } = options; - let http_client = HttpClient::builder(TokioExecutor::new()) - .pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs)) - .http2_only(true) - .build(connector.unwrap_or_else(default_connector)); - - Client { - http_client, - options: ConnectionOptions::new(endpoint, signer, request_timeout_secs), - } + fn builder() -> ClientBuilder { + ClientBuilder::default() } /// Create a connection to APNs using the provider client certificate which @@ -172,7 +174,7 @@ impl Client { }; let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?; - Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector))) + Ok(Self::builder().connector(connector).endpoint(endpoint).build()) } /// Create a connection to APNs using the raw PEM-formatted certificate and @@ -181,7 +183,7 @@ impl Client { pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result { let connector = client_cert_connector(cert_pem, key_pem)?; - Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector))) + Ok(Self::builder().endpoint(endpoint).connector(connector).build()) } /// Create a connection to APNs using system certificates, signing every @@ -195,13 +197,9 @@ impl Client { R: Read, { let signature_ttl = Duration::from_secs(60 * 55); - let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?); + let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?; - Ok(Self::new(ClientOptions { - endpoint, - signer, - ..Default::default() - })) + Ok(Self::builder().endpoint(endpoint).signer(signer).build()) } /// Send a notification payload. @@ -334,18 +332,11 @@ lCEIvbDqlUhA5FOzcakkG90E8L+hRANCAATKS2ZExEybUvchRDuKBftotMwVEus3 jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ -----END PRIVATE KEY-----"; - impl Client { - fn new_with_defaults() -> Self { - let options = ClientOptions::default(); - Self::new(options) - } - } - #[test] fn test_production_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -356,7 +347,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_sandbox_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(ClientOptions::new(Endpoint::Sandbox)); + let client = Client::builder().endpoint(Endpoint::Sandbox).build(); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -367,7 +358,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_method() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!(&Method::POST, request.method()); @@ -377,7 +368,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_invalid() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("\r\n", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().endpoint(Endpoint::Production).build(); let request = client.build_request(payload); assert!(matches!(request, Err(Error::BuildRequestError(_)))); @@ -387,7 +378,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); @@ -397,7 +388,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_length() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload.clone()).unwrap(); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -409,7 +400,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_authorization_with_no_signer() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!(None, request.headers().get(AUTHORIZATION)); @@ -427,7 +418,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(ClientOptions::new(Endpoint::Production).with_signer(signer)); + let client = Client::builder().endpoint(Endpoint::Production).signer(signer).build(); let request = client.build_request(payload).unwrap(); assert_ne!(None, request.headers().get(AUTHORIZATION)); @@ -441,7 +432,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ..Default::default() }; let payload = builder.build("a_test_id", options); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_push_type = request.headers().get("apns-push-type").unwrap(); @@ -452,7 +443,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_with_default_priority() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority"); @@ -471,7 +462,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -490,7 +481,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -503,7 +494,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id"); @@ -522,7 +513,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id").unwrap(); @@ -535,7 +526,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration"); @@ -554,7 +545,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration").unwrap(); @@ -567,7 +558,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id"); @@ -586,7 +577,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); @@ -599,7 +590,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic"); @@ -618,7 +609,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic").unwrap(); @@ -629,7 +620,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ async fn test_request_body() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload.clone()).unwrap(); let body = request.into_body().collect().await.unwrap().to_bytes();