Skip to content

Commit

Permalink
Discord bot with NOS fine-tuning API (#314)
Browse files Browse the repository at this point in the history
## Summary
- adds discord bot with Dockerfile, docker-compose.yml and requirements
for discord app
- new discord training bot with new fine-tuning API for sdv2 lora
- added tests for fine-tuning API

<!-- Thank you for your contribution! Please review
https://github.com/autonomi-ai/nos/blob/main/docs/CONTRIBUTING.md before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issues

<!-- For example: "Closes #1234" -->

## Checks

- [x] `make lint`: I've run `make lint` to lint the changes in this PR.
- [x] `make test`: I've made sure the tests (`make test-cpu` or `make
test`) are passing.
- Additional tests:
   - [ ] Benchmark tests (when contributing new models)
   - [ ] GPU/HW tests
  • Loading branch information
spillai authored Sep 4, 2023
1 parent 5007e83 commit 74bf39f
Show file tree
Hide file tree
Showing 26 changed files with 740 additions and 188 deletions.
2 changes: 2 additions & 0 deletions docker/.dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ dist
bdist
*.cache
*.ts

site/
1 change: 1 addition & 0 deletions examples/discord/.env.template
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DISCORD_BOT_TOKEN=
10 changes: 10 additions & 0 deletions examples/discord/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Image for the example Discord bot, layered on top of a NOS base image.
# BASE_IMAGE can be overridden at build time; it defaults to the CPU image.
ARG BASE_IMAGE
FROM ${BASE_IMAGE:-autonomi/nos:latest-cpu}

# Install the bot's Python dependencies.
# NOTE(review): PROJECT is not declared in this Dockerfile; unless the base
# image exports it, this resolves to /tmp/ — confirm.
WORKDIR /tmp/$PROJECT
# COPY (not ADD): this is a plain local file, no archive extraction or URL fetch.
COPY requirements.txt .
# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r requirements.txt

# Install and run the bot itself.
WORKDIR /app
COPY bot.py .
CMD ["python", "/app/bot.py"]
7 changes: 7 additions & 0 deletions examples/discord/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SHELL := /bin/bash

# The target names an action, not a file, so it must always run.
.PHONY: docker-compose-upd-discord-bot

# Prepare the shared host volume directory, build the CPU and GPU base images
# from the repository root, then build and start the bot + server stack.
docker-compose-upd-discord-bot:
	sudo mkdir -p ~/.nosd/volumes
	sudo chown -R $(USER):$(USER) ~/.nosd/volumes
	$(MAKE) -C ../../ docker-build-cpu docker-build-gpu
	docker compose -f docker-compose.yml up --build
287 changes: 287 additions & 0 deletions examples/discord/bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
"""Example Discord bot with Stable Diffusion LoRA fine-tuning support."""
import asyncio
import io
import os
import time
import uuid
from dataclasses import dataclass
from pathlib import Path

import discord
from discord.ext import commands
from diskcache import Cache

from nos.client import InferenceClient, TaskType
from nos.constants import NOS_TMP_DIR
from nos.logging import logger


@dataclass
class LoRAPromptModel:
    """Record linking a Discord thread to the LoRA model fine-tuned for it.

    Instances are stored in ``MODEL_DB`` keyed by thread id, so that a
    ``/generate`` call made inside a thread can look up which model to run.
    """

    # Discord thread ID. ``train`` stores ``thread.id`` here unchanged, which is
    # an integer, and the same integer is the ``MODEL_DB`` key.
    thread_id: int

    # Discord thread name
    thread_name: str

    # Model ID passed as ``model_name`` when running inference. ``train`` stores
    # ``f"custom/{job_id}"`` here, i.e. the training job ID with a prefix.
    model_id: str

    # Prompt used to train the model
    prompt: str

    @property
    def job_id(self) -> str:
        """Return ``model_id`` unchanged (alias kept for readability at call sites)."""
        return self.model_id

    def __str__(self) -> str:
        return f"LoRAPromptModel(thread_id={self.thread_id}, thread_name={self.thread_name}, model_id={self.model_id}, prompt={self.prompt})"


# Name of the Discord channel the bot serves. /train is only accepted here, and
# /generate in this channel uses the base model. The same string names the
# on-disk model database directory below.
NOS_PLAYGROUND_CHANNEL = "nos-playground"
# Model used for /generate in the main channel and as the base for fine-tuning.
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Create the NOS client, wait for the server to spin up, then confirm it's healthy.
# NOTE: all of this runs at import time, so importing this module blocks until
# the server responds and raises if it is unhealthy.
client = InferenceClient()

logger.debug("Waiting for server to start...")
client.WaitForServer()

logger.debug("Confirming server is healthy...")
if not client.IsHealthy():
    raise RuntimeError("NOS server is not healthy")

logger.debug("Server is healthy!")
# Root volume directory; training directories are passed to the server as
# paths relative to this root (see `train`).
NOS_VOLUME_DIR = Path(client.Volume())
# Volume under which `train` creates one sub-directory per Discord thread.
NOS_TRAINING_VOLUME_DIR = Path(client.Volume("nos-playground"))
logger.debug(f"Creating training data volume [volume={NOS_TRAINING_VOLUME_DIR}]")

# Set permissions for our bot to allow it to read messages
intents = discord.Intents.default()
intents.message_content = True

# Create our bot, with the command prefix set to "/"
bot = commands.Bot(command_prefix="/", intents=intents)

logger.debug("Starting bot, initializing existing threads ...")

# Simple persistent dict/database for storing models.
# This maps (thread id -> LoRAPromptModel); entries are written by `train` and
# read by `generate`.
MODEL_DB = Cache(str(NOS_TMP_DIR / NOS_PLAYGROUND_CHANNEL))


@bot.command()
async def generate(ctx, *, prompt):
    """Create a callback to read messages and generate images from prompt
    Usage:
        1. In the main channel, you can run:
            /generate a photo of a dog on the moon
        to generate an image with the pre-trained SDv1.5 model.
        2. In a thread that has been created by fine-tuning a new model,
        you can run:
            /generate a photo of a sks dog on the moon
        to generate the specific instance of the dog using the fine-tuned model.
    """
    # NOTE: the docstring wording above is kept as-is on purpose; discord.py
    # surfaces a command's docstring as its help text.
    logger.debug(
        f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]"
    )

    def _text_to_image(model_name):
        # One 512x512 image for `prompt`, rendered by the given model.
        return client.Run(
            task=TaskType.IMAGE_GENERATION,
            model_name=model_name,
            prompts=[prompt],
            width=512,
            height=512,
            num_images=1,
        )

    st = time.perf_counter()
    if ctx.channel.name != NOS_PLAYGROUND_CHANNEL:
        # Any other channel is treated as a training thread: the channel id is
        # the key under which `train` stored the fine-tuned model.
        destination = ctx.channel
        thread_id = ctx.channel.id
        model = MODEL_DB.get(thread_id, default=None)
        if model is None:
            logger.debug(f"Failed to fetch model [thread_id={thread_id}]")
            await destination.send("No model found")
            return

        # Make sure the server actually knows about this model before
        # acknowledging the request.
        try:
            registered = {m.name: m for m in client.ListModels()}
            info = registered[model.model_id]
            logger.debug(f"Got model info [model_id={model.model_id}, info={info}]")
        except Exception as e:
            logger.debug(f"Failed to fetch model info [model_id={model.model_id}, e={e}]")
            await destination.send(f"Failed to fetch model [model_id={model.model_id}]")
            return

        # Acknowledge the request, by reacting to the message with a checkmark
        logger.debug(f"Request acknowledged [id={ctx.message.id}]")
        await ctx.message.add_reaction("✅")

        # Run inference on the trained model
        logger.debug(f"/generate request submitted [id={ctx.message.id}, model_id={model.model_id}, model={model}]")
        result = _text_to_image(model.model_id)
        logger.debug(
            f"/generate request completed [id={ctx.message.id}, model={model}, elapsed={time.perf_counter() - st:.2f}s]"
        )
    else:
        # Main playground channel: always use the pre-trained base model.
        # Acknowledge the request, by reacting to the message with a checkmark
        logger.debug(f"Request acknowledged [id={ctx.message.id}]")
        await ctx.message.add_reaction("✅")

        logger.debug(f"/generate request submitted [id={ctx.message.id}]")
        result = _text_to_image(BASE_MODEL)
        logger.debug(f"/generate request completed [id={ctx.message.id}, elapsed={time.perf_counter() - st:.2f}s]")
        destination = ctx

    # Exactly one image is expected back (num_images=1).
    (image,) = result["images"]

    # Encode the image as PNG in memory and post it where the request came from.
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    png_buffer.seek(0)
    attachment = discord.File(png_buffer, filename=f"{ctx.message.id}.png")
    await destination.send(f"{prompt}", file=attachment)


@bot.command()
async def train(ctx, *, prompt):
    """Fine-tune a new model with the provided prompt and images.
    Example:
        Upload a few images of your favorite dog, and then run:
        `/train sks a photo of a sks dog on the moon`
    """
    # Flow: validate the request, open a thread for the job, save the attached
    # images into a per-thread training directory, submit the training job,
    # record the (thread -> model) mapping, then schedule a watcher that posts
    # a sample image to the thread once training has finished.

    logger.debug(f"/train [channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]")

    # Training may only be started from the main playground channel.
    if ctx.channel.name != NOS_PLAYGROUND_CHANNEL:
        logger.debug(f"ignoring [channel={ctx.channel.name}]")
        return

    if not ctx.message.attachments:
        logger.debug("no attachments to train on, returning!")
        return

    if "sks" not in prompt:
        await ctx.send("Please include 'sks' in your training prompt!")
        return

    # Create a thread for this training job. Discord limits thread names to
    # 100 characters, so a long prompt is shortened to leave room for the id.
    message_id = str(ctx.message.id)
    name_suffix = f" ({message_id})"
    thread = await ctx.message.create_thread(name=f"{prompt[:100 - len(name_suffix)]}{name_suffix}")
    logger.debug(f"Created thread [id={thread.id}, name={thread.name}]")

    # Save the thread id
    thread_id = thread.id

    # Create the training directory for this thread
    dirname = NOS_TRAINING_VOLUME_DIR / str(thread_id)
    dirname.mkdir(parents=True, exist_ok=True)
    logger.debug(f"Created training directory [dirname={dirname}]")

    # Save all attachments to the training directory; a short random prefix
    # keeps attachments with identical names from overwriting each other.
    logger.debug(f"Saving attachments [dirname={dirname}, attachments={len(ctx.message.attachments)}]")
    for attachment in ctx.message.attachments:
        filename = str(dirname / f"{str(uuid.uuid4().hex[:8])}_{attachment.filename}")
        await attachment.save(filename)
        logger.debug(f"Saved attachment [filename={filename}]")

    # Acknowledge the request, by reacting to the message with a checkmark
    logger.debug(f"Request acknowledged [id={ctx.message.id}]")
    await ctx.message.add_reaction("✅")

    # Submit the LoRA fine-tuning job on the saved images. The instance
    # directory is passed relative to the root volume directory.
    response = client.Train(
        method="stable-diffusion-dreambooth-lora",
        inputs={
            "model_name": BASE_MODEL,
            "instance_directory": dirname.relative_to(NOS_VOLUME_DIR),
            "instance_prompt": prompt,
            "max_train_steps": 500,
        },
        metadata={
            "name": "sdv15-dreambooth-lora",
        },
    )

    # Check for a failed submission BEFORE reading the job id out of the
    # response, and stop here: there is nothing to record or to watch.
    if response is None:
        logger.error(f"Failed to submit training job [id={thread_id}, response={response}, dirname={dirname}]")
        await thread.send(f"Failed to train [prompt={prompt}, response={response}, dirname={dirname}]")
        return

    job_id = response["job_id"]
    logger.debug(f"Submitted training job [id={thread_id}, response={response}, dirname={dirname}]")
    await thread.send(f"@here Submitted training job [id={thread.name}, model={job_id}]")

    # Save the model so /generate inside the thread can find it later.
    MODEL_DB[thread_id] = LoRAPromptModel(
        thread_id=thread_id,
        thread_name=thread.name,
        model_id=f"custom/{job_id}",
        prompt=prompt,
    )
    logger.debug(f"Saved model [id={thread_id}, model={MODEL_DB[thread_id]}]")

    # Background coroutine that waits for the job and posts a sample image.
    async def post_on_training_complete_async():
        # Wait for the model to be ready
        # NOTE(review): client.Wait and client.Run look like synchronous calls;
        # if so they block the event loop (and the bot) while they run.
        # Consider moving them to an executor once the client's thread-safety
        # is confirmed.
        response = client.Wait(job_id=job_id, timeout=600, retry_interval=10)
        logger.debug(f"Training completed [job_id={job_id}, response={response}].")

        # Get the thread
        _thread = bot.get_channel(thread_id)
        await _thread.send(f"@here Training complete [id={_thread.name}, model={job_id}]")

        # Wait for model to be registered after the job is complete
        await asyncio.sleep(5)

        # Run inference on the trained model
        st = time.perf_counter()
        response = client.Run(
            task=TaskType.IMAGE_GENERATION,
            model_name=f"custom/{job_id}",
            prompts=[prompt],
            width=512,
            height=512,
            num_images=1,
        )
        logger.debug(f"/generate request completed [model={job_id}, elapsed={time.perf_counter() - st:.2f}s]")
        (image,) = response["images"]

        # Save the image to a buffer and send it back to the user:
        image_bytes = io.BytesIO()
        image.save(image_bytes, format="PNG")
        image_bytes.seek(0)
        await _thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png"))

    def _log_watcher_failure(future):
        # Without this, an exception raised inside the watcher is never
        # retrieved and the failure goes completely unnoticed.
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logger.error(f"Training watcher failed [id={thread_id}, job_id={job_id}, e={error}]")

    logger.debug(f"Starting thread to watch training job [id={thread_id}, job_id={job_id}]")
    # Schedule on the loop this command is running on rather than on a
    # module-level `loop` global, which only exists when the file is executed
    # as a script.
    watcher = asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), asyncio.get_running_loop())
    watcher.add_done_callback(_log_watcher_failure)
    logger.debug(f"Started thread to watch training job [id={thread_id}, job_id={job_id}]")


async def run_bot(token: str) -> None:
    """Start the bot with the user-provided token.

    Args:
        token: Discord bot token, passed straight to ``bot.start``.
    """
    await bot.start(token)


if __name__ == "__main__":
    # Get the bot token from the environment. An empty value (e.g. an unfilled
    # `DISCORD_BOT_TOKEN=` line in the .env file) is treated the same as a
    # missing one instead of being handed to Discord.
    token = os.environ.get("DISCORD_BOT_TOKEN")
    if not token:
        raise Exception("DISCORD_BOT_TOKEN environment variable not set")
    logger.debug(f"Starting bot with token [token={token[:5]}****]")

    # Create the event loop explicitly (asyncio.get_event_loop() is deprecated
    # when no loop is running). `loop` stays a module-level name on purpose.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Run until the bot stops. Unlike create_task() + run_forever(), a
        # failure in bot.start (for example a rejected token) propagates and
        # ends the process instead of leaving an idle loop running forever.
        loop.run_until_complete(run_bot(token))
    finally:
        loop.close()
36 changes: 36 additions & 0 deletions examples/discord/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Two-service stack for the Discord example: the bot (CPU image) and the NOS
# server (GPU image). Both use host networking and host IPC, and both mount the
# same host directory (~/.nosd) at /app/.nos, which is also their NOS_HOME.
version: "3.8"

services:
  # Discord bot, built from the Dockerfile in this directory.
  bot:
    image: autonomi/nos:latest-discord-app
    build:
      context: .
      dockerfile: Dockerfile
      args:
        - BASE_IMAGE=autonomi/nos:latest-cpu
    # .env supplies DISCORD_BOT_TOKEN (see .env.template).
    env_file:
      - .env
    environment:
      - NOS_HOME=/app/.nos
      - NOS_LOGGING_LEVEL=DEBUG
    volumes:
      - ~/.nosd:/app/.nos
      - /dev/shm:/dev/shm
    network_mode: host
    ipc: host

  # NOS server; reserves a GPU device.
  server:
    image: autonomi/nos:latest-gpu
    environment:
      - NOS_HOME=/app/.nos
      - NOS_LOGGING_LEVEL=DEBUG
    volumes:
      - ~/.nosd:/app/.nos
      - /dev/shm:/dev/shm
    network_mode: host
    ipc: host
    deploy:
      resources:
        reservations:
          devices:
            - capabilities: [gpu]
4 changes: 4 additions & 0 deletions examples/discord/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Python dependencies for the example Discord bot (installed by the Dockerfile).
# NOTE(review): `discord` looks like the PyPI mirror package of `discord.py`;
# if so, one of the two pins is redundant — confirm before removing.
discord==2.3.2
discord.py==2.3.2
# NOTE(review): diskcache and docker are unpinned; `docker` is not imported by
# bot.py as shown in this commit — confirm it is still needed.
diskcache
docker
Loading

0 comments on commit 74bf39f

Please sign in to comment.