Skip to content

Commit

Permalink
feat(query): add config jwks_refresh_interval & jwks_refresh_timeout (d…
Browse files Browse the repository at this point in the history
…atabendlabs#17087)

* feat(query): add config jwks_refresh_interval & jwks_refresh_timeout

* fix: remove force reload when key not found

* z

* z

* z

* z

* z

* z

* z

* z

* z
  • Loading branch information
everpcpc authored Dec 19, 2024
1 parent 85c4caf commit c2294d8
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 73 deletions.
1 change: 1 addition & 0 deletions .github/actions/setup_build_tool/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ runs:
EOF
RUNNER_PROVIDER="${RUNNER_PROVIDER:-github}"
export SCCACHE_IDLE_TIMEOUT=0
case ${RUNNER_PROVIDER} in
aws)
echo "setting up sccache for AWS S3..."
Expand Down
13 changes: 12 additions & 1 deletion src/query/config/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1528,7 +1528,14 @@ pub struct QueryConfig {
#[clap(long, value_name = "VALUE", default_value_t)]
pub jwt_key_file: String,

/// If there are multiple trusted jwt provider put it into additional_jwt_key_files configuration
/// Interval in seconds to refresh jwks
#[clap(long, value_name = "VALUE", default_value = "600")]
pub jwks_refresh_interval: u64,

/// Timeout in seconds to refresh jwks
#[clap(long, value_name = "VALUE", default_value = "10")]
pub jwks_refresh_timeout: u64,

#[clap(skip)]
pub jwt_key_files: Vec<String>,

Expand Down Expand Up @@ -1754,6 +1761,8 @@ impl TryInto<InnerQueryConfig> for QueryConfig {
max_storage_io_requests: self.max_storage_io_requests,
jwt_key_file: self.jwt_key_file,
jwt_key_files: self.jwt_key_files,
jwks_refresh_interval: self.jwks_refresh_interval,
jwks_refresh_timeout: self.jwks_refresh_timeout,
default_storage_format: self.default_storage_format,
default_compression: self.default_compression,
builtin: BuiltInConfig {
Expand Down Expand Up @@ -1845,6 +1854,8 @@ impl From<InnerQueryConfig> for QueryConfig {
max_storage_io_requests: inner.max_storage_io_requests,
jwt_key_file: inner.jwt_key_file,
jwt_key_files: inner.jwt_key_files,
jwks_refresh_interval: inner.jwks_refresh_interval,
jwks_refresh_timeout: inner.jwks_refresh_timeout,
default_storage_format: inner.default_storage_format,
default_compression: inner.default_compression,
users: inner.builtin.users,
Expand Down
4 changes: 4 additions & 0 deletions src/query/config/src/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ pub struct QueryConfig {

pub jwt_key_file: String,
pub jwt_key_files: Vec<String>,
pub jwks_refresh_interval: u64,
pub jwks_refresh_timeout: u64,
pub default_storage_format: String,
pub default_compression: String,
pub builtin: BuiltInConfig,
Expand Down Expand Up @@ -301,6 +303,8 @@ impl Default for QueryConfig {
max_storage_io_requests: None,
jwt_key_file: "".to_string(),
jwt_key_files: Vec::new(),
jwks_refresh_interval: 600,
jwks_refresh_timeout: 10,
default_storage_format: "auto".to_string(),
default_compression: "auto".to_string(),
builtin: BuiltInConfig::default(),
Expand Down
2 changes: 2 additions & 0 deletions src/query/service/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ impl AuthMgr {
jwt_auth: JwtAuthenticator::create(
cfg.query.jwt_key_file.clone(),
cfg.query.jwt_key_files.clone(),
cfg.query.jwks_refresh_interval,
cfg.query.jwks_refresh_timeout,
),
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo
| 'query' | 'http_handler_tls_server_root_ca_cert' | '' | '' |
| 'query' | 'internal_enable_sandbox_tenant' | 'false' | '' |
| 'query' | 'internal_merge_on_read_mutation' | 'false' | '' |
| 'query' | 'jwks_refresh_interval' | '600' | '' |
| 'query' | 'jwks_refresh_timeout' | '10' | '' |
| 'query' | 'jwt_key_file' | '' | '' |
| 'query' | 'jwt_key_files' | '' | '' |
| 'query' | 'management_mode' | 'false' | '' |
Expand Down
22 changes: 19 additions & 3 deletions src/query/users/src/jwt/authenticator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,30 @@ impl CustomClaims {
}

impl JwtAuthenticator {
pub fn create(jwt_key_file: String, jwt_key_files: Vec<String>) -> Option<Self> {
pub fn create(
jwt_key_file: String,
jwt_key_files: Vec<String>,
jwks_refresh_interval: u64,
jwks_refresh_timeout: u64,
) -> Option<Self> {
if jwt_key_file.is_empty() && jwt_key_files.is_empty() {
return None;
}
// init a vec of key store
let mut key_stores = vec![jwk::JwkKeyStore::new(jwt_key_file)];
let mut key_stores = vec![];
if !jwt_key_file.is_empty() {
key_stores.push(
jwk::JwkKeyStore::new(jwt_key_file)
.with_refresh_interval(jwks_refresh_interval)
.with_refresh_timeout(jwks_refresh_timeout),
);
}
for u in jwt_key_files {
key_stores.push(jwk::JwkKeyStore::new(u))
key_stores.push(
jwk::JwkKeyStore::new(u)
.with_refresh_interval(jwks_refresh_interval)
.with_refresh_timeout(jwks_refresh_timeout),
);
}
Some(JwtAuthenticator { key_stores })
}
Expand Down
72 changes: 36 additions & 36 deletions src/query/users/src/jwt/jwk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ use serde::Serialize;

use super::PubKey;

const JWK_REFRESH_INTERVAL: u64 = 15;
const JWKS_REFRESH_TIMEOUT: u64 = 10;
const JWKS_REFRESH_INTERVAL: u64 = 600;

#[derive(Debug, Serialize, Deserialize)]
pub struct JwkKey {
Expand Down Expand Up @@ -99,17 +100,17 @@ pub struct JwkKeyStore {
cached_keys: Arc<RwLock<HashMap<String, PubKey>>>,
pub(crate) last_refreshed_at: RwLock<Option<Instant>>,
pub(crate) refresh_interval: Duration,
pub(crate) refresh_timeout: Duration,
pub(crate) load_keys_func: Option<Arc<dyn Fn() -> HashMap<String, PubKey> + Send + Sync>>,
}

impl JwkKeyStore {
pub fn new(url: String) -> Self {
let refresh_interval = Duration::from_secs(JWK_REFRESH_INTERVAL * 60);
let keys = Arc::new(RwLock::new(HashMap::new()));
Self {
url,
cached_keys: keys,
refresh_interval,
cached_keys: Arc::new(RwLock::new(HashMap::new())),
refresh_interval: Duration::from_secs(JWKS_REFRESH_INTERVAL),
refresh_timeout: Duration::from_secs(JWKS_REFRESH_TIMEOUT),
last_refreshed_at: RwLock::new(None),
load_keys_func: None,
}
Expand All @@ -124,6 +125,16 @@ impl JwkKeyStore {
self
}

pub fn with_refresh_interval(mut self, interval: u64) -> Self {
self.refresh_interval = Duration::from_secs(interval);
self
}

pub fn with_refresh_timeout(mut self, timeout: u64) -> Self {
self.refresh_timeout = Duration::from_secs(timeout);
self
}

pub fn url(&self) -> String {
self.url.clone()
}
Expand All @@ -136,12 +147,19 @@ impl JwkKeyStore {
return Ok(load_keys_func());
}

let response = reqwest::get(&self.url).await.map_err(|e| {
let client = reqwest::Client::builder()
.timeout(self.refresh_timeout)
.build()
.map_err(|e| {
ErrorCode::InvalidConfig(format!("Failed to create jwks client: {}", e))
})?;
let response = client.get(&self.url).send().await.map_err(|e| {
ErrorCode::AuthenticateFailure(format!("Could not download JWKS: {}", e))
})?;
let body = response.text().await.unwrap();
let jwk_keys = serde_json::from_str::<JwkKeys>(&body)
.map_err(|e| ErrorCode::InvalidConfig(format!("Failed to parse keys: {}", e)))?;
let jwk_keys: JwkKeys = response
.json()
.await
.map_err(|e| ErrorCode::InvalidConfig(format!("Failed to parse JWKS: {}", e)))?;
let mut new_keys: HashMap<String, PubKey> = HashMap::new();
for k in &jwk_keys.keys {
new_keys.insert(k.kid.to_string(), k.get_public_key()?);
Expand All @@ -166,6 +184,7 @@ impl JwkKeyStore {
let new_keys = match self.load_keys().await {
Ok(new_keys) => new_keys,
Err(err) => {
warn!("Failed to load JWKS: {}", err);
if !old_keys.is_empty() {
return Ok(old_keys);
}
Expand All @@ -177,9 +196,9 @@ impl JwkKeyStore {
if !new_keys.keys().eq(old_keys.keys()) {
info!("JWKS keys changed.");
}
*self.cached_keys.write() = new_keys;
*self.cached_keys.write() = new_keys.clone();
self.last_refreshed_at.write().replace(Instant::now());
Ok(old_keys)
Ok(new_keys)
}

#[async_backtrace::framed]
Expand All @@ -200,31 +219,12 @@ impl JwkKeyStore {
}
};

// happy path: the key_id is found in the store
if let Some(key) = keys.get(&key_id) {
return Ok(key.clone());
match keys.get(&key_id) {
None => Err(ErrorCode::AuthenticateFailure(format!(
"key id {} not found in jwk store",
key_id
))),
Some(key) => Ok(key.clone()),
}

// if the key_id is not set here, it might because the JWKS has been rotated, we need to refresh it.
warn!(
"key_id {} not found in jwks store, try to reload keys",
key_id
);
let keys = self
.load_keys_with_cache(true)
.await
.map_err(|e| e.add_message("failed to reload JWKS keys"))?;

let key = match keys.get(&key_id) {
None => {
return Err(ErrorCode::AuthenticateFailure(format!(
"key id {} not found in jwk store",
key_id
)));
}
Some(key) => key.clone(),
};

Ok(key)
}
}
34 changes: 1 addition & 33 deletions src/query/users/tests/it/jwt/authenticator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use base64::engine::general_purpose;
use base64::prelude::*;
use databend_common_base::base::tokio;
use databend_common_exception::Result;
use databend_common_users::JwkKeyStore;
use databend_common_users::JwtAuthenticator;
use databend_common_users::PubKey;
use jwt_simple::prelude::*;
use wiremock::matchers::method;
use wiremock::matchers::path;
Expand Down Expand Up @@ -60,7 +53,7 @@ async fn test_parse_non_custom_claim() -> Result<()> {
.mount(&server)
.await;
let first_url = format!("http://{}{}", server.address(), json_path);
let auth = JwtAuthenticator::create(first_url, vec![]).unwrap();
let auth = JwtAuthenticator::create(first_url, vec![], 600, 10).unwrap();
let user_name = "test-user2";
let my_additional_data = MyAdditionalData {
user_is_admin: false,
Expand All @@ -74,28 +67,3 @@ async fn test_parse_non_custom_claim() -> Result<()> {
assert_eq!(res.custom.role, None);
Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_jwk_key_store_retry_on_key_not_found() -> Result<()> {
let func_calls = Arc::new(AtomicUsize::new(0));
let func_calls_cloned = func_calls.clone();

let mock_load_keys = Arc::new(move || -> HashMap<String, PubKey> {
let mut keys_map = HashMap::new();
keys_map.insert(
"key1".to_string(),
PubKey::RSA256(RS256KeyPair::generate(2048).unwrap().public_key().into()),
);
func_calls_cloned.fetch_add(1, Ordering::SeqCst);
keys_map
});
let store = JwkKeyStore::new("".to_string()).with_load_keys_func(mock_load_keys);

let r = store.get_key(Some("key2".to_string())).await;
assert_eq!(
r.unwrap_err().message(),
"key id key2 not found in jwk store"
);
assert_eq!(func_calls.load(Ordering::SeqCst), 2);
Ok(())
}

0 comments on commit c2294d8

Please sign in to comment.