Skip to content

Commit

Permalink
feat: Cria o objeto GPTModel a partir do modelo do fine tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
RWallan committed May 20, 2024
1 parent 35ae917 commit 7f42801
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
21 changes: 20 additions & 1 deletion openiziai/fine_tuning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from pathlib import Path
from typing import Any
from typing import Any, Optional

from openai import OpenAI
from pydantic import (
Expand All @@ -11,6 +11,7 @@
field_validator,
)

from openiziai.schemas import GPTModel
from openiziai.task import Task


Expand All @@ -31,6 +32,7 @@ class FineTuning(BaseModel):
_file_id: str = PrivateAttr(default=None)
_job_id: str = PrivateAttr(default=None)
_job_status: JobStatus = PrivateAttr(default=None)
_model: GPTModel = PrivateAttr(default=None)

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -94,3 +96,20 @@ def status(self) -> str:
self._job_status = JobStatus(_job_status)

return self._job_status.name

# TODO: Consertar mocks para testar a propriedade
@property
def model(self) -> Optional[GPTModel]:
if self._model:
return self._model

model_name = self.client.fine_tuning.jobs.retrieve(
self.job_id
).fine_tuned_model
if not model_name:
print(f'Modelo não disponível. Status: {self.status}')
return None

self._model = GPTModel(
name=model_name, base_model=self.base_model, task=self.task
)
13 changes: 13 additions & 0 deletions openiziai/schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import sys
from datetime import datetime
from typing import Any, Protocol

from pydantic import Field
from pydantic.dataclasses import dataclass

from openiziai.task import Task

if sys.version_info < (3, 12):
from typing_extensions import TypedDict
else:
Expand All @@ -14,3 +20,10 @@ class DataDict(TypedDict):
class Pipeline(Protocol):
def run(self): ...


@dataclass
class GPTModel:
name: str
task: Task
base_model: str
created_at: datetime = Field(init=False, default_factory=datetime.now)

0 comments on commit 7f42801

Please sign in to comment.