From c0a45bc815abb67bb77b432026779f8d77c9f895 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Thu, 18 Apr 2024 14:38:04 -0400 Subject: [PATCH 1/6] feat: key not from file --- Cargo.toml | 2 +- src/serv_account/errors.rs | 23 +++- src/serv_account/jwt.rs | 87 ++++++++------- src/serv_account/mod.rs | 220 ++++++++++++++++++++++++------------- 4 files changed, 202 insertions(+), 130 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 02e74fc..444939e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ ring = "0.16.20" thiserror = "1.0.48" anyhow = "1.0.40" futures = { version = "0.3", features = ["executor"], optional = true } -tokio = { version = "1.33.0", optional = true } +tokio = { version = "1.33.0", features = ["test-util", "sync"] } log = { version = "0.4", optional = true } [dev-dependencies] diff --git a/src/serv_account/errors.rs b/src/serv_account/errors.rs index cd0ea9a..b6ee7b8 100644 --- a/src/serv_account/errors.rs +++ b/src/serv_account/errors.rs @@ -1,10 +1,12 @@ -use std::result::Result as StdResult; +use reqwest::StatusCode; +use ring::error::{KeyRejected, Unspecified}; +use std::{io, path::PathBuf, result::Result as StdResult}; use thiserror::Error; #[derive(Debug, Error)] pub enum ServiceAccountError { - #[error("failed to read key file: {0}")] - ReadKey(String), + #[error("failed to read key file: {0}: {1}")] + ReadKey(PathBuf, io::Error), #[error("failed to de/serialize to json")] SerdeJson(#[from] serde_json::Error), @@ -13,13 +15,22 @@ pub enum ServiceAccountError { Base64Decode(#[from] base64::DecodeError), #[error("failed to create rsa key pair: {0}")] - RsaKeyPair(String), + RsaKeyPair(KeyRejected), #[error("failed to rsa sign: {0}")] - RsaSign(String), + RsaSign(Unspecified), #[error("failed to send request")] - HttpReqwest(#[from] reqwest::Error), + HttpRequest(reqwest::Error), + + #[error("failed to send request")] + HttpRequestUnsuccessful(StatusCode, std::result::Result), + + #[error("failed to get response JSON")] + HttpJson(reqwest::Error), + + #[error("response returned non-Bearer auth access token: {0}")] + AccessTokenNotBeaarer(String), } pub type Result = StdResult; diff --git a/src/serv_account/jwt.rs b/src/serv_account/jwt.rs index d61afb6..d60244d 100644 --- a/src/serv_account/jwt.rs +++ b/src/serv_account/jwt.rs @@ -1,9 +1,15 @@ use super::errors::{Result, ServiceAccountError}; +use base64::{engine::general_purpose, Engine as _}; +use ring::{ + rand, + signature::{self, RsaKeyPair}, +}; +use serde_derive::Deserialize; use serde_derive::Serialize; -#[derive(Clone, Debug, Default, Serialize)] +#[derive(Debug)] pub struct JwtToken { - private_key: String, + key_pair: RsaKeyPair, header: JwtHeader, payload: JwtPayload, } @@ -24,41 +30,36 @@ struct JwtPayload { iat: u64, } -use base64::{engine::general_purpose, Engine as _}; -use ring::{rand, signature}; -use serde_derive::Deserialize; - impl JwtToken { - /// Creates a new JWT token from a service account key file - pub fn from_file(key_path: &str) -> Result { - let private_key_content = std::fs::read(key_path) - .map_err(|err| ServiceAccountError::ReadKey(format!("{}: {}", err, key_path)))?; - - let key_data = serde_json::from_slice::(&private_key_content)?; - + /// Creates a new JWT token from a service account key + pub fn from_key(key: &ServiceAccountKey) -> Result { let iat = chrono::Utc::now().timestamp() as u64; let exp = iat + 3600; - let private_key = key_data + let private_key = key .private_key .replace('\n', "") .replace("-----BEGIN PRIVATE KEY-----", "") .replace("-----END PRIVATE KEY-----", ""); + let private_key = private_key.as_bytes(); + let decoded = general_purpose::STANDARD.decode(private_key)?; + let key_pair = RsaKeyPair::from_pkcs8(&decoded).map_err(ServiceAccountError::RsaKeyPair)?; + Ok(Self { header: JwtHeader { alg: String::from("RS256"), typ: String::from("JWT"), }, payload: JwtPayload { - iss: key_data.client_email, + iss: key.client_email.clone(), sub: None, scope: String::new(), - aud: key_data.token_uri, + aud: key.token_uri.clone(), exp, iat, }, - private_key, + key_pair, }) } @@ -100,54 +101,52 @@ impl JwtToken { /// Signs a message with the private key fn sign_rsa(&self, message: String) -> Result> { - let private_key = self.private_key.as_bytes(); - let decoded = general_purpose::STANDARD.decode(private_key)?; - - let key_pair = signature::RsaKeyPair::from_pkcs8(&decoded).map_err(|err| { - ServiceAccountError::RsaKeyPair(format!("failed tp create key pair: {}", err)) - })?; - // Sign the message, using PKCS#1 v1.5 padding and the SHA256 digest algorithm. let rng = rand::SystemRandom::new(); - let mut signature = vec![0; key_pair.public_modulus_len()]; - key_pair + let mut signature = vec![0; self.key_pair.public_modulus_len()]; + self.key_pair .sign( &signature::RSA_PKCS1_SHA256, &rng, message.as_bytes(), &mut signature, ) - .map_err(|err| ServiceAccountError::RsaSign(format!("{}", err)))?; + .map_err(ServiceAccountError::RsaSign)?; Ok(signature) } } #[allow(dead_code)] -#[derive(Debug, Deserialize)] -struct ServiceAccountKey { - r#type: String, - project_id: String, - private_key_id: String, - private_key: String, - client_email: String, - client_id: String, - auth_uri: String, - token_uri: String, - auth_provider_x509_cert_url: String, - client_x509_cert_url: String, - universe_domain: String, +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ServiceAccountKey { + pub r#type: String, + pub project_id: String, + pub private_key_id: String, + pub private_key: String, + pub client_email: String, + pub client_id: String, + pub auth_uri: String, + pub token_uri: String, + pub auth_provider_x509_cert_url: String, + pub client_x509_cert_url: String, + pub universe_domain: String, } #[cfg(test)] mod tests { use super::*; - const SERVICE_ACCOUNT_KEY_PATH: &str = "test_fixtures/service-account-key.json"; + fn read_key() -> ServiceAccountKey { + serde_json::from_slice(include_bytes!( + "../../test_fixtures/service-account-key.json" + )) + .unwrap() + } #[test] fn test_jwt_token() { - let mut token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH).unwrap(); + let mut token = JwtToken::from_key(&read_key()).unwrap(); assert_eq!(token.header.alg, "RS256"); assert_eq!(token.header.typ, "JWT"); @@ -170,7 +169,7 @@ mod tests { fn test_sign_rsa() { let message = String::from("hello, world"); - let token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH).unwrap(); + let token = JwtToken::from_key(&read_key()).unwrap(); let signature = token.sign_rsa(message).unwrap(); assert_eq!(signature.len(), 256); @@ -178,7 +177,7 @@ mod tests { #[test] fn test_token_to_string() { - let token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH) + let token = JwtToken::from_key(&read_key()) .unwrap() .sub(String::from("some@email.com")) .scope(String::from("https://www.googleapis.com/auth/pubsub")); diff --git a/src/serv_account/mod.rs b/src/serv_account/mod.rs index 2863b12..39e5cc4 100644 --- a/src/serv_account/mod.rs +++ b/src/serv_account/mod.rs @@ -1,50 +1,43 @@ -use chrono::Utc; +use self::errors::ServiceAccountError; +use chrono::{DateTime, Duration, Utc}; use errors::Result; use reqwest::Client as HttpClient; +use serde_derive::Deserialize; +use std::{path::Path, sync::Arc}; +use tokio::sync::RwLock; -use self::errors::ServiceAccountError; +pub use self::jwt::ServiceAccountKey; -pub(crate) mod errors; +pub mod errors; mod jwt; #[derive(Debug, Clone)] pub struct ServiceAccount { + http_client: HttpClient, + key: ServiceAccountKey, scopes: String, - key_path: String, user_email: Option, - - access_token: Option, - expires_at: Option, - - http_client: HttpClient, + access_token: Arc>>, } -#[derive(Debug, serde_derive::Deserialize)] -struct Token { - access_token: String, - expires_in: u64, - token_type: String, +#[derive(Debug, Clone)] +pub struct AccessToken { + pub bearer_token: String, + pub expires_at: DateTime, } -impl Token { - fn bearer_token(&self) -> String { - format!("{} {}", self.token_type, self.access_token) +impl ServiceAccount { + pub fn builder() -> ServiceAccountBuilder { + ServiceAccountBuilder::new() } -} -impl ServiceAccount { /// Creates a new service account from a key file and scopes - pub fn from_file(key_path: &str, scopes: Vec<&str>) -> Self { - Self { - scopes: scopes.join(" "), - key_path: key_path.to_string(), - user_email: None, - - access_token: None, - expires_at: None, - - http_client: HttpClient::new(), - } + pub fn from_file>(key_path: P, scopes: Vec<&str>) -> Result { + let bytes = std::fs::read(&key_path) + .map_err(|e| ServiceAccountError::ReadKey(key_path.as_ref().to_path_buf(), e))?; + let key = serde_json::from_slice::(&bytes) + .map_err(ServiceAccountError::SerdeJson)?; + Ok(Self::builder().key(key).scopes(scopes).build()) } /// Sets the user email @@ -56,60 +49,125 @@ impl ServiceAccount { /// Returns an access token /// If the access token is not expired, it will return the cached access token /// Otherwise, it will exchange the JWT token for an access token - pub async fn access_token(&mut self) -> Result { - match (self.access_token.as_ref(), self.expires_at) { - (Some(access_token), Some(expires_at)) - if expires_at > Utc::now().timestamp() as u64 => - { - Ok(access_token.to_string()) - } + pub async fn access_token(&self) -> Result { + let access_token = self.access_token.read().await.clone(); + match access_token { + Some(access_token) if access_token.expires_at > Utc::now() => Ok(access_token), _ => { - let jwt_token = self.jwt_token()?; - let token = match self.exchange_jwt_token_for_access_token(jwt_token).await { - Ok(token) => token, - Err(err) => return Err(err), - }; - - let expires_at = Utc::now().timestamp() as u64 + token.expires_in - 30; - - self.access_token = Some(token.bearer_token()); - self.expires_at = Some(expires_at); - - Ok(token.bearer_token()) + let new_token = self.get_fresh_access_token().await?; + *self.access_token.write().await = Some(new_token.clone()); + Ok(new_token) } } } - async fn exchange_jwt_token_for_access_token( - &mut self, - jwt_token: jwt::JwtToken, - ) -> Result { - let req_builder = self.http_client.post(jwt_token.token_uri()).form(&[ - ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), - ("assertion", &jwt_token.to_string()?), - ]); - - let res = match req_builder.send().await { - Ok(resp) => resp, - Err(err) => return Err(ServiceAccountError::HttpReqwest(err)), + async fn get_fresh_access_token(&self) -> Result { + let jwt_token = { + let mut token = jwt::JwtToken::from_key(&self.key)?.scope(self.scopes.clone()); + if let Some(user_email) = &self.user_email { + token = token.sub(user_email.clone()); + }; + token }; - let token = match res.json::().await { - Ok(token) => token, - Err(err) => return Err(ServiceAccountError::HttpReqwest(err)), - }; + #[derive(Debug, Deserialize)] + pub struct TokenResponse { + token_type: String, + access_token: String, + expires_in: i64, + } + + let response = self + .http_client + .post(jwt_token.token_uri()) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), + ("assertion", &jwt_token.to_string()?), + ]) + .send() + .await + .map_err(ServiceAccountError::HttpRequest)?; + + if !response.status().is_success() { + return Err(ServiceAccountError::HttpRequestUnsuccessful( + response.status(), + response.text().await, + )); + } + + let json = response + .json::() + .await + .map_err(ServiceAccountError::HttpJson)?; + + if json.token_type != "Bearer" { + return Err(ServiceAccountError::AccessTokenNotBeaarer(json.token_type)); + } + + // Account for clock skew or time to receive or process the response + const LEEWAY: Duration = Duration::seconds(30); - Ok(token) + let expires_at = Utc::now() + Duration::seconds(json.expires_in) - LEEWAY; + + Ok(AccessToken { + bearer_token: json.access_token, + expires_at, + }) } +} - fn jwt_token(&self) -> Result { - let token = jwt::JwtToken::from_file(&self.key_path)?; +pub struct ServiceAccountBuilder { + http_client: Option, + key: Option, + scopes: Option, + user_email: Option, +} + +impl ServiceAccountBuilder { + pub fn new() -> Self { + Self { + http_client: None, + key: None, + scopes: None, + user_email: None, + } + } - Ok(match self.user_email { - Some(ref user_email) => token.sub(user_email.to_string()), - None => token, + /// Panics if key is not provided + pub fn build(self) -> ServiceAccount { + ServiceAccount { + http_client: self.http_client.unwrap_or_default(), + key: self.key.expect("Key required"), + scopes: self.scopes.unwrap_or_default(), + user_email: self.user_email, + access_token: Arc::new(RwLock::new(None)), } - .scope(self.scopes.clone())) + } + + pub fn http_client(mut self, http_client: HttpClient) -> Self { + self.http_client = Some(http_client); + self + } + + pub fn key(mut self, key: ServiceAccountKey) -> Self { + self.key = Some(key); + self + } + + pub fn scopes(mut self, scopes: Vec<&str>) -> Self { + self.scopes = Some(scopes.join(" ")); + self + } + + pub fn user_email>(mut self, user_email: S) -> Self { + self.user_email = Some(user_email.into()); + self + } +} + +impl Default for ServiceAccountBuilder { + fn default() -> Self { + Self::new() } } @@ -121,22 +179,26 @@ mod tests { async fn test_access_token() { let scopes = vec!["https://www.googleapis.com/auth/drive"]; let key_path = "test_fixtures/service-account-key.json"; - let mut service_account = ServiceAccount::from_file(key_path, scopes); + let service_account = ServiceAccount::from_file(key_path, scopes).unwrap(); // TODO: fix this test - make sure we can run an integration test // let access_token = service_account.access_token(); // assert!(access_token.is_ok()); // assert!(!access_token.unwrap().is_empty()); - service_account.access_token = Some("test_access_token".to_string()); - - let expires_at = Utc::now().timestamp() as u64 + 3600; - service_account.expires_at = Some(expires_at); + let expires_at = Utc::now() + Duration::seconds(3600); + *service_account.access_token.write().await = Some(AccessToken { + bearer_token: "test_access_token".to_string(), + expires_at, + }); assert_eq!( - service_account.access_token().await.unwrap(), + service_account.access_token().await.unwrap().bearer_token, "test_access_token" ); - assert_eq!(service_account.expires_at.unwrap(), expires_at); + assert_eq!( + service_account.access_token().await.unwrap().expires_at, + expires_at + ); } } From b422162a1c07e276a7ebcd5def28f4c5ba20e1d0 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Thu, 18 Apr 2024 16:15:04 -0400 Subject: [PATCH 2/6] fix: Cargo.toml --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 444939e..1d6e5c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,4 +34,4 @@ env_logger = "0.10.0" [features] app-blocking = ["dep:futures"] -token-watcher = ["dep:tokio", "dep:async-trait", "dep:log"] +token-watcher = ["dep:async-trait", "dep:log"] From 023be000298c7f1c8be8dc44d3c0f62ebd575ccf Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 22 Apr 2024 21:21:32 -0400 Subject: [PATCH 3/6] fix: typo --- src/serv_account/errors.rs | 2 +- src/serv_account/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serv_account/errors.rs b/src/serv_account/errors.rs index b6ee7b8..8e5da89 100644 --- a/src/serv_account/errors.rs +++ b/src/serv_account/errors.rs @@ -30,7 +30,7 @@ pub enum ServiceAccountError { HttpJson(reqwest::Error), #[error("response returned non-Bearer auth access token: {0}")] - AccessTokenNotBeaarer(String), + AccessTokenNotBearer(String), } pub type Result = StdResult; diff --git a/src/serv_account/mod.rs b/src/serv_account/mod.rs index 39e5cc4..5c02ec2 100644 --- a/src/serv_account/mod.rs +++ b/src/serv_account/mod.rs @@ -101,7 +101,7 @@ impl ServiceAccount { .map_err(ServiceAccountError::HttpJson)?; if json.token_type != "Bearer" { - return Err(ServiceAccountError::AccessTokenNotBeaarer(json.token_type)); + return Err(ServiceAccountError::AccessTokenNotBearer(json.token_type)); } // Account for clock skew or time to receive or process the response From 10c114c4e01318a1f2d3ca969a677e8d59b9b6ae Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 13 May 2024 11:17:33 -0400 Subject: [PATCH 4/6] fix: error handling improvements & refactor, add CI --- .github/workflows/ci.yaml | 13 ++ Cargo.toml | 4 +- README.md | 6 +- examples/async_token_provider.rs | 7 +- src/app/mod.rs | 2 +- src/serv_account/errors.rs | 43 +++-- src/serv_account/jwt.rs | 266 ++++++++++++++++--------------- src/serv_account/mod.rs | 84 +++++----- src/token_provider/errors.rs | 4 +- 9 files changed, 237 insertions(+), 192 deletions(-) create mode 100644 .github/workflows/ci.yaml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..76f5e97 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,13 @@ +name: Cargo checks +on: + push: + pull_request: +jobs: + clippy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + - run: cargo clippy --workspace --all-features --all-targets -- -D warnings + - run: cargo test --workspace --all-features --all-targets + - run: cargo fmt -- --check diff --git a/Cargo.toml b/Cargo.toml index 1d6e5c3..5cedfca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] description = "HTTP Client for Google OAuth2" name = "gauth" -version = "0.8.0" +version = "0.9.0" authors = ["Simon Makarski "] edition = "2021" license = "MIT OR Apache-2.0" @@ -29,7 +29,7 @@ log = { version = "0.4", optional = true } [dev-dependencies] mockito = "1.2.0" -tokio = { version = "1.33.0", features = ["test-util"] } +tokio = { version = "1.33.0", features = ["test-util", "rt", "macros", "rt-multi-thread"] } env_logger = "0.10.0" [features] diff --git a/README.md b/README.md index 70943b7..0f18c4f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ The library supports the following Google Auth flows: ```toml [dependencies] -gauth = "0.8" +gauth = "0.9" ``` #### OAuth2 @@ -45,7 +45,7 @@ It is also possible to make a **blocking call** to retrieve an access token. Thi ``` [dependencies] -gauth = { version = "0.8", features = ["app-blocking"] } +gauth = { version = "0.9", features = ["app-blocking"] } ``` ```rust,no_run @@ -123,7 +123,7 @@ To resolve this, we adopted an experimental approach by developing a `token_prov ``` [dependencies] -gauth = { version = "0.8", features = ["token-watcher"] } +gauth = { version = "0.9", features = ["token-watcher"] } ``` ```rust,no_run diff --git a/examples/async_token_provider.rs b/examples/async_token_provider.rs index f9cbbb1..46bbd5f 100644 --- a/examples/async_token_provider.rs +++ b/examples/async_token_provider.rs @@ -12,8 +12,11 @@ async fn main() -> Result<(), Box> { .nth(1) .expect("Provide a path to the service account key file"); - let service_account = - ServiceAccount::from_file(&keypath, vec!["https://www.googleapis.com/auth/pubsub"]); + let service_account = ServiceAccount::from_file(&keypath) + .unwrap() + .scopes(vec!["https://www.googleapis.com/auth/pubsub"]) + .build() + .unwrap(); let tp = AsyncTokenProvider::new(service_account).with_interval(5); diff --git a/src/app/mod.rs b/src/app/mod.rs index 8e2a80a..7403efa 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -260,7 +260,7 @@ mod tests { #[tokio::test] async fn test_access_token_success() { - let mut google = mockito::Server::new(); + let mut google = mockito::Server::new_async().await; let google_host = google.url(); google diff --git a/src/serv_account/errors.rs b/src/serv_account/errors.rs index 8e5da89..fd31bb9 100644 --- a/src/serv_account/errors.rs +++ b/src/serv_account/errors.rs @@ -1,23 +1,44 @@ use reqwest::StatusCode; use ring::error::{KeyRejected, Unspecified}; -use std::{io, path::PathBuf, result::Result as StdResult}; +use std::{io, path::PathBuf}; use thiserror::Error; #[derive(Debug, Error)] -pub enum ServiceAccountError { +pub enum ServiceAccountFromFileError { #[error("failed to read key file: {0}: {1}")] - ReadKey(PathBuf, io::Error), + ReadFile(PathBuf, io::Error), #[error("failed to de/serialize to json")] - SerdeJson(#[from] serde_json::Error), + DeserializeFile(#[from] serde_json::Error), - #[error("failed to decode base64")] - Base64Decode(#[from] base64::DecodeError), + #[error("Failed to initialize service account: {0}")] + ServiceAccountInitialization(ServiceAccountBuildError), - #[error("failed to create rsa key pair: {0}")] - RsaKeyPair(KeyRejected), + #[error("Failed to get access token: {0}")] + GetAccessToken(GetAccessTokenError), +} + +#[derive(Debug, Error)] +pub enum ServiceAccountBuildError { + #[error("RSA private key didn't start with PEM prefix: -----BEGIN PRIVATE KEY-----")] + RsaPrivateKeyNoPrefix, - #[error("failed to rsa sign: {0}")] + #[error("RSA private key didn't end with PEM suffix: -----END PRIVATE KEY-----")] + RsaPrivateKeyNoSuffix, + + #[error("RSA private key could not be decoded as base64: {0}")] + RsaPrivateKeyDecode(base64::DecodeError), + + #[error("RSA private key could not be parsed: {0}")] + RsaPrivateKeyParse(KeyRejected), +} + +#[derive(Debug, Error)] +pub enum GetAccessTokenError { + #[error("failed to serialize JSON: {0}")] + JsonSerialization(serde_json::Error), + + #[error("failed to RSA sign: {0}")] RsaSign(Unspecified), #[error("failed to send request")] @@ -31,6 +52,6 @@ pub enum ServiceAccountError { #[error("response returned non-Bearer auth access token: {0}")] AccessTokenNotBearer(String), -} -pub type Result = StdResult; + // TODO error variant for invalid authentication +} diff --git a/src/serv_account/jwt.rs b/src/serv_account/jwt.rs index d60244d..741d426 100644 --- a/src/serv_account/jwt.rs +++ b/src/serv_account/jwt.rs @@ -1,4 +1,4 @@ -use super::errors::{Result, ServiceAccountError}; +use super::errors::{GetAccessTokenError, ServiceAccountBuildError}; use base64::{engine::general_purpose, Engine as _}; use ring::{ rand, @@ -6,115 +6,115 @@ use ring::{ }; use serde_derive::Deserialize; use serde_derive::Serialize; +use std::sync::Arc; -#[derive(Debug)] -pub struct JwtToken { - key_pair: RsaKeyPair, - header: JwtHeader, - payload: JwtPayload, -} - -#[derive(Clone, Debug, Default, Serialize)] -struct JwtHeader { - alg: String, - typ: String, -} - -#[derive(Clone, Debug, Default, Serialize)] -struct JwtPayload { +#[derive(Debug, Clone)] +pub struct JwtTokenSigner { + key_pair: Arc, + rng: rand::SystemRandom, iss: String, sub: Option, scope: String, aud: String, - exp: u64, - iat: u64, } -impl JwtToken { +impl JwtTokenSigner { /// Creates a new JWT token from a service account key - pub fn from_key(key: &ServiceAccountKey) -> Result { - let iat = chrono::Utc::now().timestamp() as u64; - let exp = iat + 3600; - - let private_key = key - .private_key - .replace('\n', "") - .replace("-----BEGIN PRIVATE KEY-----", "") - .replace("-----END PRIVATE KEY-----", ""); - - let private_key = private_key.as_bytes(); - let decoded = general_purpose::STANDARD.decode(private_key)?; - let key_pair = RsaKeyPair::from_pkcs8(&decoded).map_err(ServiceAccountError::RsaKeyPair)?; + pub fn from_key( + key: ServiceAccountKey, + scope: String, + sub: Option, + ) -> Result { + let no_whitespace = key.private_key.replace('\n', ""); + let private_key = no_whitespace + .strip_prefix("-----BEGIN PRIVATE KEY-----") + .ok_or(ServiceAccountBuildError::RsaPrivateKeyNoPrefix)? + .strip_suffix("-----END PRIVATE KEY-----") + .ok_or(ServiceAccountBuildError::RsaPrivateKeyNoSuffix)?; + println!("private_key: {:?}", private_key); + + let decoded = general_purpose::STANDARD + .decode(private_key.as_bytes()) + .map_err(ServiceAccountBuildError::RsaPrivateKeyDecode)?; + let key_pair = RsaKeyPair::from_pkcs8(&decoded) + .map_err(ServiceAccountBuildError::RsaPrivateKeyParse)?; Ok(Self { - header: JwtHeader { - alg: String::from("RS256"), - typ: String::from("JWT"), - }, - payload: JwtPayload { - iss: key.client_email.clone(), - sub: None, - scope: String::new(), - aud: key.token_uri.clone(), - exp, - iat, - }, - key_pair, + iss: key.client_email, + rng: rand::SystemRandom::new(), + sub, + scope, + aud: key.token_uri, + key_pair: Arc::new(key_pair), }) } - /// Returns a JWT token string - pub fn to_string(&self) -> Result { - let header = serde_json::to_vec(&self.header)?; - let payload = serde_json::to_vec(&self.payload)?; - - let base64_header = general_purpose::STANDARD.encode(header); - let base64_payload = general_purpose::STANDARD.encode(payload); - - let raw_signature = format!("{}.{}", base64_header, base64_payload); - let signature = self.sign_rsa(raw_signature)?; + /// Returns a signed JWT token string + pub fn sign(&self) -> Result { + #[derive(Clone, Debug, Default, Serialize)] + struct JwtHeader<'a> { + alg: &'a str, + typ: &'a str, + } + let header = serde_json::to_vec(&JwtHeader { + alg: "RS256", + typ: "JWT", + }) + .map_err(GetAccessTokenError::JsonSerialization)?; + let header = general_purpose::STANDARD.encode(header); + + #[derive(Clone, Debug, Default, Serialize)] + struct JwtPayload<'a> { + iss: &'a str, + sub: Option<&'a str>, + scope: &'a str, + aud: &'a str, + exp: u64, + iat: u64, + } + let iat = chrono::Utc::now().timestamp() as u64; + let exp = iat + 3600; + let payload = serde_json::to_vec(&JwtPayload { + iss: &self.iss, + sub: self.sub.as_deref(), + scope: &self.scope, + aud: &self.aud, + exp, + iat, + }) + .map_err(GetAccessTokenError::JsonSerialization)?; + let payload = general_purpose::STANDARD.encode(payload); - let base64_signature = general_purpose::STANDARD.encode(signature); + let to_sign = format!("{header}.{payload}"); + let signature = + sign_rsa(&self.key_pair, &self.rng, &to_sign).map_err(GetAccessTokenError::RsaSign)?; + let signature = general_purpose::STANDARD.encode(signature); - Ok(format!( - "{}.{}.{}", - base64_header, base64_payload, base64_signature - )) + Ok(format!("{to_sign}.{signature}")) } /// Returns the token uri pub fn token_uri(&self) -> &str { - &self.payload.aud - } - - /// Sets the sub field in the payload - pub fn sub(mut self, sub: String) -> Self { - self.payload.sub = Some(sub); - self - } - - /// Sets the scope field in the payload - pub fn scope(mut self, scope: String) -> Self { - self.payload.scope = scope; - self + &self.aud } +} - /// Signs a message with the private key - fn sign_rsa(&self, message: String) -> Result> { - // Sign the message, using PKCS#1 v1.5 padding and the SHA256 digest algorithm. - let rng = rand::SystemRandom::new(); - let mut signature = vec![0; self.key_pair.public_modulus_len()]; - self.key_pair - .sign( - &signature::RSA_PKCS1_SHA256, - &rng, - message.as_bytes(), - &mut signature, - ) - .map_err(ServiceAccountError::RsaSign)?; - - Ok(signature) - } +/// Signs a message with the private key +fn sign_rsa( + key_pair: &RsaKeyPair, + rng: &dyn rand::SecureRandom, + message: &str, +) -> Result, ring::error::Unspecified> { + // Sign the message, using PKCS#1 v1.5 padding and the SHA256 digest algorithm. + let mut signature = vec![0; key_pair.public_modulus_len()]; + key_pair.sign( + &signature::RSA_PKCS1_SHA256, + rng, + message.as_bytes(), + &mut signature, + )?; + + Ok(signature) } #[allow(dead_code)] @@ -136,6 +136,7 @@ pub struct ServiceAccountKey { #[cfg(test)] mod tests { use super::*; + use serde_json::Value; fn read_key() -> ServiceAccountKey { serde_json::from_slice(include_bytes!( @@ -145,49 +146,64 @@ mod tests { } #[test] - fn test_jwt_token() { - let mut token = JwtToken::from_key(&read_key()).unwrap(); - - assert_eq!(token.header.alg, "RS256"); - assert_eq!(token.header.typ, "JWT"); - assert!(token.payload.iss.contains("iam.gserviceaccount.com")); - assert_eq!(token.payload.sub, None); - assert_eq!(token.payload.scope, ""); - assert_eq!(token.payload.aud, "https://oauth2.googleapis.com/token"); - assert!(token.payload.exp > 0); - assert_eq!(token.payload.iat, token.payload.exp - 3600); - - token = token - .sub(String::from("some@email.domain")) - .scope(String::from("test_scope1 test_scope2 test_scope3")); - - assert_eq!(token.payload.sub, Some(String::from("some@email.domain"))); - assert_eq!(token.payload.scope, "test_scope1 test_scope2 test_scope3"); + fn test_rsa_sign() { + let key = "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCr/KzFiWfiw5vd8KrFPmsktUfmba4x8r0uPDxxdeI/zrENHPkef3Zd3Tt4bvdG4VRWAQ/zuomHcksTW1AYaaS/TfoiH5c/xivWptKHGS/eh91SgPunmoK9wbvdNW8C4goVdw57JUz6IG1vZpenHjI7ofHMfg+2cBiTsTSWFDnd1EoNkK2lmdP1R5lzxNSRce9HgugKvHAcvDtB2goL9coo8y+3kyBTiS5qCgpWplGwIMBACGW6U4a//GajvmvvZyfym7OXJeqjXznjNH32ghhjcP2DUuGf36wika1rOpmZKCJDKBoMPQERUDa1ydYLfY3v1g/8xFTL4ezuyYEkGuu5AgMBAAECggEAP3Meglno+53SuRR6y/31JTvD5Nz98Otuo8oROoKVD5k/dGkF9xxrHMHrmMjHbVzf8kK+Edr1tgSScfe0Gu2OnA02hLRG5n5D2hL9hF3kbSKOokt3jCPSrBL3Leryo4uk0Lp1mzTtqzGfbgPZWwwm2B0syZaQUWwVhRdRITUhDBcUW8cuxGXzNeDTJMUjij0li61H62rJFjE5nyxCpwlukqR96uVWN6wXhM4xhzwhaHt6oGVUAENG3Er+ZjYCgBISQkEuiaFUgB3Zkv3qYWhaWNhwhO6MDsT33xex4Ecw4epCrAfEirkP1AIYmVWFw3uxODOJ/u8mb6IQIobnxwRiIQKBgQDihX+XxV8tSvHxgHTN5vzp4oOgnKhmiClm7/MSbjwHjLcffWh6gqBLbPAvcrfA0aewIT29xgIO0CpygJcg/4RND30YKTilYo7/ieTkdwRYsCbt9zM/WBop1snZja4Zox/SK23u4OJ4uUw0e4onXOOzAogCtiEKMx+U6+JmsyhNFQKBgQDCXmAhdrinbfXtsC5J+HwC81XaFujE2l4EiLqVaHH6DIrVTNSucf6O/nsCHWhttb3U7xT7CIHCe1om8peKZsjuiQqmlKjeqPRhDNlLXV5TadIKUs8svPM+MUXArhTc3vAv1pArhi7RpQ5F1AeTJGkOvxcY6vmMjXIb/dSiZMp1FQKBgDIii+fidjtHEB98Z92+lxGI4cslgRwYXNl8mBbnMQAWw90DW6Fp0eJ/vPUzdboGbQ/Ne6XJ8mCm8A4hqdFS3ExV9kDntrLcCnxCX9e1A9BBRIx8nuoRLNE/ybMN6Y+hDATvOciaG2XO1S/0e9JUe8z97W50MwHX6NCEGLrUQkI1AoGADD4lj/YKa4FhnDccs0wTg5wQLEyFHOEkSuTR29dYVoeztvu/6b0Ea71bwiZYDZEFBASLLcS7Z6SdaRaetPkEbwHyyctTV7MMsZA9n6Gh718a+8t7gTXlnGU+H4TXi5H/TwQU0KkDCfF7lKpmT75bX7Jpoggq7895AIpcel4e4oECgYAbddARaP5mH2KAiSoBUlvh4P2beCv5HmWjIhS2nA7KaGOtGfOk9/VGTRLZXtPed70cGD5SrgMze3umI37nAtcVv+MHcZSXhjoSQZ6M3GChaDUwJNC+f6GVjfadn7LOsY5L1+0cu1pe6r4uXBOwmvv1tynpY6sGOE+tPJibK5Pm8Q=="; + let key_pair = RsaKeyPair::from_pkcs8(&general_purpose::STANDARD.decode(key).unwrap()) + .expect("Failed to parse key"); + let rng = rand::SystemRandom::new(); + let message = "hello world"; + let signature = sign_rsa(&key_pair, &rng, message).unwrap(); + assert_eq!(signature.len(), 256); } #[test] - fn test_sign_rsa() { - let message = String::from("hello, world"); - - let token = JwtToken::from_key(&read_key()).unwrap(); - let signature = token.sign_rsa(message).unwrap(); + fn test_sign() { + let scope = "test_scope1 test_scope2 test_scope3"; + let signer = JwtTokenSigner::from_key(read_key(), scope.to_owned(), None).unwrap(); + let token = signer.sign().unwrap(); + println!("token: {:?}", token); + let parts = token.split('.').collect::>(); + assert_eq!(parts.len(), 3); + let mut parts = parts.into_iter(); + + let header = parts.next().unwrap(); + let header = general_purpose::STANDARD.decode(header).unwrap(); + let header = serde_json::from_slice::(&header).unwrap(); + assert_eq!(header["alg"], "RS256"); + assert_eq!(header["typ"], "JWT"); + + let payload = parts.next().unwrap(); + let payload = general_purpose::STANDARD.decode(payload).unwrap(); + let payload = serde_json::from_slice::(&payload).unwrap(); + assert_eq!(payload["scope"], Value::String(scope.to_owned())); + assert_eq!(payload["sub"], Value::Null); + assert_eq!(payload["aud"], "https://oauth2.googleapis.com/token"); + assert!(payload["exp"].as_i64().unwrap() > 0); + assert_eq!( + payload["iat"].as_i64().unwrap(), + payload["exp"].as_i64().unwrap() - 3600 + ); + let signature = parts.next().unwrap(); + let signature = general_purpose::STANDARD.decode(signature).unwrap(); assert_eq!(signature.len(), 256); } #[test] - fn test_token_to_string() { - let token = JwtToken::from_key(&read_key()) - .unwrap() - .sub(String::from("some@email.com")) - .scope(String::from("https://www.googleapis.com/auth/pubsub")); - - let token_string = token.to_string(); - - assert!(token_string.is_ok(), "token string successfully created"); - assert!( - !token_string.unwrap().is_empty(), - "token string is not empty" - ); + fn test_sign_email() { + let sub = "some@email.domain"; + let signer = + JwtTokenSigner::from_key(read_key(), "".to_owned(), Some(sub.to_owned())).unwrap(); + let token = signer.sign().unwrap(); + let parts = token.split('.').collect::>(); + assert_eq!(parts.len(), 3); + let mut parts = parts.into_iter(); + + let _header = parts.next().unwrap(); + + let payload = parts.next().unwrap(); + let payload = general_purpose::STANDARD.decode(payload).unwrap(); + let payload = serde_json::from_slice::(&payload).unwrap(); + assert_eq!(payload["sub"], Value::String(sub.to_owned())); } } diff --git a/src/serv_account/mod.rs b/src/serv_account/mod.rs index 5c02ec2..63053c3 100644 --- a/src/serv_account/mod.rs +++ b/src/serv_account/mod.rs @@ -1,6 +1,11 @@ -use self::errors::ServiceAccountError; +use self::{ + errors::{ + GetAccessTokenError, ServiceAccountBuildError as ServiceAccountBuilderError, + ServiceAccountFromFileError, + }, + jwt::JwtTokenSigner, +}; use chrono::{DateTime, Duration, Utc}; -use errors::Result; use reqwest::Client as HttpClient; use serde_derive::Deserialize; use std::{path::Path, sync::Arc}; @@ -14,9 +19,7 @@ mod jwt; #[derive(Debug, Clone)] pub struct ServiceAccount { http_client: HttpClient, - key: ServiceAccountKey, - scopes: String, - user_email: Option, + jwt_token: JwtTokenSigner, access_token: Arc>>, } @@ -31,25 +34,22 @@ impl ServiceAccount { ServiceAccountBuilder::new() } - /// Creates a new service account from a key file and scopes - pub fn from_file>(key_path: P, scopes: Vec<&str>) -> Result { - let bytes = std::fs::read(&key_path) - .map_err(|e| ServiceAccountError::ReadKey(key_path.as_ref().to_path_buf(), e))?; + /// Creates a new `ServiceAccountBuilder` from a key file + pub fn from_file>( + key_path: P, + ) -> Result { + let bytes = std::fs::read(&key_path).map_err(|e| { + ServiceAccountFromFileError::ReadFile(key_path.as_ref().to_path_buf(), e) + })?; let key = serde_json::from_slice::(&bytes) - .map_err(ServiceAccountError::SerdeJson)?; - Ok(Self::builder().key(key).scopes(scopes).build()) - } - - /// Sets the user email - pub fn user_email(mut self, user_email: &str) -> Self { - self.user_email = Some(user_email.to_string()); - self + .map_err(ServiceAccountFromFileError::DeserializeFile)?; + Ok(Self::builder().key(key)) } /// Returns an access token /// If the access token is not expired, it will return the cached access token /// Otherwise, it will exchange the JWT token for an access token - pub async fn access_token(&self) -> Result { + pub async fn access_token(&self) -> Result { let access_token = self.access_token.read().await.clone(); match access_token { Some(access_token) if access_token.expires_at > Utc::now() => Ok(access_token), @@ -61,15 +61,7 @@ impl ServiceAccount { } } - async fn get_fresh_access_token(&self) -> Result { - let jwt_token = { - let mut token = jwt::JwtToken::from_key(&self.key)?.scope(self.scopes.clone()); - if let Some(user_email) = &self.user_email { - token = token.sub(user_email.clone()); - }; - token - }; - + async fn get_fresh_access_token(&self) -> Result { #[derive(Debug, Deserialize)] pub struct TokenResponse { token_type: String, @@ -79,17 +71,17 @@ impl ServiceAccount { let response = self .http_client - .post(jwt_token.token_uri()) + .post(self.jwt_token.token_uri()) .form(&[ ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), - ("assertion", &jwt_token.to_string()?), + ("assertion", &self.jwt_token.sign()?), ]) .send() .await - .map_err(ServiceAccountError::HttpRequest)?; + .map_err(GetAccessTokenError::HttpRequest)?; if !response.status().is_success() { - return Err(ServiceAccountError::HttpRequestUnsuccessful( + return Err(GetAccessTokenError::HttpRequestUnsuccessful( response.status(), response.text().await, )); @@ -98,10 +90,10 @@ impl ServiceAccount { let json = response .json::() .await - .map_err(ServiceAccountError::HttpJson)?; + .map_err(GetAccessTokenError::HttpJson)?; if json.token_type != "Bearer" { - return Err(ServiceAccountError::AccessTokenNotBearer(json.token_type)); + return Err(GetAccessTokenError::AccessTokenNotBearer(json.token_type)); } // Account for clock skew or time to receive or process the response @@ -134,14 +126,15 @@ impl ServiceAccountBuilder { } /// Panics if key is not provided - pub fn build(self) -> ServiceAccount { - ServiceAccount { + pub fn build(self) -> Result { + let key = self.key.expect("Key required"); + let jwt_token = + jwt::JwtTokenSigner::from_key(key, self.scopes.unwrap_or_default(), self.user_email)?; + Ok(ServiceAccount { http_client: self.http_client.unwrap_or_default(), - key: self.key.expect("Key required"), - scopes: self.scopes.unwrap_or_default(), - user_email: self.user_email, + jwt_token, access_token: Arc::new(RwLock::new(None)), - } + }) } pub fn http_client(mut self, http_client: HttpClient) -> Self { @@ -176,15 +169,14 @@ mod tests { use super::*; #[tokio::test] - async fn test_access_token() { + async fn test_access_token_cache() { let scopes = vec!["https://www.googleapis.com/auth/drive"]; let key_path = "test_fixtures/service-account-key.json"; - let service_account = ServiceAccount::from_file(key_path, scopes).unwrap(); - - // TODO: fix this test - make sure we can run an integration test - // let access_token = service_account.access_token(); - // assert!(access_token.is_ok()); - // assert!(!access_token.unwrap().is_empty()); + let service_account = ServiceAccount::from_file(key_path) + .unwrap() + .scopes(scopes) + .build() + .unwrap(); let expires_at = Utc::now() + Duration::seconds(3600); *service_account.access_token.write().await = Some(AccessToken { diff --git a/src/token_provider/errors.rs b/src/token_provider/errors.rs index cf1fab8..11cba56 100644 --- a/src/token_provider/errors.rs +++ b/src/token_provider/errors.rs @@ -3,7 +3,7 @@ use thiserror::Error; use tokio::sync::mpsc::error::SendError; use tokio::sync::TryLockError; -use crate::serv_account::errors::ServiceAccountError; +use crate::serv_account::errors::GetAccessTokenError; #[derive(Debug, Error)] pub enum TokenProviderError { @@ -11,7 +11,7 @@ pub enum TokenProviderError { AccessToken(#[from] TryLockError), #[error("service account error: {0}")] - ServiceAccountError(#[from] ServiceAccountError), + GetAccessTokenError(#[from] GetAccessTokenError), #[error("failed to send token: {0}")] SendError(#[from] SendError), From 9cad0483c0b629fd89aa31e02d82daa5f2fc8e83 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 13 May 2024 12:22:22 -0400 Subject: [PATCH 5/6] fix: CI --- .github/workflows/ci.yaml | 2 +- src/app/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 76f5e97..c21d493 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - run: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + - run: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - run: cargo clippy --workspace --all-features --all-targets -- -D warnings - run: cargo test --workspace --all-features --all-targets - run: cargo fmt -- --check diff --git a/src/app/mod.rs b/src/app/mod.rs index 7403efa..1e43db4 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -90,7 +90,7 @@ impl Auth { /// App_name can be used to override the default app name pub fn app_name(mut self, app_name: &str) -> Self { - self.app_name = app_name.to_owned(); + app_name.clone_into(&mut self.app_name); self } From 126269c6ed7a53784dd6e14ea8530fab9e14941f Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 13 May 2024 12:24:56 -0400 Subject: [PATCH 6/6] chore: bump reqwest --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 5cedfca..0029c75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ serde = "1" serde_json = "1" serde_derive = "1" dirs = "5.0.1" -reqwest = { version = "0.11", features = ["json"] } +reqwest = { version = "0.12.4", features = ["json"] } chrono = "0.4.31" base64 = "0.21.3" ring = "0.16.20"