diff --git a/nos/server/http/_service.py b/nos/server/http/_service.py index a13df828..32394cbc 100644 --- a/nos/server/http/_service.py +++ b/nos/server/http/_service.py @@ -2,6 +2,7 @@ import dataclasses import os import time +import uuid from dataclasses import field from functools import lru_cache from pathlib import Path @@ -14,10 +15,11 @@ from PIL import Image from pydantic import ConfigDict from pydantic.dataclasses import dataclass +from tqdm import tqdm from nos.client import Client from nos.common.tasks import TaskType -from nos.constants import DEFAULT_GRPC_ADDRESS +from nos.constants import DEFAULT_GRPC_ADDRESS, NOS_TMP_DIR from nos.logging import logger from nos.protoc import import_module from nos.version import __version__ @@ -136,6 +138,7 @@ def app_factory(version: str = HTTP_API_VERSION, address: str = DEFAULT_GRPC_ADD svc = InferenceService(address=address) logger.info(f"app_factory [env={env}]: Adding CORS middleware ...") + app = FastAPI( title="NOS REST API", description=f"NOS REST API (version={__version__}, api_version={version})", @@ -156,6 +159,9 @@ def app_factory(version: str = HTTP_API_VERSION, address: str = DEFAULT_GRPC_ADD app.middleware("http")(default_exception_middleware) app.add_exception_handler(Exception, default_exception_handler) + NOS_TMP_FILES_DIR = Path(NOS_TMP_DIR) / "uploaded_files" + NOS_TMP_FILES_DIR.mkdir(parents=True, exist_ok=True) + def get_client() -> Client: """Get the inference client.""" return svc.client @@ -221,6 +227,33 @@ def model_info( except KeyError: raise HTTPException(status_code=400, detail=f"Invalid model {model}") + # TODO (delete file after processing) + @app.post(f"/{version}/file/upload", status_code=201) + def upload_file(file: UploadFile = File(...), client: Client = Depends(get_client)) -> JSONResponse: + try: + uid = uuid.uuid4() + basename = f"{uid}-{Path(file.filename).name}" + path = NOS_TMP_FILES_DIR / basename + logger.debug(f"Uploading file: [local={file.filename}, path={path}]") + file.file.seek(0) + with path.open("wb") as f: + for chunk in tqdm( + iter(lambda: file.file.read(1024), b""), + desc="Uploading file", + unit="KB", + unit_scale=True, + unit_divisor=1024, + ): + f.write(chunk) + logger.info(f"Successfully uploaded file [path={path}]") + except Exception as exc: + logger.error(f"""Failed to upload file [file={file.filename}, exc={exc}]""") + raise HTTPException(status_code=500, detail="Failed to upload file.") + return { + "file_id": str(uid), + "filename": basename, + } + @app.post(f"/{version}/chat/completions", status_code=status.HTTP_201_CREATED) def chat( request: ChatCompletionsRequest, diff --git a/nos/version.py b/nos/version.py index 6a9beea8..3d26edf7 100644 --- a/nos/version.py +++ b/nos/version.py @@ -1 +1 @@ -__version__ = "0.4.0" +__version__ = "0.4.1" diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 7036cf30..9d2c7005 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -15,6 +15,7 @@ pydantic>=2.5.0 python-multipart pyyaml rich>=12.5.1 +setuptools>=70.0.0 sentry-sdk[loguru] tqdm typer>=0.7.0