Skip to content

Commit

Permalink
Discord bot with fine-tuning API
Browse files Browse the repository at this point in the history
 - added tests for fine-tuning API
  • Loading branch information
spillai committed Sep 1, 2023
1 parent d14f302 commit ed4c513
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 119 deletions.
58 changes: 36 additions & 22 deletions examples/discord/bot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Example discord both with Stable Diffusion LoRA fine-tuning support."""
import asyncio
import io
import os
Expand Down Expand Up @@ -39,11 +40,9 @@ def __str__(self) -> str:


NOS_PLAYGROUND_CHANNEL = "nos-playground"

BASE_MODEL = "runwayml/stable-diffusion-v1-5"
# BASE_MODEL = "stabilityai/stable-diffusion-2-1"

# Init nos server, wait for it to spin up then confirm its healthy:
# Init NOS server, wait for it to spin up then confirm its healthy.
client = InferenceClient()

logger.debug("Waiting for server to start...")
Expand All @@ -58,22 +57,34 @@ def __str__(self) -> str:
NOS_TRAINING_VOLUME_DIR = Path(client.Volume("nos-playground"))
logger.debug(f"Creating training data volume [volume={NOS_TRAINING_VOLUME_DIR}]")

# Set permissions for our bot to allow it to read messages:
# Set permissions for our bot to allow it to read messages
intents = discord.Intents.default()
intents.message_content = True

# Create our bot, with the command prefix set to "/":
# Create our bot, with the command prefix set to "/"
bot = commands.Bot(command_prefix="/", intents=intents)

logger.debug("Starting bot, initializing existing threads ...")

# Maps channel_id -> LoRAPromptModel
# Simple persistent dict/database for storing models
# This maps (channel-id -> LoRAPromptModel)
MODEL_DB = Cache(str(NOS_TMP_DIR / NOS_PLAYGROUND_CHANNEL))


@bot.command()
async def generate(ctx, *, prompt):
"""Create a callback to read messages and generate images from prompt"""
"""Create a callback to read messages and generate images from prompt
Usage:
1. In the main channel, you can run:
/generate a photo of a dog on the moon
to generate an image with the pre-trained SDv1.5 model.
2. In a thread that has been created by fine-tuning a new model,
you can run:
/generate a photo of a sks dog on the moon
to generate the specific instance of the dog using the fine-tuned model.
"""
logger.debug(
f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]"
)
Expand Down Expand Up @@ -146,6 +157,13 @@ async def generate(ctx, *, prompt):

@bot.command()
async def train(ctx, *, prompt):
"""Fine-tune a new model with the provided prompt and images.
Example:
Upload a few images of your favorite dog, and then run:
`/train sks a photo of a sks dog on the moon`
"""

logger.debug(f"/train [channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]")

if ctx.channel.name != NOS_PLAYGROUND_CHANNEL:
Expand Down Expand Up @@ -246,28 +264,24 @@ async def post_on_training_complete_async():
image_bytes.seek(0)
await _thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png"))

# def post_on_training_complete():
# asyncio.run(post_on_training_complete_async())

logger.debug(f"Starting thread to watch training job [id={thread_id}, job_id={job_id}]")
# threading.Thread(target=post_on_training_complete, daemon=True).start()
asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), loop)
logger.debug(f"Started thread to watch training job [id={thread_id}, job_id={job_id}]")


# Pull API token out of environment and run the bot:
bot_token = os.environ.get("DISCORD_BOT_TOKEN")
if bot_token is None:
raise Exception("DISCORD_BOT_TOKEN environment variable not set")
logger.debug(f"Starting bot with token [token={bot_token[:5]}****]")
# bot.loop.run_until_complete(setup())


async def run_bot():
await bot.start(bot_token)
async def run_bot(token: str):
"""Start the bot with the user-provided token."""
await bot.start(token)


if __name__ == "__main__":
# Get the bot token from the environment
token = os.environ.get("DISCORD_BOT_TOKEN")
if token is None:
raise Exception("DISCORD_BOT_TOKEN environment variable not set")
logger.debug(f"Starting bot with token [token={token[:5]}****]")

# Start the asyncio event loop, and run the bot
loop = asyncio.get_event_loop()
loop.create_task(run_bot())
loop.create_task(run_bot(token))
loop.run_forever()
34 changes: 17 additions & 17 deletions examples/discord/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
version: "3.8"

services:
# bot:
# image: autonomi/nos:latest-discord-app
# build:
# context: .
# dockerfile: Dockerfile
# args:
# - BASE_IMAGE=autonomi/nos:latest-cpu
# env_file:
# - .env
# environment:
# - NOS_HOME=/app/.nos
# - NOS_LOGGING_LEVEL=DEBUG
# volumes:
# - ~/.nosd:/app/.nos
# - /dev/shm:/dev/shm
# network_mode: host
# ipc: host
bot:
image: autonomi/nos:latest-discord-app
build:
context: .
dockerfile: Dockerfile
args:
- BASE_IMAGE=autonomi/nos:latest-cpu
env_file:
- .env
environment:
- NOS_HOME=/app/.nos
- NOS_LOGGING_LEVEL=DEBUG
volumes:
- ~/.nosd:/app/.nos
- /dev/shm:/dev/shm
network_mode: host
ipc: host

server:
image: autonomi/nos:latest-gpu
Expand Down
17 changes: 16 additions & 1 deletion nos/executors/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,22 @@ def status(self, job_id: str) -> str:
def logs(self, job_id: str) -> str:
"""Get logs for a job."""
return self.client.get_job_logs(job_id)


def wait(self, job_id: str, timeout: int = 600, retry_interval: int = 5) -> str:
"""Wait for a job to complete."""
status = None
st = time.time()
while time.time() - st < timeout:
status = self.status(job_id)
if str(status) == "SUCCEEDED":
logger.debug(f"Training job completed [job_id={job_id}, status={status}]")
return status
else:
logger.debug(f"Training job not completed yet [job_id={job_id}, status={status}]")
time.sleep(retry_interval)
logger.warning(f"Training job timed out [job_id={job_id}, status={status}]")
return status


def init(*args, **kwargs) -> bool:
"""Initialize Ray executor."""
Expand Down
52 changes: 0 additions & 52 deletions nos/experimental/discord/nos_bot.py

This file was deleted.

2 changes: 1 addition & 1 deletion nos/server/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from nos.logging import logger
from nos.managers import ModelHandle, ModelManager
from nos.protoc import import_module
from nos.server.train._train_service import TrainingService
from nos.server.train._service import TrainingService
from nos.version import __version__


Expand Down
2 changes: 1 addition & 1 deletion nos/server/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._train_service import TrainingService # noqa: F401
from ._service import TrainingService # noqa: F401
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from nos.executors.ray import RayExecutor, RayJobExecutor
from nos.logging import logger
from nos.protoc import import_module
from nos.server.train.dreambooth.config import StableDiffusionTrainingJobConfig

from .dreambooth.config import StableDiffusionTrainingJobConfig


nos_service_pb2 = import_module("nos_service_pb2")
Expand Down
30 changes: 6 additions & 24 deletions tests/server/test_training_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_training_service(ray_executor: RayExecutor): # noqa: F811

job_id = svc.train(
method="stable-diffusion-dreambooth-lora",
training_inputs={
inputs={
"model_name": "stabilityai/stable-diffusion-2-1",
"instance_directory": tmp_dir,
"instance_prompt": "A photo of a bench on the moon",
Expand Down Expand Up @@ -57,26 +57,8 @@ def test_training_service(ray_executor: RayExecutor): # noqa: F811
assert status is not None
logger.debug(f"Status for job {job_id}: {status}")


def test_inference_service_with_trained_model(ray_executor: RayExecutor): # noqa: F811
"""Test inference service."""
from nos.server._service import InferenceService

# Test training service
InferenceService()
# job_id = svc.execute(
# method="stable-diffusion-dreambooth-lora",
# inference_inputs={
# "model_name": "stabilityai/stable-diffusion-2-1",
# "instance_directory": tmp_dir,
# "instance_prompt": "A photo of a bench on the moon",
# "resolution": 512,
# "max_train_steps": 100,
# "seed": 0,
# },
# metadata={
# "name": "sdv21-dreambooth-lora-test-bench",
# },
# )
# assert job_id is not None
# logger.debug(f"Submitted job with id: {job_id}")
# Wait for the job to complete
status = svc.jobs.wait(job_id, timeout=600, retry_interval=5)
assert status is not None
logger.debug(f"Status for job {job_id}: {status}")
assert status == "SUCCEEDED"

0 comments on commit ed4c513

Please sign in to comment.