diff --git a/backend/src/api/endpoints/analysis_table.py b/backend/src/api/endpoints/analysis_table.py index 54b85128f..dd5a71ae7 100644 --- a/backend/src/api/endpoints/analysis_table.py +++ b/backend/src/api/endpoints/analysis_table.py @@ -1,7 +1,10 @@ from typing import List, Optional -from api.dependencies import get_current_user, get_db_session +from api.dependencies import get_current_user, get_db_session, is_authorized from app.core.data.crud.analysis_table import crud_analysis_table +from app.core.data.crud.project import crud_project +from app.core.data.crud.user import crud_user +from app.core.data.dto.action import ActionType from app.core.data.dto.analysis_table import ( AnalysisTableCreate, AnalysisTableRead, @@ -22,6 +25,7 @@ response_model=Optional[AnalysisTableRead], summary="Creates an AnalysisTable", description="Creates an AnalysisTable", + dependencies=[is_authorized(ActionType.CREATE, crud_analysis_table)], ) async def create( *, db: Session = Depends(get_db_session), analysis_table: AnalysisTableCreate @@ -36,6 +40,9 @@ async def create( response_model=Optional[AnalysisTableRead], summary="Returns the AnalysisTable", description="Returns the AnalysisTable with the given ID if it exists", + dependencies=[ + is_authorized(ActionType.READ, crud_analysis_table, "analysis_table_id") + ], ) async def get_by_id( *, db: Session = Depends(get_db_session), analysis_table_id: int @@ -49,6 +56,10 @@ async def get_by_id( response_model=List[AnalysisTableRead], summary="Returns AnalysisTables of the Project of the User", description="Returns the AnalysisTable of the Project with the given ID and the User with the given ID if it exists", + dependencies=[ + is_authorized(ActionType.READ, crud_project, "project_id"), + is_authorized(ActionType.READ, crud_user, "user_id"), + ], ) async def get_by_project_and_user( *, db: Session = Depends(get_db_session), project_id: int, user_id: int diff --git a/backend/src/api/endpoints/user.py b/backend/src/api/endpoints/user.py index d7e3f20d7..6f2b60d21 100644 --- a/backend/src/api/endpoints/user.py +++ b/backend/src/api/endpoints/user.py @@ -1,9 +1,15 @@ from typing import Dict, List, Optional -from api.dependencies import get_current_user, get_db_session, skip_limit_params +from api.dependencies import ( + get_current_user, + get_db_session, + is_authorized, + skip_limit_params, +) from app.core.data.crud.annotation_document import crud_adoc from app.core.data.crud.memo import crud_memo from app.core.data.crud.user import crud_user +from app.core.data.dto.action import ActionType from app.core.data.dto.annotation_document import AnnotationDocumentRead from app.core.data.dto.code import CodeRead from app.core.data.dto.memo import MemoRead @@ -32,6 +38,7 @@ async def get_me(*, user: UserRead = Depends(get_current_user)) -> Optional[User response_model=Optional[UserRead], summary="Returns the User", description="Returns the User with the given ID if it exists", + dependencies=[is_authorized(ActionType.READ, crud_user, "user_id")], ) async def get_by_id( *, db: Session = Depends(get_db_session), user_id: int @@ -45,6 +52,7 @@ async def get_by_id( response_model=List[UserRead], summary="Returns all Users", description="Returns all Users that exist in the system", + # TODO do we need some kind of authorization check here? ) async def get_all( *, @@ -60,6 +68,7 @@ async def get_all( response_model=Optional[UserRead], summary="Updates the User", description="Updates the User with the given ID if it exists", + dependencies=[is_authorized(ActionType.UPDATE, crud_user, "user_id")], ) async def update_by_id( *, db: Session = Depends(get_db_session), user_id: int, user: UserUpdate @@ -73,6 +82,7 @@ async def update_by_id( response_model=Optional[UserRead], summary="Removes the User", description="Removes the User with the given ID if it exists", + dependencies=[is_authorized(ActionType.DELETE, crud_user, "user_id")], ) async def delete_by_id( *, db: Session = Depends(get_db_session), user_id: int @@ -86,6 +96,7 @@ async def delete_by_id( response_model=List[ProjectRead], summary="Returns all Projects of the User", description="Returns all Projects of the User with the given ID", + dependencies=[is_authorized(ActionType.READ, crud_user, "user_id")], ) async def get_user_projects( *, user_id: int, db: Session = Depends(get_db_session) @@ -100,6 +111,7 @@ async def get_user_projects( response_model=List[CodeRead], summary="Returns all Codes of the User", description="Returns all Codes of the User with the given ID", + dependencies=[is_authorized(ActionType.READ, crud_user, "user_id")], ) async def get_user_codes( *, user_id: int, db: Session = Depends(get_db_session) @@ -114,6 +126,7 @@ async def get_user_codes( response_model=List[MemoRead], summary="Returns all Memos of the User", description="Returns all Memos of the User with the given ID", + dependencies=[is_authorized(ActionType.READ, crud_user, "user_id")], ) async def get_user_memos( *, user_id: int, db: Session = Depends(get_db_session) @@ -131,6 +144,7 @@ async def get_user_memos( response_model=List[AnnotationDocumentRead], summary="Returns all Adocs of the User", description="Returns all Adocs of the User with the given ID", + dependencies=[is_authorized(ActionType.READ, crud_user, "user_id")], ) async def get_user_adocs( *, user_id: int, db: Session = Depends(get_db_session) @@ -147,6 +161,7 @@ async def get_user_adocs( response_model=List[AnnotationDocumentRead], summary="Returns sdoc ids of sdocs the User recently modified (annotated)", description="Returns the top k sdoc ids that the User recently modified (annotated)", + dependencies=[is_authorized(ActionType.READ, crud_user, "user_id")], ) async def recent_activity( *, user_id: int, k: int, db: Session = Depends(get_db_session) diff --git a/backend/src/app/core/data/orm/util.py b/backend/src/app/core/data/orm/util.py index 04c521f45..3b7542cf7 100644 --- a/backend/src/app/core/data/orm/util.py +++ b/backend/src/app/core/data/orm/util.py @@ -57,7 +57,6 @@ def get_parent_project_id(orm: ORMBase) -> Optional[int]: # TODO missing cases: # - SourceDocumentLinkORM - # - SourceDocumentMetadataORM raise NotImplementedError(f"Unknown ORM: {type(orm)}")