Skip to content

Commit

Permalink
feat: move fine tuning in worker, add fine tuning endpoints for worker
Browse files Browse the repository at this point in the history
  • Loading branch information
okradze committed Nov 2, 2023
1 parent b288afd commit 3f7add7
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 81 deletions.
2 changes: 1 addition & 1 deletion apps/server/controllers/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def index_documents(value: str, datasource_id: UUID, account: AccountOutput):

session = create_session()

settings = ConfigModel.get_account_settings(session, account)
settings = ConfigModel.get_account_settings(session, account.id)
datasource = DatasourceModel.get_datasource_by_id(session, datasource_id, account)

try:
Expand Down
26 changes: 24 additions & 2 deletions apps/server/controllers/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from exceptions import FineTuningNotFoundException
from models.config import ConfigModel
from models.fine_tuning import FineTuningModel
from services.fine_tuning import fine_tune_openai_model
from services.fine_tuning import check_fine_tuning, fine_tune_openai_model
from typings.auth import UserAccount
from typings.fine_tuning import FineTuningInput, FineTuningOutput
from utils.auth import authenticate
Expand All @@ -17,6 +17,14 @@
router = APIRouter()


@router.post(
    "/{fine_tuning_id}/check", status_code=200, response_model=FineTuningOutput
)
def check_fine_tuning_status(fine_tuning_id: UUID):
    """
    Refresh the status of one fine-tuning and return the updated record.

    Args:
        fine_tuning_id (UUID): Unique identifier of the fine-tuning to check.

    Returns:
        FineTuningOutput: The fine-tuning after its status was refreshed.

    Raises:
        HTTPException: 404 when no fine-tuning with this id exists.
    """
    # Reject unknown ids up front: check_fine_tuning reads attributes of the
    # looked-up model without a None check, so a bad id would become a 500.
    if FineTuningModel.get_fine_tuning_by_id(db.session, fine_tuning_id) is None:
        raise HTTPException(status_code=404, detail="Fine-tuning not found")

    check_fine_tuning(db.session, fine_tuning_id)

    # The route declares response_model=FineTuningOutput, so respond with the
    # record itself rather than an ad-hoc {"message": ...} dict.
    # NOTE(review): assumes convert_fine_tunings_to_fine_tuning_list returns an
    # indexable list with one output per input model - confirm.
    fine_tuning_model = FineTuningModel.get_fine_tuning_by_id(
        db.session, fine_tuning_id
    )
    return convert_fine_tunings_to_fine_tuning_list([fine_tuning_model])[0]


@router.post("", status_code=201, response_model=FineTuningOutput)
def create_fine_tuning(
fine_tuning: FineTuningInput,
Expand All @@ -30,7 +38,7 @@ def create_fine_tuning(
auth.account.id,
)

settings = ConfigModel.get_account_settings(db.session, auth.account)
settings = ConfigModel.get_account_settings(db.session, auth.account.id)

if not settings.openai_api_key:
raise HTTPException(
Expand Down Expand Up @@ -85,6 +93,20 @@ def update_fine_tuning(
raise HTTPException(status_code=404, detail="Fine-tuning not found")


@router.get("/pending", response_model=List[FineTuningOutput])
def get_pending_fine_tunings(
    auth: UserAccount = Depends(authenticate),
) -> List[FineTuningOutput]:
    """
    List the fine-tunings that are still pending, for the worker to poll.

    Args:
        auth (UserAccount): Authenticated caller. Required by the dependency,
            but not used to filter the result.

    Returns:
        List[FineTuningOutput]: Every pending fine-tuning returned by the model
            layer, converted to its output representation.
    """
    return convert_fine_tunings_to_fine_tuning_list(
        FineTuningModel.get_pending_fine_tunings(db.session)
    )


@router.get("/{id}", response_model=FineTuningOutput)
def get_fine_tuning_by_id(
id: UUID, auth: UserAccount = Depends(authenticate)
Expand Down
6 changes: 4 additions & 2 deletions apps/server/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def get_config_by_session_id(
return config

@classmethod
def get_account_settings(cls, session: Session, account) -> AccountSettings:
def get_account_settings(
cls, session: Session, account_id: UUID
) -> AccountSettings:
keys = [
"open_api_key",
"hugging_face_access_token",
Expand All @@ -262,7 +264,7 @@ def get_account_settings(cls, session: Session, account) -> AccountSettings:
session.query(ConfigModel)
.filter(
ConfigModel.key.in_(keys),
ConfigModel.account_id == account.id,
ConfigModel.account_id == account_id,
or_(
or_(
ConfigModel.is_deleted.is_(False),
Expand Down
45 changes: 33 additions & 12 deletions apps/server/models/fine_tuning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import uuid
from typing import Optional

from sqlalchemy import UUID, Boolean, Column, ForeignKey, String
from sqlalchemy.orm import Session, relationship
Expand Down Expand Up @@ -160,26 +161,46 @@ def get_fine_tunings(cls, session: Session, account_id: UUID):
)

@classmethod
def get_fine_tuning_by_id(cls, session: Session, id: UUID, account_id: UUID):
def get_pending_fine_tunings(cls, session: Session):
    """
    Get all FineTuningModels that are still in progress

    Args:
        session: The database session.

    Returns:
        list: Non-deleted FineTuningModel objects whose status is
            VALIDATING, QUEUED or RUNNING.
    """
    in_progress_statuses = [
        FineTuningStatus.VALIDATING.value,
        FineTuningStatus.QUEUED.value,
        FineTuningStatus.RUNNING.value,
    ]

    pending_query = session.query(FineTuningModel).filter(
        FineTuningModel.status.in_(in_progress_statuses),
        FineTuningModel.is_deleted.is_(False),
    )
    return pending_query.all()

@classmethod
def get_fine_tuning_by_id(
    cls, session: Session, id: UUID, account_id: Optional[UUID] = None
):
    """
    Get FineTuningModel from id

    Args:
        session: The database session.
        id(UUID) : Unique identifier of a FineTuningModel.
        account_id(UUID, optional) : Unique identifier of an account. When
            None, the lookup is not restricted to an account. Defaults to None.

    Returns:
        FineTuningModel: The matching non-deleted FineTuningModel, or None
            when no row matches.
    """
    query = session.query(FineTuningModel).filter(
        FineTuningModel.id == id,
        FineTuningModel.is_deleted.is_(False),
    )

    # Scope by account only when one is supplied; comparing against None
    # would otherwise filter on `account_id IS NULL` and match nothing useful.
    if account_id is not None:
        query = query.filter(FineTuningModel.account_id == account_id)

    return query.first()

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion apps/server/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def process_chat_message(
run_id=run.id,
)

settings = ConfigModel.get_account_settings(db.session, provider_account)
settings = ConfigModel.get_account_settings(db.session, provider_account.id)

if len(agents) > 0:
for agent_with_configs in agents:
Expand Down
51 changes: 23 additions & 28 deletions apps/server/services/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import openai
from openai.error import AuthenticationError
from sqlalchemy.orm import Session

from models.config import ConfigModel
from models.fine_tuning import FineTuningModel
from services.aws_s3 import AWSS3Service
from typings.config import AccountSettings
Expand Down Expand Up @@ -66,43 +68,36 @@ def fine_tune_openai_model(
)

fine_tuning_model.openai_fine_tuning_id = fine_tuning_job.id
except AuthenticationError:
fine_tuning_model.error = "Invalid OpenAI API Key"
except Exception as err:
fine_tuning_model.error = str(err)

def retrieve_job():
job = openai.FineTuningJob.retrieve(
api_key=settings.openai_api_key, id=fine_tuning_job.id
)

session.refresh(fine_tuning_model)

fine_tuning_model.status = OPENAI_TO_FINE_TUNING_STATUS[job.status].value
session.commit()

is_finished = False

if fine_tuning_model.status == FineTuningStatus.COMPLETED.value:
fine_tuning_model.model_identifier = job.fine_tuned_model
is_finished = True
def check_fine_tuning(session: Session, id: UUID):
fine_tuning_model = FineTuningModel.get_fine_tuning_by_id(session, id)
settings = ConfigModel.get_account_settings(session, fine_tuning_model.account_id)

if job.error:
fine_tuning_model.error = job.error
is_finished = True
fine_tuning_job = openai.FineTuningJob.retrieve(
id=fine_tuning_model.openai_fine_tuning_id,
api_key=settings.openai_api_key,
)

session.commit()
job = openai.FineTuningJob.retrieve(
api_key=settings.openai_api_key, id=fine_tuning_job.id
)

return is_finished
fine_tuning_model.status = OPENAI_TO_FINE_TUNING_STATUS[job.status].value

while True:
is_finished = retrieve_job()
if fine_tuning_model.status == FineTuningStatus.COMPLETED.value:
fine_tuning_model.model_identifier = job.fine_tuned_model

if is_finished:
break
if job.error:
fine_tuning_model.error = job.error

time.sleep(60)
except AuthenticationError:
fine_tuning_model.error = "Invalid OpenAI API Key"
session.commit()
except Exception as err:
fine_tuning_model.error = str(err)
session.commit()
session.commit()


def convert_message_to_openai_conversation_format(message: Dict):
Expand Down
18 changes: 18 additions & 0 deletions apps/worker/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
version: "3.8"

services:
redis:
image: redis:latest
ports:
- "6379:6379"

worker:
build:
context: .
dockerfile: docker/Dockerfile
volumes:
- .:/app
ports:
- "3001:80"
depends_on:
- redis
28 changes: 0 additions & 28 deletions apps/worker/helpers.py

This file was deleted.

58 changes: 52 additions & 6 deletions apps/worker/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
from datetime import timedelta

import httpx
import requests
from celery import Celery

from config import Config
from helpers import execute_scheduled_runs

app = Celery("l3agi", include=["main"], imports=["main"])

Expand All @@ -14,26 +15,71 @@
app.conf.accept_content = ["application/x-python-serialize", "application/json"]

# Periodic jobs run by Celery beat. Each "task" value has to match the
# `name=` a task is registered under in this module.
CELERY_BEAT_SCHEDULE = {
    # Fans out one task per due schedule.
    "register-scheduled-run-tasks": {
        "task": "register-scheduled-run-tasks",
        "schedule": timedelta(minutes=2),
    },
    # Fans out one status check per pending fine-tuning.
    "register-fine-tuning-tasks": {
        "task": "register-fine-tuning-tasks",
        "schedule": timedelta(minutes=5),
    },
}


app.conf.beat_schedule = CELERY_BEAT_SCHEDULE


@app.task(
    name="register-scheduled-run-tasks",
    autoretry_for=(Exception,),
    retry_backoff=2,
    max_retries=5,
)
def execute_scheduled_runs_task():
    """
    Fetch the due schedules from the server and enqueue one run task for each.

    Returns:
        list: Ids of the schedules a run task was enqueued for.

    Raises:
        requests.HTTPError: When the server answers with an error status, so
            that `autoretry_for` retries instead of parsing an error body.
    """
    print("Running scheduled agents")
    res = requests.get(f"{Config.SERVER_URL}/schedule/due")
    res.raise_for_status()

    # Extract every id before enqueueing anything: a malformed item then fails
    # the whole task up front instead of after some runs were already queued.
    schedule_ids = [item["schedule"]["id"] for item in res.json()]

    for schedule_id in schedule_ids:
        execute_single_schedule_task.apply_async(args=[schedule_id])

    return schedule_ids


@app.task(
    name="execute-single-schedule",
    autoretry_for=(Exception,),
    retry_backoff=2,
    max_retries=5,
)
def execute_single_schedule_task(schedule_id: str):
    """
    Ask the server to run a single schedule.

    Args:
        schedule_id (str): Id of the schedule to run.

    Returns:
        The decoded JSON body of the server response.

    Raises:
        requests.HTTPError: When the server answers with an error status.
            Without this check a failed run would be reported as a success
            and `autoretry_for` would never fire.
    """
    res = requests.post(f"{Config.SERVER_URL}/schedule/{schedule_id}/run")
    res.raise_for_status()
    return res.json()


@app.task(
    name="register-fine-tuning-tasks",
    autoretry_for=(Exception,),
    retry_backoff=2,
    max_retries=5,
)
def register_fine_tunings_task():
    """
    Fetch the pending fine-tunings from the server and enqueue a status check
    for each of them.

    Returns:
        list: Ids of the fine-tunings a check task was enqueued for.

    Raises:
        requests.HTTPError: When the server answers with an error status.
    """
    res = requests.get(f"{Config.SERVER_URL}/fine-tuning/pending")
    # Fail loudly on an error response (for example if the server rejects the
    # request): its JSON body is an object, and iterating it as if it were the
    # list of fine-tunings would break on the first key.
    res.raise_for_status()

    fine_tuning_ids = [fine_tuning["id"] for fine_tuning in res.json()]

    for fine_tuning_id in fine_tuning_ids:
        check_single_fine_tuning_task.apply_async(args=[fine_tuning_id])

    return fine_tuning_ids

@app.task(
    name="check-single-fine-tuning",
    autoretry_for=(Exception,),
    retry_backoff=2,
    max_retries=5,
)
def check_single_fine_tuning_task(fine_tuning_id: str):
    """
    Ask the server to refresh the status of a single fine-tuning.

    Args:
        fine_tuning_id (str): Id of the fine-tuning to check.

    Returns:
        The decoded JSON body of the server response.

    Raises:
        requests.HTTPError: When the server answers with an error status, so
            that `autoretry_for` can retry the check.
    """
    res = requests.post(f"{Config.SERVER_URL}/fine-tuning/{fine_tuning_id}/check")
    res.raise_for_status()
    # Return plain data like execute_single_schedule_task does, not the
    # Response object itself.
    return res.json()


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 3f7add7

Please sign in to comment.