From 74bf39ff853fb3be3bbc913edb757109383ad3df Mon Sep 17 00:00:00 2001
From: Sudeep Pillai
Date: Mon, 4 Sep 2023 11:06:43 -0700
Subject: [PATCH] Discord bot with NOS fine-tuning API (#314)

## Summary
- Adds a Discord bot example with a Dockerfile, docker-compose.yml, and requirements for the Discord app.
- Adds a Discord training bot built on a new fine-tuning API for SDv2 LoRA.
- Adds tests for the fine-tuning API.

## Related issues

## Checks
- [x] `make lint`: I've run `make lint` to lint the changes in this PR.
- [x] `make test`: I've made sure the tests (`make test-cpu` or `make test`) are passing.
- Additional tests:
  - [ ] Benchmark tests (when contributing new models)
  - [ ] GPU/HW tests

---
 docker/.dockerignore                         |   2 +
 examples/discord/.env.template               |   1 +
 examples/discord/Dockerfile                  |  10 +
 examples/discord/Makefile                    |   7 +
 examples/discord/bot.py                      | 287 ++++++++++++++++++
 examples/discord/docker-compose.yml          |  36 +++
 examples/discord/requirements.txt            |   4 +
 nos/client/grpc.py                           |  85 +++++-
 nos/executors/ray.py                         |  15 +
 nos/experimental/discord/nos_bot.py          |  52 ----
 nos/experimental/train/__init__.py           |   1 -
 nos/experimental/train/_train_service.py     |  63 ----
 nos/models/__init__.py                       |   1 +
 nos/models/dreambooth/dreambooth.py          |  20 +-
 nos/proto/nos_service.proto                  |  18 +-
 nos/server/_service.py                       |  71 +++--
 nos/server/train/__init__.py                 |   1 +
 nos/server/train/_service.py                 | 114 +++++++
 nos/{experimental => server}/train/config.py |   0
 .../train/dreambooth/config.py               |  13 +-
 requirements/requirements.server.txt         |   2 +-
 scripts/entrypoint.sh                        |   6 +-
 tests/client/grpc/test_grpc_client.py        |  16 +
 tests/client/test_client_integration.py      |  70 ++++-
 tests/server/test_inference_service.py       |   1 +
 tests/server/test_training_service.py        |  32 +-
 26 files changed, 740 insertions(+), 188 deletions(-)
 create mode 100644 examples/discord/.env.template
 create mode 100644 examples/discord/Dockerfile
 create mode 100644 examples/discord/Makefile
 create mode 100644 examples/discord/bot.py
 create mode 100644 examples/discord/docker-compose.yml
 create mode 100644 examples/discord/requirements.txt
 delete mode 100755 nos/experimental/discord/nos_bot.py
 delete mode 100644 nos/experimental/train/__init__.py
 delete mode 100644 nos/experimental/train/_train_service.py
 create mode 100644 nos/server/train/__init__.py
 create mode 100644 nos/server/train/_service.py
 rename nos/{experimental => server}/train/config.py (100%)
 rename nos/{experimental => server}/train/dreambooth/config.py (89%)

diff --git a/docker/.dockerignore b/docker/.dockerignore
index 68228799..ef5efe55 100644
--- a/docker/.dockerignore
+++ b/docker/.dockerignore
@@ -24,3 +24,5 @@ dist
 bdist
 *.cache
 *.ts
+
+site/
diff --git a/examples/discord/.env.template b/examples/discord/.env.template
new file mode 100644
index 00000000..1bcaa3cd
--- /dev/null
+++ b/examples/discord/.env.template
@@ -0,0 +1 @@
+DISCORD_BOT_TOKEN=
diff --git a/examples/discord/Dockerfile b/examples/discord/Dockerfile
new file mode 100644
index 00000000..3a92306e
--- /dev/null
+++ b/examples/discord/Dockerfile
@@ -0,0 +1,10 @@
+ARG BASE_IMAGE
+FROM ${BASE_IMAGE:-autonomi/nos:latest-cpu}
+
+WORKDIR /tmp/$PROJECT
+ADD requirements.txt .
+RUN pip install -r requirements.txt
+
+WORKDIR /app
+COPY bot.py .
+CMD ["python", "/app/bot.py"]
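The image built by this Dockerfile can also be driven programmatically. A minimal sketch using the `docker` Python SDK (which examples/discord/requirements.txt pulls in); the image tag mirrors docker-compose.yml below, and the token value is a placeholder:

```python
import docker

client = docker.from_env()
# Build the bot image from examples/discord/ (tag is an assumption, matching docker-compose.yml).
image, _ = client.images.build(path="examples/discord", tag="autonomi/nos:latest-discord-app")
# Run it on the host network next to a NOS server; the token here is illustrative.
client.containers.run(
    image.tags[0],
    environment={"DISCORD_BOT_TOKEN": "<your-token>"},
    network_mode="host",
    detach=True,
)
```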
diff --git a/examples/discord/Makefile b/examples/discord/Makefile
new file mode 100644
index 00000000..a4d0a8d1
--- /dev/null
+++ b/examples/discord/Makefile
@@ -0,0 +1,7 @@
+SHELL := /bin/bash
+
+docker-compose-upd-discord-bot:
+	sudo mkdir -p ~/.nosd/volumes
+	sudo chown -R $(USER):$(USER) ~/.nosd/volumes
+	pushd ../../ && make docker-build-cpu docker-build-gpu && popd;
+	docker compose -f docker-compose.yml up --build
diff --git a/examples/discord/bot.py b/examples/discord/bot.py
new file mode 100644
index 00000000..aed34cd4
--- /dev/null
+++ b/examples/discord/bot.py
@@ -0,0 +1,287 @@
+"""Example Discord bot with Stable Diffusion LoRA fine-tuning support."""
+import asyncio
+import io
+import os
+import time
+import uuid
+from dataclasses import dataclass
+from pathlib import Path
+
+import discord
+from discord.ext import commands
+from diskcache import Cache
+
+from nos.client import InferenceClient, TaskType
+from nos.constants import NOS_TMP_DIR
+from nos.logging import logger
+
+
+@dataclass
+class LoRAPromptModel:
+
+    thread_id: str
+    """Discord thread ID"""
+
+    thread_name: str
+    """Discord thread name"""
+
+    model_id: str
+    """Training job ID / model ID"""
+
+    prompt: str
+    """Prompt used to train the model"""
+
+    @property
+    def job_id(self) -> str:
+        return self.model_id
+
+    def __str__(self) -> str:
+        return f"LoRAPromptModel(thread_id={self.thread_id}, thread_name={self.thread_name}, model_id={self.model_id}, prompt={self.prompt})"
+
+
+NOS_PLAYGROUND_CHANNEL = "nos-playground"
+BASE_MODEL = "runwayml/stable-diffusion-v1-5"
+
+# Init the NOS server, wait for it to spin up, then confirm it's healthy.
+client = InferenceClient()
+
+logger.debug("Waiting for server to start...")
+client.WaitForServer()
+
+logger.debug("Confirming server is healthy...")
+if not client.IsHealthy():
+    raise RuntimeError("NOS server is not healthy")
+
+logger.debug("Server is healthy!")
+NOS_VOLUME_DIR = Path(client.Volume())
+NOS_TRAINING_VOLUME_DIR = Path(client.Volume("nos-playground"))
+logger.debug(f"Creating training data volume [volume={NOS_TRAINING_VOLUME_DIR}]")
+
+# Set permissions for our bot to allow it to read messages
+intents = discord.Intents.default()
+intents.message_content = True
+
+# Create our bot, with the command prefix set to "/"
+bot = commands.Bot(command_prefix="/", intents=intents)
+
+logger.debug("Starting bot, initializing existing threads ...")
+
+# Simple persistent dict/database for storing models
+# This maps (channel-id -> LoRAPromptModel)
+MODEL_DB = Cache(str(NOS_TMP_DIR / NOS_PLAYGROUND_CHANNEL))
+
+
+@bot.command()
+async def generate(ctx, *, prompt):
+    """Create a callback to read messages and generate images from prompt
+
+    Usage:
+        1. In the main channel, you can run:
+            /generate a photo of a dog on the moon
+        to generate an image with the pre-trained SDv1.5 model.
+
+        2. In a thread that has been created by fine-tuning a new model,
+        you can run:
+            /generate a photo of a sks dog on the moon
+        to generate the specific instance of the dog using the fine-tuned model.
+ """ + logger.debug( + f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]" + ) + + st = time.perf_counter() + if ctx.channel.name == NOS_PLAYGROUND_CHANNEL: + # Pull the channel id so we know which model to run: + + # Acknowledge the request, by reacting to the message with a checkmark + logger.debug(f"Request acknowledged [id={ctx.message.id}]") + await ctx.message.add_reaction("✅") + + logger.debug(f"/generate request submitted [id={ctx.message.id}]") + response = client.Run( + task=TaskType.IMAGE_GENERATION, + model_name=BASE_MODEL, + prompts=[prompt], + width=512, + height=512, + num_images=1, + ) + logger.debug(f"/generate request completed [id={ctx.message.id}, elapsed={time.perf_counter() - st:.2f}s]") + (image,) = response["images"] + thread = ctx + else: + # Pull the channel id so we know which model to run: + thread = ctx.channel + thread_id = ctx.channel.id + model = MODEL_DB.get(thread_id, default=None) + if model is None: + logger.debug(f"Failed to fetch model [thread_id={thread_id}]") + await thread.send("No model found") + return + + # Get model info + try: + models = client.ListModels() + info = {m.name: m for m in models}[model.model_id] + logger.debug(f"Got model info [model_id={model.model_id}, info={info}]") + except Exception as e: + logger.debug(f"Failed to fetch model info [model_id={model.model_id}, e={e}]") + await thread.send(f"Failed to fetch model [model_id={model.model_id}]") + return + + # Acknowledge the request, by reacting to the message with a checkmark + logger.debug(f"Request acknowledged [id={ctx.message.id}]") + await ctx.message.add_reaction("✅") + + # Run inference on the trained model + logger.debug(f"/generate request submitted [id={ctx.message.id}, model_id={model.model_id}, model={model}]") + response = client.Run( + task=TaskType.IMAGE_GENERATION, + model_name=model.model_id, + prompts=[prompt], + width=512, + height=512, + num_images=1, + ) + logger.debug( + f"/generate request completed [id={ctx.message.id}, model={model}, elapsed={time.perf_counter() - st:.2f}s]" + ) + (image,) = response["images"] + + # Save the image to a buffer and send it back to the user: + image_bytes = io.BytesIO() + image.save(image_bytes, format="PNG") + image_bytes.seek(0) + await thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png")) + + +@bot.command() +async def train(ctx, *, prompt): + """Fine-tune a new model with the provided prompt and images. 
+
+    Example:
+        Upload a few images of your favorite dog, and then run:
+        `/train sks a photo of a sks dog on the moon`
+    """
+
+    logger.debug(f"/train [channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]")
+
+    if ctx.channel.name != NOS_PLAYGROUND_CHANNEL:
+        logger.debug(f"ignoring [channel={ctx.channel.name}]")
+        return
+
+    if not ctx.message.attachments:
+        logger.debug("no attachments to train on, returning!")
+        return
+
+    if "sks" not in prompt:
+        await ctx.send("Please include 'sks' in your training prompt!")
+        return
+
+    # Create a thread for this training job
+    message_id = str(ctx.message.id)
+    thread = await ctx.message.create_thread(name=f"{prompt} ({message_id})")
+    logger.debug(f"Created thread [id={thread.id}, name={thread.name}]")
+
+    # Save the thread id
+    thread_id = thread.id
+
+    # Create the training directory for this thread
+    dirname = NOS_TRAINING_VOLUME_DIR / str(thread_id)
+    dirname.mkdir(parents=True, exist_ok=True)
+    logger.debug(f"Created training directory [dirname={dirname}]")
+
+    # Save all attachments to the training directory
+    logger.debug(f"Saving attachments [dirname={dirname}, attachments={len(ctx.message.attachments)}]")
+    for attachment in ctx.message.attachments:
+        filename = str(dirname / f"{str(uuid.uuid4().hex[:8])}_{attachment.filename}")
+        await attachment.save(filename)
+        logger.debug(f"Saved attachment [filename={filename}]")
+
+    # Acknowledge the request, by reacting to the message with a checkmark
+    logger.debug(f"Request acknowledged [id={ctx.message.id}]")
+    await ctx.message.add_reaction("✅")
+
+    # Train a new LoRA model with the uploaded images
+    response = client.Train(
+        method="stable-diffusion-dreambooth-lora",
+        inputs={
+            "model_name": BASE_MODEL,
+            "instance_directory": dirname.relative_to(NOS_VOLUME_DIR),
+            "instance_prompt": prompt,
+            "max_train_steps": 500,
+        },
+        metadata={
+            "name": "sdv15-dreambooth-lora",
+        },
+    )
+    if response is None:
+        logger.error(f"Failed to submit training job [id={thread_id}, dirname={dirname}]")
+        await thread.send(f"Failed to train [prompt={prompt}, dirname={dirname}]")
+        return
+    job_id = response["job_id"]
+    logger.debug(f"Submitted training job [id={thread_id}, response={response}, dirname={dirname}]")
+    await thread.send(f"@here Submitted training job [id={thread.name}, model={job_id}]")
+
+    # Save the model
+    MODEL_DB[thread_id] = LoRAPromptModel(
+        thread_id=thread_id,
+        thread_name=thread.name,
+        model_id=f"custom/{job_id}",
+        prompt=prompt,
+    )
+    logger.debug(f"Saved model [id={thread_id}, model={MODEL_DB[thread_id]}]")
+
+    # Create a new thread to watch the training job
+    async def post_on_training_complete_async():
+        # Wait for the model to be ready
+        response = client.Wait(job_id=job_id, timeout=600, retry_interval=10)
+        logger.debug(f"Training completed [job_id={job_id}, response={response}].")
+
+        # Get the thread
+        _thread = bot.get_channel(thread_id)
+        await _thread.send(f"@here Training complete [id={_thread.name}, model={job_id}]")
+
+        # Wait for model to be registered after the job is complete
+        await asyncio.sleep(5)
+
+        # Run inference on the trained model
+        st = time.perf_counter()
+        response = client.Run(
+            task=TaskType.IMAGE_GENERATION,
+            model_name=f"custom/{job_id}",
+            prompts=[prompt],
+            width=512,
+            height=512,
+            num_images=1,
+        )
+        logger.debug(f"/generate request completed [model={job_id}, elapsed={time.perf_counter() - st:.2f}s]")
+        (image,) = response["images"]
+
+        # Save the image to a buffer and send it back to the user:
+        image_bytes = io.BytesIO()
+        image.save(image_bytes, format="PNG")
+        image_bytes.seek(0)
+        await _thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png"))
+
+    logger.debug(f"Starting thread to watch training job [id={thread_id}, job_id={job_id}]")
+    asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), loop)
+    logger.debug(f"Started thread to watch training job [id={thread_id}, job_id={job_id}]")
+
+
+async def run_bot(token: str):
+    """Start the bot with the user-provided token."""
+    await bot.start(token)
+
+
+if __name__ == "__main__":
+    # Get the bot token from the environment
+    token = os.environ.get("DISCORD_BOT_TOKEN")
+    if token is None:
+        raise Exception("DISCORD_BOT_TOKEN environment variable not set")
+    logger.debug(f"Starting bot with token [token={token[:5]}****]")
+
+    # Start the asyncio event loop, and run the bot
+    loop = asyncio.get_event_loop()
+    loop.create_task(run_bot(token))
+    loop.run_forever()
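The training watcher above is handed off to the bot's event loop with `asyncio.run_coroutine_threadsafe`, using the module-level `loop` created in `__main__`. A minimal, self-contained sketch of that scheduling pattern (the names here are illustrative, not part of the bot):

```python
import asyncio
import threading

loop = asyncio.new_event_loop()


async def watcher(job_id: str) -> None:
    # Stand-in for post_on_training_complete_async: poll, then report.
    await asyncio.sleep(1)
    print(f"job {job_id} complete")


def submit_from_worker() -> None:
    # Thread-safe handoff: enqueues the coroutine onto `loop`,
    # which is owned and run by the main thread below.
    asyncio.run_coroutine_threadsafe(watcher("job-1234"), loop)


threading.Thread(target=submit_from_worker).start()
loop.run_forever()  # watcher("job-1234") executes on this loop
```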
diff --git a/examples/discord/docker-compose.yml b/examples/discord/docker-compose.yml
new file mode 100644
index 00000000..336f093e
--- /dev/null
+++ b/examples/discord/docker-compose.yml
@@ -0,0 +1,36 @@
+version: "3.8"
+
+services:
+  bot:
+    image: autonomi/nos:latest-discord-app
+    build:
+      context: .
+      dockerfile: Dockerfile
+      args:
+        - BASE_IMAGE=autonomi/nos:latest-cpu
+    env_file:
+      - .env
+    environment:
+      - NOS_HOME=/app/.nos
+      - NOS_LOGGING_LEVEL=DEBUG
+    volumes:
+      - ~/.nosd:/app/.nos
+      - /dev/shm:/dev/shm
+    network_mode: host
+    ipc: host
+
+  server:
+    image: autonomi/nos:latest-gpu
+    environment:
+      - NOS_HOME=/app/.nos
+      - NOS_LOGGING_LEVEL=DEBUG
+    volumes:
+      - ~/.nosd:/app/.nos
+      - /dev/shm:/dev/shm
+    network_mode: host
+    ipc: host
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - capabilities: [gpu]
diff --git a/examples/discord/requirements.txt b/examples/discord/requirements.txt
new file mode 100644
index 00000000..ffb8316a
--- /dev/null
+++ b/examples/discord/requirements.txt
@@ -0,0 +1,4 @@
+discord==2.3.2
+discord.py==2.3.2
+diskcache
+docker
+ """ + try: + response: nos_service_pb2.ServiceInfoResponse = self.stub.GetServiceInfo(empty_pb2.Empty()) + return response.runtime + except grpc.RpcError as e: + raise NosServerReadyException(f"Failed to get service info (details={e.details()})", e) + def CheckCompatibility(self) -> bool: """Check if the service version is compatible with the client. @@ -268,6 +284,71 @@ def Run( module: InferenceModule = self.Module(task, model_name) return module(**inputs) + def Train( + self, method: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = None + ) -> nos_service_pb2.GenericResponse: + """Training module. + + Args: + method (str): Training method (e.g. `stable-diffusion-dreambooth-lora`). + inputs (Dict[str, Any]): Training inputs. + metadata (Dict[str, Any], optional): Metadata for the training job. Defaults to None. + Returns: + str: Job ID. + Raises: + NosClientException: If the server fails to respond to the request. + """ + try: + request = nos_service_pb2.GenericRequest( + request_bytes=dumps({"method": method, "inputs": inputs, "metadata": metadata}) + ) + response = self.stub.Train(request) + return loads(response.response_bytes) + except grpc.RpcError as e: + raise NosClientException(f"Failed to train model (details={(e.details())})", e) + + def Volume(self, name: str = None) -> str: + """Remote volume module for NOS. + + Note: This is meant for remote volume mounts especially useful for training. + """ + runtime = self.GetServiceRuntime() + root = NOS_HOME / "volumes" if runtime == "local" else Path.home() / ".nosd/volumes" + if name is None: + root.mkdir(parents=True, exist_ok=True) + return str(root) + path = root / f"{name}_{uuid.uuid4().hex[:8]}" + path.mkdir(parents=True, exist_ok=True) + return str(path) + + def Wait(self, job_id: str, timeout: int = 60, retry_interval: int = 5) -> None: + """Wait for job to finish. + + Args: + job_id (str): Job ID. + timeout (int, optional): Timeout in seconds. Defaults to 60. + retry_interval (int, optional): Retry interval in seconds. Defaults to 5. + """ + st = time.time() + response = None + while time.time() - st <= timeout: + try: + response: nos_service_pb2.GenericResponse = self.stub.GetJobStatus( + nos_service_pb2.GenericRequest(request_bytes=dumps({"job_id": job_id})) + ) + response = loads(response.response_bytes) + if str(response) != "PENDING" and str(response) != "RUNNING": + return response + else: + logger.debug( + f"Waiting for job to finish [job_id={job_id}, response={response}, elapsed={time.time() - st:.0f}s]" + ) + except Exception as e: + logger.warning(f"Failed to fetch job status ... 
diff --git a/nos/executors/ray.py b/nos/executors/ray.py
index d3190b62..e0277ac6 100644
--- a/nos/executors/ray.py
+++ b/nos/executors/ray.py
@@ -212,6 +212,21 @@ def logs(self, job_id: str) -> str:
         """Get logs for a job."""
         return self.client.get_job_logs(job_id)
 
+    def wait(self, job_id: str, timeout: int = 600, retry_interval: int = 5) -> str:
+        """Wait for a job to complete."""
+        status = None
+        st = time.time()
+        while time.time() - st < timeout:
+            status = self.status(job_id)
+            if str(status) == "SUCCEEDED":
+                logger.debug(f"Job completed [job_id={job_id}, status={status}]")
+                return status
+            else:
+                logger.debug(f"Job not completed yet [job_id={job_id}, status={status}]")
+                time.sleep(retry_interval)
+        logger.warning(f"Job timed out [job_id={job_id}, status={status}]")
+        return status
+
 
 def init(*args, **kwargs) -> bool:
     """Initialize Ray executor."""
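With this helper, server-side code can block on job completion without hand-rolling a poll loop. A short sketch, assuming a Ray cluster is already initialized and the job ID is a placeholder:

```python
from nos.executors.ray import RayExecutor

executor = RayExecutor.get()
# job_id as returned by executor.jobs.submit(...); the value here is illustrative.
status = executor.jobs.wait("raysubmit_abc123", timeout=600, retry_interval=5)
assert str(status) == "SUCCEEDED"
```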
diff --git a/nos/experimental/discord/nos_bot.py b/nos/experimental/discord/nos_bot.py
deleted file mode 100755
index a2b535b2..00000000
--- a/nos/experimental/discord/nos_bot.py
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python
-
-import io
-import os
-
-import discord
-from discord.ext import commands
-
-import nos
-from nos.client import InferenceClient, TaskType
-
-
-# Init nos server, wait for it to spin up then confirm its healthy:
-nos.init(runtime="gpu")
-nos_client = InferenceClient()
-nos_client.WaitForServer()
-if not nos_client.IsHealthy():
-    raise RuntimeError("NOS server is not healthy")
-
-# Set permissions for our bot to allow it to read messages:
-intents = discord.Intents.default()
-intents.message_content = True
-
-# Create our bot:
-bot = commands.Bot(command_prefix="$", intents=intents)
-
-# Create a callback to read messages and generate images from prompt:
-@bot.command()
-async def generate(ctx, *, prompt):
-    response = nos_client.Run(
-        TaskType.IMAGE_GENERATION,
-        "stabilityai/stable-diffusion-2",
-        prompts=[prompt],
-        width=512,
-        height=512,
-        num_images=1,
-    )
-    image = response["images"][0]
-
-    image_bytes = io.BytesIO()
-    image.save(image_bytes, format="PNG")
-    image_bytes.seek(0)
-
-    await ctx.send(file=discord.File(image_bytes, filename="image.png"))
-
-
-# Pull API token out of environment and run the bot:
-bot_token = os.environ.get("BOT_TOKEN")
-if bot_token is None:
-    raise Exception("BOT_TOKEN environment variable not set")
-
-bot.run(bot_token)
diff --git a/nos/experimental/train/__init__.py b/nos/experimental/train/__init__.py
deleted file mode 100644
index bf8dc8e4..00000000
--- a/nos/experimental/train/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from ._train_service import TrainingService  # noqa: F401
diff --git a/nos/experimental/train/_train_service.py b/nos/experimental/train/_train_service.py
deleted file mode 100644
index 0a0fc1bf..00000000
--- a/nos/experimental/train/_train_service.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from typing import Any, Dict
-
-from nos.exceptions import ModelNotFoundError
-from nos.executors.ray import RayExecutor, RayJobExecutor
-from nos.experimental.train.dreambooth.config import StableDiffusionTrainingJobConfig
-from nos.logging import logger
-from nos.protoc import import_module
-
-
-nos_service_pb2 = import_module("nos_service_pb2")
-nos_service_pb2_grpc = import_module("nos_service_pb2_grpc")
-
-
-class TrainingService:
-    """Ray-executor based training service."""
-
-    config_cls = {
-        "stable-diffusion-dreambooth-lora": StableDiffusionTrainingJobConfig,
-    }
-
-    def __init__(self):
-        self.executor = RayExecutor.get()
-        try:
-            self.executor.init()
-        except Exception as e:
-            err_msg = f"Failed to initialize executor [e={e}]"
-            logger.info(err_msg)
-            raise RuntimeError(err_msg)
-
-    def train(self, method: str, training_inputs: Dict[str, Any], metadata: Dict[str, Any] = None) -> str:
-        """Train / Fine-tune a model by submitting a job to the RayJobExecutor.
-
-        Args:
-            method (str): Training method (e.g. `stable-diffusion-dreambooth-lora`).
-            training_inputs (Dict[str, Any]): Training inputs.
-        Returns:
-            str: Job ID.
-        """
-        try:
-            config_cls = self.config_cls[method]
-        except KeyError:
-            raise ModelNotFoundError(f"Training not supported for method [method={method}]")
-
-        # Check if the training inputs are correctly specified
-        config = config_cls(method=method, **training_inputs)
-        try:
-            pass
-        except Exception as e:
-            raise ValueError(f"Invalid training inputs [training_inputs={training_inputs}, e={e}]")
-
-        # Submit the training job as a Ray job
-        configd = config.job_dict()
-        if metadata is not None:
-            configd["metadata"] = metadata
-        logger.debug("Submitting training job")
-        logger.debug(f"config\n{configd}]")
-        job_id = self.executor.jobs.submit(**configd)
-        logger.debug(f"Submitted training job [job_id={job_id}, config={configd}]")
-        return job_id
-
-    @property
-    def jobs(self) -> RayJobExecutor:
-        return self.executor.jobs
diff --git a/nos/models/__init__.py b/nos/models/__init__.py
index eab06523..c562e63c 100644
--- a/nos/models/__init__.py
+++ b/nos/models/__init__.py
@@ -10,6 +10,7 @@
 from ._noop import NoOp  # noqa: F401
 from .clip import CLIP  # noqa: F401
+from .dreambooth.dreambooth import StableDiffusionLoRA  # noqa: F401
 from .faster_rcnn import FasterRCNN  # noqa: F401
 from .monodepth import MonoDepth  # noqa: F401
 from .openmmlab.mmdetection import MMDetection  # noqa: F401
diff --git a/nos/models/dreambooth/dreambooth.py b/nos/models/dreambooth/dreambooth.py
index c05fd2f6..46b4e1b9 100644
--- a/nos/models/dreambooth/dreambooth.py
+++ b/nos/models/dreambooth/dreambooth.py
@@ -6,6 +6,8 @@
 import torch
 from PIL import Image
 
+from nos import hub
+from nos.common import Batch, ImageSpec, ImageT, TaskType
 from nos.hub.config import NOS_MODELS_DIR
 from nos.logging import logger
 
@@ -185,8 +187,7 @@ def __call__(
         self,
         prompts: Union[str, List[str]],
         num_images: int = 1,
-        num_inference_steps: int = 50,
-        guidance_scale: float = 7.5,
+        num_inference_steps: int = 30,
         height: int = None,
         width: int = None,
     ) -> List[Image.Image]:
@@ -196,7 +197,20 @@ def __call__(
         return self.pipe(
             prompts * num_images,
             num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
             height=height if height is not None else self.cfg.resolution,
             width=width if width is not None else self.cfg.resolution,
         ).images
+
+
+for model_name in StableDiffusionLoRA.configs.keys():
+    logger.debug(f"Registering model: {model_name}")
+    hub.register(
+        model_name,
+        TaskType.IMAGE_GENERATION,
+        StableDiffusionLoRA,
+        init_args=(model_name,),
+        init_kwargs={"dtype": torch.float16},
+        method_name="__call__",
+        inputs={"prompts": Batch[str, 1], "num_images": int, "height": int, "width": int},
+        outputs={"images": Batch[ImageT[Image.Image, ImageSpec(shape=(None, None, 3), dtype="uint8")]]},
+    )
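With the module-level registration loop above, every fine-tuned LoRA config becomes discoverable through the hub. A sketch of how a client might enumerate them, assuming `ListModels` returns specs with a `name` field as used in bot.py (the printed ID is illustrative):

```python
from nos.client import InferenceClient

client = InferenceClient()
# Fine-tuned models are registered under the "custom/" namespace,
# keyed by their training job ID.
custom = [m.name for m in client.ListModels() if m.name.startswith("custom/")]
print(custom)  # e.g. ["custom/stable-diffusion-dreambooth-lora_ef939db5"]
```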
diff --git a/nos/proto/nos_service.proto b/nos/proto/nos_service.proto
index d1a9b007..c59307be 100644
--- a/nos/proto/nos_service.proto
+++ b/nos/proto/nos_service.proto
@@ -65,6 +65,7 @@ message PingResponse {
 }
 
 // Service information repsonse
 message ServiceInfoResponse {
   string version = 1;  // (e.g. "0.1.0")
+  string runtime = 2;  // (e.g. "cpu", "gpu", "local" etc)
 }
 
 // Register system shared memory request
@@ -94,6 +95,12 @@ service InferenceService {
   // Run the inference request
   rpc Run(InferenceRequest) returns (InferenceResponse) {}
 
+  // Dispatch a training request
+  rpc Train(GenericRequest) returns (GenericResponse) {}
+
+  // Job status
+  rpc GetJobStatus(GenericRequest) returns (GenericResponse) {}
+
   // Register shared memory
   rpc RegisterSystemSharedMemory(GenericRequest) returns (GenericResponse) {}
 
@@ -108,14 +115,3 @@ service InferenceService {
   // TODO (spillai): To be implemented later (for power-users)
   // rpc DeleteModel(DeleteModelRequest) returns (DeleteModelResponse) {}
 }
-
-
-message TrainingRequest {
-  ModelInfo model = 1;
-  map<string, bytes> inputs = 2;
-  map<string, bytes> outputs = 3;
-}
-
-message TrainingResponse {
-  bytes response_bytes = 1;
-}
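Both new RPCs reuse the generic byte-payload envelope instead of the dedicated `TrainingRequest`/`TrainingResponse` messages, which are removed. A sketch of the encoding contract, assuming the `dumps`/`loads` serialization helpers live in `nos.common` as used by the client and server code in this patch (field values are illustrative):

```python
from nos.common import dumps, loads  # serialization helpers assumed from nos.common

# Client side: pack the training request into GenericRequest.request_bytes.
payload = dumps({
    "method": "stable-diffusion-dreambooth-lora",
    "inputs": {"model_name": "...", "instance_directory": "...", "instance_prompt": "..."},
    "metadata": None,
})

# Server side: unpack, dispatch, and reply via GenericResponse.response_bytes.
request = loads(payload)
response_bytes = dumps({"job_id": "stable-diffusion-dreambooth-lora_ab12cd34"})
assert loads(response_bytes)["job_id"].startswith(request["method"])
```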
diff --git a/nos/server/_service.py b/nos/server/_service.py
index cfeddb9e..b6a20761 100644
--- a/nos/server/_service.py
+++ b/nos/server/_service.py
@@ -20,6 +20,7 @@
 from nos.logging import logger
 from nos.managers import ModelHandle, ModelManager
 from nos.protoc import import_module
+from nos.server.train._service import TrainingService
 from nos.version import __version__
 
 
@@ -49,14 +50,10 @@ class InferenceService:
     """
 
     def __init__(self):
-        self.model_manager = ModelManager()
         self.executor = RayExecutor.get()
-        try:
-            self.executor.init()
-        except Exception as e:
-            err_msg = f"Failed to initialize executor [e={e}]"
-            logger.info(err_msg)
-            raise RuntimeError(err_msg)
+        if not self.executor.is_initialized():
+            raise RuntimeError("Ray executor is not initialized")
+        self.model_manager = ModelManager()
         if NOS_SHM_ENABLED:
             self.shm_manager = SharedMemoryTransportManager()
         else:
@@ -115,7 +112,7 @@ def execute(self, model_name: str, task: TaskType = None, inputs: Dict[str, Any]
         return response
 
 
-class InferenceServiceImpl(nos_service_pb2_grpc.InferenceServiceServicer, InferenceService):
+class InferenceServiceImpl(nos_service_pb2_grpc.InferenceServiceServicer, InferenceService, TrainingService):
     """
     Experimental gRPC-based inference service.
 
@@ -126,6 +123,13 @@ class InferenceServiceImpl(nos_service_pb2_grpc.InferenceServiceServicer, Infere
     """
 
     def __init__(self, *args, **kwargs):
+        self.executor = RayExecutor.get()
+        try:
+            self.executor.init()
+        except Exception as e:
+            err_msg = f"Failed to initialize executor [e={e}]"
+            logger.info(err_msg)
+            raise RuntimeError(err_msg)
         super().__init__(*args, **kwargs)
 
     def Ping(self, request: empty_pb2.Empty, context: grpc.ServicerContext) -> nos_service_pb2.PingResponse:
@@ -136,7 +140,13 @@ def GetServiceInfo(
         self, request: empty_pb2.Empty, context: grpc.ServicerContext
     ) -> nos_service_pb2.ServiceInfoResponse:
         """Get information on the service."""
-        return nos_service_pb2.ServiceInfoResponse(version=__version__)
+        from nos.common.system import has_gpu, is_inside_docker
+
+        if is_inside_docker():
+            runtime = "gpu" if has_gpu() else "cpu"
+        else:
+            runtime = "local"
+        return nos_service_pb2.ServiceInfoResponse(version=__version__, runtime=runtime)
 
     def ListModels(self, request: empty_pb2.Empty, context: grpc.ServicerContext) -> nos_service_pb2.ModelListResponse:
         """List all models."""
@@ -238,23 +248,44 @@ def Run(
             context.abort(grpc.StatusCode.INTERNAL, "Internal Server Error")
 
     def Train(
-        self, request: nos_service_pb2.TrainingRequest, context: grpc.ServicerContext
-    ) -> nos_service_pb2.TrainingResponse:
-        model_request = request.model
-        logger.debug(f"=> Received training request [task={model_request.task}, model={model_request.name}]")
-        if model_request.task not in (TaskType.IMAGE_GENERATION.value,):
-            context.abort(grpc.StatusCode.NOT_FOUND, f"Invalid training task [task={model_request.task}]")
+        self, request: nos_service_pb2.GenericRequest, context: grpc.ServicerContext
+    ) -> nos_service_pb2.GenericResponse:
+        request = loads(request.request_bytes)
+        logger.debug(f"=> Received training request [method={request['method']}]")
+        if request["method"] not in TrainingService.config_cls:
+            context.abort(grpc.StatusCode.NOT_FOUND, f"Invalid training task [method={request['method']}]")
+
+        try:
+            st = time.perf_counter()
+            logger.info(f"Training request [method={request['method']}]")
+            job_id = self.train(request["method"], inputs=request["inputs"], metadata=request["metadata"])
+            response = {"job_id": job_id}
+            logger.info(
+                f"Train request dispatched [job_id={job_id}, method={request['method']}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]"
+            )
+            return nos_service_pb2.GenericResponse(response_bytes=dumps(response))
+        except (grpc.RpcError, Exception) as e:
+            msg = f"Failed to train request [method={request['method']}]"
+            msg += f"{traceback.format_exc()}"
+            logger.error(f"{msg}, e={e}")
+            context.abort(grpc.StatusCode.INTERNAL, "Internal Server Error")
+
+    def GetJobStatus(
+        self, request: nos_service_pb2.GenericRequest, context: grpc.ServicerContext
+    ) -> nos_service_pb2.GenericResponse:
+        request = loads(request.request_bytes)
+        logger.debug(f"=> Received job status request [job_id={request['job_id']}]")
         try:
             st = time.perf_counter()
-            logger.info(f"Training request [task={model_request.task}, model={model_request.name}]")
-            response = self.train(model_request.name, task=TaskType(model_request.task), inputs=request.inputs)
+            logger.info(f"Job status request [job_id={request['job_id']}]")
+            response = str(self.jobs.status(request["job_id"]))
             logger.info(
-                f"Trained request dispatched [id={id}, task={model_request.task}, model={model_request.name}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]"
+                f"Job status request [job_id={request['job_id']}, response={response}, elapsed={(time.perf_counter() - st) * 1e3:.1f}ms]"
             )
-            return nos_service_pb2.TrainingResponse(response_bytes=dumps(response))
+            return nos_service_pb2.GenericResponse(response_bytes=dumps(response))
         except (grpc.RpcError, Exception) as e:
-            msg = f"Failed to train request [task={model_request.task}, model={model_request.name}]"
+            msg = f"Failed to get job status [job_id={request['job_id']}]"
             msg += f"{traceback.format_exc()}"
             logger.error(f"{msg}, e={e}")
             context.abort(grpc.StatusCode.INTERNAL, "Internal Server Error")
diff --git a/nos/server/train/__init__.py b/nos/server/train/__init__.py
new file mode 100644
index 00000000..ab18c774
--- /dev/null
+++ b/nos/server/train/__init__.py
@@ -0,0 +1 @@
+from ._service import TrainingService  # noqa: F401
diff --git a/nos/server/train/_service.py b/nos/server/train/_service.py
new file mode 100644
index 00000000..97d4de98
--- /dev/null
+++ b/nos/server/train/_service.py
@@ -0,0 +1,114 @@
+import threading
+import time
+from typing import Any, Dict
+
+from nos.exceptions import ModelNotFoundError
+from nos.executors.ray import RayExecutor, RayJobExecutor
+from nos.logging import logger
+from nos.protoc import import_module
+
+from .dreambooth.config import StableDiffusionTrainingJobConfig
+
+
+nos_service_pb2 = import_module("nos_service_pb2")
+nos_service_pb2_grpc = import_module("nos_service_pb2_grpc")
+
+
+def register_model(model_name: str, *args, **kwargs):
+    import torch
+    from PIL import Image
+
+    from nos import hub
+    from nos.common import Batch, ImageSpec, ImageT, TaskType
+    from nos.models.dreambooth.dreambooth import StableDiffusionDreamboothHub, StableDiffusionLoRA
+
+    # Update the registry with newer configs
+    sd_hub = StableDiffusionDreamboothHub(namespace="custom")
+    psize = len(sd_hub)
+    sd_hub.update()
+    logger.debug(f"Updated registry with newer configs [namespace=custom, size={len(sd_hub)}, prev_size={psize}]")
+
+    # Register the model
+    model_id = f"custom/{model_name}"
+    logger.debug(f"Registering new model [model={model_id}]")
+    hub.register(
+        model_id,
+        TaskType.IMAGE_GENERATION,
+        StableDiffusionLoRA,
+        init_args=(model_id,),
+        init_kwargs={"dtype": torch.float16},
+        method_name="__call__",
+        inputs={"prompts": Batch[str, 1], "num_images": int, "height": int, "width": int},
+        outputs={"images": Batch[ImageT[Image.Image, ImageSpec(shape=(None, None, 3), dtype="uint8")]]},
+    )
+    logger.debug(f"Registered new model [model={model_id}]")
+
+
+class TrainingService:
+    """Ray-executor based training service."""
+
+    config_cls = {
+        "stable-diffusion-dreambooth-lora": StableDiffusionTrainingJobConfig,
+    }
+
+    def __init__(self):
+        """Initialize the training service."""
+        self.executor = RayExecutor.get()
+        if not self.executor.is_initialized():
+            raise RuntimeError("Ray executor is not initialized")
+
+    def train(self, method: str, inputs: Dict[str, Any], metadata: Dict[str, Any] = None) -> str:
+        """Train / Fine-tune a model by submitting a job to the RayJobExecutor.
+
+        Args:
+            method (str): Training method (e.g. `stable-diffusion-dreambooth-lora`).
+            inputs (Dict[str, Any]): Training inputs.
+        Returns:
+            str: Job ID.
+        """
+        try:
+            config_cls = self.config_cls[method]
+        except KeyError:
+            raise ModelNotFoundError(f"Training not supported for method [method={method}]")
+
+        # Check if the training inputs are correctly specified
+        try:
+            config = config_cls(method=method, **inputs)
+        except Exception as e:
+            raise ValueError(f"Invalid training inputs [inputs={inputs}, e={e}]")
+
+        # Submit the training job as a Ray job
+        configd = config.job_dict()
+        if metadata is not None:
+            configd["metadata"] = metadata
+        logger.debug("Submitting training job")
+        logger.debug(f"config\n{configd}")
+        job_id = self.executor.jobs.submit(**configd)
+        logger.debug(f"Submitted training job [job_id={job_id}, config={configd}]")
+
+        hooks = {"on_completed": (register_model, (job_id,), {})}
+
+        # Spawn a thread to monitor the job
+        def monitor_job_hook(job_id: str, timeout: int = 600, retry_interval: int = 5):
+            """Hook for monitoring the job status and running callbacks on completion."""
+            st = time.time()
+            while time.time() - st < timeout:
+                status = self.executor.jobs.status(job_id)
+                if str(status) == "SUCCEEDED":
+                    logger.debug(f"Training job completed [job_id={job_id}, status={status}]")
+                    cb, args, kwargs = hooks["on_completed"]
+                    logger.debug(f"Running callback [cb={cb}, args={args}, kwargs={kwargs}]")
+                    cb(*args, **kwargs)
+                    logger.debug(f"Callback completed [cb={cb}, args={args}, kwargs={kwargs}]")
+                    break
+                else:
+                    logger.debug(f"Training job not completed yet [job_id={job_id}, status={status}]")
+                    time.sleep(retry_interval)
+
+        threading.Thread(target=monitor_job_hook, args=(job_id,), daemon=True).start()
+        return job_id
+
+    @property
+    def jobs(self) -> RayJobExecutor:
+        return self.executor.jobs
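Server-side, the same flow is available without going through gRPC. A sketch of driving `TrainingService` directly, mirroring the test below (the instance directory path is illustrative; a running Ray executor is assumed):

```python
from nos.executors.ray import init
from nos.server.train import TrainingService

init()  # ensure the Ray executor is up; TrainingService requires it

svc = TrainingService()
job_id = svc.train(
    method="stable-diffusion-dreambooth-lora",
    inputs={
        "model_name": "stabilityai/stable-diffusion-2-1",
        "instance_directory": "/tmp/instances",  # illustrative path
        "instance_prompt": "A photo of a bench on the moon",
        "max_train_steps": 100,
    },
)
status = svc.jobs.wait(job_id, timeout=600, retry_interval=5)
# On SUCCEEDED, the monitor thread registers the model as f"custom/{job_id}".
```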
+ """ + try: + config_cls = self.config_cls[method] + except KeyError: + raise ModelNotFoundError(f"Training not supported for method [method={method}]") + + # Check if the training inputs are correctly specified + config = config_cls(method=method, **inputs) + try: + pass + except Exception as e: + raise ValueError(f"Invalid training inputs [inputs={inputs}, e={e}]") + + # Submit the training job as a Ray job + configd = config.job_dict() + if metadata is not None: + configd["metadata"] = metadata + logger.debug("Submitting training job") + logger.debug(f"config\n{configd}]") + job_id = self.executor.jobs.submit(**configd) + logger.debug(f"Submitted training job [job_id={job_id}, config={configd}]") + + hooks = {"on_completed": (register_model, (job_id,), {})} + + # Spawn a thread to monitor the job + def monitor_job_hook(job_id: str, timeout: int = 600, retry_interval: int = 5): + """Hook for monitoring the job status and running callbacks on completion.""" + st = time.time() + while time.time() - st < timeout: + status = self.executor.jobs.status(job_id) + if str(status) == "SUCCEEDED": + logger.debug(f"Training job completed [job_id={job_id}, status={status}]") + cb, args, kwargs = hooks["on_completed"] + logger.debug(f"Running callback [cb={cb}, args={args}, kwargs={kwargs}]") + cb(*args, **kwargs) + logger.debug(f"Callback completed [cb={cb}, args={args}, kwargs={kwargs}]") + break + else: + logger.debug(f"Training job not completed yet [job_id={job_id}, status={status}]") + time.sleep(retry_interval) + + threading.Thread(target=monitor_job_hook, args=(job_id,), daemon=True).start() + return job_id + + @property + def jobs(self) -> RayJobExecutor: + return self.executor.jobs diff --git a/nos/experimental/train/config.py b/nos/server/train/config.py similarity index 100% rename from nos/experimental/train/config.py rename to nos/server/train/config.py diff --git a/nos/experimental/train/dreambooth/config.py b/nos/server/train/dreambooth/config.py similarity index 89% rename from nos/experimental/train/dreambooth/config.py rename to nos/server/train/dreambooth/config.py index dae13757..c298e25f 100644 --- a/nos/experimental/train/dreambooth/config.py +++ b/nos/server/train/dreambooth/config.py @@ -7,13 +7,15 @@ from typing import Any, Dict from nos.common.git import cached_repo -from nos.experimental.train.config import TrainingJobConfig +from nos.constants import NOS_HOME from nos.logging import logger from nos.models.dreambooth.dreambooth import StableDiffusionDreamboothConfigs +from nos.server.train.config import TrainingJobConfig GIT_TAG = "v0.20.1" +NOS_VOLUME_DIR = NOS_HOME / "volumes" RUNTIME_ENVS = { "diffusers-latest": { "working_dir": "./nos/experimental/", @@ -71,11 +73,14 @@ def __post_init__(self): runtime_env["working_dir"] = self.repo_directory # Create a new short unique name using method and uuid (with 8 characters) - job_id = f"{self.method}_{uuid.uuid4().hex[:8]}" - self.job_config = TrainingJobConfig(uuid=job_id, runtime_env=runtime_env) + model_id = f"{self.method}_{uuid.uuid4().hex[:8]}" + self.job_config = TrainingJobConfig(uuid=model_id, runtime_env=runtime_env) + job_id = self.job_config.uuid working_directory = Path(self.job_config.working_directory) # Copy the instance directory to the working directory + self.instance_directory = NOS_VOLUME_DIR / self.instance_directory + logger.debug(f"Instance directory [dir={self.instance_directory}]") if not Path(self.instance_directory).exists(): raise IOError(f"Failed to load 
diff --git a/requirements/requirements.server.txt b/requirements/requirements.server.txt
index 7a222391..4be6cc4d 100644
--- a/requirements/requirements.server.txt
+++ b/requirements/requirements.server.txt
@@ -3,7 +3,7 @@ diffusers>=0.17.1
 huggingface_hub
 memray
 pyarrow>=12.0.0
-ray>=2.6.1
+ray[default]>=2.6.1
 safetensors>=0.3.0
 tabulate
 timm>=0.9.2
diff --git a/scripts/entrypoint.sh b/scripts/entrypoint.sh
index e70c162a..c7967271 100755
--- a/scripts/entrypoint.sh
+++ b/scripts/entrypoint.sh
@@ -2,9 +2,11 @@
 set -e
 set -x
 
-echo "Starting Ray server with OMP_NUM_THREADS=${OMP_NUM_THREADS}..."
+# Get number of cores
+NCORES=$(nproc --all)
+echo "Starting Ray server with OMP_NUM_THREADS=${OMP_NUM_THREADS:-${NCORES}}..."
 # Get OMP_NUM_THREADS from environment variable, if set otherwise use 1
-OMP_NUM_THREADS=${OMP_NUM_THREADS} ray start --head
+OMP_NUM_THREADS=${OMP_NUM_THREADS:-${NCORES}} ray start --head
 
 echo "Starting NOS server..."
 nos-grpc-server
diff --git a/tests/client/grpc/test_grpc_client.py b/tests/client/grpc/test_grpc_client.py
index a3e892b7..b301b1d7 100644
--- a/tests/client/grpc/test_grpc_client.py
+++ b/tests/client/grpc/test_grpc_client.py
@@ -27,3 +27,19 @@ def predict_module_wrap():
 
     predict_fn = dumps(predict_module_wrap)
     assert isinstance(predict_fn, bytes)
+
+    def train_wrap():
+        return grpc_client.Train(
+            method="stable-diffusion-dreambooth-lora",
+            inputs={
+                "model_name": "stabilityai/stable-diffusion-2-1",
+                "instance_directory": "/tmp",
+                "instance_prompt": "A photo of a bench on the moon",
+            },
+            metadata={
+                "name": "sdv21-dreambooth-lora-test-bench",
+            },
+        )
+
+    train_fn = dumps(train_wrap)
+    assert isinstance(train_fn, bytes)
diff --git a/tests/client/test_client_integration.py b/tests/client/test_client_integration.py
index a158afff..27cd595a 100644
--- a/tests/client/test_client_integration.py
+++ b/tests/client/test_client_integration.py
@@ -13,7 +13,7 @@
 
 @pytest.mark.client
 @pytest.mark.parametrize("runtime", ["cpu", "gpu", "auto"])
-def test_nos_init(runtime):  # noqa: F811
+def test_client_init(runtime):  # noqa: F811
     """Test the NOS server daemon initialization."""
 
     # Initialize the server
@@ -44,8 +44,8 @@ def test_nos_init(runtime):  # noqa: F811
 @pytest.mark.client
 @pytest.mark.benchmark(group=PyTestGroup.INTEGRATION)
 @pytest.mark.parametrize("runtime", ["gpu"])
-def test_nos_object_detection_benchmark(runtime):  # noqa: F811
-    """Test the NOS server daemon initialization."""
+def test_client_inference_benchmark(runtime):  # noqa: F811
+    """Test and benchmark end-to-end client inference interface."""
     from itertools import islice
 
     import cv2
@@ -55,7 +55,7 @@ def test_nos_object_detection_benchmark(runtime):  # noqa: F811
     from nos.logging import logger
 
     # Initialize the server
-    container = nos.init(runtime=runtime, port=GRPC_PORT, utilization=0.8)
+    container = nos.init(runtime=runtime, port=GRPC_PORT, utilization=1)
     assert container is not None
     assert container.id is not None
     containers = InferenceServiceRuntime.list()
@@ -103,3 +103,65 @@ def test_nos_object_detection_benchmark(runtime):  # noqa: F811
     nos.shutdown()
     containers = InferenceServiceRuntime.list()
     assert len(containers) == 0
+
+
+@pytest.mark.client
+@pytest.mark.benchmark(group=PyTestGroup.INTEGRATION)
+@pytest.mark.parametrize(
+    "client_with_server",
+    ("local_grpc_client_with_server",),
+)
+def test_client_training(client_with_server, request):  # noqa: F811
+    """Test end-to-end client training interface."""
+    import shutil
+    from pathlib import Path
+
+    from nos.common import TaskType
+    from nos.logging import logger
+    from nos.test.utils import NOS_TEST_IMAGE
+
+    # Test waiting for server to start
+    # This call should be instantaneous as the server is already ready for the test
+    client = request.getfixturevalue(client_with_server)
+    assert client.IsHealthy()
+
+    # Create a temporary volume for training images
+    volume_dir = client.Volume("dreambooth_training")
+
+    logger.debug("Testing training service...")
+
+    # Copy test image to volume and test training service
+    tmp_image = Path(volume_dir) / "test_image.jpg"
+    shutil.copy(NOS_TEST_IMAGE, tmp_image)
+
+    # Train a new LoRA model with the image of a bench
+    response = client.Train(
+        method="stable-diffusion-dreambooth-lora",
+        inputs={
+            "model_name": "stabilityai/stable-diffusion-2-1",
+            "instance_directory": volume_dir,
+            "instance_prompt": "A photo of a bench on the moon",
+            "max_train_steps": 10,
+        },
+    )
+    assert response is not None
+    model_id = response["job_id"]
+    logger.debug(f"Training service test completed [model_id={model_id}].")
+
+    # Wait for the model to be ready
+    # For e.g. model_id = "stable-diffusion-dreambooth-lora_16cd4490"
+    # model_id = "stable-diffusion-dreambooth-lora_ef939db5"
+    response = client.Wait(job_id=model_id, timeout=600, retry_interval=5)
+    logger.debug(f"Training service test completed [model_id={model_id}, response={response}].")
+    time.sleep(10)
+
+    # Test inference with the trained model
+    logger.debug("Testing inference service...")
+    response = client.Run(
+        task=TaskType.IMAGE_GENERATION,
+        model_name=f"custom/{model_id}",
+        prompts=["a photo of a bench on the moon"],
+        width=512,
+        height=512,
+        num_images=1,
+    )
diff --git a/tests/server/test_inference_service.py b/tests/server/test_inference_service.py
index 575738be..0059c55b 100644
--- a/tests/server/test_inference_service.py
+++ b/tests/server/test_inference_service.py
@@ -127,6 +127,7 @@ def test_shm_registry(client_with_server, request):  # noqa: F811
     # assert len(shm_files) == 0, "Expected no shared memory regions, but found some."
 
 
+@pytest.mark.benchmark
 @pytest.mark.parametrize(
     "client_with_server",
     ("local_grpc_client_with_server", "grpc_client_with_cpu_backend", "grpc_client_with_gpu_backend"),
diff --git a/tests/server/test_training_service.py b/tests/server/test_training_service.py
index 4d9f804f..967600e7 100644
--- a/tests/server/test_training_service.py
+++ b/tests/server/test_training_service.py
@@ -15,7 +15,7 @@
 
 def test_training_service(ray_executor: RayExecutor):  # noqa: F811
     """Test training service."""
-    from nos.experimental.train import TrainingService
+    from nos.server.train import TrainingService
 
     # Test training service
     svc = TrainingService()
@@ -27,7 +27,7 @@ def test_training_service(ray_executor: RayExecutor):  # noqa: F811
 
     job_id = svc.train(
         method="stable-diffusion-dreambooth-lora",
-        training_inputs={
+        inputs={
            "model_name": "stabilityai/stable-diffusion-2-1",
            "instance_directory": tmp_dir,
            "instance_prompt": "A photo of a bench on the moon",
@@ -57,26 +57,8 @@ def test_training_service(ray_executor: RayExecutor):  # noqa: F811
     assert status is not None
     logger.debug(f"Status for job {job_id}: {status}")
 
-
-def test_inference_service_with_trained_model(ray_executor: RayExecutor):  # noqa: F811
-    """Test inference service."""
-    from nos.server._service import InferenceService
-
-    # Test training service
-    InferenceService()
-    # job_id = svc.execute(
-    #     method="stable-diffusion-dreambooth-lora",
-    #     inference_inputs={
-    #         "model_name": "stabilityai/stable-diffusion-2-1",
-    #         "instance_directory": tmp_dir,
-    #         "instance_prompt": "A photo of a bench on the moon",
-    #         "resolution": 512,
-    #         "max_train_steps": 100,
-    #         "seed": 0,
-    #     },
-    #     metadata={
-    #         "name": "sdv21-dreambooth-lora-test-bench",
-    #     },
-    # )
-    # assert job_id is not None
-    # logger.debug(f"Submitted job with id: {job_id}")
+    # Wait for the job to complete
+    status = svc.jobs.wait(job_id, timeout=600, retry_interval=5)
+    assert status is not None
+    logger.debug(f"Status for job {job_id}: {status}")
+    assert status == "SUCCEEDED"