From 7ebbdaf3016b404f57843edd44c3cd8c734de60e Mon Sep 17 00:00:00 2001 From: threema-donat <129288638+threema-donat@users.noreply.github.com> Date: Fri, 26 Apr 2024 11:11:11 +0200 Subject: [PATCH] fixup! feat: Add option to set a request timeout --- src/client.rs | 109 +++++++++++++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 51 deletions(-) diff --git a/src/client.rs b/src/client.rs index 33dd873..d07a8e1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -70,6 +70,8 @@ pub struct ClientOptions { pub endpoint: Endpoint, /// See [`crate::signer::Signer`] pub signer: Option, + /// The HTTPS connector used to connect to APNs + pub connector: Option, } impl Default for ClientOptions { @@ -79,6 +81,7 @@ impl Default for ClientOptions { request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS), endpoint: Endpoint::Production, signer: None, + connector: Some(default_connector()), } } } @@ -91,6 +94,11 @@ impl ClientOptions { } } + pub fn with_connector(mut self, connector: HyperConnector) -> Self { + self.connector = Some(connector); + self + } + pub fn with_signer(mut self, signer: Signer) -> Self { self.signer = Some(signer); self @@ -114,14 +122,8 @@ struct ConnectionOptions { signer: Option, } -impl From for ConnectionOptions { - fn from(value: ClientOptions) -> Self { - let ClientOptions { - endpoint, - pool_idle_timeout_secs: _, - signer, - request_timeout_secs, - } = value; +impl ConnectionOptions { + fn new(endpoint: Endpoint, signer: Option, request_timeout_secs: Option) -> Self { let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS)); Self { endpoint, @@ -132,18 +134,23 @@ impl From for ConnectionOptions { } impl Client { - /// If `options` is not set, a default using [`Endpoint::Production`] will - /// be initialized. - fn new(connector: HyperConnector, options: Option) -> Client { - let options = options.unwrap_or_default(); + fn new(options: ClientOptions) -> Self { + let ClientOptions { + request_timeout_secs, + pool_idle_timeout_secs, + endpoint, + signer, + connector, + } = options; let http_client = HttpClient::builder(TokioExecutor::new()) - .pool_idle_timeout(options.pool_idle_timeout_secs.map(Duration::from_secs)) + .pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs)) .http2_only(true) - .build(connector); + .build(connector.unwrap_or_else(default_connector)); - let options = options.into(); - - Client { http_client, options } + Client { + http_client, + options: ConnectionOptions::new(endpoint, signer, request_timeout_secs), + } } /// Create a connection to APNs using the provider client certificate which @@ -165,7 +172,7 @@ impl Client { }; let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?; - Ok(Self::new(connector, Some(ClientOptions::new(endpoint)))) + Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector))) } /// Create a connection to APNs using the raw PEM-formatted certificate and @@ -174,7 +181,7 @@ impl Client { pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result { let connector = client_cert_connector(cert_pem, key_pem)?; - Ok(Self::new(connector, Some(ClientOptions::new(endpoint)))) + Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector))) } /// Create a connection to APNs using system certificates, signing every @@ -187,18 +194,14 @@ impl Client { T: Into, R: Read, { - let connector = default_connector(); let signature_ttl = Duration::from_secs(60 * 55); let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?); - Ok(Self::new( - connector, - Some(ClientOptions { - endpoint, - signer, - ..Default::default() - }), - )) + Ok(Self::new(ClientOptions { + endpoint, + signer, + ..Default::default() + })) } /// Send a notification payload. @@ -333,11 +336,18 @@ lCEIvbDqlUhA5FOzcakkG90E8L+hRANCAATKS2ZExEybUvchRDuKBftotMwVEus3 jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ -----END PRIVATE KEY-----"; + impl Client { + fn new_with_defaults() -> Self { + let options = ClientOptions::default(); + Self::new(options) + } + } + #[test] fn test_production_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let uri = format!("{}", request.uri()); @@ -348,7 +358,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_sandbox_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), Some(ClientOptions::new(Endpoint::Sandbox))); + let client = Client::new(ClientOptions::new(Endpoint::Sandbox)); let request = client.build_request(payload); let uri = format!("{}", request.uri()); @@ -359,7 +369,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_method() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); assert_eq!(&Method::POST, request.method()); @@ -369,7 +379,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); @@ -379,7 +389,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_length() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload.clone()); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -391,7 +401,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_authorization_with_no_signer() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); assert_eq!(None, request.headers().get(AUTHORIZATION)); @@ -409,10 +419,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new( - default_connector(), - Some(ClientOptions::new(Endpoint::Production).with_signer(signer)), - ); + let client = Client::new(ClientOptions::new(Endpoint::Production).with_signer(signer)); let request = client.build_request(payload); assert_ne!(None, request.headers().get(AUTHORIZATION)); @@ -426,7 +433,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ..Default::default() }; let payload = builder.build("a_test_id", options); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_push_type = request.headers().get("apns-push-type").unwrap(); @@ -437,7 +444,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_with_default_priority() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_priority = request.headers().get("apns-priority"); @@ -456,7 +463,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -475,7 +482,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -488,7 +495,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_id = request.headers().get("apns-id"); @@ -507,7 +514,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_id = request.headers().get("apns-id").unwrap(); @@ -520,7 +527,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_expiration = request.headers().get("apns-expiration"); @@ -539,7 +546,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_expiration = request.headers().get("apns-expiration").unwrap(); @@ -552,7 +559,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_collapse_id = request.headers().get("apns-collapse-id"); @@ -571,7 +578,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); @@ -584,7 +591,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_topic = request.headers().get("apns-topic"); @@ -603,7 +610,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload); let apns_topic = request.headers().get("apns-topic").unwrap(); @@ -614,7 +621,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ async fn test_request_body() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload.clone()); let body = request.into_body().collect().await.unwrap().to_bytes();