Skip to content

Commit

Permalink
add pkce to public device grants
Browse files Browse the repository at this point in the history
  • Loading branch information
dsschult committed Aug 8, 2024
1 parent 15f46cb commit e691900
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 74 deletions.
3 changes: 3 additions & 0 deletions examples/get_device_credentials_token.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import logging

from rest_tools.client import SavedDeviceGrantAuth

Expand All @@ -14,6 +15,8 @@ def main():
parser.add_argument('--address', default='https://keycloak.icecube.wisc.edu/auth/realms/IceCube', help='OAuth2 server address')
parser.add_argument('client_id', help='client id')

logging.basicConfig(level=logging.DEBUG)

args = parser.parse_args()
kwargs = vars(args)
print('access token:', get_token(**kwargs))
Expand Down
121 changes: 66 additions & 55 deletions rest_tools/client/device_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .openid_client import OpenIDRestClient
from ..utils.auth import OpenIDAuth
from ..utils.pkce import PKCEMixin


def _print_qrcode(req: Dict[str, str]) -> None:
Expand Down Expand Up @@ -58,53 +59,29 @@ def _print_qrcode(req: Dict[str, str]) -> None:
# fmt:on


def _perform_device_grant(
logger: logging.Logger,
device_url: str,
token_url: str,
client_id: str,
client_secret: Optional[str] = None,
scopes: Optional[List[str]] = None,
) -> str:
args = {
'client_id': client_id,
'scope': 'offline_access ' + (' '.join(scopes) if scopes else ''),
}
if client_secret:
args['client_secret'] = client_secret

try:
r = requests.post(device_url, data=args)
r.raise_for_status()
req = r.json()
except requests.exceptions.HTTPError as exc:
logger.debug('%r', exc.response.text)
try:
req = exc.response.json()
except Exception:
req = {}
error = req.get('error', '')
raise RuntimeError(f'Device authorization failed: {error}') from exc
except Exception as exc:
raise RuntimeError('Device authorization failed') from exc

logger.debug('Device auth in progress')

_print_qrcode(req)

args = {
'grant_type': 'urn:ietf:params:oauth:grant-type:device_code',
'device_code': req['device_code'],
'client_id': client_id,
}
if client_secret:
args['client_secret'] = client_secret

sleep_time = int(req.get('interval', 5))
while True:
time.sleep(sleep_time)
class CommonDeviceGrant(PKCEMixin):
def perform_device_grant(
self,
logger: logging.Logger,
device_url: str,
token_url: str,
client_id: str,
client_secret: Optional[str] = None,
scopes: Optional[List[str]] = None,
) -> str:
args = {
'client_id': client_id,
'scope': 'offline_access ' + (' '.join(scopes) if scopes else ''),
}
if client_secret:
args['client_secret'] = client_secret
else:
code_challenge = self.create_pkce_challenge()
args['code_challenge'] = code_challenge
args['code_challenge_method'] = 'S256'

try:
r = requests.post(token_url, data=args)
r = requests.post(device_url, data=args)
r.raise_for_status()
req = r.json()
except requests.exceptions.HTTPError as exc:
Expand All @@ -114,17 +91,49 @@ def _perform_device_grant(
except Exception:
req = {}
error = req.get('error', '')
if error == 'authorization_pending':
continue
elif error == 'slow_down':
sleep_time += 5
continue
raise RuntimeError(f'Device authorization failed: {error}') from exc
except Exception as exc:
raise RuntimeError('Device authorization failed') from exc
break

return req['refresh_token']
logger.debug('Device auth in progress')

_print_qrcode(req)

args = {
'grant_type': 'urn:ietf:params:oauth:grant-type:device_code',
'device_code': req['device_code'],
'client_id': client_id,
}
if client_secret:
args['client_secret'] = client_secret
else:
args['code_verifier'] = self.get_pkce_verifier(code_challenge)

sleep_time = int(req.get('interval', 5))
while True:
time.sleep(sleep_time)
try:
r = requests.post(token_url, data=args)
r.raise_for_status()
req = r.json()
except requests.exceptions.HTTPError as exc:
logger.debug('%r', exc.response.text)
try:
req = exc.response.json()
except Exception:
req = {}
error = req.get('error', '')
if error == 'authorization_pending':
continue
elif error == 'slow_down':
sleep_time += 5
continue
raise RuntimeError(f'Device authorization failed: {error}') from exc
except Exception as exc:
raise RuntimeError('Device authorization failed') from exc
break

return req['refresh_token']


def DeviceGrantAuth(
Expand Down Expand Up @@ -153,7 +162,8 @@ def DeviceGrantAuth(
raise RuntimeError('Device grant not supported by server')
endpoint: str = auth.provider_info['device_authorization_endpoint'] # type: ignore

refresh_token = _perform_device_grant(
device = CommonDeviceGrant()
refresh_token = device.perform_device_grant(
logger, endpoint, auth.token_url, client_id, client_secret, scopes
)

Expand Down Expand Up @@ -231,7 +241,8 @@ def update_func(access, refresh):
raise RuntimeError('Device grant not supported by server')
endpoint: str = auth.provider_info['device_authorization_endpoint'] # type: ignore

refresh_token = _perform_device_grant(
device = CommonDeviceGrant()
refresh_token = device.perform_device_grant(
logger, endpoint, auth.token_url, client_id, client_secret, scopes
)

Expand Down
22 changes: 3 additions & 19 deletions rest_tools/server/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,26 @@

import base64
import functools
import hashlib
import hmac
import json
import logging
import secrets
import time
import urllib.parse
from collections import defaultdict
from typing import Any, Dict, MutableMapping, Union
from typing import Any, Dict, Union

import rest_tools
import tornado.escape
import tornado.gen
import tornado.httpclient
import tornado.httputil
import tornado.web
from cachetools import TTLCache
from tornado.auth import OAuth2Mixin

from .. import telemetry as wtt
from ..utils.auth import Auth, OpenIDAuth
from ..utils.json_util import json_decode
from ..utils.pkce import PKCEMixin
from .decorators import catch_error
from .stats import RouteStats

Expand Down Expand Up @@ -324,15 +322,14 @@ def clear_tokens(self):
self.clear_cookie('user_info')


class OpenIDLoginHandler(OpenIDCookieHandlerMixin, OAuth2Mixin, RestHandler):
class OpenIDLoginHandler(OpenIDCookieHandlerMixin, OAuth2Mixin, PKCEMixin, RestHandler):
"""Handle OpenID Connect logins.
Should be combined with an appropriate mixin to store the token(s).
`OpenIDCookieHandlerMixin` is used by default, but can be overridden.
Requires the `login_url` application setting to be a full url.
"""
_pkcs_challenges: MutableMapping[str, str] = TTLCache(maxsize=10000, ttl=3600)

def initialize(self, oauth_client_id, oauth_client_secret, oauth_client_scope=None, **kwargs):
super().initialize(**kwargs)
Expand All @@ -354,19 +351,6 @@ def initialize(self, oauth_client_id, oauth_client_secret, oauth_client_scope=No
scopes.add('offline_access')
self.oauth_client_scope = list(scopes)

@classmethod
def create_pkce_challenge(cls) -> str:
code_verifier = secrets.token_urlsafe(64)
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).decode('utf-8').split('=')[0]
cls._pkcs_challenges[code_challenge] = code_verifier
return code_challenge

@classmethod
def get_pkce_verifier(cls, challenge: str) -> str:
if challenge in cls._pkcs_challenges:
return cls._pkcs_challenges[challenge]
raise KeyError('invalid pkce challenge')

async def get_authenticated_user(
self, redirect_uri: str, code: str, state: Dict[str, Any]
) -> Dict[str, Any]:
Expand Down
23 changes: 23 additions & 0 deletions rest_tools/utils/pkce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import base64
import hashlib
import secrets
from typing import MutableMapping

from cachetools import TTLCache


class PKCEMixin:
_pkcs_challenges: MutableMapping[str, str] = TTLCache(maxsize=10000, ttl=3600)

@classmethod
def create_pkce_challenge(cls) -> str:
code_verifier = secrets.token_urlsafe(64)
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).decode('utf-8').split('=')[0]
cls._pkcs_challenges[code_challenge] = code_verifier
return code_challenge

@classmethod
def get_pkce_verifier(cls, challenge: str) -> str:
if challenge in cls._pkcs_challenges:
return cls._pkcs_challenges[challenge]
raise KeyError('invalid pkce challenge')

0 comments on commit e691900

Please sign in to comment.