Skip to content

Commit

Permalink
fixup! feat: Add option to set a request timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
threema-donat committed Apr 26, 2024
1 parent 7186b08 commit 7ebbdaf
Showing 1 changed file with 58 additions and 51 deletions.
109 changes: 58 additions & 51 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ pub struct ClientOptions {
pub endpoint: Endpoint,
/// See [`crate::signer::Signer`]
pub signer: Option<Signer>,
/// The HTTPS connector used to connect to APNs
pub connector: Option<HyperConnector>,
}

impl Default for ClientOptions {
Expand All @@ -79,6 +81,7 @@ impl Default for ClientOptions {
request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS),
endpoint: Endpoint::Production,
signer: None,
connector: Some(default_connector()),
}
}
}
Expand All @@ -91,6 +94,11 @@ impl ClientOptions {
}
}

pub fn with_connector(mut self, connector: HyperConnector) -> Self {
self.connector = Some(connector);
self
}

pub fn with_signer(mut self, signer: Signer) -> Self {
self.signer = Some(signer);
self
Expand All @@ -114,14 +122,8 @@ struct ConnectionOptions {
signer: Option<Signer>,
}

impl From<ClientOptions> for ConnectionOptions {
fn from(value: ClientOptions) -> Self {
let ClientOptions {
endpoint,
pool_idle_timeout_secs: _,
signer,
request_timeout_secs,
} = value;
impl ConnectionOptions {
fn new(endpoint: Endpoint, signer: Option<Signer>, request_timeout_secs: Option<u64>) -> Self {
let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS));
Self {
endpoint,
Expand All @@ -132,18 +134,23 @@ impl From<ClientOptions> for ConnectionOptions {
}

impl Client {
/// If `options` is not set, a default using [`Endpoint::Production`] will
/// be initialized.
fn new(connector: HyperConnector, options: Option<ClientOptions>) -> Client {
let options = options.unwrap_or_default();
fn new(options: ClientOptions) -> Self {
let ClientOptions {
request_timeout_secs,
pool_idle_timeout_secs,
endpoint,
signer,
connector,
} = options;
let http_client = HttpClient::builder(TokioExecutor::new())
.pool_idle_timeout(options.pool_idle_timeout_secs.map(Duration::from_secs))
.pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs))
.http2_only(true)
.build(connector);
.build(connector.unwrap_or_else(default_connector));

let options = options.into();

Client { http_client, options }
Client {
http_client,
options: ConnectionOptions::new(endpoint, signer, request_timeout_secs),
}
}

/// Create a connection to APNs using the provider client certificate which
Expand All @@ -165,7 +172,7 @@ impl Client {
};
let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?;

Ok(Self::new(connector, Some(ClientOptions::new(endpoint))))
Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector)))
}

/// Create a connection to APNs using the raw PEM-formatted certificate and
Expand All @@ -174,7 +181,7 @@ impl Client {
pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result<Client, Error> {
let connector = client_cert_connector(cert_pem, key_pem)?;

Ok(Self::new(connector, Some(ClientOptions::new(endpoint))))
Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector)))
}

/// Create a connection to APNs using system certificates, signing every
Expand All @@ -187,18 +194,14 @@ impl Client {
T: Into<String>,
R: Read,
{
let connector = default_connector();
let signature_ttl = Duration::from_secs(60 * 55);
let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?);

Ok(Self::new(
connector,
Some(ClientOptions {
endpoint,
signer,
..Default::default()
}),
))
Ok(Self::new(ClientOptions {
endpoint,
signer,
..Default::default()
}))
}

/// Send a notification payload.
Expand Down Expand Up @@ -333,11 +336,18 @@ lCEIvbDqlUhA5FOzcakkG90E8L+hRANCAATKS2ZExEybUvchRDuKBftotMwVEus3
jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
-----END PRIVATE KEY-----";

impl Client {
fn new_with_defaults() -> Self {
let options = ClientOptions::default();
Self::new(options)
}
}

#[test]
fn test_production_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let uri = format!("{}", request.uri());

Expand All @@ -348,7 +358,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_sandbox_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), Some(ClientOptions::new(Endpoint::Sandbox)));
let client = Client::new(ClientOptions::new(Endpoint::Sandbox));
let request = client.build_request(payload);
let uri = format!("{}", request.uri());

Expand All @@ -359,7 +369,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_method() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);

assert_eq!(&Method::POST, request.method());
Expand All @@ -369,7 +379,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);

assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
Expand All @@ -379,7 +389,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_length() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload.clone());
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();
Expand All @@ -391,7 +401,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_authorization_with_no_signer() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);

assert_eq!(None, request.headers().get(AUTHORIZATION));
Expand All @@ -409,10 +419,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(
default_connector(),
Some(ClientOptions::new(Endpoint::Production).with_signer(signer)),
);
let client = Client::new(ClientOptions::new(Endpoint::Production).with_signer(signer));
let request = client.build_request(payload);

assert_ne!(None, request.headers().get(AUTHORIZATION));
Expand All @@ -426,7 +433,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
..Default::default()
};
let payload = builder.build("a_test_id", options);
let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_push_type = request.headers().get("apns-push-type").unwrap();

Expand All @@ -437,7 +444,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_with_default_priority() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority");

Expand All @@ -456,7 +463,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority").unwrap();

Expand All @@ -475,7 +482,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority").unwrap();

Expand All @@ -488,7 +495,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_id = request.headers().get("apns-id");

Expand All @@ -507,7 +514,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_id = request.headers().get("apns-id").unwrap();

Expand All @@ -520,7 +527,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_expiration = request.headers().get("apns-expiration");

Expand All @@ -539,7 +546,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_expiration = request.headers().get("apns-expiration").unwrap();

Expand All @@ -552,7 +559,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_collapse_id = request.headers().get("apns-collapse-id");

Expand All @@ -571,7 +578,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();

Expand All @@ -584,7 +591,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_topic = request.headers().get("apns-topic");

Expand All @@ -603,7 +610,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload);
let apns_topic = request.headers().get("apns-topic").unwrap();

Expand All @@ -614,7 +621,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
async fn test_request_body() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None);
let client = Client::new_with_defaults();
let request = client.build_request(payload.clone());

let body = request.into_body().collect().await.unwrap().to_bytes();
Expand Down

0 comments on commit 7ebbdaf

Please sign in to comment.