Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve unwrap in build body and when acquiring RwLock #79

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ ring = { version = "0.17", features = ["std"], optional = true }
hyper-rustls = { version = "0.26.0", default-features = false, features = ["http2", "webpki-roots", "ring"] }
rustls-pemfile = "2.1.1"
rustls = "0.22.0"
parking_lot = "0.12"

[dev-dependencies]
argparse = "0.2"
Expand Down
62 changes: 35 additions & 27 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl Client {
/// See [ErrorReason](enum.ErrorReason.html) for possible errors.
#[cfg_attr(feature = "tracing", ::tracing::instrument)]
pub async fn send<T: PayloadLike>(&self, payload: T) -> Result<Response, Error> {
let request = self.build_request(payload);
let request = self.build_request(payload)?;
let requesting = self.http_client.request(request);

let response = requesting.await?;
Expand Down Expand Up @@ -152,7 +152,7 @@ impl Client {
}
}

fn build_request<T: PayloadLike>(&self, payload: T) -> hyper::Request<BoxBody<Bytes, Infallible>> {
fn build_request<T: PayloadLike>(&self, payload: T) -> Result<hyper::Request<BoxBody<Bytes, Infallible>>, Error> {
let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token());

let mut builder = hyper::Request::builder()
Expand Down Expand Up @@ -180,18 +180,16 @@ impl Client {
builder = builder.header("apns-topic", apns_topic.as_bytes());
}
if let Some(ref signer) = self.signer {
let auth = signer
.with_signature(|signature| format!("Bearer {}", signature))
.unwrap();
let auth = signer.with_signature(|signature| format!("Bearer {}", signature))?;

builder = builder.header(AUTHORIZATION, auth.as_bytes());
}

let payload_json = payload.to_json_string().unwrap();
let payload_json = payload.to_json_string()?;
builder = builder.header(CONTENT_LENGTH, format!("{}", payload_json.len()).as_bytes());

let request_body = Full::from(payload_json.into_bytes()).boxed();
builder.body(request_body).unwrap()
builder.body(request_body).map_err(Error::BuildRequestError)
}
}

Expand Down Expand Up @@ -247,7 +245,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

assert_eq!("https://api.push.apple.com/3/device/a_test_id", &uri);
Expand All @@ -258,7 +256,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Sandbox);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());

assert_eq!("https://api.development.push.apple.com/3/device/a_test_id", &uri);
Expand All @@ -269,17 +267,27 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!(&Method::POST, request.method());
}

#[test]
fn test_request_invalid() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("\r\n", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);

assert!(matches!(request, Err(Error::BuildRequestError(_))));
}

#[test]
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
}
Expand All @@ -289,7 +297,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload.clone());
let request = client.build_request(payload.clone()).unwrap();
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();

Expand All @@ -301,7 +309,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_eq!(None, request.headers().get(AUTHORIZATION));
}
Expand All @@ -319,7 +327,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), Some(signer), Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();

assert_ne!(None, request.headers().get(AUTHORIZATION));
}
Expand All @@ -333,7 +341,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
};
let payload = builder.build("a_test_id", options);
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_push_type = request.headers().get("apns-push-type").unwrap();

assert_eq!("background", apns_push_type);
Expand All @@ -344,7 +352,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority");

assert_eq!(None, apns_priority);
Expand All @@ -363,7 +371,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

assert_eq!("5", apns_priority);
Expand All @@ -382,7 +390,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();

assert_eq!("10", apns_priority);
Expand All @@ -395,7 +403,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id");

assert_eq!(None, apns_id);
Expand All @@ -414,7 +422,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id").unwrap();

assert_eq!("a-test-apns-id", apns_id);
Expand All @@ -427,7 +435,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration");

assert_eq!(None, apns_expiration);
Expand All @@ -446,7 +454,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration").unwrap();

assert_eq!("420", apns_expiration);
Expand All @@ -459,7 +467,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id");

assert_eq!(None, apns_collapse_id);
Expand All @@ -478,7 +486,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();

assert_eq!("a_collapse_id", apns_collapse_id);
Expand All @@ -491,7 +499,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic");

assert_eq!(None, apns_topic);
Expand All @@ -510,7 +518,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
);

let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload);
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic").unwrap();

assert_eq!("a_topic", apns_topic);
Expand All @@ -521,7 +529,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let request = client.build_request(payload.clone());
let request = client.build_request(payload.clone()).unwrap();

let body = request.into_body().collect().await.unwrap().to_bytes();
let body_str = String::from_utf8(body.to_vec()).unwrap();
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ pub enum Error {
#[error("Error building TLS config: {0}")]
Tls(#[from] rustls::Error),

/// Error while creating the HTTP request
#[error("Failed to construct HTTP request: {0}")]
BuildRequestError(#[source] http::Error),

/// Unexpected private key (only EC keys are supported).
#[cfg(all(not(feature = "openssl"), feature = "ring"))]
#[error("Unexpected private key: {0}")]
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@
//! }
//! # }
//! ```
#![warn(clippy::unwrap_used)]

#[cfg(not(any(feature = "openssl", feature = "ring")))]
compile_error!("either feature \"openssl\" or feature \"ring\" has to be enabled");

Expand Down
12 changes: 5 additions & 7 deletions src/signer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use crate::error::Error;
use parking_lot::RwLock;
use std::io::Read;
use std::sync::Arc;
use std::{
sync::RwLock,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use base64::prelude::*;
#[cfg(feature = "openssl")]
Expand Down Expand Up @@ -138,7 +136,7 @@ impl Signer {
self.renew()?;
}

let signature = self.signature.read().unwrap();
let signature = self.signature.read();

#[cfg(feature = "tracing")]
{
Expand Down Expand Up @@ -191,7 +189,7 @@ impl Signer {
);
}

let mut signature = self.signature.write().unwrap();
let mut signature = self.signature.write();

*signature = Signature {
key: Self::create_signature(&self.secret, &self.key_id, &self.team_id, issued_at)?,
Expand All @@ -202,7 +200,7 @@ impl Signer {
}

fn is_expired(&self) -> bool {
let sig = self.signature.read().unwrap();
let sig = self.signature.read();
let expiry = get_time() - sig.issued_at;
expiry >= self.expire_after_s.as_secs() as i64
}
Expand Down
Loading