Skip to content

Commit

Permalink
Add machine-to-machine API authentication example
Browse files Browse the repository at this point in the history
This PR allows developers to authenticate their requests with a server with
machine-to-machine authentication. This is useful when the server is
accessible to the public but requires a secure API-key authentication to
access the HTTP services.
  • Loading branch information
spillai committed May 3, 2024
1 parent b0de4df commit 3acbe8e
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 43 deletions.
23 changes: 12 additions & 11 deletions docker/Dockerfile.cpu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# >>>>>>>>>>>>>>>>>>>>>>>>>>>
# Auto-generated by agi-pack (version=0.2.0).
# Auto-generated by agi-pack (version=0.3.0).
FROM debian:buster-slim AS base-cpu

# Setup environment variables
Expand Down Expand Up @@ -39,15 +39,16 @@ RUN --mount=type=cache,target=/var/cache/apt \
git \
&& echo "system install complete"

# Install mambaforge, with cache mounting ${CONDA_PKGS_DIRS} for faster builds
# Install miniconda, with cache mounting ${CONDA_PKGS_DIRS} for faster builds
RUN --mount=type=cache,target=${CONDA_PKGS_DIRS} \
curl -sLo ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-$(uname)-$(uname -m).sh" \
&& chmod +x ~/mambaforge.sh \
&& ~/mambaforge.sh -b -p ${AGIPACK_PATH}/conda \
&& ${AGIPACK_PATH}/conda/bin/mamba init bash \
&& ${AGIPACK_PATH}/conda/bin/mamba config --add channels conda-forge \
&& ${AGIPACK_PATH}/conda/bin/mamba create -n ${AGIPACK_PYENV} python=${PYTHON_VERSION} -y \
&& rm ~/mambaforge.sh
curl -sLo ~/miniconda.sh "https://repo.anaconda.com/miniconda/Miniconda3-latest-$(uname)-$(uname -m).sh" \
&& chmod +x ~/miniconda.sh \
&& ~/miniconda.sh -b -p ${AGIPACK_PATH}/conda \
&& ${AGIPACK_PATH}/conda/bin/conda init bash \
&& ${AGIPACK_PATH}/conda/bin/conda config --add channels conda-forge \
&& ${AGIPACK_PATH}/conda/bin/conda create -n ${AGIPACK_PYENV} python=${PYTHON_VERSION} -y \
&& ${AGIPACK_PATH}/conda/bin/conda install mamba -y \
&& rm ~/miniconda.sh

# Upgrade pip
RUN pip install --upgrade pip
Expand Down Expand Up @@ -110,7 +111,7 @@ ENV RAY_CONDA_HOME=/opt/conda
ENV RAY_ENABLE_MAC_LARGE_OBJECT_STORE=1

# >>>>>>>>>>>>>>>>>>>>>>>>>>>
# Auto-generated by agi-pack (version=0.2.0).
# Auto-generated by agi-pack (version=0.3.0).
FROM base-cpu AS cpu

# Setup working directory
Expand All @@ -128,7 +129,7 @@ RUN --mount=type=cache,target=${CONDA_PKGS_DIRS} \
RUN echo "run commands complete"
CMD ["bash", "-c", "/app/entrypoint.sh"]
# >>>>>>>>>>>>>>>>>>>>>>>>>>>
# Auto-generated by agi-pack (version=0.2.0).
# Auto-generated by agi-pack (version=0.3.0).
FROM cpu AS test-cpu

# Install additional system packages
Expand Down
23 changes: 12 additions & 11 deletions docker/Dockerfile.gpu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# >>>>>>>>>>>>>>>>>>>>>>>>>>>
# Auto-generated by agi-pack (version=0.2.0).
# Auto-generated by agi-pack (version=0.3.0).
FROM nvidia/cuda:11.8.0-base-ubuntu22.04 AS base-gpu

# Setup environment variables
Expand Down Expand Up @@ -39,15 +39,16 @@ RUN --mount=type=cache,target=/var/cache/apt \
git \
&& echo "system install complete"

# Install mambaforge, with cache mounting ${CONDA_PKGS_DIRS} for faster builds
# Install miniconda, with cache mounting ${CONDA_PKGS_DIRS} for faster builds
RUN --mount=type=cache,target=${CONDA_PKGS_DIRS} \
curl -sLo ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-$(uname)-$(uname -m).sh" \
&& chmod +x ~/mambaforge.sh \
&& ~/mambaforge.sh -b -p ${AGIPACK_PATH}/conda \
&& ${AGIPACK_PATH}/conda/bin/mamba init bash \
&& ${AGIPACK_PATH}/conda/bin/mamba config --add channels conda-forge \
&& ${AGIPACK_PATH}/conda/bin/mamba create -n ${AGIPACK_PYENV} python=${PYTHON_VERSION} -y \
&& rm ~/mambaforge.sh
curl -sLo ~/miniconda.sh "https://repo.anaconda.com/miniconda/Miniconda3-latest-$(uname)-$(uname -m).sh" \
&& chmod +x ~/miniconda.sh \
&& ~/miniconda.sh -b -p ${AGIPACK_PATH}/conda \
&& ${AGIPACK_PATH}/conda/bin/conda init bash \
&& ${AGIPACK_PATH}/conda/bin/conda config --add channels conda-forge \
&& ${AGIPACK_PATH}/conda/bin/conda create -n ${AGIPACK_PYENV} python=${PYTHON_VERSION} -y \
&& ${AGIPACK_PATH}/conda/bin/conda install mamba -y \
&& rm ~/miniconda.sh

# Upgrade pip
RUN pip install --upgrade pip
Expand Down Expand Up @@ -116,7 +117,7 @@ ENV RAY_CONDA_HOME=/opt/conda
ENV RAY_ENABLE_MAC_LARGE_OBJECT_STORE=1

# >>>>>>>>>>>>>>>>>>>>>>>>>>>
# Auto-generated by agi-pack (version=0.2.0).
# Auto-generated by agi-pack (version=0.3.0).
FROM base-gpu AS gpu

# Setup working directory
Expand All @@ -134,7 +135,7 @@ RUN --mount=type=cache,target=${CONDA_PKGS_DIRS} \
RUN echo "run commands complete"
CMD ["bash", "-c", "/app/entrypoint.sh"]
# >>>>>>>>>>>>>>>>>>>>>>>>>>>
# Auto-generated by agi-pack (version=0.2.0).
# Auto-generated by agi-pack (version=0.3.0).
FROM gpu AS test-gpu

# Install additional system packages
Expand Down
3 changes: 2 additions & 1 deletion examples/tutorials/05-serving-with-docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ services:
- NOS_HOME=/app/.nos
- NOS_LOGGING_LEVEL=INFO
- NOS_GRPC_HOST=nos-grpc-server
- NOS_HTTP_ENV=dev
- NOS_HTTP_ENV=prod
volumes:
- ~/.nosd:/app/.nos
- /dev/shm:/dev/shm
Expand All @@ -22,6 +22,7 @@ services:
image: autonomi/nos:latest-gpu
environment:
- NOS_HOME=/app/.nos
- NOS_GRPC_HOST=[::]
- NOS_LOGGING_LEVEL=INFO
volumes:
- ~/.nosd:/app/.nos
Expand Down
34 changes: 34 additions & 0 deletions examples/tutorials/06-serving-with-m2m-auth/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
version: "3.8"

services:
nos-http-gateway:
image: autonomi/nos:latest-cpu
command: nos-http-server --host 0.0.0.0 --port 8000 --workers 1
environment:
- NOS_HOME=/app/.nos
- NOS_LOGGING_LEVEL=INFO
- NOS_GRPC_HOST=nos-grpc-server
- NOS_HTTP_ENV=prod
- NOS_M2M_API_KEYS=sk-test-key-1,sk-test-key-2
volumes:
- ~/.nosd:/app/.nos
- /dev/shm:/dev/shm
ports:
- 8000:8000
ipc: host
depends_on:
- nos-grpc-server

nos-grpc-server:
image: autonomi/nos:latest-cpu
environment:
- NOS_HOME=/app/.nos
- NOS_GRPC_HOST=[::]
- NOS_LOGGING_LEVEL=INFO
volumes:
- ~/.nosd:/app/.nos
- /dev/shm:/dev/shm
ports:
- 50051:50051
ipc: host

27 changes: 27 additions & 0 deletions examples/tutorials/06-serving-with-m2m-auth/tests/test_m2m_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import requests


if __name__ == "__main__":
BASE_URL = "http://localhost:8000"

# Test health
response = requests.get(f"{BASE_URL}/v1/health")
response.raise_for_status()

# Test model info without authentication
response = requests.get(f"{BASE_URL}/v1/models")
assert response.status_code == 401, "Expected 401 Unauthorized"

# Test model info with invalid authentication
response = requests.get(f"{BASE_URL}/v1/models", headers={"X-Api-Key": "invalid-api-key"})
assert response.status_code == 403, "Expected 403 Forbidden"

# Test model info with valid authentication
response = requests.get(f"{BASE_URL}/v1/models", headers={"X-Api-Key": "sk-test-key-1"})
response.raise_for_status()
assert response.status_code == 200, "Expected 200 OK"

# Test model inference without authentication
response = requests.get(f"{BASE_URL}/v1/models", headers={"X-Api-Key": "sk-test-key-2"})
response.raise_for_status()
assert response.status_code == 200, "Expected 200 OK"
3 changes: 2 additions & 1 deletion nos/client/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ def WaitForServer(self, timeout: int = 60, retry_interval: int = 5) -> None:
if int(elapsed) > 10:
logger.warning("Waiting for server to start... (elapsed={:.0f}s)".format(time.time() - st))
time.sleep(retry_interval)
raise ServerReadyException("Failed to ping server.")
default_msg = """\n If you are running the server in docker, make sure the server sets `NOS_GRPC_HOST=[::]` and the client sets `NOS_GRPC_HOST=<server-container-name>` in their environment variables."""
raise ServerReadyException(f"Failed to ping server. {default_msg}")

def GetServiceVersion(self) -> str:
"""Get service version.
Expand Down
26 changes: 26 additions & 0 deletions nos/server/http/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from fastapi import status
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.requests import Request


async def default_exception_middleware(request: Request, call_next):
try:
return await call_next(request)
except Exception as exc:
base_error_message = f"Internal server error: [method={request.method}], url={request.url}]"
logger.error(f"Internal server error: [method={request.method}, url={request.url}, exc={exc}]")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"message": base_error_message},
)


async def default_exception_handler(request, error):
"""Default exception handler for all routes."""
base_error_message = f"Internal server error: [method={request.method}], url={request.url}]"
logger.error(f"Internal server error: [method={request.method}, url={request.url}, error={error}]")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"message": f"{base_error_message}."},
)
38 changes: 38 additions & 0 deletions nos/server/http/_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader
from loguru import logger


valid_m2m_keys = {}
for key in os.getenv("NOS_M2M_API_KEYS", "").split(","):
if len(key) > 0:
logger.debug(f"Adding valid_m2m_keys [key={key}]")
valid_m2m_keys[key] = key
api_key_header = APIKeyHeader(name="X-Api-Key", auto_error=False)


async def validate_m2m_key(request: Request, api_key: str = Depends(api_key_header)) -> bool:
logger.debug(f"validate_m2m_key [api_key={api_key}]")

if not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing X-Api-Key Key header",
)

if api_key not in valid_m2m_keys:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid Machine-to-Machine Key",
)

assert isinstance(api_key, str)
return True


if valid_m2m_keys:
ValidMachineToMachine = Depends(validate_m2m_key)
else:
ValidMachineToMachine = None
53 changes: 35 additions & 18 deletions nos/server/http/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from nos.protoc import import_module
from nos.version import __version__

from ._exceptions import default_exception_handler, default_exception_middleware
from ._security import ValidMachineToMachine as ValidM2M
from ._utils import decode_item, encode_item
from .integrations.openai.models import (
ChatCompletionsRequest,
Expand Down Expand Up @@ -60,28 +62,14 @@ class InferenceRequest:
class InferenceService:
"""HTTP server application for NOS API."""

version: str = field(default="v1")
"""NOS version."""

address: str = field(default=DEFAULT_GRPC_ADDRESS)
"""gRPC address."""

env: str = field(default=HTTP_ENV)
"""Environment (dev/prod/test)."""

model_config = ConfigDict(arbitrary_types_allowed=True)
"""Model configuration."""

def __post_init__(self):
"""Post initialization."""
self.app = FastAPI(
title="NOS REST API",
description=f"NOS REST API (version={__version__}, api_version={self.version})",
version=self.version,
debug=self.env != "prod",
)
logger.debug(f"Starting NOS REST API (version={__version__}, env={self.env})")

self.client = Client(self.address)
logger.debug(f"Connecting to gRPC server (address={self.client.address})")

Expand Down Expand Up @@ -144,12 +132,33 @@ def app_factory(version: str = HTTP_API_VERSION, address: str = DEFAULT_GRPC_ADD
Returns:
(FastAPI) FastAPI application.
"""
nos_app = InferenceService(version=version, address=address, env=env)
app = nos_app.app
from fastapi.middleware.cors import CORSMiddleware

svc = InferenceService(address=address)
logger.info(f"app_factory [env={env}]: Adding CORS middleware ...")
app = FastAPI(
title="NOS REST API",
description=f"NOS REST API (version={__version__}, api_version={version})",
version=version,
debug=env != "prod",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
logger.debug(f"Starting NOS REST API (version={__version__}, env={env})")

# Add default exception handler
logger.info(f"app_factory [env={env}]: Adding default exception handlers ...")
app.middleware("http")(default_exception_middleware)
app.add_exception_handler(Exception, default_exception_handler)

def get_client() -> Client:
"""Get the inference client."""
return nos_app.client
return svc.client

def unnormalize_id(model_id: str) -> str:
"""Un-normalize the model identifier."""
Expand Down Expand Up @@ -192,14 +201,19 @@ def health(client: Client = Depends(get_client)) -> JSONResponse:
@app.get(f"/{version}/models", status_code=status.HTTP_200_OK, response_model=Model)
def models(
client: Client = Depends(get_client),
user: Depends = ValidM2M,
) -> Model:
"""List all available models."""
_model_table = build_model_table(client)
logger.debug(f"Listing models [models={_model_table.values()}]")
return Model(data=list(_model_table.values()))

@app.get(f"/{version}/models/" + "{model:path}", response_model=ChatModel)
def model_info(model: str, client: Client = Depends(get_client)) -> ChatModel:
def model_info(
model: str,
client: Client = Depends(get_client),
user: Depends = ValidM2M,
) -> ChatModel:
"""Get model information."""
_model_table = build_model_table(client)
try:
Expand All @@ -211,6 +225,7 @@ def model_info(model: str, client: Client = Depends(get_client)) -> ChatModel:
def chat(
request: ChatCompletionsRequest,
client: Client = Depends(get_client),
user: Depends = ValidM2M,
) -> StreamingResponse:
"""Perform chat completion on the given input data."""
logger.debug(f"Received chat request [model={request.model}, messages={request.messages}]")
Expand Down Expand Up @@ -270,6 +285,7 @@ def openai_streaming_generator():
def infer(
request: InferenceRequest,
client: Client = Depends(get_client),
user: Depends = ValidM2M,
) -> JSONResponse:
"""Perform inference on the given input data.
Expand Down Expand Up @@ -319,6 +335,7 @@ def infer_file(
file: Optional[UploadFile] = File(None),
url: Optional[str] = Form(None),
client: Client = Depends(get_client),
user: Depends = ValidM2M,
) -> JSONResponse:
"""Perform inference on the given input data using multipart/form-data.
Expand Down
2 changes: 1 addition & 1 deletion nos/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.1"
__version__ = "0.3.0"

0 comments on commit 3acbe8e

Please sign in to comment.