
Discord bot with NOS fine-tuning API #314

Merged

10 commits merged into main, Sep 4, 2023

Conversation

@spillai (Contributor) commented Sep 1, 2023

Summary

  • Adds a Discord bot with a Dockerfile, docker-compose.yml, and requirements for the Discord app
  • Adds a new Discord training bot backed by a new fine-tuning API for SDv2 LoRA
  • Adds tests for the fine-tuning API

Related issues

Checks

  • make lint: I've run make lint to lint the changes in this PR.
  • make test: I've made sure the tests (make test-cpu or make test) are passing.
  • Additional tests:
    • Benchmark tests (when contributing new models)
    • GPU/HW tests

@spillai spillai added the demo label Sep 1, 2023
@spillai spillai added this to the NOS v0.0.10 milestone Sep 1, 2023
@spillai spillai self-assigned this Sep 1, 2023
@spillai spillai force-pushed the spillai/sdv2-finetuning-with-discord branch from 1c9ffac to ed4c513 on September 1, 2023 01:45
@outtanames (Contributor) left a comment:

I think this can be merged as-is. The training interface will need cleanup over time.



@dataclass
class LoRAPromptModel:
Contributor:

can this not live with the training service?

Contributor Author:

Yes, possibly. I just kept it here for simplicity. Also, this is specific to the Discord bot (thread_id, etc.), which has nothing to do with training parameters.
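For context, a minimal sketch of what such a per-thread dataclass might hold; the field names here are illustrative assumptions, not the actual implementation:

```python
from dataclasses import dataclass

@dataclass
class LoRAPromptModel:
    """Per-thread fine-tuning bookkeeping (illustrative fields only)."""
    thread_id: int    # Discord thread that owns this training run
    thread_name: str  # human-readable thread name
    model_id: str     # id the fine-tuned model is registered under
    prompt: str       # validation prompt used after training
    job_id: str = ""  # NOS training job id, filled in once submitted

    @property
    def tag(self) -> str:
        # Unique key combining thread and model, e.g. for a model registry
        return f"{self.thread_name}_{self.model_id}"
```

Keeping Discord-specific fields like thread_id out of the training service keeps that service agnostic to the frontend, which is the trade-off being discussed here.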

@@ -0,0 +1,4 @@
discord==2.3.2
discord.py==2.3.2
Contributor:

It's still unclear to me why (or whether) we need both discord and discord.py.

Contributor Author:

I'm not sure, I pulled this from your example, so I figured you would know.

@@ -0,0 +1 @@
DISCORD_BOT_TOKEN=
Contributor:

Probably need a more sophisticated way to manage this going forward. Doesn't look like we are persisting this secret anywhere though which is good.

Contributor Author:

What's wrong with this? It's a pretty standard way of signaling that a .env file needs to be created with DISCORD_BOT_TOKEN set.
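The convention can be sketched with a tiny loader (a hypothetical helper; real projects typically use python-dotenv or docker-compose's built-in env_file support):

```python
import os
from pathlib import Path

def load_dotenv(path: str = ".env") -> dict:
    """Minimal .env parser: KEY=VALUE lines; blank lines and '#' comments ignored."""
    env = {}
    p = Path(path)
    if p.exists():
        for line in p.read_text().splitlines():
            line = line.strip()
            if line and not line.startswith("#") and "=" in line:
                key, _, value = line.partition("=")
                env[key.strip()] = value.strip()
    return env

def get_token() -> str:
    # Process environment takes precedence over the .env file
    return {**load_dotenv(), **os.environ}.get("DISCORD_BOT_TOKEN", "")
```

The committed file only declares the required key; the secret itself lives in the developer's local .env, which is exactly why it never gets persisted in the repo.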


WORKDIR /tmp/$PROJECT
ADD requirements.txt .
RUN pip install -r requirements.txt
Contributor:

How happy are we with keeping this in examples? I believe this means it will be included in the wheel file?

Contributor Author:

Top-level examples are not included in the wheel file; only subdirectories under nos are.
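This is the usual setuptools behavior when package discovery is restricted to the main package tree. A sketch of the assumed configuration (check the repo's actual setup.py/pyproject for the real include list):

```python
# setup.py sketch: restrict package discovery to the `nos` tree so that
# top-level directories like examples/ never ship in the wheel.
from setuptools import find_packages

def shipped_packages(where: str = ".") -> list:
    # include= limits discovery to `nos` and its subpackages
    return find_packages(where=where, include=["nos", "nos.*"])
```

With this include filter, even an examples/ directory that happens to contain an __init__.py stays out of the built distribution.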

raise RuntimeError("NOS server is not healthy")

logger.debug("Server is healthy!")
NOS_VOLUME_DIR = Path(client.Volume())
Contributor:

Does the training volume get blown out when we restart the server?

Contributor Author:

No, it's persistent as long as the permissions are set correctly. We could use Docker named volumes instead of host bind mounts.
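The difference can be sketched in docker-compose terms (service and path names here are hypothetical, not the PR's actual compose file). A bind mount persists on the host but depends on host-side ownership and permissions; a named volume is Docker-managed and survives container restarts without that concern:

```yaml
services:
  nos-server:
    image: nos-server:latest
    volumes:
      # Bind mount: host path; persistence depends on host permissions
      - ~/.nos/volumes:/app/volumes
      # Named volume: Docker-managed; survives restarts
      - nos-training-data:/app/training

volumes:
  nos-training-data:
```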

logger.debug("no attachments to train on, returning!")
return

if "sks" not in prompt:
Contributor:

[nit] maybe a different tag? "sks" seems arbitrary, would prefer "INSTANCE" or "OBJECT"

Contributor Author:

Both "INSTANCE" and "OBJECT" are common words in tokenizers. We need to pick something unique, a token the model has essentially never seen before.


# Run inference on the trained model
st = time.perf_counter()
response = client.Run(
Contributor:

so this always does a single generation run at the end of training on the new thread?

Contributor Author:

Yes, it's a validation prompt to make sure the model generated something reasonable.

except grpc.RpcError as e:
raise NosClientException(f"Failed to train model (details={(e.details())})", e)

def Volume(self, name: str = None) -> str:
Contributor:

should this live in grpc? maybe utils?

Contributor Author:

It's a client utility, so we keep everything under grpc for now.
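A sketch of what such a client-side volume helper might do; the ~/.nos/volumes layout and the NOS_HOME variable are assumptions for illustration, not the actual implementation:

```python
import os
from pathlib import Path

def volume(name: str = None, root: str = None) -> str:
    """Resolve (and create) a host directory that gets bind-mounted
    into the NOS server container."""
    base = Path(root or os.getenv("NOS_HOME", str(Path.home() / ".nos"))) / "volumes"
    path = base / name if name else base
    path.mkdir(parents=True, exist_ok=True)  # idempotent; safe to call repeatedly
    return str(path)
```

Since it only touches the client's filesystem and not the gRPC channel, it could equally live in a utils module; keeping it on the client just keeps the public surface in one place.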

@@ -3,7 +3,7 @@ diffusers>=0.17.1
huggingface_hub
memray
pyarrow>=12.0.0
ray>=2.6.1
ray[default]>=2.6.1
Contributor:

?

Contributor Author:

We need this now for the Ray job submission client; I'll add some docs to explain this.

@@ -2,9 +2,11 @@
set -e
set -x

echo "Starting Ray server with OMP_NUM_THREADS=${OMP_NUM_THREADS}..."
# Get number of cores
NCORES=$(nproc --all)
Contributor:

are we sure we want to be maxing this out to all cores?

Contributor Author:

I think so. nos.init() has a utilization kwarg that optionally lets users specify less than 100% CPU core utilization.
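The utilization logic can be sketched as follows (a hypothetical helper mirroring the described kwarg, not the actual nos.init() implementation):

```python
import os

def num_threads(utilization: float = 1.0) -> int:
    """Derive a thread count (e.g. for OMP_NUM_THREADS) from the core count
    scaled by a utilization fraction in (0, 1]."""
    if not 0.0 < utilization <= 1.0:
        raise ValueError("utilization must be in (0, 1]")
    ncores = os.cpu_count() or 1  # analogous to `nproc --all` in the script
    return max(1, int(ncores * utilization))
```

So the startup script's default of all cores is just the utilization=1.0 case, and callers can dial it down without touching the script.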

BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Init NOS server, wait for it to spin up then confirm its healthy.
client = InferenceClient()
Contributor:

[nit] do we want to roll these into an init function?

async def generate(ctx, *, prompt):
"""Create a callback to read messages and generate images from prompt

Usage:
Contributor:

Should we add this to the docs? I don't think it's necessary since it's not part of the main service/API; it might be better suited to a README if we release this by itself.

Contributor Author:

These docstrings are still somewhat evolving, so it's better to keep them here until we have fully fleshed-out demos.

@@ -65,6 +65,7 @@ message PingResponse {
// Service information response
message ServiceInfoResponse {
string version = 1; // (e.g. "0.1.0")
string runtime = 2; // (e.g. "cpu", "gpu", "local" etc)
Contributor:

did we version the API?

Contributor Author:

Yes. We also check that the server and client versions are consistent via this ServiceInfo routine.
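A sketch of such a client-side guard on the ServiceInfo response; the comparison policy (strict equality here) is an assumption about the real check:

```python
def check_compatibility(client_version: str, server_version: str) -> None:
    """Raise if the client and server report different versions."""
    if client_version != server_version:
        raise RuntimeError(
            f"client ({client_version}) and server ({server_version}) "
            "versions do not match; upgrade one of them"
        )
```

The new `runtime` field rides along in the same response, so the client learns both the version and whether it is talking to a cpu/gpu/local deployment in one round trip.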

hooks = {"on_completed": (register_model, (job_id,), {})}

# Spawn a thread to monitor the job
def monitor_job_hook(job_id: str, timeout: int = 600, retry_interval: int = 5):
Contributor:

is there a way to do this without explicitly monitoring the training run?

Contributor Author:

We'll need a bunch of post-training hooks to do a lot of bookkeeping (model registry, upload to hub, etc.). For now this is a placeholder hook for registering custom models.
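The hook shape from the snippet above can be exercised with a small polling sketch (hypothetical helper names; the actual monitor runs in a spawned thread against the client's job-status API):

```python
import time

def monitor_job(get_status, hooks, job_id, timeout=600, retry_interval=5):
    """Poll job status and fire the registered on_completed callback.
    `get_status` stands in for the client's job-status call; `hooks`
    follows the {"on_completed": (fn, args, kwargs)} tuple shape."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = get_status(job_id)
        if status == "SUCCEEDED":
            fn, args, kwargs = hooks.get("on_completed", (lambda: None, (), {}))
            fn(*args, **kwargs)  # e.g. register_model(job_id)
            return status
        if status == "FAILED":
            return status
        time.sleep(retry_interval)
    raise TimeoutError(f"job {job_id} did not finish within {timeout}s")
```

An event-driven alternative (the server pushing a completion callback) would avoid polling, but that requires server-side support that the placeholder hook sidesteps for now.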

@spillai spillai merged commit 74bf39f into main Sep 4, 2023
1 check passed
@spillai spillai deleted the spillai/sdv2-finetuning-with-discord branch September 5, 2023 04:29
2 participants