Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set up nos bot to save training images #313

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions makefiles/Makefile.base.mk
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@ docker-compose-upd-cpu: docker-build-cpu

docker-compose-upd-gpu: docker-build-gpu
docker compose -f docker-compose.gpu.yml up

docker-compose-upd-discord-bot: docker-build-gpu
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to a standalone Makefile under examples/discord

docker compose -f docker-compose.discord.yml up
53 changes: 52 additions & 1 deletion nos/experimental/discord/nos_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

import nos
from nos.client import InferenceClient, TaskType
from nos.constants import NOS_TMP_DIR


# Init nos server, wait for it to spin up then confirm its healthy:
nos.init(runtime="gpu")
nos_client = InferenceClient()
nos_client.WaitForServer()
if not nos_client.IsHealthy():
Expand All @@ -24,9 +24,13 @@
# Create our bot:
bot = commands.Bot(command_prefix="$", intents=intents)

TRAINING_CHANNEL_NAME = "training"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nos-playground?

NOS_TRAINING_DIR = NOS_TMP_DIR / "train"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOS_TMP_DIR / "discord/{CHANNEL_NAME}/train"


# Create a callback to read messages and generate images from prompt:
@bot.command()
async def generate(ctx, *, prompt):
# pull the channel id so we know which model to run:
response = nos_client.Run(
TaskType.IMAGE_GENERATION,
"stabilityai/stable-diffusion-2",
Expand All @@ -44,6 +48,53 @@ async def generate(ctx, *, prompt):
await ctx.send(file=discord.File(image_bytes, filename="image.png"))


@bot.command()
async def train(ctx):
# check that its in the training channel
if ctx.channel.name != TRAINING_CHANNEL_NAME:
print("not in training channel, returning!")
return

if not ctx.message.attachments:
print("no attachments to train on, returning!")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit logger.debug

return

# create a thread for this training job:
thread_name = str(ctx.message.id)
thread = await ctx.channel.create_thread(name=thread_name, type=discord.ChannelType.public_thread)

await thread.send(f"Created a new thread: {thread.name}")

dirname = NOS_TRAINING_DIR / thread_name
dirname.mkdir(parents=True, exist_ok=True)

await thread.send("saving at dir: " + str(dirname))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit Use fstrings as much as possible `send(f"saving at dir: {dirname}")


# save the attachments
for attachment in ctx.message.attachments:
print(f"got attachement: {attachment.filename}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use logger.debug

await attachment.save(os.path.join(dirname, attachment.filename))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Pathlib operators instead dirname / attachment.filename

await thread.send(f"Image {attachment.filename} saved!")

# Kick off a nos training run
from nos.server._service import TrainingService
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to top of file


svc = TrainingService()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to top of file, no need to instantiate it every time we run training

job_id = svc.train(
method="stable-diffusion-dreambooth-lora",
training_inputs={
"model_name": "stabilityai/stable-diffusion-2-1",
"instance_directory": dirname,
},
metadata={
"name": "sdv21-dreambooth-lora-test-bench",
},
)
assert job_id is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise an error so that the discord bot user gets some internal server error


thread.send(f"Started training job: {job_id}")


# Pull API token out of environment and run the bot:
bot_token = os.environ.get("BOT_TOKEN")
if bot_token is None:
Expand Down
2 changes: 1 addition & 1 deletion scripts/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ echo "Starting Ray server with OMP_NUM_THREADS=${OMP_NUM_THREADS}..."
OMP_NUM_THREADS=${OMP_NUM_THREADS} ray start --head

echo "Starting NOS server..."
nos-grpc-server
nos-grpc-server && python ./nos_bot.py
Loading