From ed4c5136c54977112397898a44b2c7e792cfe22d Mon Sep 17 00:00:00 2001
From: Sudeep Pillai
Date: Thu, 31 Aug 2023 18:28:34 -0700
Subject: [PATCH] Discord bot with fine-tuning API

- added tests for fine-tuning API
---
 examples/discord/bot.py | 58 ++++++++++++-------
 examples/discord/docker-compose.yml | 34 +++++------
 nos/executors/ray.py | 17 +++++-
 nos/experimental/discord/nos_bot.py | 52 -----------------
 nos/server/_service.py | 2 +-
 nos/server/train/__init__.py | 2 +-
 .../train/{_train_service.py => _service.py} | 3 +-
 tests/server/test_training_service.py | 30 ++--------
 8 files changed, 79 insertions(+), 119 deletions(-)
 delete mode 100755 nos/experimental/discord/nos_bot.py
 rename nos/server/train/{_train_service.py => _service.py} (98%)

diff --git a/examples/discord/bot.py b/examples/discord/bot.py
index 1e064b4b..aed34cd4 100644
--- a/examples/discord/bot.py
+++ b/examples/discord/bot.py
@@ -1,3 +1,4 @@
+"""Example Discord bot with Stable Diffusion LoRA fine-tuning support."""
 import asyncio
 import io
 import os
@@ -39,11 +40,9 @@ def __str__(self) -> str:
 
 NOS_PLAYGROUND_CHANNEL = "nos-playground"
-
 BASE_MODEL = "runwayml/stable-diffusion-v1-5"
-# BASE_MODEL = "stabilityai/stable-diffusion-2-1"
 
-# Init nos server, wait for it to spin up then confirm its healthy:
+# Init NOS server, wait for it to spin up, then confirm it's healthy.
 client = InferenceClient()
 
 logger.debug("Waiting for server to start...")
@@ -58,22 +57,34 @@ def __str__(self) -> str:
 NOS_TRAINING_VOLUME_DIR = Path(client.Volume("nos-playground"))
 logger.debug(f"Creating training data volume [volume={NOS_TRAINING_VOLUME_DIR}]")
 
-# Set permissions for our bot to allow it to read messages:
+# Set permissions for our bot to allow it to read messages
 intents = discord.Intents.default()
 intents.message_content = True
 
-# Create our bot, with the command prefix set to "/":
+# Create our bot, with the command prefix set to "/"
 bot = commands.Bot(command_prefix="/", intents=intents)
 
 logger.debug("Starting bot, initializing existing threads ...")
 
-# Maps channel_id -> LoRAPromptModel
+# Simple persistent dict/database for storing models
+# This maps (channel-id -> LoRAPromptModel)
 MODEL_DB = Cache(str(NOS_TMP_DIR / NOS_PLAYGROUND_CHANNEL))
 
 
 @bot.command()
 async def generate(ctx, *, prompt):
-    """Create a callback to read messages and generate images from prompt"""
+    """Create a callback to read messages and generate images from prompt
+
+    Usage:
+        1. In the main channel, you can run:
+               /generate a photo of a dog on the moon
+           to generate an image with the pre-trained SDv1.5 model.
+
+        2. In a thread that has been created by fine-tuning a new model,
+           you can run:
+               /generate a photo of a sks dog on the moon
+           to generate the specific instance of the dog using the fine-tuned model.
+    """
     logger.debug(
         f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]"
    )
@@ -146,6 +157,13 @@ async def generate(ctx, *, prompt):
 
 @bot.command()
 async def train(ctx, *, prompt):
+    """Fine-tune a new model with the provided prompt and images.
+
+    Example:
+        Upload a few images of your favorite dog, and then run:
+        `/train sks a photo of a sks dog on the moon`
+    """
+
     logger.debug(f"/train [channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]")
 
     if ctx.channel.name != NOS_PLAYGROUND_CHANNEL:
@@ -246,28 +264,24 @@ async def train(ctx, *, prompt):
         image_bytes.seek(0)
         await _thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png"))
 
-    # def post_on_training_complete():
-    #     asyncio.run(post_on_training_complete_async())
-
     logger.debug(f"Starting thread to watch training job [id={thread_id}, job_id={job_id}]")
-    # threading.Thread(target=post_on_training_complete, daemon=True).start()
     asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), loop)
     logger.debug(f"Started thread to watch training job [id={thread_id}, job_id={job_id}]")
 
 
-# Pull API token out of environment and run the bot:
-bot_token = os.environ.get("DISCORD_BOT_TOKEN")
-if bot_token is None:
-    raise Exception("DISCORD_BOT_TOKEN environment variable not set")
-logger.debug(f"Starting bot with token [token={bot_token[:5]}****]")
-# bot.loop.run_until_complete(setup())
-
-
-async def run_bot():
-    await bot.start(bot_token)
+async def run_bot(token: str):
+    """Start the bot with the user-provided token."""
+    await bot.start(token)
 
 
 if __name__ == "__main__":
+    # Get the bot token from the environment
+    token = os.environ.get("DISCORD_BOT_TOKEN")
+    if token is None:
+        raise Exception("DISCORD_BOT_TOKEN environment variable not set")
+    logger.debug(f"Starting bot with token [token={token[:5]}****]")
+
+    # Start the asyncio event loop, and run the bot
     loop = asyncio.get_event_loop()
-    loop.create_task(run_bot())
+    loop.create_task(run_bot(token))
     loop.run_forever()
diff --git a/examples/discord/docker-compose.yml b/examples/discord/docker-compose.yml
index 09a7ea1e..336f093e 100644
--- a/examples/discord/docker-compose.yml
+++ b/examples/discord/docker-compose.yml
@@ -1,23 +1,23 @@
 version: "3.8"
 services:
-  # bot:
-  #   image: autonomi/nos:latest-discord-app
-  #   build:
-  #     context: .
-  #     dockerfile: Dockerfile
-  #     args:
-  #       - BASE_IMAGE=autonomi/nos:latest-cpu
-  #   env_file:
-  #     - .env
-  #   environment:
-  #     - NOS_HOME=/app/.nos
-  #     - NOS_LOGGING_LEVEL=DEBUG
-  #   volumes:
-  #     - ~/.nosd:/app/.nos
-  #     - /dev/shm:/dev/shm
-  #   network_mode: host
-  #   ipc: host
+  bot:
+    image: autonomi/nos:latest-discord-app
+    build:
+      context: .
+      dockerfile: Dockerfile
+      args:
+        - BASE_IMAGE=autonomi/nos:latest-cpu
+    env_file:
+      - .env
+    environment:
+      - NOS_HOME=/app/.nos
+      - NOS_LOGGING_LEVEL=DEBUG
+    volumes:
+      - ~/.nosd:/app/.nos
+      - /dev/shm:/dev/shm
+    network_mode: host
+    ipc: host
 
   server:
     image: autonomi/nos:latest-gpu
diff --git a/nos/executors/ray.py b/nos/executors/ray.py
index 72c2e868..e0277ac6 100644
--- a/nos/executors/ray.py
+++ b/nos/executors/ray.py
@@ -211,7 +211,22 @@ def status(self, job_id: str) -> str:
     def logs(self, job_id: str) -> str:
         """Get logs for a job."""
         return self.client.get_job_logs(job_id)
-    
+
+    def wait(self, job_id: str, timeout: int = 600, retry_interval: int = 5) -> str:
+        """Wait for a job to complete."""
+        status = None
+        st = time.time()
+        while time.time() - st < timeout:
+            status = self.status(job_id)
+            if str(status) == "SUCCEEDED":
+                logger.debug(f"Training job completed [job_id={job_id}, status={status}]")
+                return status
+            else:
+                logger.debug(f"Training job not completed yet [job_id={job_id}, status={status}]")
+                time.sleep(retry_interval)
+        logger.warning(f"Training job timed out [job_id={job_id}, status={status}]")
+        return status
+
 
 def init(*args, **kwargs) -> bool:
     """Initialize Ray executor."""
diff --git a/nos/experimental/discord/nos_bot.py b/nos/experimental/discord/nos_bot.py
deleted file mode 100755
index a2b535b2..00000000
--- a/nos/experimental/discord/nos_bot.py
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python
-
-import io
-import os
-
-import discord
-from discord.ext import commands
-
-import nos
-from nos.client import InferenceClient, TaskType
-
-
-# Init nos server, wait for it to spin up then confirm its healthy:
-nos.init(runtime="gpu")
-nos_client = InferenceClient()
-nos_client.WaitForServer()
-if not nos_client.IsHealthy():
-    raise RuntimeError("NOS server is not healthy")
-
-# Set permissions for our bot to allow it to read messages:
-intents = discord.Intents.default()
-intents.message_content = True
-
-# Create our bot:
-bot = commands.Bot(command_prefix="$", intents=intents)
-
-# Create a callback to read messages and generate images from prompt:
-@bot.command()
-async def generate(ctx, *, prompt):
-    response = nos_client.Run(
-        TaskType.IMAGE_GENERATION,
-        "stabilityai/stable-diffusion-2",
-        prompts=[prompt],
-        width=512,
-        height=512,
-        num_images=1,
-    )
-    image = response["images"][0]
-
-    image_bytes = io.BytesIO()
-    image.save(image_bytes, format="PNG")
-    image_bytes.seek(0)
-
-    await ctx.send(file=discord.File(image_bytes, filename="image.png"))
-
-
-# Pull API token out of environment and run the bot:
-bot_token = os.environ.get("BOT_TOKEN")
-if bot_token is None:
-    raise Exception("BOT_TOKEN environment variable not set")
-
-bot.run(bot_token)
diff --git a/nos/server/_service.py b/nos/server/_service.py
index 7479395f..b6a20761 100644
--- a/nos/server/_service.py
+++ b/nos/server/_service.py
@@ -20,7 +20,7 @@
 from nos.logging import logger
 from nos.managers import ModelHandle, ModelManager
 from nos.protoc import import_module
-from nos.server.train._train_service import TrainingService
+from nos.server.train._service import TrainingService
 from nos.version import __version__
 
 
diff --git a/nos/server/train/__init__.py b/nos/server/train/__init__.py
index bf8dc8e4..ab18c774 100644
--- a/nos/server/train/__init__.py
+++ b/nos/server/train/__init__.py
@@ -1 +1 @@
-from ._train_service import TrainingService  # noqa: F401
+from ._service import TrainingService  # noqa: F401
diff --git a/nos/server/train/_train_service.py b/nos/server/train/_service.py
similarity index 98%
rename from nos/server/train/_train_service.py
rename to nos/server/train/_service.py
index ce72ca37..97d4de98 100644
--- a/nos/server/train/_train_service.py
+++ b/nos/server/train/_service.py
@@ -6,7 +6,8 @@
 from nos.executors.ray import RayExecutor, RayJobExecutor
 from nos.logging import logger
 from nos.protoc import import_module
-from nos.server.train.dreambooth.config import StableDiffusionTrainingJobConfig
+
+from .dreambooth.config import StableDiffusionTrainingJobConfig
 
 
 nos_service_pb2 = import_module("nos_service_pb2")
diff --git a/tests/server/test_training_service.py b/tests/server/test_training_service.py
index 777f201f..967600e7 100644
--- a/tests/server/test_training_service.py
+++ b/tests/server/test_training_service.py
@@ -27,7 +27,7 @@ def test_training_service(ray_executor: RayExecutor):  # noqa: F811
 
     job_id = svc.train(
         method="stable-diffusion-dreambooth-lora",
-        training_inputs={
+        inputs={
             "model_name": "stabilityai/stable-diffusion-2-1",
             "instance_directory": tmp_dir,
             "instance_prompt": "A photo of a bench on the moon",
@@ -57,26 +57,8 @@ def test_training_service(ray_executor: RayExecutor):  # noqa: F811
     assert status is not None
     logger.debug(f"Status for job {job_id}: {status}")
 
-
-def test_inference_service_with_trained_model(ray_executor: RayExecutor):  # noqa: F811
-    """Test inference service."""
-    from nos.server._service import InferenceService
-
-    # Test training service
-    InferenceService()
-    # job_id = svc.execute(
-    #     method="stable-diffusion-dreambooth-lora",
-    #     inference_inputs={
-    #         "model_name": "stabilityai/stable-diffusion-2-1",
-    #         "instance_directory": tmp_dir,
-    #         "instance_prompt": "A photo of a bench on the moon",
-    #         "resolution": 512,
-    #         "max_train_steps": 100,
-    #         "seed": 0,
-    #     },
-    #     metadata={
-    #         "name": "sdv21-dreambooth-lora-test-bench",
-    #     },
-    # )
-    # assert job_id is not None
-    # logger.debug(f"Submitted job with id: {job_id}")
+    # Wait for the job to complete
+    status = svc.jobs.wait(job_id, timeout=600, retry_interval=5)
+    assert status is not None
+    logger.debug(f"Status for job {job_id}: {status}")
+    assert status == "SUCCEEDED"
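
For reference, a minimal sketch of how the fine-tuning flow added in this patch can be exercised from the caller's side. It mirrors tests/server/test_training_service.py above: the `method`/`inputs`/`metadata` arguments to `TrainingService.train()` and the `svc.jobs.wait()` call come straight from the updated test, while the bare `TrainingService()` construction, the already-initialized Ray/NOS backend (the test gets this from the `ray_executor` fixture), and the `instance_directory` path are assumptions for illustration.

    """Sketch: submit a Dreambooth LoRA fine-tuning job and block until it finishes."""
    from nos.server.train import TrainingService

    # Assumes a Ray/NOS backend is already running (the test relies on the ray_executor fixture for this).
    svc = TrainingService()

    # Submit the fine-tuning job; argument names follow the updated test in this patch.
    job_id = svc.train(
        method="stable-diffusion-dreambooth-lora",
        inputs={
            "model_name": "stabilityai/stable-diffusion-2-1",
            "instance_directory": "/tmp/instance-images",  # hypothetical directory with a few training images
            "instance_prompt": "A photo of a bench on the moon",
            "resolution": 512,
            "max_train_steps": 100,
            "seed": 0,
        },
        metadata={"name": "sdv21-dreambooth-lora-example"},
    )
    assert job_id is not None

    # Poll the Ray job via the new RayJobExecutor.wait() until it succeeds or times out.
    status = svc.jobs.wait(job_id, timeout=600, retry_interval=5)
    assert str(status) == "SUCCEEDED"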