diff --git a/alembic/versions/71df5b41ae41_initial_schema.py b/alembic/versions/71df5b41ae41_initial_schema.py index 2e5ab6a..0a7926e 100644 --- a/alembic/versions/71df5b41ae41_initial_schema.py +++ b/alembic/versions/71df5b41ae41_initial_schema.py @@ -113,6 +113,7 @@ def upgrade(): ), Column("status", Enum(TransferStatus), nullable=False), Column("uploader", String(256), nullable=False), + Column("upload_name", String(256), nullable=False), Column("source", String(256), nullable=False), Column("transfer_size", BigInteger, nullable=False), Column("transfer_checksum", String(256), nullable=False), diff --git a/librarian_background/check_integrity.py b/librarian_background/check_integrity.py index 7b006ea..5179d12 100644 --- a/librarian_background/check_integrity.py +++ b/librarian_background/check_integrity.py @@ -9,8 +9,10 @@ from schedule import CancelJob -from librarian_server.database import session, query from librarian_server.orm import StoreMetadata, Instance +from librarian_server.database import get_session + +from sqlalchemy.orm import Session logger = logging.getLogger("schedule") @@ -26,9 +28,9 @@ class CheckIntegrity(Task): age_in_days: int "Age in days of the files to check. I.e. only check files younger than this (we assume older files are fine as they've been checked before)" - def get_store(self) -> StoreMetadata: + def get_store(self, session: Session) -> StoreMetadata: possible_metadata = ( - query(StoreMetadata).filter(StoreMetadata.name == self.store_name).first() + session.query(StoreMetadata).filter_by(name=self.store_name).first() ) if not possible_metadata: @@ -37,8 +39,15 @@ def get_store(self) -> StoreMetadata: return possible_metadata def on_call(self): + with get_session() as session: + return self.core(session=session) + + def core(self, session: Session): + """ + Frame this out with the session so that it is automatically closed. + """ try: - store = self.get_store() + store = self.get_store(session=session) except ValueError: # Store doesn't exist. Cancel this job. logger.error( @@ -51,9 +60,8 @@ def on_call(self): # Now we can query the database for all files that were uploaded in the past age_in_days days. files = ( - query(Instance) - .filter(Instance.store == store) - .filter(Instance.created_time > start_time) + session.query(Instance) + .filter(Instance.store == store and Instance.created_time > start_time) .all() ) diff --git a/librarian_background/create_clone.py b/librarian_background/create_clone.py index 2642040..7dc5eae 100644 --- a/librarian_background/create_clone.py +++ b/librarian_background/create_clone.py @@ -12,8 +12,10 @@ from schedule import CancelJob from pathlib import Path -from librarian_server.database import session, query from librarian_server.orm import StoreMetadata, Instance, CloneTransfer, TransferStatus +from librarian_server.database import get_session + +from sqlalchemy.orm import Session logger = logging.getLogger("schedule") @@ -33,19 +35,23 @@ class CreateLocalClone(Task): # TODO: In the future, we could implement a _rolling_ n day clone here, i.e. only keep the last n days of files on the clone_to store. - def get_store(self, name: str) -> StoreMetadata: + def get_store(self, name: str, session: Session) -> StoreMetadata: possible_metadata = ( - query(StoreMetadata).filter(StoreMetadata.name == name).first() + session.query(StoreMetadata).filter_by(name=name).first() ) if not possible_metadata: raise ValueError(f"Store {name} does not exist.") return possible_metadata - + def on_call(self): + with get_session() as session: + return self.core(session=session) + + def core(self, session: Session): try: - store_from = self.get_store(self.clone_from) + store_from = self.get_store(self.clone_from, session) except ValueError: # Store doesn't exist. Cancel this job. logger.error( @@ -54,7 +60,7 @@ def on_call(self): return CancelJob try: - store_to = self.get_store(self.clone_to) + store_to = self.get_store(self.clone_to, session) except ValueError: # Store doesn't exist. Cancel this job. logger.error( @@ -67,7 +73,7 @@ def on_call(self): # Now we can query the database for all files that were uploaded in the past age_in_days days. instances: list[Instance] = ( - query(Instance) + session.query(Instance) .filter(Instance.store == store_from) .filter(Instance.created_time > start_time) .all() @@ -81,7 +87,7 @@ def on_call(self): # Check if there is a matching instance already on our clone_to store. # If there is, we don't need to clone it. if ( - query(Instance) + session.query(Instance) .filter(Instance.store == store_to) .filter(Instance.file == instance.file) .first() @@ -113,7 +119,7 @@ def on_call(self): f"File {instance.file.name} is too large to fit on store {store_to}. Skipping." ) - transfer.fail_transfer() + transfer.fail_transfer(session=session) all_transfers_successful = False @@ -136,7 +142,7 @@ def on_call(self): f"Failed to transfer file {instance.path} to store {store_to} using transfer manager {transfer_manager}." ) - transfer.fail_transfer() + transfer.fail_transfer(session=session) continue except FileNotFoundError as e: @@ -144,7 +150,7 @@ def on_call(self): f"File {instance.path} does not exist on store {store_from}. Skipping." ) - transfer.fail_transfer() + transfer.fail_transfer(session=session) all_transfers_successful = False @@ -155,7 +161,7 @@ def on_call(self): f"Failed to transfer file {instance.path} to store {store_to}. Skipping." ) - transfer.fail_transfer() + transfer.fail_transfer(session=session) all_transfers_successful = False @@ -176,7 +182,7 @@ def on_call(self): f"Expected {instance.file.checksum}, got {path_info.md5}." ) - transfer.fail_transfer() + transfer.fail_transfer(session=session) store_to.store_manager.unstage(staged_path) @@ -193,7 +199,7 @@ def on_call(self): ) store_to.store_manager.unstage(staging_name) - transfer.fail_transfer() + transfer.fail_transfer(session=session) all_transfers_successful = False diff --git a/librarian_background/recieve_clone.py b/librarian_background/recieve_clone.py index 6247589..b4d0751 100644 --- a/librarian_background/recieve_clone.py +++ b/librarian_background/recieve_clone.py @@ -12,7 +12,7 @@ from .task import Task -from librarian_server.database import session, query +from librarian_server.database import get_session from librarian_server.orm import ( File, Instance, @@ -28,6 +28,10 @@ CloneCompleteResponse, ) +from typing import TYPE_CHECKING + +from sqlalchemy.orm import Session + logger = logging.getLogger("schedule") @@ -39,15 +43,21 @@ class RecieveClone(Task): deletion_policy: DeletionPolicy = DeletionPolicy.DISALLOWED def on_call(self): + with get_session() as session: + return self.core(session=session) + + def core(self, session: Session): """ Checks for incoming transfers and processes them. """ # Find incoming transfers that are ONGOING - ongoing_transfers: list[IncomingTransfer] = query( - IncomingTransfer, status=TransferStatus.ONGOING - ).all() - + ongoing_transfers: list[IncomingTransfer] = ( + session.query(IncomingTransfer) + .filter_by(status=TransferStatus.ONGOING) + .all() + ) + all_transfers_succeeded = True if len(ongoing_transfers) == 0: @@ -118,7 +128,7 @@ def on_call(self): path=path_info.path, file=file, store=store, - deletion_policy=self.deletion_policy + deletion_policy=self.deletion_policy, ) session.add(file) @@ -132,9 +142,9 @@ def on_call(self): session.commit() # Callback to the source librarian. - librarian: Optional[Librarian] = query( - Librarian, name=transfer.source - ).first() + librarian: Optional[Librarian] = ( + session.query(Librarian).filter_by(name=transfer.source).first() + ) if librarian: # Need to call back @@ -148,10 +158,12 @@ def on_call(self): ) try: - response: CloneCompleteResponse = librarian.client.do_pydantic_http_post( - endpoint="/api/v2/clone/complete", - request_model=request, - response_model=CloneCompleteResponse, + response: CloneCompleteResponse = ( + librarian.client.do_pydantic_http_post( + endpoint="/api/v2/clone/complete", + request_model=request, + response_model=CloneCompleteResponse, + ) ) except Exception as e: logger.error( @@ -168,5 +180,4 @@ def on_call(self): logger.info(f"Transfer {transfer.id} has not yet completed. Skipping.") continue - return all_transfers_succeeded diff --git a/librarian_background/send_clone.py b/librarian_background/send_clone.py index bb443af..906f788 100644 --- a/librarian_background/send_clone.py +++ b/librarian_background/send_clone.py @@ -12,7 +12,7 @@ from schedule import CancelJob from pathlib import Path -from librarian_server.database import session, query +from librarian_server.database import get_session() from librarian_server.orm import ( StoreMetadata, Instance, @@ -30,6 +30,13 @@ CloneOngoingResponse ) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from hera_librarian import LibrarianClient + +from sqlalchemy.orm import Session + logger = logging.getLogger("schedule") @@ -49,13 +56,17 @@ class SendClone(Task): "Name of the store to prefer when sending files. If None, we will use whatever store is available for sending that file." def on_call(self): + with get_session() as session: + return self.core(session=session) + + def core(self, session: Session): """ Creates uploads to the remote librarian as specified. """ # Before even attempting to do anything, get the information about the librarian and create # a client connection to it. - librarian: Optional[Librarian] = query( - Librarian, Librarian.name == self.destination_librarian + librarian: Optional[Librarian] = session.query( + Librarian).filter_by(name=self.destination_librarian ).first() if librarian is None: @@ -79,10 +90,10 @@ def on_call(self): age_in_days = datetime.timedelta(days=self.age_in_days) oldest_file_age = current_time - age_in_days - files_without_remote_instances: list[File] = query( - File, - File.create_time > oldest_file_age, - File.remote_instances.any(librarian_name=self.destination_librarian), + files_without_remote_instances: list[File] = session.query( + File).filter( + File.create_time > oldest_file_age and + File.remote_instances.any(librarian_name=self.destination_librarian) ).all() logger.info( @@ -90,8 +101,8 @@ def on_call(self): ) if self.store_preference is not None: - use_store: StoreMetadata = query( - StoreMetadata, StoreMetadata.name == self.store_preference + use_store: StoreMetadata = session.query( + StoreMetadata).filter_by(name = self.store_preference ).first() if use_store is None: @@ -159,7 +170,7 @@ def on_call(self): ) # Mark the transfer as failed. - transfer.fail_transfer() + transfer.fail_transfer(session=session) continue @@ -192,7 +203,7 @@ def on_call(self): f"Failed to transfer file {instance.path} to remote store. Skipping." ) - transfer.fail_transfer() + transfer.fail_transfer(session=session) continue # Great! We can now mark the transfer as ONGOING in the background. diff --git a/librarian_server/api/clone.py b/librarian_server/api/clone.py index b8bdf45..7add28f 100644 --- a/librarian_server/api/clone.py +++ b/librarian_server/api/clone.py @@ -25,7 +25,7 @@ from ..orm.storemetadata import StoreMetadata from ..orm.transfer import TransferStatus, IncomingTransfer, OutgoingTransfer from ..orm.file import File -from ..database import session, query +from ..database import yield_session from ..logger import log from hera_librarian.models.clone import ( @@ -40,12 +40,13 @@ CloneFailRequest, ) -from fastapi import APIRouter, Response, status +from fastapi import APIRouter, Response, status, Depends +from sqlalchemy.orm import Session router = APIRouter(prefix="/api/v2/clone") @router.post("/stage", response_model=CloneInitiationResponse | CloneFailedResponse) -def stage(request: CloneInitiationRequest, response: Response): +def stage(request: CloneInitiationRequest, response: Response, session: Session = Depends(yield_session)): """ Recieved from a remote librarian to initiate a clone. @@ -104,16 +105,12 @@ def stage(request: CloneInitiationRequest, response: Response): # again a logic error. They should not have tried to send us that again! It's already # on its way. - existing_transfer = ( - query(IncomingTransfer) - .filter( - (IncomingTransfer.transfer_checksum == request.upload_checksum) - & (IncomingTransfer.status != TransferStatus.FAILED) - & (IncomingTransfer.status != TransferStatus.COMPLETED) - & (IncomingTransfer.status != TransferStatus.CANCELLED) - ) - .all() - ) + existing_transfer = session.query(IncomingTransfer).filter( + (IncomingTransfer.transfer_checksum == request.upload_checksum) + & (IncomingTransfer.status != TransferStatus.FAILED) + & (IncomingTransfer.status != TransferStatus.COMPLETED) + & (IncomingTransfer.status != TransferStatus.CANCELLED) + ).all() if len(existing_transfer) != 0: log.info( @@ -149,7 +146,7 @@ def stage(request: CloneInitiationRequest, response: Response): # Unstage the files. try: - store = StoreMetadata.from_id(transfer.store_id) + store = session.get(StoreMetadata, transfer.store_id) store.store_manager.unstage(Path(transfer.staging_path)) except ServerError: # No store was yet assigned, do not need to delete. @@ -175,6 +172,7 @@ def stage(request: CloneInitiationRequest, response: Response): transfer = IncomingTransfer.new_transfer( source=request.source, uploader=request.uploader, + upload_name=str(request.upload_name), transfer_size=request.upload_size, transfer_checksum=request.upload_checksum, ) @@ -184,7 +182,7 @@ def stage(request: CloneInitiationRequest, response: Response): use_store: Optional[StoreMetadata] = None - for store in query(StoreMetadata, ingestable=True).all(): + for store in session.query(StoreMetadata).filter_by(ingestable=True).all(): if not store.store_manager.available: continue @@ -242,7 +240,7 @@ def stage(request: CloneInitiationRequest, response: Response): @router.post("/ongoing", response_model=CloneOngoingResponse | CloneFailedResponse) -def ongoing(request: CloneOngoingRequest, response: Response): +def ongoing(request: CloneOngoingRequest, response: Response, session: Session = Depends(yield_session)): """ Called when the remote librarian has started the transfer. We should update the status of the transfer to ONGOING. @@ -256,7 +254,7 @@ def ongoing(request: CloneOngoingRequest, response: Response): log.debug(f"Received clone ongoing request: {request}") - transfer = query(IncomingTransfer, id=request.destination_transfer_id).first() + transfer = session.query(IncomingTransfer).filter_by(id=request.destination_transfer_id).first() if transfer is None: log.debug( @@ -302,7 +300,7 @@ def ongoing(request: CloneOngoingRequest, response: Response): @router.post("/complete", response_model=CloneCompleteResponse | CloneFailedResponse) -def complete(request: CloneCompleteRequest, response: Response): +def complete(request: CloneCompleteRequest, response: Response, session: Session = Depends(yield_session)): """ The callback from librarian B to librarian A that it has completed the transfer. Used to update anything in our OutgiongTransfers that needs it. @@ -316,7 +314,7 @@ def complete(request: CloneCompleteRequest, response: Response): log.debug(f"Received clone complete request: {request}") - transfer = query(OutgoingTransfer, id=request.source_transfer_id).first() + transfer = session.query(OutgoingTransfer).filter_by(id=request.source_transfer_id).first() if transfer is None: log.debug( @@ -362,14 +360,14 @@ def complete(request: CloneCompleteRequest, response: Response): @router.post("/fail", response_model=CloneFailResponse | CloneFailedResponse) -def fail(request: CloneFailRequest, response: Response): +def fail(request: CloneFailRequest, response: Response, session: Session = Depends(yield_session)): """ Endpoint to send to if you would like to fail a specific IncomingTransfer. """ log.debug(f"Received clone fail request: {request}") - transfer = query(IncomingTransfer, id=request.destination_transfer_id).first() + transfer = session.query(IncomingTransfer).filter_by(id=request.destination_transfer_id).first() if transfer is None: log.debug( diff --git a/librarian_server/api/ping.py b/librarian_server/api/ping.py index 000fe6f..69c8bc9 100644 --- a/librarian_server/api/ping.py +++ b/librarian_server/api/ping.py @@ -2,8 +2,6 @@ Contains endpoints for pinging and requesting a ping back. """ -from ..webutil import ServerError -from ..database import session, query from ..logger import log from ..settings import server_settings diff --git a/librarian_server/api/upload.py b/librarian_server/api/upload.py index 945f7e8..95a77a3 100644 --- a/librarian_server/api/upload.py +++ b/librarian_server/api/upload.py @@ -7,7 +7,7 @@ from ..orm.storemetadata import StoreMetadata from ..orm.transfer import TransferStatus, IncomingTransfer from ..orm.file import File -from ..database import session, query +from ..database import yield_session from ..logger import log from hera_librarian.models.uploads import ( @@ -20,13 +20,15 @@ from pathlib import Path from typing import Optional -from fastapi import APIRouter, Response, status +from fastapi import APIRouter, Response, status, Depends +from sqlalchemy.orm import Session +from sqlalchemy import select router = APIRouter(prefix="/api/v2/upload") @router.post("/stage", response_model=UploadInitiationResponse | UploadFailedResponse) -def stage(request: UploadInitiationRequest, response: Response): +def stage(request: UploadInitiationRequest, response: Response, session: Session = Depends(yield_session)): """ Initiates an upload to a store. @@ -67,16 +69,12 @@ def stage(request: UploadInitiationRequest, response: Response): ) # First, try to see if this is someone trying to re-start an existing transfer! - existing_transfer = ( - session.query(IncomingTransfer) - .filter( - (IncomingTransfer.transfer_checksum == request.upload_checksum) - & (IncomingTransfer.status != TransferStatus.FAILED) - & (IncomingTransfer.status != TransferStatus.COMPLETED) - & (IncomingTransfer.status != TransferStatus.CANCELLED) - ) - .all() - ) + existing_transfer = session.query(IncomingTransfer).filter( + (IncomingTransfer.transfer_checksum == request.upload_checksum) + & (IncomingTransfer.status != TransferStatus.FAILED) + & (IncomingTransfer.status != TransferStatus.COMPLETED) + & (IncomingTransfer.status != TransferStatus.CANCELLED) + ).all() if len(existing_transfer) != 0: log.info( @@ -86,7 +84,7 @@ def stage(request: UploadInitiationRequest, response: Response): for transfer in existing_transfer: # Unstage the files. try: - store = StoreMetadata.from_id(transfer.store_id) + store = session.get(StoreMetadata, transfer.store_id) store.store_manager.unstage(Path(transfer.staging_path)) except ServerError: # Store with ID does not exist (usually store_id is None as transfer never got there.) @@ -101,6 +99,7 @@ def stage(request: UploadInitiationRequest, response: Response): transfer = IncomingTransfer.new_transfer( source=request.uploader, uploader=request.uploader, + upload_name=str(request.upload_name), transfer_size=request.upload_size, transfer_checksum=request.upload_checksum, ) @@ -110,7 +109,7 @@ def stage(request: UploadInitiationRequest, response: Response): use_store: Optional[StoreMetadata] = None - for store in query(StoreMetadata, ingestable=True).all(): + for store in session.query(StoreMetadata).filter_by(ingestable=True).all(): if not store.store_manager.available: continue @@ -165,7 +164,7 @@ def stage(request: UploadInitiationRequest, response: Response): @router.post("/commit") -def commit(request: UploadCompletionRequest, response: Response): +def commit(request: UploadCompletionRequest, response: Response, session: Session = Depends(yield_session)): """ Commits a file to a store, called once it has been uploaded. @@ -178,10 +177,10 @@ def commit(request: UploadCompletionRequest, response: Response): log.debug(f"Received upload completion request: {request}") - store: StoreMetadata = StoreMetadata.from_name(request.store_name) + store: StoreMetadata = session.query(StoreMetadata).filter_by(name=request.store_name).first() # Go grab the transfer from the database. - transfer = query(IncomingTransfer, id=request.transfer_id).first() + transfer = session.get(IncomingTransfer, request.transfer_id) transfer.status = TransferStatus.STAGED transfer.transfer_manager_name = request.transfer_provider_name # DB cannot handle path objects; serialize to string. @@ -196,6 +195,7 @@ def commit(request: UploadCompletionRequest, response: Response): store.ingest_staged_file( request=request, transfer=transfer, + session=session, ) except FileNotFoundError: log.debug( diff --git a/librarian_server/api/util.py b/librarian_server/api/util.py deleted file mode 100644 index 3bf754c..0000000 --- a/librarian_server/api/util.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Web utils for v2 of the API that uses pydantic models. -""" - -from pydantic import BaseModel - -from flask import request, jsonify, Response -from typing import Optional -from functools import wraps - -# TODO: Authentication - -def pydantic_api(recieve_model: Optional[BaseModel] = None): - """ - This decorator wraps API functions and serializes and deserializes - them based upon the expected response types. - - Crucially, if you provide a 'recieve_model' argument, a keyword - argument of 'request' is provided to the function that is the - deserialized request body. - """ - def decorator(f): - def wrapped(*args, **kwargs): - # If we have a recieve model, we need to deserialize the - # request body into it. - if recieve_model is not None: - try: - request_data = request.get_json() - except: - return jsonify({ - "error": "Invalid JSON." - }), 400 - - try: - request_model = recieve_model.model_validate_json(request_data) - except: - return jsonify({ - "error": "Invalid request body." - }), 400 - - kwargs["request"] = request_model - - # Now run the function. - try: - result = f(*args, **kwargs) - except Exception as e: - return jsonify({ - "Internal server error": str(e) - }), 500 - - # If the result is a Response, just return it. - if isinstance(result, Response): - return result - - # If the result is a tuple, assume it is (data, status). - if isinstance(result, tuple): - data, status = result - else: - data, status = result, 200 - - # If the data is a pydantic model, serialize it. - if isinstance(data, BaseModel): - data = data.model_dump_json() - else: - # Just try to jsonify our stuff - data = jsonify(data) - - - # Return the data. - return data, status - return wrapped - return decorator \ No newline at end of file diff --git a/librarian_server/database.py b/librarian_server/database.py index c0548fd..6042004 100644 --- a/librarian_server/database.py +++ b/librarian_server/database.py @@ -22,25 +22,26 @@ log.info("Creating database session.") SessionMaker = sessionmaker(bind=engine, autocommit=False, autoflush=False) -session = SessionMaker() -Base = declarative_base() +def yield_session() -> SessionMaker: + """ + Yields a new databse session. + """ + + session = SessionMaker() + try: + yield session + finally: + session.close() -def query(model: Base, **kwargs): + +def get_session() -> SessionMaker: """ - Query the database for a model. - - Parameters - ---------- - model : Base - The model to query. - kwargs - The query parameters. - - Returns - ------- - Base - The query result. + Returns a new database session. Unlike yield_session, it is + your responsibility to close the session. """ - return session.query(model).filter_by(**kwargs) + return SessionMaker() + +Base = declarative_base() + diff --git a/librarian_server/orm/file.py b/librarian_server/orm/file.py index e84390b..610f9f4 100644 --- a/librarian_server/orm/file.py +++ b/librarian_server/orm/file.py @@ -64,7 +64,11 @@ def file_exists(self, filename: Path) -> bool: True if it exists already. """ - existing_file = db.query(File, name=str(filename)).first() + session = db.get_session() + + existing_file = session.get(File, str(filename)) + + session.close() return existing_file is not None diff --git a/librarian_server/orm/storemetadata.py b/librarian_server/orm/storemetadata.py index 8e6e5d3..424ecd9 100644 --- a/librarian_server/orm/storemetadata.py +++ b/librarian_server/orm/storemetadata.py @@ -18,7 +18,7 @@ from pathlib import Path from typing import Optional -from sqlalchemy.orm import reconstructor +from sqlalchemy.orm import reconstructor, Session from sqlalchemy.exc import SQLAlchemyError from .file import File @@ -96,6 +96,7 @@ def ingest_staged_file( self, request: UploadCompletionRequest, transfer: IncomingTransfer, + session: "Session", ) -> Instance: """ Ingests a file into the store. Creates a new File and associated file Instance. @@ -133,7 +134,7 @@ def ingest_staged_file( info = self.store_manager.path_info(staged_path) except FileNotFoundError: transfer.status = TransferStatus.FAILED - db.session.commit() + session.commit() raise FileNotFoundError( f"File {staged_path} not found in staging area. " @@ -148,7 +149,7 @@ def ingest_staged_file( self.store_manager.unstage(staged_path) transfer.status = TransferStatus.FAILED - db.session.commit() + session.commit() raise ValueError( f"File {staged_path} does not match expected size/checksum; " @@ -165,7 +166,7 @@ def ingest_staged_file( self.store_manager.unstage(staged_path) transfer.status = TransferStatus.FAILED - db.session.commit() + session.commit() raise FileExistsError(f"File {store_path} already exists on store.") @@ -191,13 +192,13 @@ def ingest_staged_file( deletion_policy=deletion_policy, ) - db.session.add(file) - db.session.add(instance) + session.add(file) + session.add(instance) # Commit our change to the transfer, file, and instance simultaneously. try: - db.session.commit() + session.commit() # We're good to go and move the file to where it needs to be. self.store_manager.commit( @@ -208,11 +209,11 @@ def ingest_staged_file( # Need to rollback everything. The upload failed... self.store_manager.unstage(request.staging_name) - db.session.rollback() + session.rollback() try: transfer.status = TransferStatus.FAILED - db.session.commit() + session.commit() except SQLAlchemyError as e: # We can't even set the transfer status... We are in big trouble! raise ServerError( @@ -221,28 +222,6 @@ def ingest_staged_file( return instance - @classmethod - def from_name(cls, name) -> "StoreMetadata": - stores = db.query(cls, name=name).all() - - if len(stores) == 0: - raise ServerError(f"Store {name} does not exist") - elif len(stores) > 1: - raise ServerError(f"Multiple stores with name {name} exist") - - return stores[0] - - @classmethod - def from_id(cls, id) -> "StoreMetadata": - stores = db.query(cls, id=id).all() - - if len(stores) == 0: - raise ServerError(f"Store with ID {id} does not exist") - elif len(stores) > 1: - raise ServerError(f"Multiple stores with ID {id} exist") - - return stores[0] - def __repr__(self) -> str: return ( f" "IncomingTransfer": """ Create a new transfer! @@ -94,6 +104,7 @@ def new_transfer( return IncomingTransfer( status=TransferStatus.INITIATED, uploader=uploader, + upload_name=upload_name, source=source, transfer_size=transfer_size, transfer_checksum=transfer_checksum, @@ -161,14 +172,14 @@ def new_transfer( ) - def fail_transfer(self): + def fail_transfer(self, session: "Session"): """ Fail the transfer and commit to the database. """ self.status = TransferStatus.FAILED self.end_time = datetime.datetime.utcnow() - db.session.commit() + session.commit() if self.remote_transfer_id is None: # No remote transfer ID, so we can't do anything. @@ -177,7 +188,7 @@ def fail_transfer(self): # Now here's the interesting part - we need to communicate to the # remote librarian that the transfer failed! - librarian: Librarian = db.query(Librarian, name=self.destination).first() + librarian: Librarian = session.query(Librarian).filter_by(name=self.destination).first() if not librarian: # Librarian doesn't exist. We can't do anything. @@ -273,7 +284,7 @@ def new_transfer( start_time=datetime.datetime.utcnow(), ) - def fail_transfer(self): + def fail_transfer(self, session: "Session"): """ Fail the transfer and commit to the database. """ @@ -281,6 +292,6 @@ def fail_transfer(self): self.status = TransferStatus.FAILED self.end_time = datetime.datetime.utcnow() - db.session.commit() + session.commit() return \ No newline at end of file diff --git a/librarian_server_scripts/librarian_server_start.py b/librarian_server_scripts/librarian_server_start.py index ae2bf50..0a1230c 100755 --- a/librarian_server_scripts/librarian_server_start.py +++ b/librarian_server_scripts/librarian_server_start.py @@ -9,7 +9,7 @@ from librarian_server.settings import server_settings -from librarian_server.database import session, engine +from librarian_server.database import get_session from librarian_server.orm import StoreMetadata from librarian_server.logger import log @@ -22,11 +22,15 @@ # Do this in if __name__ == "__main__" so we can spawn threads on MacOS... + def main(): log.info("Librarian-server-start settings: " + str(server_settings)) # Perform pre-startup tasks! log.debug("Creating the database.") - return_value = subprocess.call(f"cd {server_settings.alembic_config_path}; {server_settings.alembic_path} upgrade head", shell=True) + return_value = subprocess.call( + f"cd {server_settings.alembic_config_path}; {server_settings.alembic_path} upgrade head", + shell=True, + ) if return_value != 0: log.debug("Error creating or updating the database. Exiting.") exit(0) @@ -37,29 +41,36 @@ def main(): stores_added = 0 - for store_config in server_settings.add_stores: - if session.query(StoreMetadata).filter(StoreMetadata.name == store_config.store_name).first(): - log.debug(f"Store {store_config.store_name} already exists in database.") - continue + with get_session() as session: + for store_config in server_settings.add_stores: + if ( + session.query(StoreMetadata) + .filter(StoreMetadata.name == store_config.store_name) + .first() + ): + log.debug( + f"Store {store_config.store_name} already exists in database." + ) + continue - log.debug(f"Adding store {store_config.store_name} to database.") + log.debug(f"Adding store {store_config.store_name} to database.") - store = StoreMetadata( - name=store_config.store_name, - store_type=store_config.store_type, - ingestable=store_config.ingestable, - store_data={**store_config.store_data, "name": store_config.store_name}, - transfer_manager_data=store_config.transfer_manager_data, - ) + store = StoreMetadata( + name=store_config.store_name, + store_type=store_config.store_type, + ingestable=store_config.ingestable, + store_data={**store_config.store_data, "name": store_config.store_name}, + transfer_manager_data=store_config.transfer_manager_data, + ) - session.add(store) + session.add(store) - stores_added += 1 + stores_added += 1 - log.debug(f"Added {stores_added} store to the database. Committing.") + log.debug(f"Added {stores_added} store to the database. Committing.") - if stores_added > 0: - session.commit() + if stores_added > 0: + session.commit() # Now we can start the background process thread. log.info("Starting background process.") @@ -84,4 +95,4 @@ def main(): log.info("Waiting for background process to finish.") background_process.terminate() - log.info("Background process finished.") \ No newline at end of file + log.info("Background process finished.") diff --git a/tests/background_unit_test/test_check_integrity.py b/tests/background_unit_test/test_check_integrity.py index d50d934..f248693 100644 --- a/tests/background_unit_test/test_check_integrity.py +++ b/tests/background_unit_test/test_check_integrity.py @@ -11,12 +11,15 @@ def test_check_integrity(test_client, test_server_with_valid_file, test_orm): from librarian_background.check_integrity import CheckIntegrity # Get a store to check - _, session, _ = test_server_with_valid_file - store = session.query(test_orm.StoreMetadata).first() + _, get_session, _ = test_server_with_valid_file + + with get_session() as session: + store = session.query(test_orm.StoreMetadata).first().name integrity_task = CheckIntegrity( - name="Integrity check", store_name=store.name, age_in_days=1 + name="Integrity check", store_name=store, age_in_days=1 ) + assert integrity_task() @@ -28,11 +31,13 @@ def test_check_integrity_failure(test_client, test_server_with_invalid_file, tes from librarian_background.check_integrity import CheckIntegrity # Get a store to check - _, session, _ = test_server_with_invalid_file - store = session.query(test_orm.StoreMetadata).first() + _, get_session, _ = test_server_with_invalid_file + + with get_session() as session: + store = session.query(test_orm.StoreMetadata).first().name integrity_task = CheckIntegrity( - name="Integrity check", store_name=store.name, age_in_days=1 + name="Integrity check", store_name=store, age_in_days=1 ) assert integrity_task() == False @@ -62,10 +67,12 @@ def test_check_integrity_missing_store( from librarian_background.check_integrity import CheckIntegrity # Get a store to check - _, session, _ = test_server_with_missing_file - store = session.query(test_orm.StoreMetadata).first() + _, get_session, _ = test_server_with_missing_file + + with get_session() as session: + store = session.query(test_orm.StoreMetadata).first().name integrity_task = CheckIntegrity( - name="Integrity check", store_name=store.name, age_in_days=1 + name="Integrity check", store_name=store, age_in_days=1 ) assert integrity_task() == False diff --git a/tests/background_unit_test/test_create_clone.py b/tests/background_unit_test/test_create_clone.py index ff68087..a978878 100644 --- a/tests/background_unit_test/test_create_clone.py +++ b/tests/background_unit_test/test_create_clone.py @@ -13,16 +13,18 @@ def test_create_local_clone_with_valid( from librarian_background.create_clone import CreateLocalClone # Get a store to check - _, session, _ = test_server_with_valid_file - stores = session.query(test_orm.StoreMetadata).all() + _, get_session, _ = test_server_with_valid_file - from_store = [store for store in stores if store.ingestable][0] - to_store = [store for store in stores if not store.ingestable][0] + with get_session() as session: + stores = session.query(test_orm.StoreMetadata).all() + + from_store = [store.name for store in stores if store.ingestable][0] + to_store = [store.name for store in stores if not store.ingestable][0] clone_task = CreateLocalClone( name="Local clone", - clone_from=from_store.name, - clone_to=to_store.name, + clone_from=from_store, + clone_to=to_store, age_in_days=1, ) @@ -39,16 +41,18 @@ def test_create_local_clone_with_invalid( from librarian_background.create_clone import CreateLocalClone # Get a store to check - _, session, _ = test_server_with_invalid_file - stores = session.query(test_orm.StoreMetadata).all() + _, get_session, _ = test_server_with_invalid_file + + with get_session() as session: + stores = session.query(test_orm.StoreMetadata).all() - from_store = [store for store in stores if store.ingestable][0] - to_store = [store for store in stores if not store.ingestable][0] + from_store = [store.name for store in stores if store.ingestable][0] + to_store = [store.name for store in stores if not store.ingestable][0] clone_task = CreateLocalClone( name="Local clone", - clone_from=from_store.name, - clone_to=to_store.name, + clone_from=from_store, + clone_to=to_store, age_in_days=1, ) @@ -65,16 +69,18 @@ def test_create_local_clone_with_missing( from librarian_background.create_clone import CreateLocalClone # Get a store to check - _, session, _ = test_server_with_missing_file - stores = session.query(test_orm.StoreMetadata).all() + _, get_session, _ = test_server_with_missing_file + + with get_session() as session: + stores = session.query(test_orm.StoreMetadata).all() - from_store = [store for store in stores if store.ingestable][0] - to_store = [store for store in stores if not store.ingestable][0] + from_store = [store.name for store in stores if store.ingestable][0] + to_store = [store.name for store in stores if not store.ingestable][0] clone_task = CreateLocalClone( name="Local clone", - clone_from=from_store.name, - clone_to=to_store.name, + clone_from=from_store, + clone_to=to_store, age_in_days=1, ) diff --git a/tests/background_unit_test/test_recieve_clone.py b/tests/background_unit_test/test_recieve_clone.py index 6acef9a..999345c 100644 --- a/tests/background_unit_test/test_recieve_clone.py +++ b/tests/background_unit_test/test_recieve_clone.py @@ -26,7 +26,10 @@ def test_recieve_clone_with_valid(test_client, test_server, test_orm, garbage_fi from librarian_background.recieve_clone import RecieveClone # Get a store to use - _, session, _ = test_server + _, get_session, _ = test_server + + session = get_session() + store = session.query(test_orm.StoreMetadata).filter_by(ingestable=True).first() # Create the fake incoming transfer @@ -38,6 +41,7 @@ def test_recieve_clone_with_valid(test_client, test_server, test_orm, garbage_fi incoming_transfer = test_orm.IncomingTransfer.new_transfer( uploader="test_fake_librarian", source="test_user", + upload_name=garbage_file.name, transfer_size=info.size, transfer_checksum=info.md5, ) @@ -51,6 +55,10 @@ def test_recieve_clone_with_valid(test_client, test_server, test_orm, garbage_fi session.add(incoming_transfer) session.commit() + incoming_transfer_id = incoming_transfer.id + + session.close() + clone_task = RecieveClone( name="Recieve clone", ) @@ -60,11 +68,9 @@ def test_recieve_clone_with_valid(test_client, test_server, test_orm, garbage_fi # Now check in the DB to see if we marked the status as correct and moved the file to # the right place. - incoming_transfer = ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=incoming_transfer.id) - .first() - ) + session = get_session() + + incoming_transfer = session.get(test_orm.IncomingTransfer, incoming_transfer_id) assert incoming_transfer.status == test_orm.TransferStatus.COMPLETED @@ -75,3 +81,5 @@ def test_recieve_clone_with_valid(test_client, test_server, test_orm, garbage_fi # Check the file is in the store. assert store.store_manager.path_info(Path(incoming_transfer.store_path)).md5 == info.md5 + + session.close() diff --git a/tests/conftest.py b/tests/conftest.py index 3bb8311..9d6289c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,28 +79,29 @@ def test_server(tmp_path_factory): import librarian_server.database app = librarian_server.main() - session = librarian_server.database.session + get_session = librarian_server.database.get_session # Need to add our stores... from librarian_server.orm import StoreMetadata from librarian_server.settings import StoreSettings - for store_config in json.loads(setup.ADD_STORES): - store_config = StoreSettings(**store_config) + with get_session() as session: + for store_config in json.loads(setup.ADD_STORES): + store_config = StoreSettings(**store_config) - store = StoreMetadata( - name=store_config.store_name, - store_type=store_config.store_type, - ingestable=store_config.ingestable, - store_data={**store_config.store_data, "name": store_config.store_name}, - transfer_manager_data=store_config.transfer_manager_data, - ) + store = StoreMetadata( + name=store_config.store_name, + store_type=store_config.store_type, + ingestable=store_config.ingestable, + store_data={**store_config.store_data, "name": store_config.store_name}, + transfer_manager_data=store_config.transfer_manager_data, + ) - session.add(store) + session.add(store) - session.commit() + session.commit() - yield app, session, setup + yield app, get_session, setup for env_var in list(env_vars.keys()): if env_vars[env_var] is None: @@ -108,8 +109,6 @@ def test_server(tmp_path_factory): else: os.environ[env_var] = env_vars[env_var] - session.close() - @pytest.fixture(scope="package") def test_client(test_server): @@ -154,7 +153,9 @@ def test_server_with_valid_file(test_server, test_orm): Test server with a valid file and instance in the store. """ - store = test_server[1].query(test_orm.StoreMetadata).first() + session = test_server[1]() + + store = session.query(test_orm.StoreMetadata).first() data = random.randbytes(1024) @@ -179,18 +180,27 @@ def test_server_with_valid_file(test_server, test_orm): deletion_policy="ALLOWED", ) - test_server[1].add_all([file, instance]) + session.add_all([file, instance]) + session.commit() + + instance_id = instance.id - test_server[1].commit() + session.close() yield test_server # Now delete those items from the database. - test_server[1].delete(instance) - test_server[1].delete(file) + session = test_server[1]() + + instance = session.get(test_orm.Instance, instance_id) + file = session.get(test_orm.File, "example_file.txt") - test_server[1].commit() + session.delete(instance) + session.delete(file) + + session.commit() + session.close() path.unlink() @@ -201,7 +211,9 @@ def test_server_with_invalid_file(test_server, test_orm): Test server with a invalid file and instance in the store. """ - store = test_server[1].query(test_orm.StoreMetadata).first() + session = test_server[1]() + + store = session.query(test_orm.StoreMetadata).first() data = random.randbytes(1024) @@ -226,18 +238,28 @@ def test_server_with_invalid_file(test_server, test_orm): deletion_policy="ALLOWED", ) - test_server[1].add_all([file, instance]) + session.add_all([file, instance]) - test_server[1].commit() + session.commit() + + instance_id = instance.id + + session.close() yield test_server # Now delete those items from the database. - test_server[1].delete(instance) - test_server[1].delete(file) + session = test_server[1]() + + instance = session.get(test_orm.Instance, instance_id) + file = session.get(test_orm.File, "example_file.txt") - test_server[1].commit() + session.delete(instance) + session.delete(file) + + session.commit() + session.close() path.unlink() @@ -248,7 +270,9 @@ def test_server_with_missing_file(test_server, test_orm): Test server with a missing file and instance in the store. """ - store = test_server[1].query(test_orm.StoreMetadata).first() + session = test_server[1]() + + store = session.query(test_orm.StoreMetadata).first() data = random.randbytes(1024) @@ -274,15 +298,25 @@ def test_server_with_missing_file(test_server, test_orm): deletion_policy="ALLOWED", ) - test_server[1].add_all([file, instance]) + session.add_all([file, instance]) + + session.commit() + + instance_id = instance.id - test_server[1].commit() + session.close() yield test_server # Now delete those items from the database. - test_server[1].delete(instance) - test_server[1].delete(file) + session = test_server[1]() + + instance = session.get(test_orm.Instance, instance_id) + file = session.get(test_orm.File, "example_file.txt") - test_server[1].commit() + session.delete(instance) + session.delete(file) + + session.commit() + session.close() diff --git a/tests/server_unit_test/test_clone.py b/tests/server_unit_test/test_clone.py index d4d22c2..db60884 100644 --- a/tests/server_unit_test/test_clone.py +++ b/tests/server_unit_test/test_clone.py @@ -85,15 +85,16 @@ def test_valid_stage_and_fail(test_client, test_server, test_orm): # Check we got this thing in the database. - _, session, _ = test_server - - assert ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=decoded_response.destination_transfer_id) - .first() - .status - == test_orm.TransferStatus.INITIATED - ) + _, get_session, _ = test_server + + with get_session() as session: + assert ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=decoded_response.destination_transfer_id) + .first() + .status + == test_orm.TransferStatus.INITIATED + ) # Now see what happens if we try to clone again. @@ -119,13 +120,14 @@ def test_valid_stage_and_fail(test_client, test_server, test_orm): decoded_response = CloneFailResponse.model_validate_json(response.content) - assert ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=decoded_response.destination_transfer_id) - .first() - .status - == test_orm.TransferStatus.FAILED - ) + with get_session() as session: + assert ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=decoded_response.destination_transfer_id) + .first() + .status + == test_orm.TransferStatus.FAILED + ) def test_try_to_fail_non_existent_transfer(test_client, test_server, test_orm): @@ -217,15 +219,16 @@ def test_ongoing_transfer( # Check it's in the database with correct status - _, session, _ = test_server + _, get_session, _ = test_server - assert ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=decoded_response.destination_transfer_id) - .first() - .status - == test_orm.TransferStatus.ONGOING - ) + with get_session() as session: + assert ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=decoded_response.destination_transfer_id) + .first() + .status + == test_orm.TransferStatus.ONGOING + ) # If we try to upload again with the same source and destination, it should fail. @@ -264,45 +267,49 @@ def test_incoming_transfer_endpoints( that is having stuff sent to it, not the client that is sending) """ - _, session, _ = test_server + _, get_session, _ = test_server # First we need to create fake files and instances. - file = test_orm.File.new_file( - filename=garbage_filename, - size=100, - checksum="abcd", - uploader="test", - source="test", - ) + with get_session() as session: + file = test_orm.File.new_file( + filename=garbage_filename, + size=100, + checksum="abcd", + uploader="test", + source="test", + ) - store = session.query(test_orm.StoreMetadata).first() + store = session.query(test_orm.StoreMetadata).first() - instance = test_orm.Instance.new_instance( - path=garbage_file, - file=file, - store=store, - deletion_policy="DISALLOWED", - ) + instance = test_orm.Instance.new_instance( + path=garbage_file, + file=file, + store=store, + deletion_policy="DISALLOWED", + ) - # Add first to get IDs - session.add_all([file, instance]) - session.commit() + # Add first to get IDs + session.add_all([file, instance]) + session.commit() - transfer = test_orm.OutgoingTransfer.new_transfer( - destination="test2", instance=instance, file=file - ) + transfer = test_orm.OutgoingTransfer.new_transfer( + destination="test2", instance=instance, file=file + ) + + session.add(transfer) + session.commit() - session.add(transfer) - session.commit() + transfer_id = transfer.id + instance_id = instance.id # We will first test the failure case where we have not set the transfer to be ongoing # Now call the endpoint request = CloneCompleteRequest( - source_transfer_id=transfer.id, - destination_transfer_id=transfer.id, + source_transfer_id=transfer_id, + destination_transfer_id=transfer_id, ) response = test_client.post( @@ -314,8 +321,11 @@ def test_incoming_transfer_endpoints( decoded_response = CloneFailedResponse.model_validate_json(response.content) # Now try again but set the transfer to be ongoing - transfer.status = test_orm.TransferStatus.ONGOING - session.commit() + with get_session() as session: + transfer = session.get(test_orm.OutgoingTransfer, transfer_id) + + transfer.status = test_orm.TransferStatus.ONGOING + session.commit() response = test_client.post( "/api/v2/clone/complete", content=request.model_dump_json() @@ -327,12 +337,19 @@ def test_incoming_transfer_endpoints( # Check it's in the database with correct status - assert transfer.status == test_orm.TransferStatus.COMPLETED # Clean up that garbage - session.delete(instance) - session.delete(file) - session.commit() + with get_session() as session: + transfer = session.get(test_orm.OutgoingTransfer, transfer_id) + + assert transfer.status == test_orm.TransferStatus.COMPLETED + + instance = session.get(test_orm.Instance, instance_id) + file = session.get(test_orm.File, str(garbage_filename)) + + session.delete(instance) + session.delete(file) + session.commit() def test_complete_no_transfer(test_client, test_server, test_orm): @@ -360,23 +377,27 @@ def test_set_ongoing_with_different_status(test_client, test_server, test_orm): completed (or has some other status). """ - _, session, _ = test_server + _, get_session, _ = test_server - transfer = test_orm.IncomingTransfer.new_transfer( - uploader="test", - source="test", - transfer_size=100, - transfer_checksum="", - ) + with get_session() as session: + transfer = test_orm.IncomingTransfer.new_transfer( + uploader="test", + source="test", + upload_name="test", + transfer_size=100, + transfer_checksum="", + ) - transfer.status = test_orm.TransferStatus.COMPLETED + transfer.status = test_orm.TransferStatus.COMPLETED - session.add(transfer) - session.commit() + session.add(transfer) + session.commit() + + transfer_id = transfer.id request = CloneOngoingRequest( - source_transfer_id=transfer.id, - destination_transfer_id=transfer.id, + source_transfer_id=transfer_id, + destination_transfer_id=transfer_id, ) response = test_client.post( @@ -401,9 +422,11 @@ def test_clone_file_exists(test_client, test_server, test_orm, garbage_filename) source="test", ) - _, session, _ = test_server - session.add(file) - session.commit() + _, get_session, _ = test_server + + with get_session() as session: + session.add(file) + session.commit() request = CloneInitiationRequest( destination_location=garbage_filename, @@ -425,5 +448,8 @@ def test_clone_file_exists(test_client, test_server, test_orm, garbage_filename) decoded_response = CloneFailedResponse.model_validate_json(response.content) # Clean up that garbage - session.delete(file) - session.commit() + with get_session() as session: + file = session.get(test_orm.File, str(garbage_filename)) + + session.delete(file) + session.commit() diff --git a/tests/server_unit_test/test_upload.py b/tests/server_unit_test/test_upload.py index 4b1a7bf..7e24c52 100644 --- a/tests/server_unit_test/test_upload.py +++ b/tests/server_unit_test/test_upload.py @@ -45,7 +45,7 @@ def test_negative_upload_size(test_client: TestClient): def test_extreme_upload_size( - test_client: TestClient, test_server: tuple[FastAPI, Session, Server], test_orm: Any + test_client: TestClient, test_server: tuple[FastAPI, callable, Server], test_orm: Any ): """ Tests that an upload size that is too large results in an error. @@ -69,16 +69,17 @@ def test_extreme_upload_size( assert response.status_code == 413 # Check we put the stuff in the database! - _, session, _ = test_server + _, get_session, _ = test_server - assert ( - session.query(test_orm.IncomingTransfer).first().status - == test_orm.TransferStatus.FAILED - ) + with get_session() as session: + assert ( + session.query(test_orm.IncomingTransfer).first().status + == test_orm.TransferStatus.FAILED + ) def test_valid_stage( - test_client: TestClient, test_server: tuple[FastAPI, Session, Server], test_orm: Any + test_client: TestClient, test_server: tuple[FastAPI, callable, Server], test_orm: Any ): """ Tests that a valid stage works. @@ -104,15 +105,16 @@ def test_valid_stage( # Check we got this thing in the database. - _, session, _ = test_server + _, get_session, _ = test_server - assert ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=decoded_response.transfer_id) - .first() - .status - == test_orm.TransferStatus.INITIATED - ) + with get_session() as session: + assert ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=decoded_response.transfer_id) + .first() + .status + == test_orm.TransferStatus.INITIATED + ) # Now we can check what happens when we try to upload the same file. response = test_client.post( @@ -160,7 +162,7 @@ def helper_generate_transfer( def test_full_upload( test_client: TestClient, - test_server: tuple[FastAPI, Session, Server], + test_server: tuple[FastAPI, callable, Server], test_orm: Any, garbage_file: Path, garbage_filename: Path, @@ -197,24 +199,26 @@ def test_full_upload( assert response.status_code == 200 # Check we got this thing in the database. - _, session, _ = test_server - incoming_transfer = ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=stage_response.transfer_id) - .first() - ) + _, get_session, _ = test_server - assert incoming_transfer.status == test_orm.TransferStatus.COMPLETED + with get_session() as session: + incoming_transfer = ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=stage_response.transfer_id) + .first() + ) - # Find the file in the store. - instance = ( - session.query(test_orm.Instance) - .filter_by(file_name=str(garbage_filename)) - .first() - ) + assert incoming_transfer.status == test_orm.TransferStatus.COMPLETED - # Check the file is where it should be. - assert Path(instance.path).exists() + # Find the file in the store. + instance = ( + session.query(test_orm.Instance) + .filter_by(file_name=str(garbage_filename)) + .first() + ) + + # Check the file is where it should be. + assert Path(instance.path).exists() # Now highjack this test to see what happens if we try to upload again! @@ -273,14 +277,16 @@ def test_commit_no_file_uploaded( assert response.status_code == 404 # Check we got this thing in the database. - _, session, _ = test_server - incoming_transfer = ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=stage_response.transfer_id) - .first() - ) + _, get_session, _ = test_server + + with get_session() as session: + incoming_transfer = ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=stage_response.transfer_id) + .first() + ) - assert incoming_transfer.status == test_orm.TransferStatus.FAILED + assert incoming_transfer.status == test_orm.TransferStatus.FAILED def test_commit_wrong_file_uploaded( @@ -322,14 +328,16 @@ def test_commit_wrong_file_uploaded( assert response.status_code == 406 # Check we got this thing in the database. - _, session, _ = test_server - incoming_transfer = ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=stage_response.transfer_id) - .first() - ) + _, get_session, _ = test_server + + with get_session() as session: + incoming_transfer = ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=stage_response.transfer_id) + .first() + ) - assert incoming_transfer.status == test_orm.TransferStatus.FAILED + assert incoming_transfer.status == test_orm.TransferStatus.FAILED # Check we deleted the file assert not Path(stage_response.staging_location).exists() @@ -346,21 +354,22 @@ def test_commit_file_exists( test_client, test_server, test_orm, garbage_file, garbage_filename ) - _, session, _ = test_server + _, get_session, _ = test_server - store_metadata = ( - session.query(test_orm.StoreMetadata) - .filter_by(name=stage_response.store_name) - .first() - ) + with get_session() as session: + store_metadata = ( + session.query(test_orm.StoreMetadata) + .filter_by(name=stage_response.store_name) + .first() + ) - # Copy the file to the store area manually. - shutil.copy2( - garbage_file, - store_metadata.store_manager._resolved_path_store( - stage_response.destination_location - ), - ) + # Copy the file to the store area manually. + shutil.copy2( + garbage_file, + store_metadata.store_manager._resolved_path_store( + stage_response.destination_location + ), + ) # Now we can actually test the commit endpoint. @@ -387,13 +396,14 @@ def test_commit_file_exists( # Check we got this thing in the database. - incoming_transfer = ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=stage_response.transfer_id) - .first() - ) + with get_session() as session: + incoming_transfer = ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=stage_response.transfer_id) + .first() + ) - assert incoming_transfer.status == test_orm.TransferStatus.FAILED + assert incoming_transfer.status == test_orm.TransferStatus.FAILED assert not Path(stage_response.staging_location).exists() @@ -456,12 +466,13 @@ def test_directory_upload(test_client, test_server, test_orm, tmp_path): # Check we got this thing in the database. - _, session, _ = test_server + _, get_session, _ = test_server - incoming_transfer = ( - session.query(test_orm.IncomingTransfer) - .filter_by(id=decoded_response.transfer_id) - .first() - ) + with get_session() as session: + incoming_transfer = ( + session.query(test_orm.IncomingTransfer) + .filter_by(id=decoded_response.transfer_id) + .first() + ) - assert incoming_transfer.status == test_orm.TransferStatus.COMPLETED + assert incoming_transfer.status == test_orm.TransferStatus.COMPLETED