diff --git a/dev-requirements.txt b/dev-requirements.txt index 68b3ecac7..f66a9c6fc 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,68 +1,67 @@ accessible-pygments==0.0.5 -aioca==1.7 +aioca==1.8 aiofiles==24.1.0 -aiohappyeyeballs==2.4.0 -aiohttp==3.10.5 +aiohappyeyeballs==2.4.3 +aiohttp==3.10.10 aiosignal==1.3.1 alabaster==1.0.0 annotated-types==0.7.0 -anyio==4.4.0 +anyio==4.6.2.post1 appdirs==1.4.4 asciitree==0.3.3 asttokens==2.4.1 -async-timeout==4.0.3 attrs==24.2.0 babel==2.16.0 beautifulsoup4==4.12.3 bidict==0.23.1 -bluesky==1.13.0a4 +bluesky==1.13 bluesky-kafka==0.10.0 bluesky-live==0.0.8 bluesky-stomp==0.1.2 boltons==24.0.0 -bump-pydantic==0.8.0 cachetools==5.5.0 caproto==1.1.1 certifi==2024.8.30 +cffi==1.17.1 cfgv==3.4.0 -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 click==8.1.7 -cloudpickle==3.0.0 +cloudpickle==3.1.0 colorama==0.4.6 colorlog==6.8.2 comm==0.2.2 compress-pickle==2.1.0 -confluent-kafka==2.5.3 +confluent-kafka==2.6.0 contourpy==1.3.0 -copier==9.3.1 -coverage==7.6.1 +copier==9.4.0 +coverage==7.6.3 +cryptography==43.0.1 cycler==0.12.1 -dask==2024.9.0 +dask==2024.9.1 databroker==1.2.5 dataclasses-json==0.6.7 decorator==5.1.1 +deepdiff==8.0.1 deepmerge==2.0 -distlib==0.3.8 +Deprecated==1.2.14 +distlib==0.3.9 dls-bluesky-core==0.0.4 -dls-dodal==1.31.1 -dnspython==2.6.1 +dls-dodal==1.33.0 +dnspython==2.7.0 docopt==0.6.2 doct==1.1.0 docutils==0.21.2 dunamai==1.22.0 -email_validator==2.2.0 entrypoints==0.4 -epicscorelibs==7.0.7.99.0.2 +epicscorelibs==7.0.7.99.1.1 event-model==1.21.0 -exceptiongroup==1.2.2 executing==2.1.0 -fastapi==0.114.2 -fastapi-cli==0.0.5 +fastapi==0.115.2 fasteners==0.19 -filelock==3.16.0 +filelock==3.16.1 flexcache==0.3 flexparser==0.3.1 -fonttools==4.53.1 +fonttools==4.54.1 frozenlist==1.4.1 fsspec==2024.9.0 funcy==2.0 @@ -70,38 +69,34 @@ gitdb==4.0.11 GitPython==3.1.43 graypy==2.1.0 h11==0.14.0 -h5py==3.11.0 +h5py==3.12.1 HeapDict==1.0.1 historydict==1.2.6 -httpcore==1.0.5 -httptools==0.6.1 +httpcore==1.0.6 httpx==0.27.2 -humanize==4.10.0 +humanize==4.11.0 identify==2.6.1 idna==3.10 -imageio==2.35.1 +imageio==2.36.0 imagesize==1.4.1 -importlib_metadata==8.5.0 +importlib_metadata==8.4.0 importlib_resources==6.4.5 iniconfig==2.0.0 intake==0.6.4 ipython==8.18.0 ipywidgets==8.1.5 -itsdangerous==2.2.0 jedi==0.19.1 Jinja2==3.1.4 jinja2-ansible-filters==1.3.2 jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 +jsonschema-specifications==2024.10.1 jupyterlab_widgets==3.0.13 kiwisolver==1.4.7 ldap3==2.9.1 -libcst==1.4.0 -livereload==2.7.0 locket==1.0.0 lz4==4.3.3 markdown-it-py==3.0.0 -MarkupSafe==2.1.5 +MarkupSafe==3.0.1 marshmallow==3.22.0 matplotlib==3.9.2 matplotlib-inline==0.1.7 @@ -113,22 +108,24 @@ mongoquery==1.4.2 msgpack==1.1.0 msgpack-numpy==0.4.8 multidict==6.1.0 -mypy==1.11.2 +mypy==1.12.0 mypy-extensions==1.0.0 myst-parser==4.0.0 -networkx==3.3 +networkx==3.4.1 nodeenv==1.9.1 nose2==0.15.1 -nslsii==0.10.3 -numcodecs==0.13.0 +nslsii==0.10.5 +numcodecs==0.13.1 numpy==1.26.4 opencv-python-headless==4.10.0.84 +opentelemetry-api==1.27.0 ophyd==1.9.0 -ophyd-async==0.5.2 +ophyd-async==0.6.0 +orderly-set==5.2.2 orjson==3.10.7 -p4p==4.1.12 +p4p==4.2.0 packaging==24.1 -pandas==2.2.2 +pandas==2.2.3 parso==0.8.4 partd==1.4.2 pathlib2==2.3.7.post1 @@ -136,67 +133,66 @@ pathspec==0.12.1 pexpect==4.9.0 picobox==4.0.0 pika==1.3.2 -pillow==10.4.0 +pillow==11.0.0 PIMS==0.7 Pint==0.24.3 -pipdeptree==2.23.3 -platformdirs==4.3.3 +pipdeptree==2.23.4 +platformdirs==4.3.6 pluggy==1.5.0 -plumbum==1.8.3 +plumbum==1.9.0 ply==3.11 -pre-commit==3.8.0 +pre_commit==4.0.1 prettytable==3.11.0 prompt-toolkit==3.0.36 +propcache==0.2.0 psutil==6.0.0 ptyprocess==0.7.0 pure_eval==0.2.3 -pvxslibs==1.3.1 +pvxslibs==1.3.2 py==1.11.0 pyasn1==0.6.1 -pycryptodome==3.20.0 -pydantic==2.9.1 -pydantic-extra-types==2.9.0 +pycparser==2.22 +pycryptodome==3.21.0 +pydantic==2.9.2 pydantic-settings==2.5.2 -pydantic_core==2.23.3 +pydantic_core==2.23.4 pydantic_numpy==5.0.2 pydata-sphinx-theme==0.15.4 pyepics==3.5.7 Pygments==2.18.0 -pymongo==4.8.0 +PyJWT==2.9.0 +pymongo==4.10.1 pyOlog==4.5.0 -pyparsing==3.1.4 +pyparsing==3.2.0 pytest==8.3.3 pytest-asyncio==0.24.0 pytest-cov==5.0.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 -python-multipart==0.0.9 +python-multipart==0.0.12 pytz==2024.2 PyYAML==6.0.2 -pyyaml-include==2.1 questionary==2.0.1 -redis==5.0.8 -redis-json-dict==0.2.0 +redis==5.1.1 +redis-json-dict==0.2.1 referencing==0.35.1 requests==2.32.3 responses==0.25.3 -rich==13.7.1 rpds-py==0.20.0 ruamel.yaml==0.18.6 ruamel.yaml.clib==0.2.8 -ruff==0.6.5 +ruff==0.6.9 scanspec==0.7.2 semver==3.0.2 setuptools-dso==2.11 -shellingham==1.5.4 six==1.16.0 slicerator==1.1.0 smmap==5.0.1 sniffio==1.3.1 snowballstemmer==2.2.0 soupsieve==2.6 -Sphinx==8.0.2 -sphinx-autobuild==2024.9.3 +Sphinx==8.1.3 +sphinx-autobuild==2024.10.3 sphinx-click==6.0.0 sphinx-copybutton==0.5.2 sphinx_design==0.6.1 @@ -210,43 +206,39 @@ sphinxcontrib-openapi==0.8.4 sphinxcontrib-qthelp==2.0.0 sphinxcontrib-serializinghtml==2.0.0 stack-data==0.6.3 -starlette==0.38.5 +starlette==0.40.0 stomp-py==8.1.2 suitcase-mongo==0.6.0 suitcase-msgpack==0.3.0 suitcase-utils==0.5.4 super-state-machine==2.0.2 -tifffile==2024.8.30 -tomli==2.0.1 -toolz==0.12.1 -tornado==6.4.1 +tifffile==2024.9.20 +toolz==1.0.0 tox==3.28.0 tox-direct==0.4 tqdm==4.66.5 traitlets==5.14.3 -typer==0.12.4 types-mock==5.1.0.20240425 types-PyYAML==6.0.12.20240917 -types-requests==2.32.0.20240914 +types-requests==2.32.0.20241016 types-urllib3==1.26.25.14 typing-inspect==0.9.0 typing_extensions==4.12.2 -tzdata==2024.1 +tzdata==2024.2 tzlocal==5.2 -ujson==5.10.0 urllib3==2.2.3 -uvicorn==0.30.6 -uvloop==0.19.0 -virtualenv==20.26.4 +uvicorn==0.32.0 +virtualenv==20.26.6 watchfiles==0.24.0 wcwidth==0.2.13 websocket-client==1.8.0 -websockets==13.0.1 +websockets==13.1 widgetsnbextension==4.0.13 workflows==2.27 +wrapt==1.16.0 xarray==2024.9.0 -yarl==1.11.1 +yarl==1.15.3 zarr==2.18.3 zict==2.2.0 zipp==3.20.2 -zocalo==1.1.0 +zocalo==1.1.1 diff --git a/pyproject.toml b/pyproject.toml index ea16df3a0..a500631ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,9 @@ dependencies = [ "super-state-machine", # See GH issue 553 "GitPython", "bluesky-stomp>=0.1.2", + "pyjwt", + "python-multipart", + "cryptography" ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index c0069214d..369ac8883 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -17,8 +17,13 @@ from blueapi.client.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueskyRemoteControlError -from blueapi.config import ApplicationConfig, ConfigLoader +from blueapi.config import ( + ApplicationConfig, + CLIClientConfig, + ConfigLoader, +) from blueapi.core import DataEvent +from blueapi.service.authentication import CliTokenManager, SessionManager from blueapi.service.main import start from blueapi.service.openapi import ( DOCS_SCHEMA_LOCATION, @@ -253,7 +258,6 @@ def pause(obj: dict, defer: bool = False) -> None: @click.pass_obj def resume(obj: dict) -> None: """Resume the execution of the current task""" - client: BlueapiClient = obj["client"] pprint(client.resume()) @@ -329,3 +333,35 @@ def scratch(obj: dict) -> None: setup_scratch(config.scratch) else: raise KeyError("No scratch config supplied") + + +@main.command(name="login") +@click.pass_obj +def login(obj: dict) -> None: + config: ApplicationConfig = obj["config"] + if isinstance(config.oauth_client, CLIClientConfig) and config.oauth_server: + print("Logging in") + auth: SessionManager = SessionManager( + server_config=config.oauth_server, + client_config=config.oauth_client, + token_manager=CliTokenManager(Path(config.oauth_client.token_file_path)), + ) + auth.start_device_flow() + else: + print("Please provide configuration to login!") + + +@main.command(name="logout") +@click.pass_obj +def logout(obj: dict) -> None: + config: ApplicationConfig = obj["config"] + if isinstance(config.oauth_client, CLIClientConfig) and config.oauth_server: + auth: SessionManager = SessionManager( + server_config=config.oauth_server, + client_config=config.oauth_client, + token_manager=CliTokenManager(Path(config.oauth_client.token_file_path)), + ) + auth.logout() + print("Logged out") + else: + print("Please provide configuration to logout!") diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 4468bf10e..f3752cb3c 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -6,6 +6,7 @@ from blueapi.config import ApplicationConfig from blueapi.core.bluesky_types import DataEvent +from blueapi.service.authentication import SessionManager from blueapi.service.model import ( DeviceModel, DeviceResponse, @@ -39,7 +40,10 @@ def __init__( @classmethod def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": - rest = BlueapiRestClient(config.api) + rest: BlueapiRestClient = BlueapiRestClient( + config.api, + SessionManager.from_config(config.oauth_server, config.oauth_client), + ) if config.stomp is not None: template = StompClient.for_broker( broker=Broker( diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 2ec60a1c7..adeb9360d 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -1,10 +1,12 @@ from collections.abc import Callable, Mapping from typing import Any, Literal, TypeVar +import jwt import requests from pydantic import TypeAdapter from blueapi.config import RestConfig +from blueapi.service.authentication import SessionManager from blueapi.service.model import ( DeviceModel, DeviceResponse, @@ -38,8 +40,13 @@ def _exception(response: requests.Response) -> Exception | None: class BlueapiRestClient: _config: RestConfig - def __init__(self, config: RestConfig | None = None) -> None: + def __init__( + self, + config: RestConfig | None = None, + session_manager: SessionManager | None = None, + ) -> None: self._config = config or RestConfig() + self._session_manager: SessionManager | None = session_manager def get_plans(self) -> PlanResponse: return self._request_and_deserialize("/plans", PlanResponse) @@ -127,10 +134,20 @@ def _request_and_deserialize( get_exception: Callable[[requests.Response], Exception | None] = _exception, ) -> T: url = self._url(suffix) + headers: dict[str, str] = { + "content-type": "application/json; charset=UTF-8", + } + if self._session_manager and (token := self._session_manager.get_token()): + try: + self._session_manager.authenticator.decode_jwt(token["access_token"]) + headers["Authorization"] = f"Bearer {token['access_token']}" + except jwt.ExpiredSignatureError: + if token := self._session_manager.refresh_auth_token(): + headers["Authorization"] = f"Bearer {token['access_token']}" if data: - response = requests.request(method, url, json=data) + response = requests.request(method, url, json=data, headers=headers) else: - response = requests.request(method, url) + response = requests.request(method, url, headers=headers) exception = get_exception(response) if exception is not None: raise exception diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 3502590ba..554d819b5 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -3,9 +3,15 @@ from pathlib import Path from typing import Any, Generic, Literal, TypeVar +import requests import yaml from bluesky_stomp.models import BasicAuthentication -from pydantic import BaseModel, Field, TypeAdapter, ValidationError +from pydantic import ( + BaseModel, + Field, + TypeAdapter, + ValidationError, +) from blueapi.utils import BlueapiBaseModel, InvalidConfigError @@ -77,6 +83,63 @@ class ScratchConfig(BlueapiBaseModel): repositories: list[ScratchRepository] = Field(default_factory=list) +class OAuthServerConfig(BlueapiBaseModel): + oidc_config_url: str = Field( + description="URL to fetch OIDC config from the provider" + ) + # Initialized post-init + device_auth_url: str = "" + pkce_auth_url: str = "" + token_url: str = "" + issuer: str = "" + jwks_uri: str = "" + logout_url: str = "" + signing_algos: list[str] = [] + + def model_post_init(self, __context: Any) -> None: + response: requests.Response = requests.get(self.oidc_config_url) + response.raise_for_status() + config_data: dict[str, Any] = response.json() + + device_auth_url: str | None = config_data.get("device_authorization_endpoint") + pkce_auth_url: str | None = config_data.get("authorization_endpoint") + token_url: str | None = config_data.get("token_endpoint") + issuer: str | None = config_data.get("issuer") + jwks_uri: str | None = config_data.get("jwks_uri") + logout_url: str | None = config_data.get("end_session_endpoint") + signing_algos: list[str] | None = config_data.get( + "id_token_signing_alg_values_supported" + ) + # post this we need to check if all the values are present + if ( + device_auth_url + and pkce_auth_url + and token_url + and issuer + and jwks_uri + and logout_url + and signing_algos + ): + self.device_auth_url = device_auth_url + self.pkce_auth_url = pkce_auth_url + self.token_url = token_url + self.issuer = issuer + self.jwks_uri = jwks_uri + self.logout_url = logout_url + self.signing_algos = signing_algos + else: + raise ValueError("OIDC config is missing required fields") + + +class OAuthClientConfig(BlueapiBaseModel): + client_id: str = Field(description="Client ID") + client_audience: str = Field(description="Client Audience") + + +class CLIClientConfig(OAuthClientConfig): + token_file_path: Path = Path("~/token") + + class ApplicationConfig(BlueapiBaseModel): """ Config for the worker application as a whole. Root of @@ -88,6 +151,8 @@ class ApplicationConfig(BlueapiBaseModel): logging: LoggingConfig = Field(default_factory=LoggingConfig) api: RestConfig = Field(default_factory=RestConfig) scratch: ScratchConfig | None = None + oauth_server: OAuthServerConfig | None = None + oauth_client: OAuthClientConfig | CLIClientConfig | None = None def __eq__(self, other: object) -> bool: if isinstance(other, ApplicationConfig): diff --git a/src/blueapi/service/__init__.py b/src/blueapi/service/__init__.py index 7c2fa404c..ae9ffaf79 100644 --- a/src/blueapi/service/__init__.py +++ b/src/blueapi/service/__init__.py @@ -1,3 +1,4 @@ +from .authentication import Authenticator, SessionManager from .model import DeviceModel, PlanModel -__all__ = ["PlanModel", "DeviceModel"] +__all__ = ["PlanModel", "DeviceModel", "Authenticator", "SessionManager"] diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py new file mode 100644 index 000000000..2bb96e8d0 --- /dev/null +++ b/src/blueapi/service/authentication.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import base64 +import json +import os +import time +from abc import ABC, abstractmethod +from enum import Enum +from http import HTTPStatus +from pathlib import Path +from typing import Any + +import jwt +import requests + +from blueapi.config import ( + CLIClientConfig, + OAuthClientConfig, + OAuthServerConfig, +) + + +class AuthenticationType(Enum): + DEVICE = "device" + PKCE = "pkce" + + +class Authenticator: + def __init__( + self, + server_config: OAuthServerConfig, + client_config: OAuthClientConfig, + ): + self._server_config: OAuthServerConfig = server_config + self._client_config: OAuthClientConfig = client_config + + def decode_jwt(self, token: str) -> dict[str, str]: + signing_key = jwt.PyJWKClient( + self._server_config.jwks_uri + ).get_signing_key_from_jwt(token) + decode: dict[str, str] = jwt.decode( + token, + signing_key.key, + algorithms=self._server_config.signing_algos, + verify=True, + audience=self._client_config.client_audience, + issuer=self._server_config.issuer, + ) + return decode + + def print_user_info(self, token: str) -> None: + decode: dict[str, str] = self.decode_jwt(token) + print(f'Logged in as {decode.get("name")} with fed-id {decode.get("fedid")}') + + +class TokenManager(ABC): + @abstractmethod + def save_token(self, token: dict[str, Any]) -> None: ... + @abstractmethod + def load_token(token) -> dict[str, Any] | None: ... + @abstractmethod + def delete_token(self): ... + + +class CliTokenManager(TokenManager): + def __init__(self, token_file_path: Path) -> None: + self._token_file_path: Path = token_file_path + + def _file_path(self) -> str: + return os.path.expanduser(self._token_file_path) + + def save_token(self, token: dict[str, Any]) -> None: + token_json: str = json.dumps(token) + token_bytes: bytes = token_json.encode("utf-8") + token_base64: bytes = base64.b64encode(token_bytes) + with open(self._file_path(), "wb") as token_file: + token_file.write(token_base64) + + def load_token(self) -> dict[str, Any] | None: + file_path = self._file_path() + if not os.path.exists(file_path): + return None + with open(file_path, "rb") as token_file: + token_base64: bytes = token_file.read() + token_bytes: bytes = base64.b64decode(token_base64) + token_json: str = token_bytes.decode("utf-8") + return json.loads(token_json) + + def delete_token(self) -> None: + Path(self._file_path()).unlink(missing_ok=True) + + +class SessionManager: + def __init__( + self, + server_config: OAuthServerConfig, + client_config: OAuthClientConfig, + token_manager: TokenManager, + ) -> None: + self._server_config: OAuthServerConfig = server_config + self._client_config: OAuthClientConfig = client_config + self.authenticator: Authenticator = Authenticator(server_config, client_config) + self._token_manager = token_manager + + @classmethod + def from_config( + cls, + server_config: OAuthServerConfig | None, + client_config: OAuthClientConfig | None, + ) -> SessionManager | None: + if server_config and client_config: + if isinstance(client_config, CLIClientConfig): + return SessionManager( + server_config, + client_config, + CliTokenManager(Path(client_config.token_file_path)), + ) + return None + + def get_token(self) -> dict[str, Any] | None: + return self._token_manager.load_token() + + def logout(self) -> None: + self._token_manager.delete_token() + + def refresh_auth_token(self) -> dict[str, Any] | None: + if token := self._token_manager.load_token(): + response = requests.post( + self._server_config.token_url, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": self._client_config.client_id, + "grant_type": "refresh_token", + "refresh_token": token["refresh_token"], + }, + ) + if response.status_code == HTTPStatus.OK: + token = response.json() + if token: + self._token_manager.save_token(token) + return token + return None + + def poll_for_token( + self, device_code: str, timeout: float = 30, polling_interval: float = 0.5 + ) -> dict[str, Any]: + too_late: float = time.time() + timeout + while time.time() < too_late: + response = requests.post( + self._server_config.token_url, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": device_code, + "client_id": self._client_config.client_id, + }, + ) + if response.status_code == HTTPStatus.OK: + return response.json() + if response.status_code == HTTPStatus.BAD_REQUEST: + polling_interval += 0.5 + time.sleep(polling_interval) + + raise TimeoutError("Polling timed out") + + def start_device_flow(self) -> None: + if token := self._token_manager.load_token(): + try: + access_token_info: dict[str, Any] = self.authenticator.decode_jwt( + token["access_token"] + ) + if access_token_info: + self.authenticator.print_user_info(token["access_token"]) + return + except jwt.ExpiredSignatureError: + if token := self.refresh_auth_token(): + self.authenticator.print_user_info(token["access_token"]) + return + + response: requests.Response = requests.post( + self._server_config.device_auth_url, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": self._client_config.client_id, + "scope": "openid profile offline_access", + "audience": self._client_config.client_audience, + }, + ) + + if response.status_code == HTTPStatus.OK: + response_json: Any = response.json() + device_code: str = response_json.get("device_code") + print( + "Please login from this URL:- " + f"{response_json['verification_uri_complete']}" + ) + auth_token_json: dict[str, Any] = self.poll_for_token(device_code) + decoded_token: dict[str, Any] = self.authenticator.decode_jwt( + auth_token_json["access_token"] + ) + if decoded_token: + self._token_manager.save_token(auth_token_json) + self.authenticator.print_user_info(auth_token_json["access_token"]) + else: + print("Failed to login") diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 0a5c25070..6ac4ca136 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -1,4 +1,6 @@ +import os from contextlib import asynccontextmanager +from typing import Any from fastapi import ( BackgroundTasks, @@ -10,12 +12,15 @@ Response, status, ) +from fastapi.security import OAuth2AuthorizationCodeBearer from pydantic import ValidationError from starlette.responses import JSONResponse from super_state_machine.errors import TransitionError from blueapi.config import ApplicationConfig from blueapi.service import interface +from blueapi.service.authentication import Authenticator +from blueapi.service.runner import WorkerDispatcher from blueapi.worker import Task, TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -30,11 +35,18 @@ TasksListResponse, WorkerTask, ) -from .runner import WorkerDispatcher REST_API_VERSION = "0.0.5" RUNNER: WorkerDispatcher | None = None +AUTHENTICATOR: Authenticator | None = None +SWAGGER_CONFIG: dict[str, Any] | None = None +_PKCE_AUTHENTICATION_URL: str = "PKCE_AUTHENTICATION_URL" +_TOKEN_URL: str = "TOKEN_URL" +_PKCE_CLIENT_ID: str = "PKCE_CLIENT_ID" +_PKCE_CLIENT_SECRET: str = "PKCE_CLIENT_SECRET" +AUTH_URL: str = os.getenv(_PKCE_AUTHENTICATION_URL, "") +TOKEN_URL: str = os.getenv(_TOKEN_URL, "") def _runner() -> WorkerDispatcher: @@ -68,11 +80,42 @@ async def lifespan(app: FastAPI): teardown_runner() +oauth_scheme = OAuth2AuthorizationCodeBearer( + authorizationUrl=AUTH_URL, + tokenUrl=TOKEN_URL, + refreshUrl=TOKEN_URL, +) + + +def verify_access_token(access_token: str = Depends(oauth_scheme)): + if AUTHENTICATOR: + try: + decoded_token: dict[str, Any] = AUTHENTICATOR.decode_jwt(access_token) + if not decoded_token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + except Exception as exception: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + ) from exception + + +if TOKEN_URL and AUTH_URL: + dependencies = [Depends(verify_access_token)] +else: + dependencies = [] + app = FastAPI( docs_url="/docs", title="BlueAPI Control", lifespan=lifespan, version=REST_API_VERSION, + swagger_ui_init_oauth={ + "clientId": os.getenv(_PKCE_CLIENT_ID), + "clientSecret": os.getenv(_PKCE_CLIENT_SECRET), + "usePkceWithAuthorizationCodeGrant": True, + "scopes": "openid profile offline_access", + }, + dependencies=dependencies, ) @@ -98,7 +141,6 @@ async def delete_environment( runner: WorkerDispatcher = Depends(_runner), ) -> EnvironmentResponse: """Delete the current environment, causing internal components to be reloaded.""" - if runner.state.initialized or runner.state.error_message is not None: background_tasks.add_task(runner.reload) return EnvironmentResponse(initialized=False) @@ -333,7 +375,12 @@ def set_state( def start(config: ApplicationConfig): import uvicorn + global AUTHENTICATOR app.state.config = config + if config.oauth_client and config.oauth_server: + AUTHENTICATOR = Authenticator( + server_config=config.oauth_server, client_config=config.oauth_client + ) uvicorn.run(app, host=config.api.host, port=config.api.port) diff --git a/tests/system_tests/Dockerfile b/tests/system_tests/Dockerfile new file mode 100644 index 000000000..74077f407 --- /dev/null +++ b/tests/system_tests/Dockerfile @@ -0,0 +1,2 @@ +FROM docker.io/rabbitmq:management +RUN rabbitmq-plugins enable --offline rabbitmq_stomp \ No newline at end of file diff --git a/tests/system_tests/plans.json b/tests/system_tests/plans.json index b5d0f76e2..1d2c552e6 100644 --- a/tests/system_tests/plans.json +++ b/tests/system_tests/plans.json @@ -3,7 +3,7 @@ { "name": "count", "description": "\n Take `n` readings from a device\n\n Args:\n detectors (Set[Readable]): Readable devices to read\n num (int, optional): Number of readings to take. Defaults to 1.\n delay (Optional[Union[float, List[float]]], optional): Delay between readings.\n Defaults to None.\n metadata (Optional[Mapping[str, Any]], optional): Key-value metadata to include\n in exported data.\n Defaults to None.\n\n Returns:\n MsgGenerator: _description_\n\n Yields:\n Iterator[MsgGenerator]: _description_\n ", - "parameter_schema": { + "schema": { "additionalProperties": false, "properties": { "detectors": { @@ -57,7 +57,7 @@ { "name": "move", "description": "\n Move a device, wrapper for `bp.mv`.\n\n Args:\n moves (Mapping[Movable, Any]): Mapping of Movables to target positions\n group (Optional[Group], optional): The message group to associate with the\n setting, for sequencing. Defaults to None.\n\n Returns:\n MsgGenerator: Plan\n\n Yields:\n Iterator[MsgGenerator]: Bluesky messages\n ", - "parameter_schema": { + "schema": { "additionalProperties": false, "properties": { "moves": { @@ -86,7 +86,7 @@ { "name": "stp_snapshot", "description": "\n Moves devices for pressure and temperature (defaults fetched from the context)\n and captures a single frame from a collection of devices\n\n Args:\n detectors (List[Readable]): A list of devices to read while the sample is at STP\n temperature (Optional[Movable]): A device controlling temperature of the sample,\n defaults to fetching a device name \"sample_temperature\" from the context\n pressure (Optional[Movable]): A device controlling pressure on the sample,\n defaults to fetching a device name \"sample_pressure\" from the context\n Returns:\n MsgGenerator: Plan\n Yields:\n Iterator[MsgGenerator]: Bluesky messages\n ", - "parameter_schema": { + "schema": { "additionalProperties": false, "properties": { "detectors": { @@ -115,7 +115,7 @@ { "name": "scan", "description": "\n Scan wrapping `bp.scan_nd`\n\n Args:\n detectors: Set of readable devices, will take a reading at\n each point\n axes_to_move: All axes involved in this scan, names and\n objects\n spec: ScanSpec modelling the path of the scan\n metadata: Key-value metadata to include\n in exported data, defaults to\n None.\n\n Returns:\n MsgGenerator: Plan\n\n Yields:\n Iterator[MsgGenerator]: Bluesky messages\n ", - "parameter_schema": { + "schema": { "$defs": { "Circle": { "additionalProperties": false, @@ -170,19 +170,11 @@ "description": "Abstract baseclass for a combination of two regions, left and right.", "properties": { "left": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The left-hand Region to combine" }, "right": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The right-hand Region to combine" }, "type": { @@ -207,19 +199,11 @@ "description": "Concatenate two Specs together, running one after the other.\n\nEach Dimension of left and right must contain the same axes. Typically\nformed using `Spec.concat`.\n\n.. example_spec::\n\n from scanspec.specs import Line\n\n spec = Line(\"x\", 1, 3, 3).concat(Line(\"x\", 4, 5, 5))", "properties": { "left": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "The left-hand Spec to Concat, midpoints will appear earlier" }, "right": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "The right-hand Spec to Concat, midpoints will appear later" }, "gap": { @@ -256,19 +240,11 @@ "description": "A point is in DifferenceOf(a, b) if in a and not in b.\n\nTypically created with the ``-`` operator.\n\n>>> r = Range(\"x\", 0.5, 2.5) - Range(\"x\", 1.5, 3.5)\n>>> r.mask({\"x\": np.array([0, 1, 2, 3, 4])})\narray([False, True, False, False, False])", "properties": { "left": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The left-hand Region to combine" }, "right": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The right-hand Region to combine" }, "type": { @@ -354,19 +330,11 @@ "description": "A point is in IntersectionOf(a, b) if in both a and b.\n\nTypically created with the ``&`` operator.\n\n>>> r = Range(\"x\", 0.5, 2.5) & Range(\"x\", 1.5, 3.5)\n>>> r.mask({\"x\": np.array([0, 1, 2, 3, 4])})\narray([False, False, True, False, False])", "properties": { "left": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The left-hand Region to combine" }, "right": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The right-hand Region to combine" }, "type": { @@ -434,19 +402,11 @@ "description": "Restrict Spec to only midpoints that fall inside the given Region.\n\nTypically created with the ``&`` operator. It also pushes down the\n``& | ^ -`` operators to its `Region` to avoid the need for brackets on\ncombinations of Regions.\n\nIf a Region spans multiple Frames objects, they will be squashed together.\n\n.. example_spec::\n\n from scanspec.regions import Circle\n from scanspec.specs import Line\n\n spec = Line(\"y\", 1, 3, 3) * Line(\"x\", 3, 5, 5) & Circle(\"x\", \"y\", 4, 2, 1.2)\n\nSee Also: `why-squash-can-change-path`", "properties": { "spec": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "The Spec containing the source midpoints" }, "region": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The Region that midpoints will be inside" }, "check_path_changes": { @@ -526,19 +486,11 @@ "description": "Outer product of two Specs, nesting inner within outer.\n\nThis means that inner will run in its entirety at each point in outer.\n\n.. example_spec::\n\n from scanspec.specs import Line\n\n spec = Line(\"y\", 1, 2, 3) * Line(\"x\", 3, 4, 12)", "properties": { "outer": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "Will be executed once" }, "inner": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "Will be executed len(outer) times" }, "type": { @@ -739,11 +691,7 @@ "description": "Run the Spec in reverse on every other iteration when nested.\n\nTypically created with the ``~`` operator.\n\n.. example_spec::\n\n from scanspec.specs import Line\n\n spec = Line(\"y\", 1, 3, 3) * ~Line(\"x\", 3, 5, 5)", "properties": { "spec": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "The Spec to run in reverse every other iteration" }, "type": { @@ -882,11 +830,7 @@ "description": "Squash a stack of Frames together into a single expanded Frames object.\n\nSee Also:\n `why-squash-can-change-path`\n\n.. example_spec::\n\n from scanspec.specs import Line, Squash\n\n spec = Squash(Line(\"y\", 1, 2, 3) * Line(\"x\", 0, 1, 4))", "properties": { "spec": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "The Spec to squash the dimensions of" }, "check_path_changes": { @@ -953,19 +897,11 @@ "description": "A point is in SymmetricDifferenceOf(a, b) if in either a or b, but not both.\n\nTypically created with the ``^`` operator.\n\n>>> r = Range(\"x\", 0.5, 2.5) ^ Range(\"x\", 1.5, 3.5)\n>>> r.mask({\"x\": np.array([0, 1, 2, 3, 4])})\narray([False, True, False, True, False])", "properties": { "left": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The left-hand Region to combine" }, "right": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The right-hand Region to combine" }, "type": { @@ -990,19 +926,11 @@ "description": "A point is in UnionOf(a, b) if in either a or b.\n\nTypically created with the ``|`` operator\n\n>>> r = Range(\"x\", 0.5, 2.5) | Range(\"x\", 1.5, 3.5)\n>>> r.mask({\"x\": np.array([0, 1, 2, 3, 4])})\narray([False, True, True, True, False])", "properties": { "left": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The left-hand Region to combine" }, "right": { - "allOf": [ - { - "$ref": "#/$defs/Region" - } - ], + "$ref": "#/$defs/Region", "description": "The right-hand Region to combine" }, "type": { @@ -1027,19 +955,11 @@ "description": "Run two Specs in parallel, merging their midpoints together.\n\nTypically formed using `Spec.zip`.\n\nStacks of Frames are merged by:\n\n- If right creates a stack of a single Frames object of size 1, expand it to\n the size of the fastest Frames object created by left\n- Merge individual Frames objects together from fastest to slowest\n\nThis means that Zipping a Spec producing stack [l2, l1] with a Spec\nproducing stack [r1] will assert len(l1)==len(r1), and produce\nstack [l2, l1.zip(r1)].\n\n.. example_spec::\n\n from scanspec.specs import Line\n\n spec = Line(\"z\", 1, 2, 3) * Line(\"y\", 3, 4, 5).zip(Line(\"x\", 4, 5, 5))", "properties": { "left": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "The left-hand Spec to Zip, will appear earlier in axes" }, "right": { - "allOf": [ - { - "$ref": "#/$defs/Spec" - } - ], + "$ref": "#/$defs/Spec", "description": "The right-hand Spec to Zip, will appear later in axes" }, "type": { @@ -1104,7 +1024,7 @@ { "name": "set_absolute", "description": "\n Set a device, wrapper for `bp.abs_set`.\n\n Args:\n movable (Movable): The device to set\n value (T): The new value\n group (Optional[Group], optional): The message group to associate with the\n setting, for sequencing. Defaults to None.\n wait (bool, optional): The group should wait until all setting is complete\n (e.g. a motor has finished moving). Defaults to False.\n\n Returns:\n MsgGenerator: Plan\n\n Yields:\n Iterator[MsgGenerator]: Bluesky messages\n ", - "parameter_schema": { + "schema": { "additionalProperties": false, "properties": { "movable": { @@ -1141,7 +1061,7 @@ { "name": "set_relative", "description": "\n Change a device, wrapper for `bp.rel_set`.\n\n Args:\n movable (Movable): The device to set\n value (T): The new value\n group (Optional[Group], optional): The message group to associate with the\n setting, for sequencing. Defaults to None.\n wait (bool, optional): The group should wait until all setting is complete\n (e.g. a motor has finished moving). Defaults to False.\n\n Returns:\n MsgGenerator: Plan\n\n Yields:\n Iterator[MsgGenerator]: Bluesky messages\n ", - "parameter_schema": { + "schema": { "additionalProperties": false, "properties": { "movable": { @@ -1178,7 +1098,7 @@ { "name": "move_relative", "description": "\n Move a device relative to its current position, wrapper for `bp.mvr`.\n\n Args:\n moves (Mapping[Movable, Any]): Mapping of Movables to target deltas\n group (Optional[Group], optional): The message group to associate with the\n setting, for sequencing. Defaults to None.\n\n Returns:\n MsgGenerator: Plan\n\n Yields:\n Iterator[MsgGenerator]: Bluesky messages\n ", - "parameter_schema": { + "schema": { "additionalProperties": false, "properties": { "moves": { @@ -1207,7 +1127,7 @@ { "name": "sleep", "description": "\n Suspend all action for a given time, wrapper for `bp.sleep`\n\n Args:\n time (float): Time to wait in seconds\n\n Returns:\n MsgGenerator: Plan\n\n Yields:\n Iterator[MsgGenerator]: Bluesky messages\n ", - "parameter_schema": { + "schema": { "additionalProperties": false, "properties": { "time": { @@ -1225,7 +1145,7 @@ { "name": "wait", "description": "\n Wait for a group status to complete, wrapper for `bp.wait`\n\n Args:\n group (Optional[Group], optional): The name of the group to wait for, defaults\n to None.\n\n Returns:\n MsgGenerator: Plan\n\n Yields:\n Iterator[MsgGenerator]: Bluesky messages\n ", - "parameter_schema": { + "schema": { "additionalProperties": false, "properties": { "group": { diff --git a/tests/system_tests/server.yaml b/tests/system_tests/server.yaml new file mode 100644 index 000000000..faa90c6ef --- /dev/null +++ b/tests/system_tests/server.yaml @@ -0,0 +1,13 @@ +oauth_server: + oidc_config_url: https://example.com + +oauth_client: + client_id: example-client + client_audience: example + +stomp: + host: localhost + port: 61613 + auth: + username: guest + password: guest diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 4297598bc..607392c8a 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -1,9 +1,9 @@ +import inspect import time from pathlib import Path import pytest -import requests -from fastapi import status +from bluesky_stomp.models import BasicAuthentication from pydantic import TypeAdapter from blueapi.client.client import ( @@ -11,16 +11,20 @@ BlueskyRemoteControlError, ) from blueapi.client.event_bus import AnyEvent -from blueapi.config import ApplicationConfig, StompConfig +from blueapi.config import ( + ApplicationConfig, + CLIClientConfig, + OAuthServerConfig, + StompConfig, +) from blueapi.service.model import ( DeviceResponse, EnvironmentResponse, PlanResponse, TaskResponse, - TasksListResponse, WorkerTask, ) -from blueapi.worker.event import TaskStatus, TaskStatusEnum, WorkerEvent, WorkerState +from blueapi.worker.event import TaskStatus, WorkerEvent, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask @@ -29,15 +33,70 @@ _DATA_PATH = Path(__file__).parent +# Step 1: Ensure a message bus that supports stomp is running and available: +# podman build --tag 'rabbitmq_stomp' tests/system_tests/ # get the latest rabbitmq +# podman run -d -p 15672:15672 -p 61613:61613 'rabbitmq_stomp' +# +# Step 2: Set the required environment variables: +# export TOKEN_URL="https://example.com/token" +# export PKCE_AUTHENTICATION_URL="https://example.com/auth" +# +# Step 3: Start the BlueAPI server with valid configuration: +# blueapi -c tests/unit_tests/example_yaml/valid_stomp_config.yaml serve +# +# Step 4: Run the system tests using tox: +# tox -e system-test +# +# Step 5: Optionally tear down the message bus: +# podman container stop 'rabbitmq_stomp' +# Note: The system tests will be executed in the CI pipeline after resolving: +# https://github.com/DiamondLightSource/blueapi/issues/630 + + +@pytest.fixture +def oauth_server() -> OAuthServerConfig: + return OAuthServerConfig(oidc_config_url="https://example.com") + + +@pytest.fixture +def oauth_client() -> CLIClientConfig: + return CLIClientConfig( + client_id="example-client", + client_audience="example", + token_file_path=Path("example-token-file"), + ) + @pytest.fixture -def client() -> BlueapiClient: +def client_without_auth() -> BlueapiClient: return BlueapiClient.from_config(config=ApplicationConfig()) @pytest.fixture -def client_with_stomp() -> BlueapiClient: - return BlueapiClient.from_config(config=ApplicationConfig(stomp=StompConfig())) +def client_with_stomp( + oauth_server: OAuthServerConfig, oauth_client: CLIClientConfig +) -> BlueapiClient: + return BlueapiClient.from_config( + config=ApplicationConfig( + stomp=StompConfig( + auth=BasicAuthentication(username="guest", password="guest") # type: ignore + ), + oauth_server=oauth_server, + oauth_client=oauth_client, + ) + ) + + +@pytest.fixture +def client( + oauth_server: OAuthServerConfig, oauth_client: CLIClientConfig +) -> BlueapiClient: + return BlueapiClient.from_config( + config=ApplicationConfig( + oauth_server=oauth_server, + oauth_client=oauth_client, + ) + ) @pytest.fixture @@ -54,6 +113,38 @@ def expected_devices() -> DeviceResponse: ) +@pytest.fixture +def blueapi_client_get_methods() -> list[str]: + # Get a list of methods that take only one argument (self) + # This will currently return + # ['get_plans', 'get_devices', 'get_state', 'resume', 'get_all_tasks', + # 'get_active_task', 'stop', 'get_environment'] + return [ + method + for method in BlueapiClient.__dict__ + if callable(getattr(BlueapiClient, method)) + and not method.startswith("__") + and len(inspect.signature(getattr(BlueapiClient, method)).parameters) == 1 + and "self" in inspect.signature(getattr(BlueapiClient, method)).parameters + ] + + +@pytest.fixture(autouse=True) +def clean_existing_tasks(client: BlueapiClient): + for task in client.get_all_tasks().tasks: + client.clear_task(task.task_id) + yield + + +def test_cannot_access_endpoints( + client_without_auth: BlueapiClient, blueapi_client_get_methods: list[str] +): + for get_method in blueapi_client_get_methods: + with pytest.raises(BlueskyRemoteControlError) as exception: + getattr(client_without_auth, get_method)() + assert str(exception) == "" + + def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): assert client.get_plans() == expected_plans @@ -178,15 +269,11 @@ def test_set_state_transition_error(client: BlueapiClient): def test_get_task_by_status(client: BlueapiClient): task_1 = client.create_task(_SIMPLE_TASK) task_2 = client.create_task(_SIMPLE_TASK) - task_by_pending_request = requests.get( - client._rest._url("/tasks"), params={"task_status": TaskStatusEnum.PENDING} - ) - assert task_by_pending_request.status_code == status.HTTP_200_OK - task_by_pending = TypeAdapter(TasksListResponse).validate_python( - task_by_pending_request.json() - ) - + task_by_pending = client.get_all_tasks() + # https://github.com/DiamondLightSource/blueapi/issues/680 + # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.PENDING) assert len(task_by_pending.tasks) == 2 + # Check if all the tasks are pending for task in task_by_pending.tasks: trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is False and trackable_task.is_pending is True @@ -197,13 +284,11 @@ def test_get_task_by_status(client: BlueapiClient): client.start_task(WorkerTask(task_id=task_2.task_id)) while not client.get_task(task_2.task_id).is_complete: time.sleep(0.1) - task_by_completed_request = requests.get( - client._rest._url("/tasks"), params={"task_status": TaskStatusEnum.COMPLETE} - ) - task_by_completed = TypeAdapter(TasksListResponse).validate_python( - task_by_completed_request.json() - ) + task_by_completed = client.get_all_tasks() + # https://github.com/DiamondLightSource/blueapi/issues/680 + # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.COMPLETE) assert len(task_by_completed.tasks) == 2 + # Check if all the tasks are completed for task in task_by_completed.tasks: trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is True and trackable_task.is_pending is False diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index 925ae9664..54dd13654 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -1,8 +1,17 @@ +import base64 +from pathlib import Path from unittest.mock import Mock, patch +import jwt import pytest +import responses +from pydantic import BaseModel from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError +from blueapi.config import OAuthClientConfig, OAuthServerConfig +from blueapi.core.bluesky_types import Plan +from blueapi.service.authentication import CliTokenManager, SessionManager +from blueapi.service.model import PlanModel, PlanResponse @pytest.fixture @@ -10,6 +19,43 @@ def rest() -> BlueapiRestClient: return BlueapiRestClient() +@pytest.fixture +def cache_token(tmp_path: Path): + with open(tmp_path / "token", "w") as token_file: + # base64 encoded token + token_file.write( + base64.b64encode( + b'{"access_token":"token","refresh_token":"refresh_token"}' + ).decode("utf-8") + ) + + +@pytest.fixture +@responses.activate +def rest_with_auth(tmp_path: Path) -> BlueapiRestClient: + responses.add( + responses.GET, + "http://example.com", + json={ + "device_authorization_endpoint": "https://example.com/device_authorization", + "authorization_endpoint": "https://example.com/authorization", + "token_endpoint": "https://example.com/token", + "issuer": "https://example.com", + "jwks_uri": "https://example.com/realms/master/protocol/openid-connect/certs", + "end_session_endpoint": "https://example.com/logout", + "id_token_signing_alg_values_supported": ["RS256", "RS384", "RS512"], + }, + status=200, + ) + + session_manager = SessionManager( + token_manager=CliTokenManager(tmp_path / "token"), + client_config=OAuthClientConfig(client_id="foo", client_audience="bar"), + server_config=OAuthServerConfig(oidc_config_url="http://example.com"), + ) + return BlueapiRestClient(session_manager=session_manager) + + @pytest.mark.parametrize( "code,expected_exception", [ @@ -30,3 +76,47 @@ def test_rest_error_code( mock_request.return_value = response with pytest.raises(expected_exception): rest.get_plans() + + +class MyModel(BaseModel): + id: str + + +@responses.activate +def test_auth_request_functionality(rest_with_auth: BlueapiRestClient, cache_token): + plan = Plan(name="my-plan", model=MyModel) + responses.add( + responses.GET, + "http://localhost:8000/plans", + json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(), + status=200, + ) + with patch("blueapi.service.Authenticator.decode_jwt") as mock_decode_jwt: + mock_decode_jwt.return_value = {"name": "John Doe", "fedid": "jd1"} + + result = rest_with_auth.get_plans() + mock_decode_jwt.assert_called_once_with("token") + assert result == PlanResponse(plans=[PlanModel.from_plan(plan)]) + + +@responses.activate +def test_refresh_if_signature_expired(rest_with_auth: BlueapiRestClient, cache_token): + plan = Plan(name="my-plan", model=MyModel) + responses.add( + responses.GET, + "http://localhost:8000/plans", + json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(), + status=200, + ) + with ( + patch("blueapi.service.Authenticator.decode_jwt") as mock_decode_token, + patch( + "blueapi.service.SessionManager.refresh_auth_token" + ) as mock_refresh_token, + ): + mock_decode_token.side_effect = jwt.ExpiredSignatureError + mock_refresh_token.return_value = {"access_token": "new_token"} + result = rest_with_auth.get_plans() + mock_decode_token.assert_called_once_with("token") + mock_refresh_token.assert_called_once() + assert result == PlanResponse(plans=[PlanModel.from_plan(plan)]) diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py new file mode 100644 index 000000000..a56de89c2 --- /dev/null +++ b/tests/unit_tests/service/test_authentication.py @@ -0,0 +1,155 @@ +import base64 +import os +from http import HTTPStatus +from pathlib import Path +from unittest import mock + +import jwt +import pytest +from fastapi.exceptions import HTTPException + +from blueapi.config import CLIClientConfig, OAuthClientConfig, OAuthServerConfig +from blueapi.service import main +from blueapi.service.authentication import Authenticator, SessionManager + + +@pytest.fixture +def mock_client_config(tmp_path: Path) -> OAuthClientConfig: + return CLIClientConfig( + client_id="client_id", + client_audience="client_audience", + token_file_path=tmp_path / "token", + ) + + +@pytest.fixture +@mock.patch("requests.get") +def mock_server_config(mock_requests_get) -> OAuthServerConfig: + mock_requests_get.return_value.status_code = 200 + mock_requests_get.return_value.json.return_value = { + "device_authorization_endpoint": "https://example.com/device_authorization", + "authorization_endpoint": "https://example.com/authorization", + "token_endpoint": "https://example.com/token", + "issuer": "https://example.com", + "jwks_uri": "https://example.com/.well-known/jwks.json", + "end_session_endpoint": "https://example.com/logout", + "id_token_signing_alg_values_supported": ["RS256", "RS384", "RS512"], + } + return OAuthServerConfig( + oidc_config_url="https://auth.example.com/realms/sample/.well-known/openid-configuration", + ) + + +@pytest.fixture +def mock_session_manager( + mock_client_config: OAuthClientConfig, mock_server_config: OAuthServerConfig +) -> SessionManager | None: + return SessionManager.from_config(mock_server_config, mock_client_config) + + +@pytest.fixture +def mock_connected_client_config(mock_client_config: OAuthClientConfig): + assert isinstance(mock_client_config, CLIClientConfig) + with open(mock_client_config.token_file_path, "w") as token_file: + # base64 encoded token + token_file.write( + base64.b64encode( + b'{"access_token":"token","refresh_token":"refresh_token"}' + ).decode("utf-8") + ) + return mock_client_config + + +@pytest.fixture +def mock_authenticator( + mock_server_config: OAuthServerConfig, mock_client_config: OAuthClientConfig +) -> Authenticator: + return Authenticator(mock_server_config, mock_client_config) + + +def test_logout( + mock_session_manager: SessionManager, mock_connected_client_config: CLIClientConfig +): + assert os.path.exists(mock_connected_client_config.token_file_path) + mock_session_manager.logout() + assert not os.path.exists(mock_connected_client_config.token_file_path) + + +@mock.patch("requests.post") +def test_refresh_auth_token( + mock_post, + mock_session_manager: SessionManager, + mock_connected_client_config: CLIClientConfig, +): + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = {"access_token": "new_access_token"} + result = mock_session_manager.refresh_auth_token() + assert result == {"access_token": "new_access_token"} + assert os.path.exists(mock_connected_client_config.token_file_path) + with open(mock_connected_client_config.token_file_path) as token_file: + token = token_file.read() + assert token == base64.b64encode( + b'{"access_token": "new_access_token"}' + ).decode("utf-8") + + +@mock.patch("requests.post") +def test_poll_for_token( + mock_post, + mock_session_manager: SessionManager, +): + mock_post.return_value.status_code = HTTPStatus.OK + mock_post.return_value.json.return_value = {"access_token": "access_token"} + device_code = "device_code" + token = mock_session_manager.poll_for_token(device_code) + assert token == {"access_token": "access_token"} + + +@mock.patch("requests.post") +@mock.patch("time.sleep") +def test_poll_for_token_timeout( + mock_sleep, + mock_post, + mock_session_manager: SessionManager, +): + mock_post.return_value.status_code = HTTPStatus.BAD_REQUEST + device_code = "device_code" + with pytest.raises(TimeoutError): + mock_session_manager.poll_for_token( + device_code, timeout=1, polling_interval=0.1 + ) + + +@mock.patch("jwt.decode") +@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") +def test_valid_token_access_granted( + mock_get_signing_key, mock_decode, mock_authenticator: Authenticator +): + with mock.patch.object(main, "AUTHENTICATOR", mock_authenticator): + decode_return_value = {"token": "valid_token", "name": "John Doe"} + mock_decode.return_value = decode_return_value + main.verify_access_token("token") + + +@mock.patch("jwt.decode") +@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") +def test_invalid_token_no_access( + mock_get_signing_key, mock_decode, mock_authenticator: Authenticator +): + with pytest.raises(HTTPException) as exec: + with mock.patch.object(main, "AUTHENTICATOR", mock_authenticator): + mock_decode.return_value = None + main.verify_access_token("token") + assert exec.value.status_code == HTTPStatus.UNAUTHORIZED + + +@mock.patch("jwt.decode") +@mock.patch("jwt.PyJWKClient.get_signing_key_from_jwt") +def test_verify_access_token_failure( + mock_get_signing_key, mock_decode, mock_authenticator: Authenticator +): + with pytest.raises(HTTPException) as exec: + with mock.patch.object(main, "AUTHENTICATOR", mock_authenticator): + mock_decode.side_effect = jwt.ExpiredSignatureError + main.verify_access_token("token") + assert exec.value.status_code == HTTPStatus.UNAUTHORIZED diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 6bb08cb01..c62975c44 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -25,9 +25,7 @@ @pytest.fixture def client() -> Iterator[TestClient]: - with ( - patch("blueapi.service.interface.worker"), - ): + with patch("blueapi.service.interface.worker"): main.setup_runner(use_subprocess=False) yield TestClient(main.app) main.teardown_runner() diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 370a90e85..a74b742fe 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -1,3 +1,4 @@ +import base64 import json from collections.abc import Mapping from dataclasses import dataclass @@ -7,6 +8,7 @@ from typing import Any from unittest.mock import Mock, patch +import jwt import pytest import responses from bluesky_stomp.messaging import StompClient @@ -43,7 +45,7 @@ def template(mock_connection: Mock) -> StompClient: @pytest.fixture -def runner(): +def runner() -> CliRunner: return CliRunner() @@ -617,3 +619,223 @@ def _assert_matching_formatting(fmt: OutputFormat, obj: Any, expected: str): output = StringIO() fmt.display(obj, output) assert expected == output.getvalue() + + +@responses.activate +def test_login_missing_config(runner: CliRunner): + result = runner.invoke(main, ["login"]) + assert "Please provide configuration to login!" in result.output + assert result.exit_code == 0 + + +@responses.activate +def test_logout_missing_config(runner: CliRunner): + result = runner.invoke(main, ["logout"]) + assert "Please provide configuration to logout!" in result.output + assert result.exit_code == 0 + + +TOKEN_URL: str = "https://example.com/token" +DEVICE_AUTHORIZATION_URL: str = "https://example.com/device_authorization" +OIDC_URL: str = ( + "https://auth.example.com/realms/sample/.well-known/openid-configuration" +) +OAUTH_CONFIGURATION: dict[str, str | list[str]] = { + "device_authorization_endpoint": DEVICE_AUTHORIZATION_URL, + "authorization_endpoint": "https://example.com/authorization", + "token_endpoint": TOKEN_URL, + "issuer": "https://example.com", + "jwks_uri": "https://example.com/realms/master/protocol/openid-connect/certs", + "end_session_endpoint": "https://example.com/logout", + "id_token_signing_alg_values_supported": ["RS256", "RS384", "RS512"], +} +ERROR_RESPONSE: dict[str, str] = { + "details": "not found", +} +PAYLOAD: dict[str, str] = { + "name": "John Doe", + "fedid": "jd1", +} + + +@pytest.fixture +def valid_auth_config(tmp_path: Path) -> str: + config: str = f""" +oauth_server: + oidc_config_url: {OIDC_URL} +oauth_client: + client_id: sample-cli + client_audience: sample-account + token_file_path: {tmp_path}/token +""" + with open(tmp_path / "auth_config.yaml", mode="w") as valid_auth_config_file: + valid_auth_config_file.write(config) + return valid_auth_config_file.name + + +@responses.activate +def test_login_success(runner: CliRunner, valid_auth_config: str): + with responses.RequestsMock(assert_all_requests_are_fired=True) as requests_mock: + requests_mock.add( + requests_mock.GET, + OIDC_URL, + json=OAUTH_CONFIGURATION, + status=200, + ) + requests_mock.add( + requests_mock.POST, + DEVICE_AUTHORIZATION_URL, + json={ + "device_code": "device_code", + "verification_uri_complete": "https://example.com/verify", + }, + status=200, + ) + requests_mock.add( + requests_mock.POST, + TOKEN_URL, + json={ + "access_token": "token", + }, + status=200, + ) + with ( + patch("blueapi.service.Authenticator.decode_jwt") as mock_decode, + ): + mock_decode.return_value = PAYLOAD + result = runner.invoke(main, ["-c", valid_auth_config, "login"]) + assert ( + "Logging in\n" + "Please login from this URL:- https://example.com/verify\n" + f"Logged in as {PAYLOAD['name']} with fed-id {PAYLOAD['fedid']}\n" + == result.output + ) + assert result.exit_code == 0 + + +@responses.activate +def test_token_login_early_exit( + runner: CliRunner, valid_auth_config: str, tmp_path: Path +): + with open(tmp_path / "token", "w") as token_file: + # base64 encoded token + token_file.write( + base64.b64encode( + b'{"access_token":"token","refresh_token":"refresh_token"}' + ).decode("utf-8") + ) + + with responses.RequestsMock(assert_all_requests_are_fired=True) as requests_mock: + requests_mock.add( + requests_mock.GET, + OIDC_URL, + json=OAUTH_CONFIGURATION, + status=200, + ) + with ( + patch("blueapi.service.Authenticator.decode_jwt") as mock_decode, + ): + mock_decode.side_effect = [PAYLOAD, PAYLOAD] + result = runner.invoke(main, ["-c", valid_auth_config, "login"]) + assert ( + "Logging in\n" + f"Logged in as {PAYLOAD['name']} with fed-id {PAYLOAD['fedid']}\n" + == result.output + ) + assert result.exit_code == 0 + + +@responses.activate +def test_login_with_refresh_token( + runner: CliRunner, valid_auth_config: str, tmp_path: Path +): + with open(tmp_path / "token", "w") as token_file: + # base64 encoded token + token_file.write( + base64.b64encode( + b'{"access_token":"token","refresh_token":"refresh_token"}' + ).decode("utf-8") + ) + + with responses.RequestsMock(assert_all_requests_are_fired=True) as requests_mock: + requests_mock.add( + requests_mock.GET, + OIDC_URL, + json=OAUTH_CONFIGURATION, + status=200, + ) + requests_mock.add( + requests_mock.POST, + TOKEN_URL, + json={ + "access_token": "token", + }, + status=200, + ) + with ( + patch("blueapi.service.Authenticator.decode_jwt") as mock_decode, + ): + mock_decode.side_effect = [jwt.ExpiredSignatureError, PAYLOAD] + result = runner.invoke(main, ["-c", valid_auth_config, "login"]) + assert ( + "Logging in\n" + f"Logged in as {PAYLOAD['name']} with fed-id {PAYLOAD['fedid']}\n" + == result.output + ) + assert result.exit_code == 0 + + +@responses.activate +def test_login_edge_cases(runner: CliRunner, valid_auth_config: str, tmp_path: Path): + with open(tmp_path / "token", "w") as token_file: + # base64 encoded token + token_file.write( + base64.b64encode( + b'{"access_token":"token","refresh_token":"refresh_token"}' + ).decode("utf-8") + ) + with responses.RequestsMock(assert_all_requests_are_fired=True) as requests_mock: + requests_mock.add( + requests_mock.GET, + OIDC_URL, + json=OAUTH_CONFIGURATION, + status=200, + ) + requests_mock.add( + requests_mock.POST, + TOKEN_URL, + json=ERROR_RESPONSE, + status=400, + ) + requests_mock.add( + requests_mock.POST, + DEVICE_AUTHORIZATION_URL, + json=ERROR_RESPONSE, + status=400, + ) + with ( + patch("blueapi.service.Authenticator.decode_jwt") as mock_decode, + ): + mock_decode.side_effect = jwt.ExpiredSignatureError + result = runner.invoke(main, ["-c", valid_auth_config, "login"]) + assert "Logging in\nFailed to login\n" == result.output + assert result.exit_code == 0 + + +@responses.activate +def test_logout_success(runner: CliRunner, valid_auth_config: str, tmp_path: Path): + with open(tmp_path / "token", "w") as token_file: + # base64 encoded token + token_file.write(base64.b64encode(b'{"access_token":"token"}').decode("utf-8")) + response = responses.add( + responses.GET, + OIDC_URL, + json=OAUTH_CONFIGURATION, + status=200, + ) + assert tmp_path.joinpath("token").exists() is True + result = runner.invoke(main, ["-c", valid_auth_config, "logout"]) + assert "Logged out" in result.output + assert result.exit_code == 0 + assert response.call_count == 1 + assert tmp_path.joinpath("token").exists() is False diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index d4ffa10c1..2a861dfdd 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -7,7 +7,7 @@ from bluesky_stomp.models import BasicAuthentication from pydantic import BaseModel, Field -from blueapi.config import ConfigLoader +from blueapi.config import ConfigLoader, OAuthServerConfig from blueapi.utils import InvalidConfigError @@ -147,3 +147,49 @@ def test_auth_from_env_throws_when_not_available(): BasicAuthentication(username="${BAZ}", password="baz") with pytest.raises(KeyError): BasicAuthentication(username="${baz}", password="baz") + + +@mock.patch("requests.get") +def test_oauth_config_model_post_init(mock_get): + oidc_config_url = "https://example.com/.well-known/openid-configuration" + mock_response = { + "device_authorization_endpoint": "https://example.com/device_authorization", + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + "issuer": "https://example.com/", + "jwks_uri": "https://example.com/jwks", + "end_session_endpoint": "https://example.com/logout", + "id_token_signing_alg_values_supported": ["RS256", "RS384", "RS512"], + } + + mock_get.return_value.json.return_value = mock_response + mock_get.return_value.raise_for_status = lambda: None + + oauth_config = OAuthServerConfig(oidc_config_url=oidc_config_url) + + assert ( + oauth_config.device_auth_url == mock_response["device_authorization_endpoint"] + ) + assert oauth_config.pkce_auth_url == mock_response["authorization_endpoint"] + assert oauth_config.token_url == mock_response["token_endpoint"] + assert oauth_config.issuer == mock_response["issuer"] + assert oauth_config.jwks_uri == mock_response["jwks_uri"] + assert oauth_config.logout_url == mock_response["end_session_endpoint"] + + +@mock.patch("requests.get") +def test_oauth_config_model_post_init_missing_fields(mock_get): + oidc_config_url = "https://example.com/.well-known/openid-configuration" + mock_response = { + "device_authorization_endpoint": "https://example.com/device_authorization", + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + "issuer": "https://example.com/", + "jwks_uri": "https://example.com/jwks", + "end_session_endpoint": "", # Missing end_session_endpoint + } + + mock_get.return_value.json.return_value = mock_response + mock_get.return_value.raise_for_status = lambda: None + with pytest.raises(ValueError, match="OIDC config is missing required fields"): + OAuthServerConfig(oidc_config_url=oidc_config_url) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 96777db9b..4bb34cc3b 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -307,6 +307,9 @@ def begin_task_and_wait_until_complete( # +@pytest.mark.skip( + "This test is currently waiting for https://github.com/DiamondLightSource/dls-bluesky-core/blob/main/src/dls_bluesky_core/plans/wrapped.py" +) def test_worker_and_data_events_produce_in_order(worker: TaskWorker) -> None: assert_running_count_plan_produces_ordered_worker_and_data_events( [