Skip to content

Commit

Permalink
Support file-uploads in the HTTP service
Browse files Browse the repository at this point in the history
  • Loading branch information
spillai committed Jun 1, 2024
1 parent c15a185 commit 2cb13d9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
35 changes: 34 additions & 1 deletion nos/server/http/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import os
import time
import uuid
from dataclasses import field
from functools import lru_cache
from pathlib import Path
Expand All @@ -14,10 +15,11 @@
from PIL import Image
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from tqdm import tqdm

from nos.client import Client
from nos.common.tasks import TaskType
from nos.constants import DEFAULT_GRPC_ADDRESS
from nos.constants import DEFAULT_GRPC_ADDRESS, NOS_TMP_DIR
from nos.logging import logger
from nos.protoc import import_module
from nos.version import __version__
Expand Down Expand Up @@ -136,6 +138,7 @@ def app_factory(version: str = HTTP_API_VERSION, address: str = DEFAULT_GRPC_ADD

svc = InferenceService(address=address)
logger.info(f"app_factory [env={env}]: Adding CORS middleware ...")

app = FastAPI(
title="NOS REST API",
description=f"NOS REST API (version={__version__}, api_version={version})",
Expand All @@ -156,6 +159,9 @@ def app_factory(version: str = HTTP_API_VERSION, address: str = DEFAULT_GRPC_ADD
app.middleware("http")(default_exception_middleware)
app.add_exception_handler(Exception, default_exception_handler)

NOS_TMP_FILES_DIR = Path(NOS_TMP_DIR) / "uploaded_files"
NOS_TMP_FILES_DIR.mkdir(parents=True, exist_ok=True)

def get_client() -> Client:
"""Get the inference client."""
return svc.client
Expand Down Expand Up @@ -221,6 +227,33 @@ def model_info(
except KeyError:
raise HTTPException(status_code=400, detail=f"Invalid model {model}")

# TODO (delete file after processing)
@app.post(f"/{version}/file/upload", status_code=201)
def upload_file(file: UploadFile = File(...), client: Client = Depends(get_client)) -> JSONResponse:
try:
uid = uuid.uuid4()
basename = f"{uid}-{Path(file.filename).name}"
path = NOS_TMP_FILES_DIR / basename
logger.debug(f"Uploading file: [local={file.filename}, path={path}]")
file.file.seek(0)
with path.open("wb") as f:
for chunk in tqdm(
iter(lambda: file.file.read(1024), b""),
desc="Uploading file",
unit="KB",
unit_scale=True,
unit_divisor=1024,
):
f.write(chunk)
logger.info(f"Successfully uploaded file [path={path}]")
except Exception as exc:
logger.error(f"""Failed to upload file [file={file.filename}, exc={exc}]""")
raise HTTPException(status_code=500, detail="Failed to upload file.")
return {
"file_id": str(uid),
"filename": basename,
}

@app.post(f"/{version}/chat/completions", status_code=status.HTTP_201_CREATED)
def chat(
request: ChatCompletionsRequest,
Expand Down
2 changes: 1 addition & 1 deletion nos/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.0"
__version__ = "0.4.1"
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pydantic>=2.5.0
python-multipart
pyyaml
rich>=12.5.1
setuptools>=70.0.0
sentry-sdk[loguru]
tqdm
typer>=0.7.0
Expand Down

0 comments on commit 2cb13d9

Please sign in to comment.