From 023f0583e5effa6351031e5add6be5dd3fa976fc Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Mon, 18 Dec 2023 21:43:52 +0100 Subject: [PATCH] In case base_url is not defined use entity_id. When entity_id is defined use it as is. --- src/idpyoidc/claims.py | 17 ++++++++++++----- src/idpyoidc/client/claims/__init__.py | 9 +++++++-- src/idpyoidc/client/service_context.py | 3 ++- src/idpyoidc/metadata.py | 17 ++++++++++------- src/idpyoidc/server/claims/__init__.py | 7 +++++-- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/src/idpyoidc/claims.py b/src/idpyoidc/claims.py index f8f2d2b1..e684624f 100644 --- a/src/idpyoidc/claims.py +++ b/src/idpyoidc/claims.py @@ -122,7 +122,7 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): return keyjar, _uri_path - def get_base_url(self, configuration: dict): + def get_base_url(self, configuration: dict, entity_id: Optional[str]=""): raise NotImplementedError() def get_id(self, configuration: dict): @@ -134,7 +134,10 @@ def add_extra_keys(self, keyjar, id): def get_jwks(self, keyjar): return keyjar.export_jwks() - def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): + def handle_keys(self, + configuration: dict, + keyjar: Optional[KeyJar] = None, + entity_id: Optional[str] = ""): _jwks = _jwks_uri = None _id = self.get_id(configuration) keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) @@ -147,7 +150,7 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): if "jwks_uri" in configuration: # simple _jwks_uri = configuration.get("jwks_uri") elif uri_path: - _base_url = self.get_base_url(configuration) + _base_url = self.get_base_url(configuration, entity_id=entity_id) _jwks_uri = add_path(_base_url, uri_path) else: # jwks or nothing _jwks = self.get_jwks(keyjar) @@ -155,7 +158,11 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri} def load_conf( - self, configuration: dict, supports: dict, keyjar: Optional[KeyJar] = None + self, + configuration: dict, + supports: dict, + keyjar: Optional[KeyJar] = None, + entity_id: Optional[str] = "" ) -> KeyJar: for attr, val in configuration.items(): if attr in ["preference", "capabilities"]: @@ -167,7 +174,7 @@ def load_conf( self.locals(configuration) - for key, val in self.handle_keys(configuration, keyjar=keyjar).items(): + for key, val in self.handle_keys(configuration, keyjar=keyjar, entity_id=entity_id).items(): if key == "keyjar": keyjar = val elif val: diff --git a/src/idpyoidc/client/claims/__init__.py b/src/idpyoidc/client/claims/__init__.py index 12a0a358..59ff0687 100644 --- a/src/idpyoidc/client/claims/__init__.py +++ b/src/idpyoidc/client/claims/__init__.py @@ -1,3 +1,5 @@ +from typing import Optional + from cryptojwt import KeyJar from cryptojwt.exception import IssuerNotFound from cryptojwt.jwk.hmac import SYMKey @@ -11,10 +13,13 @@ def get_client_authn_methods(): class Claims(claims.Claims): - def get_base_url(self, configuration: dict): + def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""): _base = configuration.get("base_url") if not _base: - _base = configuration.get("client_id") + if entity_id: + _base = entity_id + else: + _base = configuration.get("client_id") return _base diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index a919a339..37dfa072 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -173,7 +173,8 @@ def __init__( for key, val in kwargs.items(): setattr(self, key, val) - self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), keyjar=keyjar) + self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), keyjar=keyjar, + entity_id=self.entity_id) _jwks_uri = self.provider_info.get("jwks_uri") if _jwks_uri: diff --git a/src/idpyoidc/metadata.py b/src/idpyoidc/metadata.py index 8b332322..7561d483 100644 --- a/src/idpyoidc/metadata.py +++ b/src/idpyoidc/metadata.py @@ -1,5 +1,5 @@ -import logging from functools import cmp_to_key +import logging from typing import Callable from typing import Optional @@ -128,7 +128,7 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): _uri_path = conf["key_conf"].get("uri_path") return keyjar, _uri_path - def get_base_url(self, configuration: dict): + def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""): raise NotImplementedError() def get_id(self, configuration: dict): @@ -140,9 +140,11 @@ def add_extra_keys(self, keyjar, id): def get_jwks(self, keyjar): return None - def handle_keys( - self, configuration: dict, keyjar: Optional[KeyJar] = None, base_url: Optional[str] = "" - ): + def handle_keys(self, + configuration: dict, + keyjar: Optional[KeyJar] = None, + base_url: Optional[str] = "", + entity_id: Optional[str] = ""): _jwks = _jwks_uri = None _id = self.get_id(configuration) keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) @@ -154,7 +156,7 @@ def handle_keys( _jwks_uri = configuration.get("jwks_uri") elif uri_path: if not base_url: - base_url = self.get_base_url(configuration) + base_url = self.get_base_url(configuration, entity_id=entity_id) _jwks_uri = add_path(base_url, uri_path) else: # jwks or nothing _jwks = self.get_jwks(keyjar) @@ -162,7 +164,8 @@ def handle_keys( return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri} def load_conf( - self, configuration, supports, keyjar: Optional[KeyJar] = None, base_url: Optional[str] = "" + self, configuration, supports, keyjar: Optional[KeyJar] = None, + base_url: Optional[str] = "" ): for attr, val in configuration.items(): if attr == "preference": diff --git a/src/idpyoidc/server/claims/__init__.py b/src/idpyoidc/server/claims/__init__.py index 6ca13ecc..7f4f3ea1 100644 --- a/src/idpyoidc/server/claims/__init__.py +++ b/src/idpyoidc/server/claims/__init__.py @@ -4,10 +4,13 @@ class Claims(claims.Claims): - def get_base_url(self, configuration: dict): + def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""): _base = configuration.get("base_url") if not _base: - _base = configuration.get("issuer") + if entity_id: + _base = entity_id + else: + _base = configuration.get("issuer") return _base