Skip to content

Commit

Permalink
All tests green
Browse files Browse the repository at this point in the history
  • Loading branch information
rohe committed Oct 14, 2023
1 parent 640f837 commit 2529b77
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 148 deletions.
79 changes: 43 additions & 36 deletions satosa_oidcop/core/persistence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging
from typing import Optional
from typing import Union
Expand All @@ -6,11 +7,12 @@
from cryptojwt.exception import BadSignature
from cryptojwt.exception import Invalid
from cryptojwt.exception import MissingKey
from cryptojwt.utils import as_bytes
from idpyoidc.message import Message
from idpyoidc.message.oidc import JsonWebToken
from idpyoidc.server.client_authn import basic_authn
from idpyoidc.server.exception import ClientAuthenticationError
from idpyoidc.util import instantiate
from idpyoidc.server.token import UnknownToken
from idpyoidc.util import sanitize

from . import ExtendedContext
Expand Down Expand Up @@ -47,13 +49,15 @@ def _deal_with_client_assertion(self, session_manager, token):
return ca_jwt["iss"]

def _get_client_id(self,
session_manager,
endpoint_context,
request: Union[Message, dict],
http_info: dict) -> Optional[str]:
# Figure out which client is concerned
if "client_id" in request:
return request["client_id"]

session_manager = endpoint_context.session_manager

for param in ["code", "access_token", "refresh_token", "registration_access_token"]:
if param in request:
_token_info = session_manager.token_handler.info(request[param])
Expand All @@ -66,63 +70,66 @@ def _get_client_id(self,

authz = http_info.get("headers", {}).get("authorization", "")
if authz:

if "Basic " in authz:
token = authz.split(" ", 1)[1]
_info = basic_authn(token)
_info = basic_authn(authz)
return _info["id"]
else:
token = authz.split(" ", 1)[1]
_token_info = session_manager.token_handler.info(token)
try:
_token_info = session_manager.token_handler.info(token)
except UnknownToken:
_msg = ""
logger.error("Someone is using a token I can not parse")
raise
sid = _token_info["sid"]
_path = session_manager.decrypt_branch_id(sid)
return _path[1]

return None

# def get_client_info(self, client_id, context):
# _cinfo = context.cdb.get(client_id)
# if _cinfo:
# return _cinfo
# else:
# _cinfo = self.app.storage.fetch(informantion_type="client_info",
# key=client_id)
# if _cinfo:
# context.cdb = {client_id: _cinfo}
#
# return _cinfo

def update_state(self,
request: Union[Message, dict],
http_info: Optional[dict]):
sman = self.app.server.context.session_manager
_session_info = self.app.storage.fetch(information_type="session_info", key="")

self._flush_endpoint_context_memory(sman)
sman.load(_session_info)

http_info: Optional[dict]) -> str:
endpoint_context = self.app.server.context
sman = endpoint_context.session_manager
# Find the client_id
client_id = self._get_client_id(session_manager=sman,
client_id = self._get_client_id(endpoint_context=endpoint_context,
request=request,
http_info=http_info)
# Update session
_client_session_info = self.app.storage.fetch(information_type="client_session_info",
key=client_id)
_session_info["db"] = _client_session_info

self._flush_endpoint_context_memory(sman)
sman.load(_session_info)
_session_info = self.app.storage.fetch(information_type="session_info", key="")
if _session_info:
self._flush_endpoint_context_memory(sman)
sman.load(_session_info)

# Update session
_client_session_info = self.app.storage.fetch(information_type="client_session_info",
key=client_id)
_session_info["db"] = _client_session_info

self._flush_endpoint_context_memory(sman)
sman.load(_session_info)

# Update client database
client_info = self.app.storage.fetch(information_type="client_info", key=client_id)
self.app.server.context.cdb = {client_id: client_info}
return client_id

def _hash_session_id(self, session_id):
return hashlib.sha256(as_bytes(session_id)).hexdigest()

def load_claims(self, client_id):
return self.app.storage.fetch(information_type="claims", key=client_id)
def load_claims(self, session_id):
# session IDs can be quite large, so I just use the hash
sid_hash = self._hash_session_id(session_id)
return self.app.storage.fetch(information_type="claims", key=sid_hash)

# Now for the store part

def store_claims(self, claims: dict, client_id: str):
self.app.storage.store(information_type="claims", value=claims, key=client_id)
def store_claims(self, claims: dict, session_id: str):
# session IDs can be quite large, so I just use the hash
sid_hash = self._hash_session_id(session_id)
self.app.storage.store(information_type="claims", value=claims, key=sid_hash)

def get_client_session_info(self, client_id, db, session_manager):
res = {}
Expand All @@ -147,7 +154,7 @@ def store_state(self, client_id):
_session_state["db"] = {}
self.app.storage.store(information_type="session_info", value=_session_state)

def _get_http_headers(self, context: ExtendedContext):
def _get_http_info(self, context: ExtendedContext):
"""
aligns parameters for oidcop interoperability needs
"""
Expand Down
4 changes: 4 additions & 0 deletions satosa_oidcop/core/storage/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ def store(self, information_type: str, value, key: Optional[str] = ""):
self[":".join([information_type, key])] = value
else:
self[information_type] = value

def information_type_keys(self, information_type: str):
return [k[len(information_type) + 1:] for k in (self.keys()) if
k.startswith(information_type)]
Loading

0 comments on commit 2529b77

Please sign in to comment.