Discord bot with NOS fine-tuning API #314

Merged: 10 commits, Sep 4, 2023
2 changes: 2 additions & 0 deletions docker/.dockerignore
@@ -24,3 +24,5 @@ dist
bdist
*.cache
*.ts

site/
1 change: 1 addition & 0 deletions examples/discord/.env.template
@@ -0,0 +1 @@
DISCORD_BOT_TOKEN=
Contributor:
Probably need a more sophisticated way to manage this going forward. Doesn't look like we are persisting this secret anywhere though which is good.

Contributor (Author):
What's wrong with this? This is a pretty standard way of showing that a .env needs to be created with the DISCORD_BOT_TOKEN= specified.
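For reference, the pattern under discussion can be sketched without any third-party dependency. Below is a minimal, hypothetical `.env` loader (not part of this PR; in the example itself, docker-compose's `env_file:` does this loading) showing how a template like `DISCORD_BOT_TOKEN=` gets filled in and read:

```python
import os
from pathlib import Path


def load_env_file(path: str = ".env") -> None:
    """Parse simple KEY=VALUE lines from a .env file into os.environ.

    Existing environment variables take precedence (setdefault), so a
    token already exported in the shell is never overwritten.
    """
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        # Skip blanks, comments, and malformed lines.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        os.environ.setdefault(key.strip(), value.strip())
```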

10 changes: 10 additions & 0 deletions examples/discord/Dockerfile
@@ -0,0 +1,10 @@
ARG BASE_IMAGE
FROM ${BASE_IMAGE:-autonomi/nos:latest-cpu}

WORKDIR /tmp/$PROJECT
ADD requirements.txt .
RUN pip install -r requirements.txt
Contributor:
How happy are we with keeping this in examples? I believe this means it will be included in the wheel file?

Contributor (Author):
Top-level examples are not included in the wheel file. Only subdirectories under nos are.


WORKDIR /app
COPY bot.py .
CMD ["python", "/app/bot.py"]
7 changes: 7 additions & 0 deletions examples/discord/Makefile
@@ -0,0 +1,7 @@
SHELL := /bin/bash

docker-compose-upd-discord-bot:
	sudo mkdir -p ~/.nosd/volumes
	sudo chown -R $(USER):$(USER) ~/.nosd/volumes
	pushd ../../ && make docker-build-cpu docker-build-gpu && popd;
	docker compose -f docker-compose.yml up --build
287 changes: 287 additions & 0 deletions examples/discord/bot.py
@@ -0,0 +1,287 @@
"""Example Discord bot with Stable Diffusion LoRA fine-tuning support."""
import asyncio
import io
import os
import time
import uuid
from dataclasses import dataclass
from pathlib import Path

import discord
from discord.ext import commands
from diskcache import Cache

from nos.client import InferenceClient, TaskType
from nos.constants import NOS_TMP_DIR
from nos.logging import logger


@dataclass
class LoRAPromptModel:
Contributor:
can this not live with the training service?

Contributor (Author):
Yes, possibly. I just kept it here for simplicity. Also, this is specific to the Discord bot (thread_id, etc.), which has nothing to do with training parameters.


    thread_id: str
    """Discord thread ID"""

    thread_name: str
    """Discord thread name"""

    model_id: str
    """Training job ID / model ID"""

    prompt: str
    """Prompt used to train the model"""

    @property
    def job_id(self) -> str:
        return self.model_id

    def __str__(self) -> str:
        return f"LoRAPromptModel(thread_id={self.thread_id}, thread_name={self.thread_name}, model_id={self.model_id}, prompt={self.prompt})"


NOS_PLAYGROUND_CHANNEL = "nos-playground"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Init the NOS client, wait for the server to spin up, then confirm it's healthy.
client = InferenceClient()
Contributor:
[nit] do we want to roll these into an init function?
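A minimal version of the init helper suggested here might look like the following (hypothetical sketch; it assumes only the `WaitForServer`/`IsHealthy` methods used below, and takes the client as an argument so it is easy to test):

```python
def init_client(client):
    """Roll the startup steps into one call: wait for the server,
    then fail fast if it does not report healthy.

    `client` is any object exposing WaitForServer() and IsHealthy(),
    e.g. nos.client.InferenceClient.
    """
    client.WaitForServer()
    if not client.IsHealthy():
        raise RuntimeError("NOS server is not healthy")
    return client
```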


logger.debug("Waiting for server to start...")
client.WaitForServer()

logger.debug("Confirming server is healthy...")
if not client.IsHealthy():
    raise RuntimeError("NOS server is not healthy")

logger.debug("Server is healthy!")
NOS_VOLUME_DIR = Path(client.Volume())
Contributor:
Does the training volume get blown out when we restart the server?

Contributor (Author):
No, it's persistent as long as the permissions are set correctly. We could use docker volumes instead of volume mounts.

NOS_TRAINING_VOLUME_DIR = Path(client.Volume("nos-playground"))
logger.debug(f"Creating training data volume [volume={NOS_TRAINING_VOLUME_DIR}]")

# Set permissions for our bot to allow it to read messages
intents = discord.Intents.default()
intents.message_content = True

# Create our bot, with the command prefix set to "/"
bot = commands.Bot(command_prefix="/", intents=intents)

logger.debug("Starting bot, initializing existing threads ...")

# Simple persistent dict/database for storing models
# This maps (channel-id -> LoRAPromptModel)
MODEL_DB = Cache(str(NOS_TMP_DIR / NOS_PLAYGROUND_CHANNEL))
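`diskcache.Cache` is used here as a dict that survives restarts. As a rough illustration of the role it plays for `MODEL_DB` (a toy stand-in built on sqlite3, not the diskcache API):

```python
import pickle
import sqlite3


class TinyCache:
    """Toy persistent mapping, illustrating what MODEL_DB needs from
    diskcache.Cache: __setitem__ and get(key, default)."""

    def __init__(self, path: str):
        self.db = sqlite3.connect(path)
        self.db.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v BLOB)")

    def __setitem__(self, key, value) -> None:
        # Pickle the value so arbitrary objects (e.g. a dataclass) can be stored.
        blob = pickle.dumps(value)
        self.db.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", (str(key), blob))
        self.db.commit()

    def get(self, key, default=None):
        row = self.db.execute("SELECT v FROM kv WHERE k = ?", (str(key),)).fetchone()
        return pickle.loads(row[0]) if row else default
```

Unlike this sketch, diskcache also handles cross-process locking, eviction, and atomic writes, which is why the bot uses it rather than rolling its own.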


@bot.command()
async def generate(ctx, *, prompt):
    """Create a callback to read messages and generate images from prompt

    Usage:
Contributor:
Should we add this to the docs? I don't think it's necessary since it's not part of the main service/API; it might be better suited to a README if we release this by itself.

Contributor (Author):
Somewhat evolving set of docstrings for now, so better to keep them here until we have fully fleshed out demos.

    1. In the main channel, you can run:
         /generate a photo of a dog on the moon
       to generate an image with the pre-trained SDv1.5 model.

    2. In a thread that has been created by fine-tuning a new model,
       you can run:
         /generate a photo of a sks dog on the moon
       to generate the specific instance of the dog using the fine-tuned model.
    """
    logger.debug(
        f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]"
    )

    st = time.perf_counter()
    if ctx.channel.name == NOS_PLAYGROUND_CHANNEL:
        # Acknowledge the request by reacting to the message with a checkmark
        logger.debug(f"Request acknowledged [id={ctx.message.id}]")
        await ctx.message.add_reaction("✅")

        logger.debug(f"/generate request submitted [id={ctx.message.id}]")
        response = client.Run(
            task=TaskType.IMAGE_GENERATION,
            model_name=BASE_MODEL,
            prompts=[prompt],
            width=512,
            height=512,
            num_images=1,
        )
        logger.debug(f"/generate request completed [id={ctx.message.id}, elapsed={time.perf_counter() - st:.2f}s]")
        (image,) = response["images"]
        thread = ctx
    else:
else:
Contributor:
[nit] a bit long for a single function

        # Pull the channel id so we know which model to run:
        thread = ctx.channel
        thread_id = ctx.channel.id
        model = MODEL_DB.get(thread_id, default=None)
        if model is None:
            logger.debug(f"Failed to fetch model [thread_id={thread_id}]")
            await thread.send("No model found")
            return

        # Get model info
        try:
            models = client.ListModels()
            info = {m.name: m for m in models}[model.model_id]
            logger.debug(f"Got model info [model_id={model.model_id}, info={info}]")
        except Exception as e:
            logger.debug(f"Failed to fetch model info [model_id={model.model_id}, e={e}]")
            await thread.send(f"Failed to fetch model [model_id={model.model_id}]")
            return

        # Acknowledge the request by reacting to the message with a checkmark
        logger.debug(f"Request acknowledged [id={ctx.message.id}]")
        await ctx.message.add_reaction("✅")

        # Run inference on the trained model
        logger.debug(f"/generate request submitted [id={ctx.message.id}, model_id={model.model_id}, model={model}]")
        response = client.Run(
            task=TaskType.IMAGE_GENERATION,
            model_name=model.model_id,
            prompts=[prompt],
            width=512,
            height=512,
            num_images=1,
        )
        logger.debug(
            f"/generate request completed [id={ctx.message.id}, model={model}, elapsed={time.perf_counter() - st:.2f}s]"
        )
        (image,) = response["images"]

    # Save the image to a buffer and send it back to the user:
    image_bytes = io.BytesIO()
    image.save(image_bytes, format="PNG")
    image_bytes.seek(0)
    await thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png"))


@bot.command()
async def train(ctx, *, prompt):
    """Fine-tune a new model with the provided prompt and images.

    Example:
        Upload a few images of your favorite dog, and then run:
        `/train sks a photo of a sks dog on the moon`
    """
    logger.debug(f"/train [channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]")

    if ctx.channel.name != NOS_PLAYGROUND_CHANNEL:
        logger.debug(f"ignoring [channel={ctx.channel.name}]")
        return

    if not ctx.message.attachments:
        logger.debug("No attachments to train on, returning!")
        return

    if "sks" not in prompt:
Contributor:
[nit] maybe a different tag? "sks" seems arbitrary, would prefer "INSTANCE" or "OBJECT"

Contributor (Author):
Both "INSTANCE" and "OBJECT" are common words in tokenizers. We need to pick something unique, where the model hasn't seen the token before.

        await ctx.send("Please include 'sks' in your training prompt!")
        return

    # Create a thread for this training job
    message_id = str(ctx.message.id)
    thread = await ctx.message.create_thread(name=f"{prompt} ({message_id})")
    logger.debug(f"Created thread [id={thread.id}, name={thread.name}]")

    # Save the thread id
    thread_id = thread.id

    # Create the training directory for this thread
    dirname = NOS_TRAINING_VOLUME_DIR / str(thread_id)
    dirname.mkdir(parents=True, exist_ok=True)
    logger.debug(f"Created training directory [dirname={dirname}]")

    # Save all attachments to the training directory
    logger.debug(f"Saving attachments [dirname={dirname}, attachments={len(ctx.message.attachments)}]")
    for attachment in ctx.message.attachments:
        filename = str(dirname / f"{str(uuid.uuid4().hex[:8])}_{attachment.filename}")
        await attachment.save(filename)
        logger.debug(f"Saved attachment [filename={filename}]")

    # Acknowledge the request by reacting to the message with a checkmark
    logger.debug(f"Request acknowledged [id={ctx.message.id}]")
    await ctx.message.add_reaction("✅")

    # Train a new LoRA model on the uploaded images
    response = client.Train(
        method="stable-diffusion-dreambooth-lora",
        inputs={
            "model_name": BASE_MODEL,
            "instance_directory": dirname.relative_to(NOS_VOLUME_DIR),
            "instance_prompt": prompt,
            "max_train_steps": 500,
        },
        metadata={
            "name": "sdv15-dreambooth-lora",
        },
    )
    if response is None:
        logger.error(f"Failed to submit training job [id={thread_id}, dirname={dirname}]")
        await thread.send(f"Failed to train [prompt={prompt}, dirname={dirname}]")
        return

    job_id = response["job_id"]
    logger.debug(f"Submitted training job [id={thread_id}, response={response}, dirname={dirname}]")
    await thread.send(f"@here Submitted training job [id={thread.name}, model={job_id}]")

    # Save the model
    MODEL_DB[thread_id] = LoRAPromptModel(
        thread_id=thread_id,
        thread_name=thread.name,
        model_id=f"custom/{job_id}",
        prompt=prompt,
    )
    logger.debug(f"Saved model [id={thread_id}, model={MODEL_DB[thread_id]}]")

    # Create a coroutine to watch the training job and report back on the thread
    async def post_on_training_complete_async():
        # Wait for the model to be ready
        response = client.Wait(job_id=job_id, timeout=600, retry_interval=10)
        logger.debug(f"Training completed [job_id={job_id}, response={response}].")

        # Get the thread
        _thread = bot.get_channel(thread_id)
        await _thread.send(f"@here Training complete [id={_thread.name}, model={job_id}]")

        # Wait for model to be registered after the job is complete
        await asyncio.sleep(5)

        # Run inference on the trained model
        st = time.perf_counter()
        response = client.Run(
Contributor:
so this always does a single generation run at the end of training on the new thread?

Contributor (Author):
Yes, it's a validation prompt to make sure it generated something reasonable.

            task=TaskType.IMAGE_GENERATION,
            model_name=f"custom/{job_id}",
            prompts=[prompt],
            width=512,
            height=512,
            num_images=1,
        )
        logger.debug(f"/generate request completed [model={job_id}, elapsed={time.perf_counter() - st:.2f}s]")
        (image,) = response["images"]

        # Save the image to a buffer and send it back to the user:
        image_bytes = io.BytesIO()
        image.save(image_bytes, format="PNG")
        image_bytes.seek(0)
        await _thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png"))

    logger.debug(f"Starting thread to watch training job [id={thread_id}, job_id={job_id}]")
    asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), loop)
    logger.debug(f"Started thread to watch training job [id={thread_id}, job_id={job_id}]")


async def run_bot(token: str):
    """Start the bot with the user-provided token."""
    await bot.start(token)


if __name__ == "__main__":
    # Get the bot token from the environment
    token = os.environ.get("DISCORD_BOT_TOKEN")
    if token is None:
        raise Exception("DISCORD_BOT_TOKEN environment variable not set")
    logger.debug(f"Starting bot with token [token={token[:5]}****]")

    # Start the asyncio event loop, and run the bot
    loop = asyncio.get_event_loop()
    loop.create_task(run_bot(token))
    loop.run_forever()
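The `asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), loop)` call in `train` depends on the module-level `loop` created here. A self-contained illustration of that scheduling pattern (toy coroutine standing in for the bot's actual watcher):

```python
import asyncio
import threading


async def watcher() -> str:
    # Stand-in for post_on_training_complete_async: do async work, return a result.
    await asyncio.sleep(0.01)
    return "training complete"


def schedule_on_loop() -> str:
    # Run an event loop on a background thread, then submit the coroutine
    # to it from this (non-async) context, as bot.py does with `loop`.
    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()
    # Returns a concurrent.futures.Future we can block on from this thread.
    future = asyncio.run_coroutine_threadsafe(watcher(), loop)
    result = future.result(timeout=5)
    loop.call_soon_threadsafe(loop.stop)
    return result
```

Note that `run_coroutine_threadsafe` returns a `concurrent.futures.Future`, not an asyncio one, which is what makes it safe to call from outside the loop's thread.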
36 changes: 36 additions & 0 deletions examples/discord/docker-compose.yml
@@ -0,0 +1,36 @@
version: "3.8"

services:
  bot:
    image: autonomi/nos:latest-discord-app
    build:
      context: .
      dockerfile: Dockerfile
      args:
        - BASE_IMAGE=autonomi/nos:latest-cpu
    env_file:
      - .env
    environment:
      - NOS_HOME=/app/.nos
      - NOS_LOGGING_LEVEL=DEBUG
    volumes:
      - ~/.nosd:/app/.nos
      - /dev/shm:/dev/shm
    network_mode: host
    ipc: host

  server:
    image: autonomi/nos:latest-gpu
    environment:
      - NOS_HOME=/app/.nos
      - NOS_LOGGING_LEVEL=DEBUG
    volumes:
      - ~/.nosd:/app/.nos
      - /dev/shm:/dev/shm
    network_mode: host
    ipc: host
    deploy:
      resources:
        reservations:
          devices:
            - capabilities: [gpu]
4 changes: 4 additions & 0 deletions examples/discord/requirements.txt
@@ -0,0 +1,4 @@
discord==2.3.2
discord.py==2.3.2
Contributor:
still unclear to me why/if we need both discord and discord.py

Contributor (Author):
I'm not sure, I pulled this from your example, so I figured you would know.

diskcache
docker