diff --git a/CHANGELOG.md b/CHANGELOG.md index 130b683..0de7aa9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog + ## Unreleased + + - feat: Add support to set the request timeout + ## v0.6.2 - Add support for Safari web push diff --git a/Cargo.toml b/Cargo.toml index e9fcdc8..94eda69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ base64 = "0.20" tracing = { version = "0.1", optional = true } pem = { version = "1.0", optional = true } ring = { version = "0.16", features = ["std"], optional = true } +tokio = { version = "1", features = ["time"] } [dev-dependencies] argparse = "0.2" diff --git a/src/client.rs b/src/client.rs index 2e80818..f5df133 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,6 +4,7 @@ use crate::error::Error; use crate::error::Error::ResponseError; use crate::signer::Signer; use hyper_alpn::AlpnConnector; +use tokio::time::timeout; use crate::request::payload::PayloadLike; use crate::response::Response; @@ -13,6 +14,8 @@ use std::fmt; use std::io::Read; use std::time::Duration; +const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 20; + /// The APNs service endpoint to connect. #[derive(Debug, Clone)] pub enum Endpoint { @@ -43,23 +46,96 @@ impl fmt::Display for Endpoint { /// holds the response for handling. #[derive(Debug, Clone)] pub struct Client { - endpoint: Endpoint, - signer: Option, + options: ConnectionOptions, http_client: HttpClient, } -impl Client { - fn new(connector: AlpnConnector, signer: Option, endpoint: Endpoint) -> Client { - let mut builder = HttpClient::builder(); - builder.pool_idle_timeout(Some(Duration::from_secs(600))); - builder.http2_only(true); +/// Uses [`Endpoint::Production`] by default. +#[derive(Debug, Clone)] +pub struct ClientOptions { + /// The timeout of the HTTP requests + pub request_timeout_secs: Option, + /// The timeout for idle sockets being kept alive + pub pool_idle_timeout_secs: Option, + /// The endpoint where the requests are sent to + pub endpoint: Endpoint, + /// See [`crate::signer::Signer`] + pub signer: Option, +} + +impl Default for ClientOptions { + fn default() -> Self { + Self { + pool_idle_timeout_secs: Some(600), + request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS), + endpoint: Endpoint::Production, + signer: None, + } + } +} - Client { - http_client: builder.build(connector), +impl ClientOptions { + pub fn new(endpoint: Endpoint) -> Self { + Self { + endpoint, + ..Default::default() + } + } + + pub fn with_signer(mut self, signer: Signer) -> Self { + self.signer = Some(signer); + self + } + + pub fn with_request_timeout(mut self, seconds: u64) -> Self { + self.request_timeout_secs = Some(seconds); + self + } + + pub fn with_pool_idle_timeout(mut self, seconds: u64) -> Self { + self.pool_idle_timeout_secs = Some(seconds); + self + } +} + +#[derive(Debug, Clone)] +struct ConnectionOptions { + endpoint: Endpoint, + request_timeout: Duration, + signer: Option, +} + +impl From for ConnectionOptions { + fn from(value: ClientOptions) -> Self { + let ClientOptions { + endpoint, + pool_idle_timeout_secs: _, signer, + request_timeout_secs, + } = value; + let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS)); + Self { endpoint, + request_timeout, + signer, } } +} + +impl Client { + /// If `options` is not set, a default using [`Endpoint::Production`] will + /// be initialized. + fn new(connector: AlpnConnector, options: Option) -> Client { + let options = options.unwrap_or_default(); + let http_client = HttpClient::builder() + .pool_idle_timeout(options.pool_idle_timeout_secs.map(Duration::from_secs)) + .http2_only(true) + .build(connector); + + let options = options.into(); + + Client { http_client, options } + } /// Create a connection to APNs using the provider client certificate which /// you obtain from your [Apple developer @@ -77,7 +153,7 @@ impl Client { let pkcs = openssl::pkcs12::Pkcs12::from_der(&cert_der)?.parse(password)?; let connector = AlpnConnector::with_client_cert(&pkcs.cert.to_pem()?, &pkcs.pkey.private_key_to_pem_pkcs8()?)?; - Ok(Self::new(connector, None, endpoint)) + Ok(Self::new(connector, Some(ClientOptions::new(endpoint)))) } /// Create a connection to APNs using the raw PEM-formatted certificate and @@ -86,7 +162,7 @@ impl Client { pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result { let connector = AlpnConnector::with_client_cert(cert_pem, key_pem)?; - Ok(Self::new(connector, None, endpoint)) + Ok(Self::new(connector, Some(ClientOptions::new(endpoint)))) } /// Create a connection to APNs using system certificates, signing every @@ -101,9 +177,16 @@ impl Client { { let connector = AlpnConnector::new(); let signature_ttl = Duration::from_secs(60 * 55); - let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?; + let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?); - Ok(Self::new(connector, Some(signer), endpoint)) + Ok(Self::new( + connector, + Some(ClientOptions { + endpoint, + signer, + ..Default::default() + }), + )) } /// Send a notification payload. @@ -114,7 +197,11 @@ impl Client { let request = self.build_request(payload); let requesting = self.http_client.request(request); - let response = requesting.await?; + let Ok(response_result) = timeout(self.options.request_timeout, requesting).await else { + return Err(Error::RequestTimeout(self.options.request_timeout.as_secs())); + }; + + let response = response_result?; let apns_id = response .headers() @@ -141,7 +228,11 @@ impl Client { } fn build_request(&self, payload: T) -> hyper::Request { - let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token()); + let path = format!( + "https://{}/3/device/{}", + self.options.endpoint, + payload.get_device_token() + ); let mut builder = hyper::Request::builder() .uri(&path) @@ -167,7 +258,7 @@ impl Client { if let Some(apns_topic) = options.apns_topic { builder = builder.header("apns-topic", apns_topic.as_bytes()); } - if let Some(ref signer) = self.signer { + if let Some(ref signer) = self.options.signer { let auth = signer .with_signature(|signature| format!("Bearer {}", signature)) .unwrap(); @@ -205,7 +296,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_production_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let uri = format!("{}", request.uri()); @@ -216,7 +307,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_sandbox_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Sandbox); + let client = Client::new(AlpnConnector::new(), Some(ClientOptions::new(Endpoint::Sandbox))); let request = client.build_request(payload); let uri = format!("{}", request.uri()); @@ -227,7 +318,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_method() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); assert_eq!(&Method::POST, request.method()); @@ -237,7 +328,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); @@ -247,7 +338,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_length() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload.clone()); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -259,7 +350,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_authorization_with_no_signer() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); assert_eq!(None, request.headers().get(AUTHORIZATION)); @@ -277,7 +368,10 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), Some(signer), Endpoint::Production); + let client = Client::new( + AlpnConnector::new(), + Some(ClientOptions::new(Endpoint::Production).with_signer(signer)), + ); let request = client.build_request(payload); assert_ne!(None, request.headers().get(AUTHORIZATION)); @@ -291,7 +385,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ..Default::default() }; let payload = builder.build("a_test_id", options); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_push_type = request.headers().get("apns-push-type").unwrap(); @@ -302,7 +396,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_with_default_priority() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_priority = request.headers().get("apns-priority"); @@ -321,7 +415,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -340,7 +434,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -353,7 +447,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_id = request.headers().get("apns-id"); @@ -372,7 +466,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_id = request.headers().get("apns-id").unwrap(); @@ -385,7 +479,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_expiration = request.headers().get("apns-expiration"); @@ -404,7 +498,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_expiration = request.headers().get("apns-expiration").unwrap(); @@ -417,7 +511,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_collapse_id = request.headers().get("apns-collapse-id"); @@ -436,7 +530,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); @@ -449,7 +543,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_topic = request.headers().get("apns-topic"); @@ -468,7 +562,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload); let apns_topic = request.headers().get("apns-topic").unwrap(); @@ -479,7 +573,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ async fn test_request_body() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(AlpnConnector::new(), None, Endpoint::Production); + let client = Client::new(AlpnConnector::new(), None); let request = client.build_request(payload.clone()); let body = hyper::body::to_bytes(request).await.unwrap(); @@ -497,7 +591,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let cert: Vec = include_str!("../test_cert/test.crt").bytes().collect(); let c = Client::certificate_parts(&cert, &key, Endpoint::Sandbox)?; - assert!(c.signer.is_none()); + assert!(c.options.signer.is_none()); Ok(()) } } diff --git a/src/error.rs b/src/error.rs index af7ffb1..8477d80 100644 --- a/src/error.rs +++ b/src/error.rs @@ -38,6 +38,10 @@ pub enum Error { #[error("Error in reading a certificate file: {0}")] ReadError(#[from] io::Error), + /// No repsonse from APNs after the given amount of time + #[error("The request timed out after {0} s")] + RequestTimeout(u64), + /// Unexpected private key (only EC keys are supported). #[cfg(all(not(feature = "openssl"), feature = "ring"))] #[error("Unexpected private key: {0}")]