Skip to content

Commit

Permalink
Added initial authentication to ping endpoints.
Browse files Browse the repository at this point in the history
  • Loading branch information
JBorrow committed Feb 1, 2024
1 parent 2efcaa2 commit dd38024
Show file tree
Hide file tree
Showing 8 changed files with 325 additions and 3 deletions.
8 changes: 8 additions & 0 deletions alembic/versions/71df5b41ae41_initial_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from hera_librarian.deletion import DeletionPolicy
from hera_librarian.errors import ErrorCategory, ErrorSeverity
from hera_librarian.transfer import TransferStatus
from librarian_server.authlevel import AuthLevel


def upgrade():
Expand Down Expand Up @@ -185,6 +186,13 @@ def upgrade():
Column("caller", String(256)),
)

op.create_table(
"users",
Column("username", String(256), primary_key=True, unique=True),
Column("auth_token", String(256), nullable=False),
Column("auth_level", Enum(AuthLevel), nullable=False),
)


def downgrade():
op.drop_table("incoming_transfers")
Expand Down
138 changes: 138 additions & 0 deletions librarian_server/api/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
Authenticaiton helper functions for the web API.
You should really care about the following dependencies:
- NoneUserDependency
- ReadonlyUserDependency
- ReadappendUserDependency
- ReadwriteUserDependency
- AdminUserDependency
These are used to ensure that the user is authenticated with the correct level
of permissions. If they are not, we raise a HTTPException (see
UnauthorizedError).
"""

from typing import Annotated

from fastapi import Depends, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ..authlevel import AuthLevel
from ..database import yield_session
from ..orm import User

security = HTTPBasic()

SecurityDepedency = Annotated[HTTPBasicCredentials, Depends(security)]
SessionDependency = Annotated[Session, Depends(yield_session)]

UnauthorizedError = HTTPException(
status_code=401,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)


class UserPermissions(BaseModel):
"""
A simple model to represent a user and their permission.
"""

username: str
"The username of the user."
permission: AuthLevel
"The permission level of the user."


def get_user(
credentials: SecurityDepedency, session: SessionDependency
) -> UserPermissions:
"""
Get the user and their permissions from the database.
"""

return UserPermissions(
username=credentials.username,
permission=User.check_user(
username=credentials.username,
password=credentials.password,
session=session,
),
)


def get_user_with_level(
level: AuthLevel, credentials: SecurityDepedency, session: SessionDependency
) -> UserPermissions:
"""
Get the user and their permissions from the database.
If the user does not have the required level, raise an UnauthorizedError.
"""

user = get_user(credentials, session)

if user.permission.value < level.value:
raise UnauthorizedError

return user


def get_user_with_none(
credentials: SecurityDepedency, session: SessionDependency
) -> UserPermissions:
"""
Ensure user is authenticated with a level of at least NONE.
"""

return get_user_with_level(AuthLevel.NONE, credentials, session)


def get_user_with_readonly(
credentials: SecurityDepedency, session: SessionDependency
) -> UserPermissions:
"""
Ensure user is authenticated with a level of at least READONLY.
"""

return get_user_with_level(AuthLevel.READONLY, credentials, session)


def get_user_with_readappend(
credentials: SecurityDepedency, session: SessionDependency
) -> UserPermissions:
"""
Ensure user is authenticated with a level of at least READAPPEND.
"""

return get_user_with_level(AuthLevel.READAPPEND, credentials, session)


def get_user_with_readwrite(
credentials: SecurityDepedency, session: SessionDependency
) -> UserPermissions:
"""
Ensure user is authenticated with a level of at least READWRITE.
"""

return get_user_with_level(AuthLevel.READWRITE, credentials, session)


def get_user_with_admin(
credentials: SecurityDepedency, session: SessionDependency
) -> UserPermissions:
"""
Ensure user is authenticated with a level of at least ADMIN.
"""

return get_user_with_level(AuthLevel.ADMIN, credentials, session)


NoneUserDependency = Annotated[UserPermissions, Depends(get_user_with_none)]
ReadonlyUserDependency = Annotated[UserPermissions, Depends(get_user_with_readonly)]
ReadappendUserDependency = Annotated[UserPermissions, Depends(get_user_with_readappend)]
ReadwriteUserDependency = Annotated[UserPermissions, Depends(get_user_with_readwrite)]
AdminUserDependency = Annotated[UserPermissions, Depends(get_user_with_admin)]
35 changes: 33 additions & 2 deletions librarian_server/api/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,49 @@

from ..logger import log
from ..settings import server_settings
from .auth import AdminUserDependency, NoneUserDependency, ReadonlyUserDependency

router = APIRouter(prefix="/api/v2/ping")


@router.post("/", response_model=PingResponse)
def ping(request: PingRequest):
def ping(request: PingRequest, user: NoneUserDependency):
"""
Pings the librarian server. Returns some information about
the server.
"""

log.debug(f"Received ping request: {request}")
log.debug(f"Received ping request: {request} from user {user}")

return PingResponse(
name=server_settings.displayed_site_name,
description=server_settings.displayed_site_description,
)


@router.post("/logged", response_model=PingResponse)
def ping_logged_in(request: PingRequest, user: ReadonlyUserDependency):
"""
Pings the librarian server. Returns some information about
the server.
"""

log.debug(f"Received ping (logged in) request: {request} from user {user}")

return PingResponse(
name=server_settings.displayed_site_name,
description=server_settings.displayed_site_description,
)


@router.post("/admin", response_model=PingResponse)
def ping_admin(request: PingRequest, user: AdminUserDependency):
"""
Pings the librarian server. Returns some information about
the server.
"""

log.debug(f"Received ping (admin) request: {request} from user {user}")

return PingResponse(
name=server_settings.displayed_site_name,
Expand Down
26 changes: 26 additions & 0 deletions librarian_server/authlevel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Enumeration for authentication levels.
"""

from enum import Enum


class AuthLevel(Enum):
"""
The level of authorization that a given user has.
"""

NONE = 0
"Not used, but in the case where someone is not allowed to do anything."

READONLY = 1
"Can read from the databases and store, but not write."

READAPPEND = 2
"Can read and append to the databases and store."

READWIRTE = 3
"Can read and write to the databases and store."

ADMIN = 100
"Can do anything, including modifying the configuration."
1 change: 1 addition & 0 deletions librarian_server/orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .librarian import Librarian
from .storemetadata import StoreMetadata
from .transfer import CloneTransfer, IncomingTransfer, OutgoingTransfer, TransferStatus
from .user import User
105 changes: 105 additions & 0 deletions librarian_server/orm/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
ORM model for a user.
"""

import argon2
from sqlalchemy.orm import Session

from .. import database as db
from ..authlevel import AuthLevel


class User(db.Base):
"""
A user in the librarian system, along with authentication functons.
"""

__tablename__ = "users"

username = db.Column(db.String(256), primary_key=True, unique=True)
"The username of the user."
auth_token = db.Column(db.String(256), nullable=False)
"The authentication token for the user (a salted and hashed password with argon2)."
auth_level = db.Column(db.Enum(AuthLevel), nullable=False)
"The authorization level of the user."

@classmethod
def new_user(cls, username: str, password: str, auth_level: int) -> "User":
"""
Create a new user in the database.
Parameters
----------
username : str
The username of the new user.
password : str
The password of the new user.
auth_level : int
The authorization level of the new user.
Returns
-------
User
The new user.
"""
# Create a new user.
ph = argon2.PasswordHasher()

user = cls(
username=username,
auth_token=ph.hash(password),
auth_level=auth_level,
)

return user

@classmethod
def check_user(cls, username: str, password: str, session: Session) -> AuthLevel:
"""
Check if a user exists and the password is correct.
Parameters
----------
username : str
The username to check.
password : str
The password to check.
session : Session
The database session to use.
Returns
-------
AuthLevel
The authorization level of the user.
"""

potential_user = session.get(User, username)

if potential_user is not None:
try:
if potential_user.check_password(password):
return potential_user.auth_level
else:
return AuthLevel.NONE
except argon2.exceptions.VerifyMismatchError:
return AuthLevel.NONE

return AuthLevel.NONE

def check_password(self, password: str) -> bool:
"""
Check if the password is correct for this user.
Parameters
----------
password : str
The password to check.
Returns
-------
bool
True if the password is correct.
"""
ph = argon2.PasswordHasher()

return ph.verify(self.auth_token, password)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"schedule",
"checksumdir",
"python-dateutil",
"argon2-cffi",
]
authors = [
{name = "HERA Team", email = "hera@lists.berkeley.edu"},
Expand Down Expand Up @@ -90,4 +91,4 @@ profile = "black"
skip = ["docs", ".github", ".vscode", "container", "env", "env311", "build"]

[tool.black]
exclude = "docs|.github|.vscode|container|env|env311|build"
exclude = "docs|.github|.vscode|container|env|env311|build"
12 changes: 12 additions & 0 deletions tests/server_unit_test/test_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@ def test_ping(test_client):
"/api/v2/ping",
headers={"Content-Type": "application/json"},
content=request.model_dump_json(),
auth=("test", "test"),
)
assert response.status_code == 200
# Check we can decode the response
response = PingResponse.model_validate_json(response.content)


def test_ping_logged_not_logged(test_client):
request = PingRequest()
response = test_client.post(
"/api/v2/ping/logged",
headers={"Content-Type": "application/json"},
content=request.model_dump_json(),
auth=("test", "test-not-real-password"),
)
assert response.status_code == 401

0 comments on commit dd38024

Please sign in to comment.