Skip to content

Commit

Permalink
fixup! feat: Add option to set a request timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
threema-donat committed Apr 30, 2024
1 parent 8f47453 commit 7555c2f
Showing 1 changed file with 58 additions and 67 deletions.
125 changes: 58 additions & 67 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct Client {

/// Uses [`Endpoint::Production`] by default.
#[derive(Debug, Clone)]
pub struct ClientOptions {
pub struct ClientBuilder {
/// The timeout of the HTTP requests
pub request_timeout_secs: Option<u64>,
/// The timeout for idle sockets being kept alive
Expand All @@ -74,7 +74,7 @@ pub struct ClientOptions {
pub connector: Option<HyperConnector>,
}

impl Default for ClientOptions {
impl Default for ClientBuilder {
fn default() -> Self {
Self {
pool_idle_timeout_secs: Some(600),
Expand All @@ -86,33 +86,50 @@ impl Default for ClientOptions {
}
}

impl ClientOptions {
pub fn new(endpoint: Endpoint) -> Self {
Self {
endpoint,
..Default::default()
}
}

pub fn with_connector(mut self, connector: HyperConnector) -> Self {
impl ClientBuilder {
pub fn connector(mut self, connector: HyperConnector) -> Self {
self.connector = Some(connector);
self
}

pub fn with_signer(mut self, signer: Signer) -> Self {
pub fn signer(mut self, signer: Signer) -> Self {
self.signer = Some(signer);
self
}

pub fn with_request_timeout(mut self, seconds: u64) -> Self {
pub fn request_timeout(mut self, seconds: u64) -> Self {
self.request_timeout_secs = Some(seconds);
self
}

pub fn with_pool_idle_timeout(mut self, seconds: u64) -> Self {
pub fn pool_idle_timeout(mut self, seconds: u64) -> Self {
self.pool_idle_timeout_secs = Some(seconds);
self
}

pub fn endpoint(mut self, endpoint: Endpoint) -> Self {
self.endpoint = endpoint;
self
}

pub fn build(self) -> Client {
let ClientBuilder {
request_timeout_secs,
pool_idle_timeout_secs,
endpoint,
signer,
connector,
} = self;
let http_client = HttpClient::builder(TokioExecutor::new())
.pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs))
.http2_only(true)
.build(connector.unwrap_or_else(default_connector));

Client {
http_client,
options: ConnectionOptions::new(endpoint, signer, request_timeout_secs),
}
}
}

#[derive(Debug, Clone)]
Expand All @@ -134,23 +151,8 @@ impl ConnectionOptions {
}

impl Client {
fn new(options: ClientOptions) -> Self {
let ClientOptions {
request_timeout_secs,
pool_idle_timeout_secs,
endpoint,
signer,
connector,
} = options;
let http_client = HttpClient::builder(TokioExecutor::new())
.pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs))
.http2_only(true)
.build(connector.unwrap_or_else(default_connector));

Client {
http_client,
options: ConnectionOptions::new(endpoint, signer, request_timeout_secs),
}
fn builder() -> ClientBuilder {
ClientBuilder::default()
}

/// Create a connection to APNs using the provider client certificate which
Expand All @@ -172,7 +174,7 @@ impl Client {
};
let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?;

Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector)))
Ok(Self::builder().connector(connector).endpoint(endpoint).build())
}

/// Create a connection to APNs using the raw PEM-formatted certificate and
Expand All @@ -181,7 +183,7 @@ impl Client {
pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result<Client, Error> {
let connector = client_cert_connector(cert_pem, key_pem)?;

Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector)))
Ok(Self::builder().endpoint(endpoint).connector(connector).build())
}

/// Create a connection to APNs using system certificates, signing every
Expand All @@ -195,13 +197,9 @@ impl Client {
R: Read,
{
let signature_ttl = Duration::from_secs(60 * 55);
let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?);
let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?;

Ok(Self::new(ClientOptions {
endpoint,
signer,
..Default::default()
}))
Ok(Self::builder().endpoint(endpoint).signer(signer).build())
}

/// Send a notification payload.
Expand Down Expand Up @@ -334,18 +332,11 @@ lCEIvbDqlUhA5FOzcakkG90E8L+hRANCAATKS2ZExEybUvchRDuKBftotMwVEus3
jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
-----END PRIVATE KEY-----";

impl Client {
fn new_with_defaults() -> Self {
let options = ClientOptions::default();
Self::new(options)
}
}

#[test]
fn test_production_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

Expand All @@ -356,7 +347,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_sandbox_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(ClientOptions::new(Endpoint::Sandbox));
let client = Client::builder().endpoint(Endpoint::Sandbox).build();
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

Expand All @@ -367,7 +358,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_method() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();

assert_eq!(&Method::POST, request.method());
Expand All @@ -377,7 +368,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_invalid() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("\r\n", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().endpoint(Endpoint::Production).build();
let request = client.build_request(payload);

assert!(matches!(request, Err(Error::BuildRequestError(_))));
Expand All @@ -387,7 +378,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();

assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
Expand All @@ -397,7 +388,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_length() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload.clone()).unwrap();
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();
Expand All @@ -409,7 +400,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_authorization_with_no_signer() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();

assert_eq!(None, request.headers().get(AUTHORIZATION));
Expand All @@ -427,7 +418,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(ClientOptions::new(Endpoint::Production).with_signer(signer));
let client = Client::builder().endpoint(Endpoint::Production).signer(signer).build();
let request = client.build_request(payload).unwrap();

assert_ne!(None, request.headers().get(AUTHORIZATION));
Expand All @@ -441,7 +432,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
..Default::default()
};
let payload = builder.build("a_test_id", options);
let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_push_type = request.headers().get("apns-push-type").unwrap();

Expand All @@ -452,7 +443,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_with_default_priority() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority");

Expand All @@ -471,7 +462,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

Expand All @@ -490,7 +481,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

Expand All @@ -503,7 +494,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id");

Expand All @@ -522,7 +513,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id").unwrap();

Expand All @@ -535,7 +526,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration");

Expand All @@ -554,7 +545,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration").unwrap();

Expand All @@ -567,7 +558,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id");

Expand All @@ -586,7 +577,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();

Expand All @@ -599,7 +590,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic");

Expand All @@ -618,7 +609,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic").unwrap();

Expand All @@ -629,7 +620,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
async fn test_request_body() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new_with_defaults();
let client = Client::builder().build();
let request = client.build_request(payload.clone()).unwrap();

let body = request.into_body().collect().await.unwrap().to_bytes();
Expand Down

0 comments on commit 7555c2f

Please sign in to comment.