Skip to content

Commit

Permalink
Some fixes on model server (#362)
Browse files Browse the repository at this point in the history
* Some fixes on model server

* Remove prompt_prefilling message

* Fix logging

* Fix poetry issues

* Improve logging and update the support for text truncation

* Fix tests

* Fix tests

* Fix tests

* Fix modelserver tests

* Update modelserver tests
  • Loading branch information
nehcgs authored Jan 11, 2025
1 parent ebda682 commit 88a02dc
Show file tree
Hide file tree
Showing 25 changed files with 1,090 additions and 1,666 deletions.
566 changes: 108 additions & 458 deletions arch/tools/poetry.lock

Large diffs are not rendered by default.

9 changes: 3 additions & 6 deletions arch/tools/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw"
version = "0.1.7"
version = "0.1.8"
description = "Python-based CLI tool to manage Arch Gateway."
authors = ["Katanemo Labs, Inc."]
packages = [
Expand All @@ -9,15 +9,12 @@ packages = [
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.12"
archgw_modelserver = "0.1.7"
pyyaml = "^6.0.2"
pydantic = "^2.10.1"
python = "^3.10"
archgw_modelserver = "^0.1.8"
click = "^8.1.7"
jinja2 = "^3.1.4"
jsonschema = "^4.23.0"
setuptools = "75.5.0"
huggingface_hub = "^0.26.0"
docker = "^7.1.0"
python-dotenv = "^1.0.1"

Expand Down
6 changes: 3 additions & 3 deletions archgw.code-workspace
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"folders": [
{
{
"name": "root",
"path": "."
},
"path": "."
},
{
"name": "crates",
"path": "crates"
Expand Down
844 changes: 359 additions & 485 deletions model_server/poetry.lock

Large diffs are not rendered by default.

16 changes: 7 additions & 9 deletions model_server/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw_modelserver"
version = "0.1.7"
version = "0.1.8"
description = "A model server for serving models"
authors = ["Katanemo Labs, Inc <info@katanemo.com>"]
license = "Apache 2.0"
Expand All @@ -10,17 +10,15 @@ packages = [
]

[tool.poetry.dependencies]
python = "^3.12"
python = "^3.10"
fastapi = "0.115.0"
torch = "2.4.1"
uvicorn = "0.31.0"
transformers = "*"
pyyaml = "6.0.2"
accelerate = "*"
psutil = "6.0.0"
pandas = "*"
transformers = "^4.37.0"
accelerate = "^1.0.0"
pydantic = "^2.10.1"
dateparser = "*"
openai = "1.50.2"
openai = "^1.50.2"
httpx = "0.27.2" # https://community.openai.com/t/typeerror-asyncclient-init-got-an-unexpected-keyword-argument-proxies/1040287
pytest-asyncio = "*"
pytest = "*"
Expand All @@ -33,7 +31,7 @@ pytest-retry = "^1.6.3"
pytest-httpserver = "^1.1.0"

[tool.poetry.scripts]
archgw_modelserver = "src.cli:run_server"
archgw_modelserver = "src.cli:main"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
140 changes: 61 additions & 79 deletions model_server/src/cli.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
import importlib
import logging
from os import path
import os
from signal import SIGKILL
import sys
import subprocess
import argparse
import signal
import tempfile
import time

import requests

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
import src.commons.utils as utils


logger = logging.getLogger(__name__)
logger = utils.get_model_server_logger()


def get_version():
Expand All @@ -42,76 +37,9 @@ def wait_for_health_check(url, timeout=300):
return False


def parse_args(argv=None):
    """
    Parse the command-line arguments of the model server CLI.

    Parameters:
    - argv (list[str] | None): Argument strings to parse. Defaults to None,
      in which case argparse reads them from `sys.argv[1:]`.

    Returns:
    - argparse.Namespace: Namespace with `action` ("start", "stop" or
      "restart"; "start" when omitted), `port` (int, 51000 by default) and
      `foreground` (bool, False unless `--foreground` is given).
    """
    parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
    parser.add_argument(
        "action",
        choices=["start", "stop", "restart"],
        default="start",
        nargs="?",  # Optional positional: the default applies when it is omitted
        help="Action to perform on the server (default: start).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=51000,
        help="Port number for the server (default: 51000).",
    )
    parser.add_argument(
        "--foreground",
        default=False,
        action="store_true",
        help="Run the server in the foreground (default: False).",
    )

    # Passing None keeps the previous behavior of parsing sys.argv[1:]
    return parser.parse_args(argv)


def get_pid_file():
temp_dir = tempfile.gettempdir()
return path.join(temp_dir, "model_server.pid")


def stop_server():
    """
    Stop the Uvicorn server recorded in the PID file, if there is one.

    Reads the PID stored in the file returned by `get_pid_file()`, sends
    SIGKILL to that process and removes the PID file. A missing PID file is
    logged and treated as "server is not running". A PID file whose content
    is not a positive integer is logged as an error and removed without
    signalling any process.
    """
    pid_file = get_pid_file()

    # Open the file directly instead of checking for existence first, so a
    # file that disappears between the check and the read cannot raise here.
    try:
        with open(pid_file, "r") as f:
            raw_pid = f.read()
    except FileNotFoundError:
        logger.info("No PID file found, server is not running.")
        return

    logger.info("PID file found, shutting down the server.")

    try:
        pid = int(raw_pid)
    except ValueError:
        # Empty or corrupted file: without this, every later start/stop would
        # raise and the broken file would never be cleaned up.
        pid = None

    if pid is None or pid <= 0:
        # os.kill() with a PID of 0 or below targets process groups rather
        # than a single process, so such values must never be signalled.
        logger.error(f"Invalid PID {raw_pid!r} in {pid_file}, removing the PID file.")
    else:
        logger.info(f"Killing model server {pid}")
        try:
            os.kill(pid, SIGKILL)
        except ProcessLookupError:
            logger.info(f"Process {pid} not found")

    os.remove(pid_file)


def restart_server(port=51000, foreground=False):
    """
    Restart the Uvicorn server by stopping it and starting it again.

    Parameters:
    - port (int): Port handed on to `start_server`. Defaults to 51000.
    - foreground (bool): Foreground flag handed on to `start_server`.
      Defaults to False.
    """
    stop_server()
    start_server(port=port, foreground=foreground)


def run_server():
    """
    Entry point: run the server action selected on the command line.

    Parses the CLI arguments with `parse_args()` and dispatches to
    `start_server`, `stop_server` or `restart_server`. Any other action is
    logged and the process exits with status 1.
    """
    args = parse_args()

    if args.action == "start":
        start_server(args.port, args.foreground)
        return

    if args.action == "stop":
        stop_server()
        return

    if args.action == "restart":
        restart_server(args.port, args.foreground)
        return

    # Fallback for an action that matches none of the branches above
    logger.info(f"Unknown action: {args.action}")
    sys.exit(1)
return os.path.join(temp_dir, "model_server.pid")


def ensure_killed(process):
Expand All @@ -131,7 +59,7 @@ def ensure_killed(process):
def start_server(port=51000, foreground=False):
"""Start the Uvicorn server."""

logging.info("model server version: %s", get_version())
logger.info("model server version: %s", get_version())

stop_server()

Expand Down Expand Up @@ -196,6 +124,57 @@ def start_server(port=51000, foreground=False):
ensure_killed(process)


def stop_server():
    """
    Stop the Uvicorn server recorded in the PID file, if there is one.

    Reads the PID stored in the file returned by `get_pid_file()`, sends
    SIGKILL to that process and removes the PID file. A missing PID file is
    logged and treated as "server is not running". A PID file whose content
    is not a positive integer is logged as an error and removed without
    signalling any process.
    """

    pid_file = get_pid_file()

    # Open the file directly instead of checking for existence first, so a
    # file that disappears between the check and the read cannot raise here.
    try:
        with open(pid_file, "r") as f:
            raw_pid = f.read()
    except FileNotFoundError:
        logger.info("No PID file found, server is not running.")
        return

    logger.info("PID file found, shutting down the server.")

    try:
        pid = int(raw_pid)
    except ValueError:
        # Empty or corrupted file: without this, every later start/stop would
        # raise and the broken file would never be cleaned up.
        pid = None

    if pid is None or pid <= 0:
        # os.kill() with a PID of 0 or below targets process groups rather
        # than a single process, so such values must never be signalled.
        logger.error(f"Invalid PID {raw_pid!r} in {pid_file}, removing the PID file.")
    else:
        logger.info(f"Killing model server {pid}")
        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            logger.info(f"Process {pid} not found")

    os.remove(pid_file)


def restart_server(port=51000, foreground=False):
    """
    Restart the Uvicorn server by stopping it and starting it again.

    Parameters:
    - port (int): Port handed on to `start_server`. Defaults to 51000.
    - foreground (bool): Foreground flag handed on to `start_server`.
      Defaults to False.
    """
    stop_server()
    start_server(port=port, foreground=foreground)


def parse_args(argv=None):
    """
    Parse the command-line arguments of the model server CLI.

    Parameters:
    - argv (list[str] | None): Argument strings to parse. Defaults to None,
      in which case argparse reads them from `sys.argv[1:]`.

    Returns:
    - argparse.Namespace: Namespace with `action` ("start", "stop" or
      "restart"; "start" when omitted), `port` (int, 51000 by default) and
      `foreground` (bool, False unless `--foreground` is given).
    """
    parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
    parser.add_argument(
        "action",
        choices=["start", "stop", "restart"],
        default="start",
        nargs="?",  # Optional positional: the default applies when it is omitted
        help="Action to perform on the server (default: start).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=51000,
        help="Port number for the server (default: 51000).",
    )
    parser.add_argument(
        "--foreground",
        default=False,
        action="store_true",
        help="Run the server in the foreground (default: False).",
    )

    # Passing None keeps the previous behavior of parsing sys.argv[1:]
    return parser.parse_args(argv)


def main():
"""
Start, stop, or restart the Uvicorn server based on command-line arguments.
Expand All @@ -204,11 +183,14 @@ def main():
args = parse_args()

if args.action == "start":
logger.info("[CLI] - Starting server")
start_server(args.port, args.foreground)
elif args.action == "stop":
logger.info("[CLI] - Stopping server")
stop_server()
elif args.action == "restart":
logger.info("[CLI] - Restarting server")
restart_server(args.port)
else:
logger.error(f"Unknown action: {args.action}")
logger.error(f"[CLI] - Unknown action: {args.action}")
sys.exit(1)
6 changes: 2 additions & 4 deletions model_server/src/commons/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
# Define model names
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"

logger.info("loading prompt guard model ...")
arch_guard_model = get_guardrail_handler()
ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"

# Define model handlers
handler_map = {
Expand All @@ -34,5 +32,5 @@
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
),
"Arch-Guard": arch_guard_model,
"Arch-Guard": get_guardrail_handler(ARCH_GUARD_MODEL_ALIAS),
}
79 changes: 21 additions & 58 deletions model_server/src/commons/utils.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,50 @@
import os
import sys
import time
import torch
import logging
import requests
import subprocess
import importlib


PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from datetime import datetime

# Default log directory and file
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
DEFAULT_LOG_FILE = "modelserver.log"


def get_model_server_logger(log_dir=None, log_file=None):
def get_model_server_logger():
"""
Get or initialize the logger instance for the model server.
Parameters:
- log_dir (str): Custom directory to store the log file. Defaults to `./.logs`.
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
Returns:
- logging.Logger: Configured logger instance.
"""
log_dir = log_dir or DEFAULT_LOG_DIR
log_file = log_file or DEFAULT_LOG_FILE
log_file_path = os.path.join(log_dir, log_file)

# Check if the logger is already configured
logger = logging.getLogger("model_server_logger")
logger = logging.getLogger("model_server")

# Return existing logger instance if already configured
if logger.hasHandlers():
# Return existing logger instance if already configured
return logger

# Ensure the log directory exists, create it if necessary
try:
# Create directory if it doesn't exist
os.makedirs(log_dir, exist_ok=True)

# Check for write permissions
if not os.access(log_dir, os.W_OK):
raise PermissionError(f"No write permission for the directory: {log_dir}")
except (PermissionError, OSError) as e:
raise RuntimeError(f"Failed to initialize logger: {e}")

# Configure logging to file
# Configure logging to only log to console
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
# logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in the file
logging.StreamHandler(), # Also log to console
],
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
)

return logger


logger = get_model_server_logger()

logging.info("initializing torch device ...")
import torch


def get_device():
available_device = {
"cpu": True,
"cuda": torch.cuda.is_available(),
"mps": (
torch.backends.mps.is_available()
if hasattr(torch.backends, "mps")
else False
),
}

if available_device["cuda"]:
if torch.cuda.is_available():
device = "cuda"
elif available_device["mps"]:
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"

return device


def get_today_date():
    """
    Get today's date as a string.

    Returns:
    - str: The date from `datetime.now()` formatted as `YYYY-MM-DD`.
    """
    return datetime.now().strftime("%Y-%m-%d")
Loading

0 comments on commit 88a02dc

Please sign in to comment.