Skip to content

Commit

Permalink
Add OpenID to config integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sea-snake committed Jan 15, 2025
1 parent e52ce03 commit 94526a3
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 87 deletions.
2 changes: 2 additions & 0 deletions src/internet_identity/internet_identity.did
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ type InternetIdentityInit = record {
captcha_config: opt CaptchaConfig;
// Configuration for Related Origins Requests
related_origins: opt vec text;
// Configuration for OpenID Google client
openid_google_client_id: opt text;
};

type ChallengeKey = text;
Expand Down
40 changes: 20 additions & 20 deletions src/internet_identity/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ async fn add_tentative_device(
tentative_device_registration::add_tentative_device(anchor_number, device_data).await;
match result {
Ok(TentativeRegistrationInfo {
verification_code,
device_registration_timeout,
}) => AddTentativeDeviceResponse::AddedTentatively {
verification_code,
device_registration_timeout,
}) => AddTentativeDeviceResponse::AddedTentatively {
verification_code,
device_registration_timeout,
},
Expand Down Expand Up @@ -154,7 +154,7 @@ fn add(anchor_number: AnchorNumber, device_data: DeviceData) {
anchor_operation_with_authz_check(anchor_number, |anchor| {
Ok::<_, String>(((), anchor_management::add(anchor, device_data)))
})
.unwrap_or_else(|err| trap(err.as_str()))
.unwrap_or_else(|err| trap(err.as_str()))
}

#[update]
Expand All @@ -165,7 +165,7 @@ fn update(anchor_number: AnchorNumber, device_key: DeviceKey, device_data: Devic
anchor_management::update(anchor, device_key, device_data),
))
})
.unwrap_or_else(|err| trap(err.as_str()))
.unwrap_or_else(|err| trap(err.as_str()))
}

#[update]
Expand All @@ -176,7 +176,7 @@ fn replace(anchor_number: AnchorNumber, device_key: DeviceKey, device_data: Devi
anchor_management::replace(anchor_number, anchor, device_key, device_data),
))
})
.unwrap_or_else(|err| trap(err.as_str()))
.unwrap_or_else(|err| trap(err.as_str()))
}

#[update]
Expand All @@ -187,7 +187,7 @@ fn remove(anchor_number: AnchorNumber, device_key: DeviceKey) {
anchor_management::remove(anchor_number, anchor, device_key),
))
})
.unwrap_or_else(|err| trap(err.as_str()))
.unwrap_or_else(|err| trap(err.as_str()))
}

/// Returns all devices of the anchor (authentication and recovery) but no information about device registrations.
Expand Down Expand Up @@ -271,7 +271,7 @@ async fn prepare_delegation(
max_time_to_live,
&ii_domain,
)
.await
.await
}

#[query]
Expand Down Expand Up @@ -388,13 +388,13 @@ fn post_upgrade(maybe_arg: Option<InternetIdentityInit>) {
}

fn initialize(maybe_arg: Option<InternetIdentityInit>) {
let related_origins = maybe_arg.clone().map_or_else(
let related_origins = maybe_arg.as_ref().map_or_else(
|| persistent_state(|storage| storage.related_origins.clone()),
|arg| arg.related_origins,
|arg| arg.related_origins.clone(),
);
let openid_google_client_id = maybe_arg.clone().map_or_else(
let openid_google_client_id = maybe_arg.as_ref().map_or_else(
|| persistent_state(|storage| storage.openid_google_client_id.clone()),
|arg| arg.openid_google_client_id,
|arg| arg.openid_google_client_id.clone(),
);
init_assets(related_origins);
apply_install_arg(maybe_arg);
Expand Down Expand Up @@ -470,7 +470,7 @@ fn update_root_hash() {
/// Calls raw rand to retrieve a random salt (32 bytes).
async fn random_salt() -> Salt {
let res: Vec<u8> = match call(Principal::management_canister(), "raw_rand", ()).await {
Ok((res,)) => res,
Ok((res, )) => res,
Err((_, err)) => trap(&format!("failed to get salt: {err}")),
};
let salt: Salt = res[..].try_into().unwrap_or_else(|_| {
Expand Down Expand Up @@ -752,7 +752,7 @@ mod attribute_sharing_mvp {
issuer: req.issuer.clone(),
},
)
.await;
.await;
Ok(prepared_id_alias)
}

Expand Down Expand Up @@ -791,11 +791,11 @@ mod test {
CandidSource::Text(&canister_interface),
CandidSource::File(Path::new("internet_identity.did")),
)
.unwrap_or_else(|e| {
panic!(
"the canister code interface is not equal to the did file: {:?}",
e
)
});
.unwrap_or_else(|e| {
panic!(
"the canister code interface is not equal to the did file: {:?}",
e
)
});
}
}
12 changes: 5 additions & 7 deletions src/internet_identity/src/openid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,8 @@ thread_local! {
}

pub fn setup_google(client_id: String) {
OPEN_ID_PROVIDERS.with(|providers| {
providers
.borrow_mut()
.push(Box::new(google::Provider::create(client_id)));
});
OPEN_ID_PROVIDERS
.with_borrow_mut(|providers| providers.push(Box::new(google::Provider::create(client_id))));
}

#[allow(unused)]
Expand All @@ -47,9 +44,8 @@ pub fn verify(jwt: &str, salt: &[u8; 32]) -> Result<OpenIdCredential, String> {
let claims: PartialClaims =
serde_json::from_slice(validation_item.claims()).map_err(|_| "Unable to decode claims")?;

OPEN_ID_PROVIDERS.with(|providers| {
OPEN_ID_PROVIDERS.with_borrow(|providers| {
match providers
.borrow()
.iter()
.find(|provider| provider.issuer() == claims.iss)
{
Expand All @@ -61,6 +57,7 @@ pub fn verify(jwt: &str, salt: &[u8; 32]) -> Result<OpenIdCredential, String> {

#[cfg(test)]
struct ExampleProvider;

#[cfg(test)]
impl OpenIdProvider for ExampleProvider {
fn issuer(&self) -> &'static str {
Expand All @@ -71,6 +68,7 @@ impl OpenIdProvider for ExampleProvider {
Ok(self.credential())
}
}

#[cfg(test)]
impl ExampleProvider {
fn credential(&self) -> OpenIdCredential {
Expand Down
111 changes: 56 additions & 55 deletions src/internet_identity/src/openid/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,8 @@ use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use base64::Engine;
use candid::Principal;
use candid::{Deserialize, Nat};
#[cfg(not(test))]
use ic_cdk::api::management_canister::http_request::http_request_with_closure;
#[cfg(not(test))]
use ic_cdk::api::management_canister::http_request::CanisterHttpRequestArgument;
#[cfg(not(test))]
use ic_cdk::api::management_canister::http_request::HttpMethod;
use ic_cdk::api::management_canister::http_request::{HttpHeader, HttpResponse};
#[cfg(not(test))]
use ic_cdk::spawn;
use ic_cdk::trap;
#[cfg(not(test))]
use ic_cdk_timers::set_timer;
use ic_stable_structures::Storable;
use identity_jose::jwk::{Jwk, JwkParamsRsa};
use identity_jose::jws::JwsAlgorithm::RS256;
Expand All @@ -27,14 +17,10 @@ use internet_identity_interface::internet_identity::types::MetadataEntryV2;
use rsa::{Pkcs1v15Sign, RsaPublicKey};
use serde::Serialize;
use sha2::{Digest, Sha256};
use std::cell::RefCell;
#[cfg(not(test))]
use std::cmp::min;
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::convert::Into;
use std::rc::Rc;
#[cfg(not(test))]
use std::time::Duration;

const ISSUER: &str = "https://accounts.google.com";

Expand All @@ -50,11 +36,13 @@ const HTTP_STATUS_OK: u8 = 200;
// Fetch the Google certs every hour, the responses are always
// valid for at least 5 hours so that should be enough margin.
#[cfg(not(test))]
const FETCH_CERTS_INTERVAL: u64 = 60 * 60;
const FETCH_CERTS_INTERVAL: u64 = 60 * 60; // 1 hour in seconds

const NANOSECONDS_PER_SECOND: u64 = 1_000_000_000;

// A JWT is only valid for a very small window, even if the JWT itself says it's valid for longer,
// we only need it right after it's being issued to create a JWT delegation with its own expiry.
const MAX_VALIDITY_WINDOW: u64 = 60_000_000_000; // 5 minutes in nanos, same as ingress expiry
const MAX_VALIDITY_WINDOW: u64 = 5 * 60 * NANOSECONDS_PER_SECOND; // 5 minutes in nanos, same as ingress expiry

#[derive(Serialize, Deserialize)]
struct Certs {
Expand Down Expand Up @@ -127,24 +115,27 @@ impl OpenIdProvider for Provider {
}

impl Provider {
#[cfg(not(test))]
pub fn create(client_id: String) -> Provider {
#[cfg(test)]
let certs = Rc::new(RefCell::new(TEST_CERTS.take()));

#[cfg(not(test))]
let certs: Rc<RefCell<Vec<Jwk>>> = Rc::new(RefCell::new(vec![]));

#[cfg(not(test))]
schedule_fetch_certs(Rc::clone(&certs), None);
Provider { client_id, certs }
}

#[cfg(test)]
pub fn create(client_id: String) -> Provider {
let certs = Rc::new(RefCell::new(
TEST_CERTS.with(|certs| certs.borrow().clone()),
));
Provider { client_id, certs }
}
}

#[cfg(not(test))]
fn schedule_fetch_certs(certs_reference: Rc<RefCell<Vec<Jwk>>>, delay: Option<u64>) {
use ic_cdk::spawn;
use ic_cdk_timers::set_timer;
use std::cmp::min;
use std::time::Duration;

set_timer(Duration::from_secs(delay.unwrap_or(0)), move || {
spawn(async move {
let new_delay = match fetch_certs().await {
Expand All @@ -163,6 +154,10 @@ fn schedule_fetch_certs(certs_reference: Rc<RefCell<Vec<Jwk>>>, delay: Option<u6

#[cfg(not(test))]
async fn fetch_certs() -> Result<Vec<Jwk>, String> {
use ic_cdk::api::management_canister::http_request::{
http_request_with_closure, CanisterHttpRequestArgument, HttpMethod,
};

let request = CanisterHttpRequestArgument {
url: CERTS_URL.into(),
method: HttpMethod::GET,
Expand Down Expand Up @@ -219,6 +214,30 @@ fn transform_certs(response: HttpResponse) -> HttpResponse {
}
}

fn create_rsa_public_key(jwk: &Jwk) -> Result<RsaPublicKey, String> {
// Extract the RSA parameters (modulus 'n' and exponent 'e') from the JWK.
let JwkParamsRsa { n, e, .. } = jwk
.try_rsa_params()
.map_err(|_| "Unable to extract modulus and exponent")?;

// Decode the base64-url encoded modulus 'n' of the RSA public key.
let n = BASE64_URL_SAFE_NO_PAD
.decode(n)
.map_err(|_| "Unable to decode modulus")?;

// Decode the base64-url encoded public exponent 'e' of the RSA public key.
let e = BASE64_URL_SAFE_NO_PAD
.decode(e)
.map_err(|_| "Unable to decode exponent")?;

// Construct the RSA public key using the decoded modulus and exponent.
RsaPublicKey::new(
rsa::BigUint::from_bytes_be(&n),
rsa::BigUint::from_bytes_be(&e),
)
.map_err(|_| "Unable to construct RSA public key".into())
}

/// Verifier implementation for `identity_jose` that verifies the signature of a JWT.
///
/// - `input`: A `VerificationInput` struct containing the JWT's algorithm (`alg`),
Expand All @@ -241,33 +260,14 @@ fn verify_signature(input: VerificationInput, jwk: &Jwk) -> Result<(), Signature
// Define the signature scheme to be used for verification (RSA PKCS#1 v1.5 with SHA-256).
let scheme = Pkcs1v15Sign::new::<Sha256>();

// Extract the RSA parameters (modulus 'n' and exponent 'e') from the JWK.
let JwkParamsRsa { n, e, .. } = jwk.try_rsa_params().map_err(|_| {
SignatureVerificationError::from(SignatureVerificationErrorKind::KeyDecodingFailure)
})?;

// Decode the base64-url encoded modulus 'n' of the RSA public key.
let n = BASE64_URL_SAFE_NO_PAD.decode(n).map_err(|_| {
SignatureVerificationError::from(SignatureVerificationErrorKind::KeyDecodingFailure)
})?;

// Decode the base64-url encoded public exponent 'e' of the RSA public key.
let e = BASE64_URL_SAFE_NO_PAD.decode(e).map_err(|_| {
SignatureVerificationError::from(SignatureVerificationErrorKind::KeyDecodingFailure)
})?;

// Construct the RSA public key using the decoded modulus and exponent.
let rsa_key = RsaPublicKey::new(
rsa::BigUint::from_bytes_be(&n),
rsa::BigUint::from_bytes_be(&e),
)
.map_err(|_| {
SignatureVerificationError::from(SignatureVerificationErrorKind::KeyDecodingFailure)
// Create RSA public key from JWK
let public_key = create_rsa_public_key(jwk).map_err(|_| {
SignatureVerificationError::new(SignatureVerificationErrorKind::KeyDecodingFailure)
})?;

// Verify the JWT signature using the RSA public key and the defined signature scheme.
// If the signature is invalid, return an InvalidSignature error.
rsa_key
public_key
.verify(scheme, &hashed_input, input.decoded_signature.as_ref())
.map_err(|_| SignatureVerificationErrorKind::InvalidSignature.into())
}
Expand All @@ -289,10 +289,10 @@ fn verify_claims(client_id: &String, claims: &Claims, salt: &[u8; 32]) -> Result
if claims.nonce != expected_nonce {
return Err(format!("Invalid nonce: {}", claims.nonce));
}
if now > claims.iat * 1_000_000_000 + MAX_VALIDITY_WINDOW {
if now > claims.iat * NANOSECONDS_PER_SECOND + MAX_VALIDITY_WINDOW {
return Err("JWT is no longer valid".into());
}
if now < claims.iat * 1_000_000_000 {
if now < claims.iat * NANOSECONDS_PER_SECOND {
return Err("JWT is not valid yet".into());
}

Expand All @@ -301,9 +301,9 @@ fn verify_claims(client_id: &String, claims: &Claims, salt: &[u8; 32]) -> Result

#[cfg(test)]
thread_local! {
static TEST_CALLER: RefCell<Principal> = RefCell::new(Principal::from_text("x4gp4-hxabd-5jt4d-wc6uw-qk4qo-5am4u-mncv3-wz3rt-usgjp-od3c2-oae").unwrap());
static TEST_TIME: RefCell<u64> = const { RefCell::new(1_736_794_102_000_000_000) };
static TEST_CERTS: RefCell<Vec<Jwk>> = RefCell::new(serde_json::from_str::<Certs>(r#"{"keys":[{"n": "jwstqI4w2drqbTTVRDriFqepwVVI1y05D5TZCmGvgMK5hyOsVW0tBRiY9Jk9HKDRue3vdXiMgarwqZEDOyOA0rpWh-M76eauFhRl9lTXd5gkX0opwh2-dU1j6UsdWmMa5OpVmPtqXl4orYr2_3iAxMOhHZ_vuTeD0KGeAgbeab7_4ijyLeJ-a8UmWPVkglnNb5JmG8To77tSXGcPpBcAFpdI_jftCWr65eL1vmAkPNJgUTgI4sGunzaybf98LSv_w4IEBc3-nY5GfL-mjPRqVCRLUtbhHO_5AYDpqGj6zkKreJ9-KsoQUP6RrAVxkNuOHV9g1G-CHihKsyAifxNN2Q","use": "sig","kty": "RSA","alg": "RS256","kid": "dd125d5f462fbc6014aedab81ddf3bcedab70847","e": "AQAB"}]}"#).unwrap().keys);
static TEST_CALLER: Cell<Principal> = Cell::new(Principal::from_text("x4gp4-hxabd-5jt4d-wc6uw-qk4qo-5am4u-mncv3-wz3rt-usgjp-od3c2-oae").unwrap());
static TEST_TIME: Cell<u64> = const { Cell::new(1_736_794_102 * NANOSECONDS_PER_SECOND) };
static TEST_CERTS: Cell<Vec<Jwk>> = Cell::new(serde_json::from_str::<Certs>(r#"{"keys":[{"n": "jwstqI4w2drqbTTVRDriFqepwVVI1y05D5TZCmGvgMK5hyOsVW0tBRiY9Jk9HKDRue3vdXiMgarwqZEDOyOA0rpWh-M76eauFhRl9lTXd5gkX0opwh2-dU1j6UsdWmMa5OpVmPtqXl4orYr2_3iAxMOhHZ_vuTeD0KGeAgbeab7_4ijyLeJ-a8UmWPVkglnNb5JmG8To77tSXGcPpBcAFpdI_jftCWr65eL1vmAkPNJgUTgI4sGunzaybf98LSv_w4IEBc3-nY5GfL-mjPRqVCRLUtbhHO_5AYDpqGj6zkKreJ9-KsoQUP6RrAVxkNuOHV9g1G-CHihKsyAifxNN2Q","use": "sig","kty": "RSA","alg": "RS256","kid": "dd125d5f462fbc6014aedab81ddf3bcedab70847","e": "AQAB"}]}"#).unwrap().keys);
}

#[cfg(not(test))]
Expand All @@ -313,7 +313,7 @@ fn caller() -> Principal {

#[cfg(test)]
fn caller() -> Principal {
TEST_CALLER.with(|caller| *caller.borrow())
TEST_CALLER.get()
}

#[cfg(not(test))]
Expand All @@ -322,7 +322,7 @@ fn time() -> u64 {
}
#[cfg(test)]
fn time() -> u64 {
TEST_TIME.with(|time| *time.borrow())
TEST_TIME.get()
}

#[test]
Expand Down Expand Up @@ -494,6 +494,7 @@ fn should_return_error_when_no_longer_valid() {
Err("JWT is no longer valid".into())
);
}

#[test]
fn should_return_error_when_not_valid_yet() {
TEST_TIME.replace(time() - 1);
Expand Down
Loading

0 comments on commit 94526a3

Please sign in to comment.