# Discord bot with NOS fine-tuning API #314
Changes from all commits: `d8321ef`, `5a94567`, `6e1f87e`, `6ee8363`, `8609402`, `2ff9cd9`, `4150c12`, `d1da82e`, `d14f302`, `ed4c513`
**`.gitignore`** (filename inferred from the patterns)

```diff
@@ -24,3 +24,5 @@ dist
 bdist
 *.cache
 *.ts
+
+site/
```
**`.env`** (new file)

```diff
@@ -0,0 +1 @@
+DISCORD_BOT_TOKEN=
```
**`Dockerfile`** (new file)

```dockerfile
ARG BASE_IMAGE
FROM ${BASE_IMAGE:-autonomi/nos:latest-cpu}

WORKDIR /tmp/$PROJECT
ADD requirements.txt .
RUN pip install -r requirements.txt

WORKDIR /app
COPY bot.py .
CMD ["python", "/app/bot.py"]
```

> **Review comment:** How happy are we with keeping this in examples? This means that it will be included with the wheel file, I believe?
>
> **Reply:** Top-level
**`Makefile`** (new file)

```makefile
SHELL := /bin/bash

docker-compose-upd-discord-bot:
	sudo mkdir -p ~/.nosd/volumes
	sudo chown -R $(USER):$(USER) ~/.nosd/volumes
	pushd ../../ && make docker-build-cpu docker-build-gpu && popd;
	docker compose -f docker-compose.yml up --build
```
**`bot.py`** (new file, 287 lines)

```python
"""Example discord bot with Stable Diffusion LoRA fine-tuning support."""
import asyncio
import io
import os
import time
import uuid
from dataclasses import dataclass
from pathlib import Path

import discord
from discord.ext import commands
from diskcache import Cache

from nos.client import InferenceClient, TaskType
from nos.constants import NOS_TMP_DIR
from nos.logging import logger


@dataclass
class LoRAPromptModel:

    thread_id: str
    """Discord thread ID"""

    thread_name: str
    """Discord thread name"""

    model_id: str
    """Training job ID / model ID"""

    prompt: str
    """Prompt used to train the model"""

    @property
    def job_id(self) -> str:
        return self.model_id

    def __str__(self) -> str:
        return f"LoRAPromptModel(thread_id={self.thread_id}, thread_name={self.thread_name}, model_id={self.model_id}, prompt={self.prompt})"
```

> **Review comment** (on `class LoRAPromptModel`): Can this not live with the training service?
>
> **Reply:** Yes, possibly. Just kept it here for simplicity. Also this is specific to the discord bot with

```python
NOS_PLAYGROUND_CHANNEL = "nos-playground"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Init NOS server, wait for it to spin up, then confirm it's healthy.
client = InferenceClient()

logger.debug("Waiting for server to start...")
client.WaitForServer()

logger.debug("Confirming server is healthy...")
if not client.IsHealthy():
    raise RuntimeError("NOS server is not healthy")

logger.debug("Server is healthy!")
NOS_VOLUME_DIR = Path(client.Volume())
NOS_TRAINING_VOLUME_DIR = Path(client.Volume("nos-playground"))
logger.debug(f"Creating training data volume [volume={NOS_TRAINING_VOLUME_DIR}]")

# Set permissions for our bot to allow it to read messages
intents = discord.Intents.default()
intents.message_content = True

# Create our bot, with the command prefix set to "/"
bot = commands.Bot(command_prefix="/", intents=intents)

logger.debug("Starting bot, initializing existing threads ...")

# Simple persistent dict/database for storing models
# This maps (channel-id -> LoRAPromptModel)
MODEL_DB = Cache(str(NOS_TMP_DIR / NOS_PLAYGROUND_CHANNEL))
```

> **Review comment** (on `client = InferenceClient()`): [nit] Do we want to roll these into an init function?
>
> **Review comment** (on `NOS_VOLUME_DIR`): Does the training volume get blown out when we restart the server?
>
> **Reply:** No, it's persistent as long as the permissions are set correctly. We could use docker volumes instead of volume mounts.
```python
@bot.command()
async def generate(ctx, *, prompt):
    """Create a callback to read messages and generate images from prompt

    Usage:
        1. In the main channel, you can run:
               /generate a photo of a dog on the moon
           to generate an image with the pre-trained SDv1.5 model.

        2. In a thread that has been created by fine-tuning a new model,
           you can run:
               /generate a photo of a sks dog on the moon
           to generate the specific instance of the dog using the fine-tuned model.
    """
    logger.debug(
        f"/generate [prompt={prompt}, channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]"
    )

    st = time.perf_counter()
    if ctx.channel.name == NOS_PLAYGROUND_CHANNEL:
        # Acknowledge the request by reacting to the message with a checkmark
        logger.debug(f"Request acknowledged [id={ctx.message.id}]")
        await ctx.message.add_reaction("✅")

        logger.debug(f"/generate request submitted [id={ctx.message.id}]")
        response = client.Run(
            task=TaskType.IMAGE_GENERATION,
            model_name=BASE_MODEL,
            prompts=[prompt],
            width=512,
            height=512,
            num_images=1,
        )
        logger.debug(f"/generate request completed [id={ctx.message.id}, elapsed={time.perf_counter() - st:.2f}s]")
        (image,) = response["images"]
        thread = ctx
    else:
        # Pull the channel id so we know which model to run:
        thread = ctx.channel
        thread_id = ctx.channel.id
        model = MODEL_DB.get(thread_id, default=None)
        if model is None:
            logger.debug(f"Failed to fetch model [thread_id={thread_id}]")
            await thread.send("No model found")
            return

        # Get model info
        try:
            models = client.ListModels()
            info = {m.name: m for m in models}[model.model_id]
            logger.debug(f"Got model info [model_id={model.model_id}, info={info}]")
        except Exception as e:
            logger.debug(f"Failed to fetch model info [model_id={model.model_id}, e={e}]")
            await thread.send(f"Failed to fetch model [model_id={model.model_id}]")
            return

        # Acknowledge the request by reacting to the message with a checkmark
        logger.debug(f"Request acknowledged [id={ctx.message.id}]")
        await ctx.message.add_reaction("✅")

        # Run inference on the trained model
        logger.debug(f"/generate request submitted [id={ctx.message.id}, model_id={model.model_id}, model={model}]")
        response = client.Run(
            task=TaskType.IMAGE_GENERATION,
            model_name=model.model_id,
            prompts=[prompt],
            width=512,
            height=512,
            num_images=1,
        )
        logger.debug(
            f"/generate request completed [id={ctx.message.id}, model={model}, elapsed={time.perf_counter() - st:.2f}s]"
        )
        (image,) = response["images"]

    # Save the image to a buffer and send it back to the user:
    image_bytes = io.BytesIO()
    image.save(image_bytes, format="PNG")
    image_bytes.seek(0)
    await thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png"))
```

> **Review comment** (on the `Usage:` docstring): Should we add this to docs? Don't think it's necessary since it's not part of the main service/API; might be better suited to a readme if we release this by itself.
>
> **Reply:** Somewhat evolving set of docstrings for now, so better to keep them here until we have fully fleshed out demos.
>
> **Review comment** (on the `else:` branch): [nit] A bit long for a single function.
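The branch in `/generate` — base model in the playground channel, per-thread fine-tuned model everywhere else — can be factored into a small pure function. The sketch below is illustrative only (`resolve_model` is not part of the bot or the NOS API), with a plain dict standing in for `MODEL_DB`:

```python
from typing import Optional

PLAYGROUND = "nos-playground"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"


def resolve_model(channel_name: str, thread_id: int, db: dict) -> Optional[str]:
    """Pick the model name for a /generate request.

    In the playground channel we always use the base SD v1.5 model;
    inside a training thread we look up the fine-tuned model by thread id.
    Returns None when the thread has no registered model (the bot then
    replies "No model found").
    """
    if channel_name == PLAYGROUND:
        return BASE_MODEL
    return db.get(thread_id)


db = {42: "custom/job-abc123"}
print(resolve_model("nos-playground", 0, db))  # → runwayml/stable-diffusion-v1-5
print(resolve_model("my-thread", 42, db))      # → custom/job-abc123
print(resolve_model("my-thread", 99, db))      # → None
```

Extracting the dispatch like this would also address the "a bit long for a single function" review nit, since the two `client.Run` calls differ only in `model_name`.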
```python
@bot.command()
async def train(ctx, *, prompt):
    """Fine-tune a new model with the provided prompt and images.

    Example:
        Upload a few images of your favorite dog, and then run:
        `/train sks a photo of a sks dog on the moon`
    """
    logger.debug(f"/train [channel={ctx.channel.name}, channel_id={ctx.channel.id}, user={ctx.author.name}]")

    if ctx.channel.name != NOS_PLAYGROUND_CHANNEL:
        logger.debug(f"ignoring [channel={ctx.channel.name}]")
        return

    if not ctx.message.attachments:
        logger.debug("no attachments to train on, returning!")
        return

    if "sks" not in prompt:
        await ctx.send("Please include 'sks' in your training prompt!")
        return

    # Create a thread for this training job
    message_id = str(ctx.message.id)
    thread = await ctx.message.create_thread(name=f"{prompt} ({message_id})")
    logger.debug(f"Created thread [id={thread.id}, name={thread.name}]")

    # Save the thread id
    thread_id = thread.id

    # Create the training directory for this thread
    dirname = NOS_TRAINING_VOLUME_DIR / str(thread_id)
    dirname.mkdir(parents=True, exist_ok=True)
    logger.debug(f"Created training directory [dirname={dirname}]")

    # Save all attachments to the training directory
    logger.debug(f"Saving attachments [dirname={dirname}, attachments={len(ctx.message.attachments)}]")
    for attachment in ctx.message.attachments:
        filename = str(dirname / f"{str(uuid.uuid4().hex[:8])}_{attachment.filename}")
        await attachment.save(filename)
        logger.debug(f"Saved attachment [filename={filename}]")

    # Acknowledge the request by reacting to the message with a checkmark
    logger.debug(f"Request acknowledged [id={ctx.message.id}]")
    await ctx.message.add_reaction("✅")

    # Train a new LoRA model on the uploaded images
    response = client.Train(
        method="stable-diffusion-dreambooth-lora",
        inputs={
            "model_name": BASE_MODEL,
            "instance_directory": dirname.relative_to(NOS_VOLUME_DIR),
            "instance_prompt": prompt,
            "max_train_steps": 500,
        },
        metadata={
            "name": "sdv15-dreambooth-lora",
        },
    )
    # Check for failure before dereferencing the response
    if response is None:
        logger.error(f"Failed to submit training job [id={thread_id}, dirname={dirname}]")
        await thread.send(f"Failed to train [prompt={prompt}, dirname={dirname}]")
        return

    job_id = response["job_id"]
    logger.debug(f"Submitted training job [id={thread_id}, response={response}, dirname={dirname}]")
    await thread.send(f"@here Submitted training job [id={thread.name}, model={job_id}]")

    # Save the model
    MODEL_DB[thread_id] = LoRAPromptModel(
        thread_id=thread_id,
        thread_name=thread.name,
        model_id=f"custom/{job_id}",
        prompt=prompt,
    )
    logger.debug(f"Saved model [id={thread_id}, model={MODEL_DB[thread_id]}]")

    # Create a new thread to watch the training job
    async def post_on_training_complete_async():
        # Wait for the model to be ready
        response = client.Wait(job_id=job_id, timeout=600, retry_interval=10)
        logger.debug(f"Training completed [job_id={job_id}, response={response}].")

        # Get the thread
        _thread = bot.get_channel(thread_id)
        await _thread.send(f"@here Training complete [id={_thread.name}, model={job_id}]")

        # Wait for model to be registered after the job is complete
        await asyncio.sleep(5)

        # Run inference on the trained model
        st = time.perf_counter()
        response = client.Run(
            task=TaskType.IMAGE_GENERATION,
            model_name=f"custom/{job_id}",
            prompts=[prompt],
            width=512,
            height=512,
            num_images=1,
        )
        logger.debug(f"/generate request completed [model={job_id}, elapsed={time.perf_counter() - st:.2f}s]")
        (image,) = response["images"]

        # Save the image to a buffer and send it back to the user:
        image_bytes = io.BytesIO()
        image.save(image_bytes, format="PNG")
        image_bytes.seek(0)
        await _thread.send(f"{prompt}", file=discord.File(image_bytes, filename=f"{ctx.message.id}.png"))

    logger.debug(f"Starting thread to watch training job [id={thread_id}, job_id={job_id}]")
    asyncio.run_coroutine_threadsafe(post_on_training_complete_async(), loop)
    logger.debug(f"Started thread to watch training job [id={thread_id}, job_id={job_id}]")


async def run_bot(token: str):
    """Start the bot with the user-provided token."""
    await bot.start(token)


if __name__ == "__main__":
    # Get the bot token from the environment
    token = os.environ.get("DISCORD_BOT_TOKEN")
    if token is None:
        raise Exception("DISCORD_BOT_TOKEN environment variable not set")
    logger.debug(f"Starting bot with token [token={token[:5]}****]")

    # Start the asyncio event loop, and run the bot
    loop = asyncio.get_event_loop()
    loop.create_task(run_bot(token))
    loop.run_forever()
```

> **Review comment** (on the `"sks"` check): [nit] Maybe a different tag? "sks" seems arbitrary; would prefer "INSTANCE" or "OBJECT".
>
> **Reply:** Both "INSTANCE" and "OBJECT" are common words in tokenizers. We need to pick something that's unique, where the model hasn't seen this token before.
>
> **Review comment** (on the post-training `client.Run`): So this always does a single generation run at the end of training on the new thread?
>
> **Reply:** Yes, it's a validation prompt to make sure it generated something reasonable.
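The startup code logs `token[:5]` followed by `****`, a common masking idiom. A slightly more defensive helper (hypothetical, not part of the bot) also handles secrets shorter than the visible prefix, so a debug log can never reveal the whole value:

```python
def mask_secret(secret: str, visible: int = 5) -> str:
    """Show at most `visible` leading characters; mask the rest.

    Secrets no longer than `visible` are fully masked, so the helper
    never echoes an entire short token into the logs.
    """
    if len(secret) <= visible:
        return "*" * len(secret)
    return secret[:visible] + "*" * (len(secret) - visible)


print(mask_secret("MTA2NzQ4abcdef"))  # → MTA2N*********
print(mask_secret("abc"))             # → ***
```

Unlike the fixed `****` suffix in the bot, this preserves the secret's length in the mask; either choice is fine for a debug log, as long as the raw token never reaches it.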
**`docker-compose.yml`** (new file)

```yaml
version: "3.8"

services:
  bot:
    image: autonomi/nos:latest-discord-app
    build:
      context: .
      dockerfile: Dockerfile
      args:
        - BASE_IMAGE=autonomi/nos:latest-cpu
    env_file:
      - .env
    environment:
      - NOS_HOME=/app/.nos
      - NOS_LOGGING_LEVEL=DEBUG
    volumes:
      - ~/.nosd:/app/.nos
      - /dev/shm:/dev/shm
    network_mode: host
    ipc: host

  server:
    image: autonomi/nos:latest-gpu
    environment:
      - NOS_HOME=/app/.nos
      - NOS_LOGGING_LEVEL=DEBUG
    volumes:
      - ~/.nosd:/app/.nos
      - /dev/shm:/dev/shm
    network_mode: host
    ipc: host
    deploy:
      resources:
        reservations:
          devices:
            - capabilities: [gpu]
```
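Putting the pieces together, local bring-up roughly follows the sequence below. This is a sketch based on the files in this PR (the token value is a placeholder, and the `docker compose` step is shown as a comment since it needs Docker and a real token):

```shell
# Work in a scratch directory so we don't clobber a real checkout.
workdir="$(mktemp -d)"
cd "$workdir"

# 1. The bot reads DISCORD_BOT_TOKEN from a .env file
#    (wired up via `env_file:` in docker-compose.yml).
printf 'DISCORD_BOT_TOKEN=%s\n' "your-token-here" > .env

# 2. The Makefile target then prepares the shared ~/.nosd volume with
#    correct ownership, builds the NOS images, and starts both services:
#
#      make docker-compose-upd-discord-bot
#
#    which ends with `docker compose -f docker-compose.yml up --build`.

grep -q '^DISCORD_BOT_TOKEN=' .env && echo "env file ready"
```

As the review thread at the bottom notes, the secret lives only in the local `.env` file and is injected into the `bot` container at runtime; it is never baked into an image.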
**`requirements.txt`** (new file)

```text
discord==2.3.2
discord.py==2.3.2
diskcache
docker
```

> **Review comment** (on `discord` / `discord.py`): Still unclear to me why/if we need both.
>
> **Reply:** I'm not sure, I pulled this from your example, so I figured you would know.
> **Review comment** (on `.env`): Probably need a more sophisticated way to manage this going forward. Doesn't look like we are persisting this secret anywhere, though, which is good.
>
> **Reply:** What's wrong with this? This is a pretty standard way of showing that a `.env` needs to be created with the `DISCORD_BOT_TOKEN=` specified.