diff --git a/src/internet_identity/internet_identity.did b/src/internet_identity/internet_identity.did index bdaf682540..978eac4570 100644 --- a/src/internet_identity/internet_identity.did +++ b/src/internet_identity/internet_identity.did @@ -259,6 +259,8 @@ type InternetIdentityInit = record { captcha_config: opt CaptchaConfig; // Configuration for Related Origins Requests related_origins: opt vec text; + // Configuration for OpenID Google client + openid_google_client_id: opt text; }; type ChallengeKey = text; diff --git a/src/internet_identity/src/main.rs b/src/internet_identity/src/main.rs index a18730b2de..edece59822 100644 --- a/src/internet_identity/src/main.rs +++ b/src/internet_identity/src/main.rs @@ -82,9 +82,9 @@ async fn add_tentative_device( tentative_device_registration::add_tentative_device(anchor_number, device_data).await; match result { Ok(TentativeRegistrationInfo { - verification_code, - device_registration_timeout, - }) => AddTentativeDeviceResponse::AddedTentatively { + verification_code, + device_registration_timeout, + }) => AddTentativeDeviceResponse::AddedTentatively { verification_code, device_registration_timeout, }, @@ -154,7 +154,7 @@ fn add(anchor_number: AnchorNumber, device_data: DeviceData) { anchor_operation_with_authz_check(anchor_number, |anchor| { Ok::<_, String>(((), anchor_management::add(anchor, device_data))) }) - .unwrap_or_else(|err| trap(err.as_str())) + .unwrap_or_else(|err| trap(err.as_str())) } #[update] @@ -165,7 +165,7 @@ fn update(anchor_number: AnchorNumber, device_key: DeviceKey, device_data: Devic anchor_management::update(anchor, device_key, device_data), )) }) - .unwrap_or_else(|err| trap(err.as_str())) + .unwrap_or_else(|err| trap(err.as_str())) } #[update] @@ -176,7 +176,7 @@ fn replace(anchor_number: AnchorNumber, device_key: DeviceKey, device_data: Devi anchor_management::replace(anchor_number, anchor, device_key, device_data), )) }) - .unwrap_or_else(|err| trap(err.as_str())) + .unwrap_or_else(|err| trap(err.as_str())) } #[update] @@ -187,7 +187,7 @@ fn remove(anchor_number: AnchorNumber, device_key: DeviceKey) { anchor_management::remove(anchor_number, anchor, device_key), )) }) - .unwrap_or_else(|err| trap(err.as_str())) + .unwrap_or_else(|err| trap(err.as_str())) } /// Returns all devices of the anchor (authentication and recovery) but no information about device registrations. @@ -271,7 +271,7 @@ async fn prepare_delegation( max_time_to_live, &ii_domain, ) - .await + .await } #[query] @@ -388,13 +388,13 @@ fn post_upgrade(maybe_arg: Option) { } fn initialize(maybe_arg: Option) { - let related_origins = maybe_arg.clone().map_or_else( + let related_origins = maybe_arg.as_ref().map_or_else( || persistent_state(|storage| storage.related_origins.clone()), - |arg| arg.related_origins, + |arg| arg.related_origins.clone(), ); - let openid_google_client_id = maybe_arg.clone().map_or_else( + let openid_google_client_id = maybe_arg.as_ref().map_or_else( || persistent_state(|storage| storage.openid_google_client_id.clone()), - |arg| arg.openid_google_client_id, + |arg| arg.openid_google_client_id.clone(), ); init_assets(related_origins); apply_install_arg(maybe_arg); @@ -470,7 +470,7 @@ fn update_root_hash() { /// Calls raw rand to retrieve a random salt (32 bytes). async fn random_salt() -> Salt { let res: Vec = match call(Principal::management_canister(), "raw_rand", ()).await { - Ok((res,)) => res, + Ok((res, )) => res, Err((_, err)) => trap(&format!("failed to get salt: {err}")), }; let salt: Salt = res[..].try_into().unwrap_or_else(|_| { @@ -752,7 +752,7 @@ mod attribute_sharing_mvp { issuer: req.issuer.clone(), }, ) - .await; + .await; Ok(prepared_id_alias) } @@ -791,11 +791,11 @@ mod test { CandidSource::Text(&canister_interface), CandidSource::File(Path::new("internet_identity.did")), ) - .unwrap_or_else(|e| { - panic!( - "the canister code interface is not equal to the did file: {:?}", - e - ) - }); + .unwrap_or_else(|e| { + panic!( + "the canister code interface is not equal to the did file: {:?}", + e + ) + }); } } diff --git a/src/internet_identity/src/openid.rs b/src/internet_identity/src/openid.rs index 37afc160d4..565411706c 100644 --- a/src/internet_identity/src/openid.rs +++ b/src/internet_identity/src/openid.rs @@ -32,11 +32,8 @@ thread_local! { } pub fn setup_google(client_id: String) { - OPEN_ID_PROVIDERS.with(|providers| { - providers - .borrow_mut() - .push(Box::new(google::Provider::create(client_id))); - }); + OPEN_ID_PROVIDERS + .with_borrow_mut(|providers| providers.push(Box::new(google::Provider::create(client_id)))); } #[allow(unused)] @@ -47,9 +44,8 @@ pub fn verify(jwt: &str, salt: &[u8; 32]) -> Result { let claims: PartialClaims = serde_json::from_slice(validation_item.claims()).map_err(|_| "Unable to decode claims")?; - OPEN_ID_PROVIDERS.with(|providers| { + OPEN_ID_PROVIDERS.with_borrow(|providers| { match providers - .borrow() .iter() .find(|provider| provider.issuer() == claims.iss) { @@ -61,6 +57,7 @@ pub fn verify(jwt: &str, salt: &[u8; 32]) -> Result { #[cfg(test)] struct ExampleProvider; + #[cfg(test)] impl OpenIdProvider for ExampleProvider { fn issuer(&self) -> &'static str { @@ -71,6 +68,7 @@ impl OpenIdProvider for ExampleProvider { Ok(self.credential()) } } + #[cfg(test)] impl ExampleProvider { fn credential(&self) -> OpenIdCredential { diff --git a/src/internet_identity/src/openid/google.rs b/src/internet_identity/src/openid/google.rs index 8968b87df2..7fc697108e 100644 --- a/src/internet_identity/src/openid/google.rs +++ b/src/internet_identity/src/openid/google.rs @@ -4,18 +4,8 @@ use base64::prelude::BASE64_URL_SAFE_NO_PAD; use base64::Engine; use candid::Principal; use candid::{Deserialize, Nat}; -#[cfg(not(test))] -use ic_cdk::api::management_canister::http_request::http_request_with_closure; -#[cfg(not(test))] -use ic_cdk::api::management_canister::http_request::CanisterHttpRequestArgument; -#[cfg(not(test))] -use ic_cdk::api::management_canister::http_request::HttpMethod; use ic_cdk::api::management_canister::http_request::{HttpHeader, HttpResponse}; -#[cfg(not(test))] -use ic_cdk::spawn; use ic_cdk::trap; -#[cfg(not(test))] -use ic_cdk_timers::set_timer; use ic_stable_structures::Storable; use identity_jose::jwk::{Jwk, JwkParamsRsa}; use identity_jose::jws::JwsAlgorithm::RS256; @@ -27,14 +17,10 @@ use internet_identity_interface::internet_identity::types::MetadataEntryV2; use rsa::{Pkcs1v15Sign, RsaPublicKey}; use serde::Serialize; use sha2::{Digest, Sha256}; -use std::cell::RefCell; -#[cfg(not(test))] -use std::cmp::min; +use std::cell::{Cell, RefCell}; use std::collections::HashMap; use std::convert::Into; use std::rc::Rc; -#[cfg(not(test))] -use std::time::Duration; const ISSUER: &str = "https://accounts.google.com"; @@ -50,11 +36,13 @@ const HTTP_STATUS_OK: u8 = 200; // Fetch the Google certs every hour, the responses are always // valid for at least 5 hours so that should be enough margin. #[cfg(not(test))] -const FETCH_CERTS_INTERVAL: u64 = 60 * 60; +const FETCH_CERTS_INTERVAL: u64 = 60 * 60; // 1 hour in seconds + +const NANOSECONDS_PER_SECOND: u64 = 1_000_000_000; // A JWT is only valid for a very small window, even if the JWT itself says it's valid for longer, // we only need it right after it's being issued to create a JWT delegation with its own expiry. -const MAX_VALIDITY_WINDOW: u64 = 60_000_000_000; // 5 minutes in nanos, same as ingress expiry +const MAX_VALIDITY_WINDOW: u64 = 5 * 60 * NANOSECONDS_PER_SECOND; // 5 minutes in nanos, same as ingress expiry #[derive(Serialize, Deserialize)] struct Certs { @@ -127,24 +115,27 @@ impl OpenIdProvider for Provider { } impl Provider { - #[cfg(not(test))] pub fn create(client_id: String) -> Provider { + #[cfg(test)] + let certs = Rc::new(RefCell::new(TEST_CERTS.take())); + + #[cfg(not(test))] let certs: Rc>> = Rc::new(RefCell::new(vec![])); + + #[cfg(not(test))] schedule_fetch_certs(Rc::clone(&certs), None); - Provider { client_id, certs } - } - #[cfg(test)] - pub fn create(client_id: String) -> Provider { - let certs = Rc::new(RefCell::new( - TEST_CERTS.with(|certs| certs.borrow().clone()), - )); Provider { client_id, certs } } } #[cfg(not(test))] fn schedule_fetch_certs(certs_reference: Rc>>, delay: Option) { + use ic_cdk::spawn; + use ic_cdk_timers::set_timer; + use std::cmp::min; + use std::time::Duration; + set_timer(Duration::from_secs(delay.unwrap_or(0)), move || { spawn(async move { let new_delay = match fetch_certs().await { @@ -163,6 +154,10 @@ fn schedule_fetch_certs(certs_reference: Rc>>, delay: Option Result, String> { + use ic_cdk::api::management_canister::http_request::{ + http_request_with_closure, CanisterHttpRequestArgument, HttpMethod, + }; + let request = CanisterHttpRequestArgument { url: CERTS_URL.into(), method: HttpMethod::GET, @@ -219,6 +214,30 @@ fn transform_certs(response: HttpResponse) -> HttpResponse { } } +fn create_rsa_public_key(jwk: &Jwk) -> Result { + // Extract the RSA parameters (modulus 'n' and exponent 'e') from the JWK. + let JwkParamsRsa { n, e, .. } = jwk + .try_rsa_params() + .map_err(|_| "Unable to extract modulus and exponent")?; + + // Decode the base64-url encoded modulus 'n' of the RSA public key. + let n = BASE64_URL_SAFE_NO_PAD + .decode(n) + .map_err(|_| "Unable to decode modulus")?; + + // Decode the base64-url encoded public exponent 'e' of the RSA public key. + let e = BASE64_URL_SAFE_NO_PAD + .decode(e) + .map_err(|_| "Unable to decode exponent")?; + + // Construct the RSA public key using the decoded modulus and exponent. + RsaPublicKey::new( + rsa::BigUint::from_bytes_be(&n), + rsa::BigUint::from_bytes_be(&e), + ) + .map_err(|_| "Unable to construct RSA public key".into()) +} + /// Verifier implementation for `identity_jose` that verifies the signature of a JWT. /// /// - `input`: A `VerificationInput` struct containing the JWT's algorithm (`alg`), @@ -241,33 +260,14 @@ fn verify_signature(input: VerificationInput, jwk: &Jwk) -> Result<(), Signature // Define the signature scheme to be used for verification (RSA PKCS#1 v1.5 with SHA-256). let scheme = Pkcs1v15Sign::new::(); - // Extract the RSA parameters (modulus 'n' and exponent 'e') from the JWK. - let JwkParamsRsa { n, e, .. } = jwk.try_rsa_params().map_err(|_| { - SignatureVerificationError::from(SignatureVerificationErrorKind::KeyDecodingFailure) - })?; - - // Decode the base64-url encoded modulus 'n' of the RSA public key. - let n = BASE64_URL_SAFE_NO_PAD.decode(n).map_err(|_| { - SignatureVerificationError::from(SignatureVerificationErrorKind::KeyDecodingFailure) - })?; - - // Decode the base64-url encoded public exponent 'e' of the RSA public key. - let e = BASE64_URL_SAFE_NO_PAD.decode(e).map_err(|_| { - SignatureVerificationError::from(SignatureVerificationErrorKind::KeyDecodingFailure) - })?; - - // Construct the RSA public key using the decoded modulus and exponent. - let rsa_key = RsaPublicKey::new( - rsa::BigUint::from_bytes_be(&n), - rsa::BigUint::from_bytes_be(&e), - ) - .map_err(|_| { - SignatureVerificationError::from(SignatureVerificationErrorKind::KeyDecodingFailure) + // Create RSA public key from JWK + let public_key = create_rsa_public_key(jwk).map_err(|_| { + SignatureVerificationError::new(SignatureVerificationErrorKind::KeyDecodingFailure) })?; // Verify the JWT signature using the RSA public key and the defined signature scheme. // If the signature is invalid, return an InvalidSignature error. - rsa_key + public_key .verify(scheme, &hashed_input, input.decoded_signature.as_ref()) .map_err(|_| SignatureVerificationErrorKind::InvalidSignature.into()) } @@ -289,10 +289,10 @@ fn verify_claims(client_id: &String, claims: &Claims, salt: &[u8; 32]) -> Result if claims.nonce != expected_nonce { return Err(format!("Invalid nonce: {}", claims.nonce)); } - if now > claims.iat * 1_000_000_000 + MAX_VALIDITY_WINDOW { + if now > claims.iat * NANOSECONDS_PER_SECOND + MAX_VALIDITY_WINDOW { return Err("JWT is no longer valid".into()); } - if now < claims.iat * 1_000_000_000 { + if now < claims.iat * NANOSECONDS_PER_SECOND { return Err("JWT is not valid yet".into()); } @@ -301,9 +301,9 @@ fn verify_claims(client_id: &String, claims: &Claims, salt: &[u8; 32]) -> Result #[cfg(test)] thread_local! { - static TEST_CALLER: RefCell = RefCell::new(Principal::from_text("x4gp4-hxabd-5jt4d-wc6uw-qk4qo-5am4u-mncv3-wz3rt-usgjp-od3c2-oae").unwrap()); - static TEST_TIME: RefCell = const { RefCell::new(1_736_794_102_000_000_000) }; - static TEST_CERTS: RefCell> = RefCell::new(serde_json::from_str::(r#"{"keys":[{"n": "jwstqI4w2drqbTTVRDriFqepwVVI1y05D5TZCmGvgMK5hyOsVW0tBRiY9Jk9HKDRue3vdXiMgarwqZEDOyOA0rpWh-M76eauFhRl9lTXd5gkX0opwh2-dU1j6UsdWmMa5OpVmPtqXl4orYr2_3iAxMOhHZ_vuTeD0KGeAgbeab7_4ijyLeJ-a8UmWPVkglnNb5JmG8To77tSXGcPpBcAFpdI_jftCWr65eL1vmAkPNJgUTgI4sGunzaybf98LSv_w4IEBc3-nY5GfL-mjPRqVCRLUtbhHO_5AYDpqGj6zkKreJ9-KsoQUP6RrAVxkNuOHV9g1G-CHihKsyAifxNN2Q","use": "sig","kty": "RSA","alg": "RS256","kid": "dd125d5f462fbc6014aedab81ddf3bcedab70847","e": "AQAB"}]}"#).unwrap().keys); + static TEST_CALLER: Cell = Cell::new(Principal::from_text("x4gp4-hxabd-5jt4d-wc6uw-qk4qo-5am4u-mncv3-wz3rt-usgjp-od3c2-oae").unwrap()); + static TEST_TIME: Cell = const { Cell::new(1_736_794_102 * NANOSECONDS_PER_SECOND) }; + static TEST_CERTS: Cell> = Cell::new(serde_json::from_str::(r#"{"keys":[{"n": "jwstqI4w2drqbTTVRDriFqepwVVI1y05D5TZCmGvgMK5hyOsVW0tBRiY9Jk9HKDRue3vdXiMgarwqZEDOyOA0rpWh-M76eauFhRl9lTXd5gkX0opwh2-dU1j6UsdWmMa5OpVmPtqXl4orYr2_3iAxMOhHZ_vuTeD0KGeAgbeab7_4ijyLeJ-a8UmWPVkglnNb5JmG8To77tSXGcPpBcAFpdI_jftCWr65eL1vmAkPNJgUTgI4sGunzaybf98LSv_w4IEBc3-nY5GfL-mjPRqVCRLUtbhHO_5AYDpqGj6zkKreJ9-KsoQUP6RrAVxkNuOHV9g1G-CHihKsyAifxNN2Q","use": "sig","kty": "RSA","alg": "RS256","kid": "dd125d5f462fbc6014aedab81ddf3bcedab70847","e": "AQAB"}]}"#).unwrap().keys); } #[cfg(not(test))] @@ -313,7 +313,7 @@ fn caller() -> Principal { #[cfg(test)] fn caller() -> Principal { - TEST_CALLER.with(|caller| *caller.borrow()) + TEST_CALLER.get() } #[cfg(not(test))] @@ -322,7 +322,7 @@ fn time() -> u64 { } #[cfg(test)] fn time() -> u64 { - TEST_TIME.with(|time| *time.borrow()) + TEST_TIME.get() } #[test] @@ -494,6 +494,7 @@ fn should_return_error_when_no_longer_valid() { Err("JWT is no longer valid".into()) ); } + #[test] fn should_return_error_when_not_valid_yet() { TEST_TIME.replace(time() - 1); diff --git a/src/internet_identity/tests/integration/config.rs b/src/internet_identity/tests/integration/config.rs index 0d7153eabd..59430eeb01 100644 --- a/src/internet_identity/tests/integration/config.rs +++ b/src/internet_identity/tests/integration/config.rs @@ -44,12 +44,13 @@ fn should_retain_anchor_on_user_range_change() -> Result<(), CallError> { #[test] fn should_retain_config_after_none() -> Result<(), CallError> { let env = env(); - let related_origins: Vec = [ + let related_origins = [ "https://identity.internetcomputer.org".to_string(), "https://identity.ic0.app".to_string(), "https://identity.icp0.io".to_string(), ] .to_vec(); + let openid_google_client_id = "https://example.com".to_string(); let config = InternetIdentityInit { assigned_user_number_range: Some((3456, 798977)), archive_config: Some(ArchiveConfig { @@ -72,7 +73,7 @@ fn should_retain_config_after_none() -> Result<(), CallError> { }, }), related_origins: Some(related_origins), - openid_google_client_id: None, + openid_google_client_id: Some(openid_google_client_id), }; let canister_id = install_ii_canister_with_arg(&env, II_WASM.clone(), Some(config.clone())); @@ -89,12 +90,13 @@ fn should_retain_config_after_none() -> Result<(), CallError> { #[test] fn should_override_partially() -> Result<(), CallError> { let env = env(); - let related_origins: Vec = [ + let related_origins = [ "https://identity.internetcomputer.org".to_string(), "https://identity.ic0.app".to_string(), "https://identity.icp0.io".to_string(), ] .to_vec(); + let openid_google_client_id = "https://example.com".to_string(); let config = InternetIdentityInit { assigned_user_number_range: Some((3456, 798977)), archive_config: Some(ArchiveConfig { @@ -117,7 +119,7 @@ fn should_override_partially() -> Result<(), CallError> { }, }), related_origins: Some(related_origins), - openid_google_client_id: None, + openid_google_client_id: Some(openid_google_client_id), }; let canister_id = install_ii_canister_with_arg(&env, II_WASM.clone(), Some(config.clone())); @@ -157,6 +159,7 @@ fn should_override_partially() -> Result<(), CallError> { "https://identity.ic0.app".to_string(), ] .to_vec(); + let openid_google_client_id_2 = "https://example2.com".to_string(); let config_3 = InternetIdentityInit { assigned_user_number_range: None, archive_config: None, @@ -164,7 +167,7 @@ fn should_override_partially() -> Result<(), CallError> { register_rate_limit: None, captcha_config: None, related_origins: Some(related_origins_2.clone()), - openid_google_client_id: None, + openid_google_client_id: Some(openid_google_client_id_2.clone()), }; let _ = @@ -172,6 +175,7 @@ fn should_override_partially() -> Result<(), CallError> { let expected_config_3 = InternetIdentityInit { related_origins: Some(related_origins_2.clone()), + openid_google_client_id: Some(openid_google_client_id_2.clone()), ..expected_config_2 };