Skip to content

Commit

Permalink
feat: key not from file
Browse files Browse the repository at this point in the history
  • Loading branch information
chris13524 committed Apr 18, 2024
1 parent 697e671 commit c0a45bc
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 130 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ring = "0.16.20"
thiserror = "1.0.48"
anyhow = "1.0.40"
futures = { version = "0.3", features = ["executor"], optional = true }
tokio = { version = "1.33.0", optional = true }
tokio = { version = "1.33.0", features = ["test-util", "sync"] }
log = { version = "0.4", optional = true }

[dev-dependencies]
Expand Down
23 changes: 17 additions & 6 deletions src/serv_account/errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::result::Result as StdResult;
use reqwest::StatusCode;
use ring::error::{KeyRejected, Unspecified};
use std::{io, path::PathBuf, result::Result as StdResult};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ServiceAccountError {
#[error("failed to read key file: {0}")]
ReadKey(String),
#[error("failed to read key file: {0}: {1}")]
ReadKey(PathBuf, io::Error),

#[error("failed to de/serialize to json")]
SerdeJson(#[from] serde_json::Error),
Expand All @@ -13,13 +15,22 @@ pub enum ServiceAccountError {
Base64Decode(#[from] base64::DecodeError),

#[error("failed to create rsa key pair: {0}")]
RsaKeyPair(String),
RsaKeyPair(KeyRejected),

#[error("failed to rsa sign: {0}")]
RsaSign(String),
RsaSign(Unspecified),

#[error("failed to send request")]
HttpReqwest(#[from] reqwest::Error),
HttpRequest(reqwest::Error),

#[error("failed to send request")]
HttpRequestUnsuccessful(StatusCode, std::result::Result<String, reqwest::Error>),

#[error("failed to get response JSON")]
HttpJson(reqwest::Error),

#[error("response returned non-Bearer auth access token: {0}")]
AccessTokenNotBeaarer(String),
}

pub type Result<T> = StdResult<T, ServiceAccountError>;
87 changes: 43 additions & 44 deletions src/serv_account/jwt.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use super::errors::{Result, ServiceAccountError};
use base64::{engine::general_purpose, Engine as _};
use ring::{
rand,
signature::{self, RsaKeyPair},
};
use serde_derive::Deserialize;
use serde_derive::Serialize;

#[derive(Clone, Debug, Default, Serialize)]
#[derive(Debug)]
pub struct JwtToken {
private_key: String,
key_pair: RsaKeyPair,
header: JwtHeader,
payload: JwtPayload,
}
Expand All @@ -24,41 +30,36 @@ struct JwtPayload {
iat: u64,
}

use base64::{engine::general_purpose, Engine as _};
use ring::{rand, signature};
use serde_derive::Deserialize;

impl JwtToken {
/// Creates a new JWT token from a service account key file
pub fn from_file(key_path: &str) -> Result<Self> {
let private_key_content = std::fs::read(key_path)
.map_err(|err| ServiceAccountError::ReadKey(format!("{}: {}", err, key_path)))?;

let key_data = serde_json::from_slice::<ServiceAccountKey>(&private_key_content)?;

/// Creates a new JWT token from a service account key
pub fn from_key(key: &ServiceAccountKey) -> Result<Self> {
let iat = chrono::Utc::now().timestamp() as u64;
let exp = iat + 3600;

let private_key = key_data
let private_key = key
.private_key
.replace('\n', "")
.replace("-----BEGIN PRIVATE KEY-----", "")
.replace("-----END PRIVATE KEY-----", "");

let private_key = private_key.as_bytes();
let decoded = general_purpose::STANDARD.decode(private_key)?;
let key_pair = RsaKeyPair::from_pkcs8(&decoded).map_err(ServiceAccountError::RsaKeyPair)?;

Ok(Self {
header: JwtHeader {
alg: String::from("RS256"),
typ: String::from("JWT"),
},
payload: JwtPayload {
iss: key_data.client_email,
iss: key.client_email.clone(),
sub: None,
scope: String::new(),
aud: key_data.token_uri,
aud: key.token_uri.clone(),
exp,
iat,
},
private_key,
key_pair,
})
}

Expand Down Expand Up @@ -100,54 +101,52 @@ impl JwtToken {

/// Signs a message with the private key
fn sign_rsa(&self, message: String) -> Result<Vec<u8>> {
let private_key = self.private_key.as_bytes();
let decoded = general_purpose::STANDARD.decode(private_key)?;

let key_pair = signature::RsaKeyPair::from_pkcs8(&decoded).map_err(|err| {
ServiceAccountError::RsaKeyPair(format!("failed tp create key pair: {}", err))
})?;

// Sign the message, using PKCS#1 v1.5 padding and the SHA256 digest algorithm.
let rng = rand::SystemRandom::new();
let mut signature = vec![0; key_pair.public_modulus_len()];
key_pair
let mut signature = vec![0; self.key_pair.public_modulus_len()];
self.key_pair
.sign(
&signature::RSA_PKCS1_SHA256,
&rng,
message.as_bytes(),
&mut signature,
)
.map_err(|err| ServiceAccountError::RsaSign(format!("{}", err)))?;
.map_err(ServiceAccountError::RsaSign)?;

Ok(signature)
}
}

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct ServiceAccountKey {
r#type: String,
project_id: String,
private_key_id: String,
private_key: String,
client_email: String,
client_id: String,
auth_uri: String,
token_uri: String,
auth_provider_x509_cert_url: String,
client_x509_cert_url: String,
universe_domain: String,
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServiceAccountKey {
pub r#type: String,
pub project_id: String,
pub private_key_id: String,
pub private_key: String,
pub client_email: String,
pub client_id: String,
pub auth_uri: String,
pub token_uri: String,
pub auth_provider_x509_cert_url: String,
pub client_x509_cert_url: String,
pub universe_domain: String,
}

#[cfg(test)]
mod tests {
use super::*;

const SERVICE_ACCOUNT_KEY_PATH: &str = "test_fixtures/service-account-key.json";
fn read_key() -> ServiceAccountKey {
serde_json::from_slice(include_bytes!(
"../../test_fixtures/service-account-key.json"
))
.unwrap()
}

#[test]
fn test_jwt_token() {
let mut token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH).unwrap();
let mut token = JwtToken::from_key(&read_key()).unwrap();

assert_eq!(token.header.alg, "RS256");
assert_eq!(token.header.typ, "JWT");
Expand All @@ -170,15 +169,15 @@ mod tests {
fn test_sign_rsa() {
let message = String::from("hello, world");

let token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH).unwrap();
let token = JwtToken::from_key(&read_key()).unwrap();
let signature = token.sign_rsa(message).unwrap();

assert_eq!(signature.len(), 256);
}

#[test]
fn test_token_to_string() {
let token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH)
let token = JwtToken::from_key(&read_key())
.unwrap()
.sub(String::from("some@email.com"))
.scope(String::from("https://www.googleapis.com/auth/pubsub"));
Expand Down
Loading

0 comments on commit c0a45bc

Please sign in to comment.