Discord bot with NOS fine-tuning API #314
Conversation
- added tests for fine-tuning API (1c9ffac to ed4c513)
Think this can be merged as is. Training interface will need cleanup over time.
@dataclass
class LoRAPromptModel:
can this not live with the training service?
Yes, possibly. Just kept it here for simplicity. Also, this is specific to the Discord bot (thread_id etc.), which has nothing to do with training parameters.
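To make the separation concrete, here is a minimal sketch of what a bot-side `LoRAPromptModel` dataclass might hold. The field names beyond `thread_id` are illustrative assumptions, not the actual fields in this PR:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRAPromptModel:
    """Bot-side bookkeeping for a fine-tuning request.

    Couples the Discord thread that requested training with the resulting
    model, so later generation requests on that thread can be routed to
    the right LoRA weights. Kept out of the training service because none
    of this is a training parameter.
    """

    thread_id: int                   # Discord thread the request came from
    thread_name: str                 # human-readable thread name
    model_id: Optional[str] = None   # registry id assigned after training
    job_id: Optional[str] = None     # training job handle, if still running

    @property
    def ready(self) -> bool:
        """True once training has finished and the model is registered."""
        return self.model_id is not None
```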
@@ -0,0 +1,4 @@
discord==2.3.2
discord.py==2.3.2
Still unclear to me why/if we need both discord and discord.py.
I'm not sure; I pulled this from your example, so I figured you would know.
@@ -0,0 +1 @@
DISCORD_BOT_TOKEN=
Probably need a more sophisticated way to manage this going forward. Doesn't look like we are persisting this secret anywhere though which is good.
What's wrong with this? This is a pretty standard way of showing that a .env file needs to be created with DISCORD_BOT_TOKEN= specified.
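For context, the usual pattern is to keep the real token out of version control entirely and read it from the environment at startup. A minimal sketch (the function name is illustrative, not code from this PR):

```python
import os


def load_bot_token(var: str = "DISCORD_BOT_TOKEN") -> str:
    """Read the bot token from the environment, failing loudly if unset.

    The .env template ships an empty DISCORD_BOT_TOKEN= line; the real
    value is injected at runtime (e.g. `docker run --env-file .env`),
    so the secret is never persisted in the repository.
    """
    token = os.environ.get(var, "").strip()
    if not token:
        raise RuntimeError(f"{var} is not set; populate it in your .env file")
    return token
```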
WORKDIR /tmp/$PROJECT
ADD requirements.txt .
RUN pip install -r requirements.txt
How happy are we with keeping this in examples? This means it will be included in the wheel file, I believe?
Top-level examples are not included in the wheel file. Only subdirectories under nos are.
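This behavior typically comes from the package discovery configuration. A sketch of the kind of setup.py stanza that produces it (an assumption about this repo's config, not a quote from it):

```python
# setup.py (sketch): restricting discovery to `nos` and its subpackages
# means top-level directories like examples/ are never packaged into
# the wheel, even though they live in the repository.
from setuptools import find_packages, setup

setup(
    name="nos",
    packages=find_packages(include=["nos", "nos.*"]),
)
```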
raise RuntimeError("NOS server is not healthy")

logger.debug("Server is healthy!")
NOS_VOLUME_DIR = Path(client.Volume())
Does the training volume get blown out when we restart the server?
No, it's persistent as long as the permissions are set correctly. We could use docker volumes instead of volume mounts.
logger.debug("no attachments to train on, returning!")
return

if "sks" not in prompt:
[nit] maybe a different tag? "sks" seems arbitrary, would prefer "INSTANCE" or "OBJECT"
Both "INSTANCE" and "OBJECT" are common words in tokenizers. We need to pick something unique, a token the model hasn't seen before.
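The intuition (from DreamBooth-style fine-tuning) is that the subject identifier should not collide with a word the base model already has strong priors for. A toy sketch of that check, using a stand-in word list rather than a real tokenizer vocabulary:

```python
# Toy stand-in for a tokenizer vocabulary of common words; a real check
# would consult the base model's tokenizer instead.
COMMON_WORDS = {"instance", "object", "person", "dog", "photo"}


def is_good_identifier(candidate: str, vocab: set) -> bool:
    """Heuristic: a good subject identifier should NOT already be a
    common token the base model associates with existing concepts.
    A rare string like "sks" survives this check; everyday words like
    "OBJECT" do not."""
    return candidate.lower() not in vocab
```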
# Run inference on the trained model
st = time.perf_counter()
response = client.Run(
so this always does a single generation run at the end of training on the new thread?
Yes, it's a validation prompt to make sure it generated something reasonable.
except grpc.RpcError as e:
    raise NosClientException(f"Failed to train model (details={(e.details())})", e)

def Volume(self, name: str = None) -> str:
should this live in grpc? maybe utils?
It's a client utility, so we have everything under grpc for now.
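A sketch of what such a `Volume()` helper might do, resolving a named directory under a persistent base path and creating it on first use so training artifacts survive server restarts. The layout and signature here are assumptions for illustration, not the actual implementation:

```python
from pathlib import Path
from typing import Optional


def volume(name: Optional[str] = None, base: Optional[Path] = None) -> str:
    """Resolve (and create, if needed) a persistent volume directory.

    With no name, returns the base volumes directory; with a name,
    returns a subdirectory for that volume. Hypothetical default
    location: ~/.nos/volumes.
    """
    base = base or Path.home() / ".nos" / "volumes"
    path = base / name if name else base
    path.mkdir(parents=True, exist_ok=True)
    return str(path)
```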
@@ -3,7 +3,7 @@
diffusers>=0.17.1
huggingface_hub
memray
pyarrow>=12.0.0
-ray>=2.6.1
+ray[default]>=2.6.1
?
We need this now for the Ray job submission client; will add some docs to explain this.
@@ -2,9 +2,11 @@
set -e
set -x

echo "Starting Ray server with OMP_NUM_THREADS=${OMP_NUM_THREADS}..."
# Get number of cores
NCORES=$(nproc --all)
are we sure we want to be maxing this out to all cores?
I think so. The nos.init() has a utilization kwarg that optionally allows users to specify <100% CPU core utilization.
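A sketch of how a utilization fraction could map to a core count, mirroring what a `utilization` kwarg on `nos.init()` might do under the hood (the function name and validation here are assumptions):

```python
import os


def cpu_count_for(utilization: float = 1.0) -> int:
    """Map a target utilization fraction in (0, 1] to a core count.

    utilization=1.0 reproduces the `nproc --all` default; smaller
    fractions let a user cap OMP_NUM_THREADS below the machine total.
    Always returns at least one core.
    """
    if not 0.0 < utilization <= 1.0:
        raise ValueError("utilization must be in (0, 1]")
    total = os.cpu_count() or 1
    return max(1, int(total * utilization))
```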
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Init NOS server, wait for it to spin up then confirm its healthy.
client = InferenceClient()
[nit] do we want to roll these into an init function?
async def generate(ctx, *, prompt):
    """Create a callback to read messages and generate images from prompt

    Usage:
Should we add this to docs? I don't think it's necessary since it's not part of the main service/API; it might be better suited to a README if we release this by itself.
Somewhat evolving set of docstrings for now, so better to keep them here until we have fully fleshed out demos.
@@ -65,6 +65,7 @@
// Service information repsonse
message ServiceInfoResponse {
  string version = 1;  // (e.g. "0.1.0")
+  string runtime = 2;  // (e.g. "cpu", "gpu", "local" etc)
did we version the API?
Yes. We also check if the server/client versions are consistent with this ServiceInfo routine.
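A sketch of the kind of check a ServiceInfo round-trip enables. The exact compatibility policy (major.minor match, patch releases free to differ) is an assumption for illustration:

```python
def versions_compatible(client_version: str, server_version: str) -> bool:
    """Treat client and server as compatible when their major.minor
    components match, letting patch releases differ. Versions are
    assumed to be dotted strings like "0.1.0"."""
    c_major, c_minor, *_ = client_version.split(".")
    s_major, s_minor, *_ = server_version.split(".")
    return (c_major, c_minor) == (s_major, s_minor)
```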
hooks = {"on_completed": (register_model, (job_id,), {})}

# Spawn a thread to monitor the job
def monitor_job_hook(job_id: str, timeout: int = 600, retry_interval: int = 5):
is there a way to do this without explicitly monitoring the training run?
We'll need a bunch of post-training hooks to do a lot of book-keeping (model registry, upload to hub etc). For now this is a placeholder hook for registering custom models.
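A sketch of the polling shape such a hook takes. The status strings and callable signatures here are assumptions; the real hook would query the training job's status instead of a caller-supplied function:

```python
import time


def monitor_job(get_status, on_completed, timeout: int = 600, retry_interval: int = 5):
    """Poll a job until it reaches a terminal state, then fire the
    post-training hook (e.g. model registration, upload to hub).

    `get_status` returns one of "PENDING" / "RUNNING" / "SUCCEEDED" /
    "FAILED"; `on_completed` runs only on success.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = get_status()
        if status == "SUCCEEDED":
            on_completed()
            return status
        if status == "FAILED":
            return status
        time.sleep(retry_interval)
    raise TimeoutError("job did not finish within timeout")
```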
Summary
Related issues
Checks
- make lint: I've run make lint to lint the changes in this PR.
- make test: I've made sure the tests (make test-cpu or make test) are passing.