
Commit

Add retry to get OpenAI embeddings (#378)
Allow retrying multiple times to get OpenAI embeddings if some requests fail.
zechengz authored Mar 15, 2024
1 parent c1f0cb8 commit 0cb4d7f
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions benchmark/data_frame_text_benchmark.py
@@ -11,6 +11,7 @@
 from peft import LoraConfig
 from peft import TaskType as peftTaskType
 from peft import get_peft_model
+from tenacity import retry, stop_after_attempt, wait_random_exponential
 from torch import Tensor
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss
 from torch.optim.lr_scheduler import ExponentialLR
@@ -202,15 +203,21 @@ def __init__(self, model: str, api_key: str):
     def __call__(self, sentences: list[str]) -> Tensor:
         from openai import Embedding

-        items: list[Embedding] = self.client.embeddings.create(
-            input=sentences, model=self.model).data
+        items: list[Embedding] = embeddings_with_backoff(
+            self.client, self.model, sentences)
         assert len(items) == len(sentences)
         embeddings = [
             torch.FloatTensor(item.embedding).view(1, -1) for item in items
         ]
         return torch.cat(embeddings, dim=0)


+@retry(wait=wait_random_exponential(min=1, max=30), stop=stop_after_attempt(6))
+def embeddings_with_backoff(client: Any, model: str,
+                            sentences: list[str]) -> list[Any]:
+    return client.embeddings.create(input=sentences, model=model).data
+
+
 def mean_pooling(last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor:
     input_mask_expanded = (attention_mask.unsqueeze(-1).expand(
         last_hidden_state.size()).float())
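For context, the pattern introduced here is tenacity's decorator-based retry with randomized exponential backoff. Below is a minimal standalone sketch using the same `wait_random_exponential(min=1, max=30)` / `stop_after_attempt(6)` settings as the commit; `FlakyClient` and `embed_with_backoff` are hypothetical stand-ins for the real OpenAI client and the repository's `embeddings_with_backoff`, so the retry behavior can be observed without network access.

```python
# Sketch only: FlakyClient simulates two transient failures before succeeding,
# roughly what a rate-limited or flaky embeddings endpoint looks like.
from tenacity import retry, stop_after_attempt, wait_random_exponential


class FlakyClient:
    """Hypothetical stand-in for the OpenAI client used in the benchmark."""

    def __init__(self) -> None:
        self.calls = 0

    def create(self, sentences: list[str]) -> list[list[float]]:
        self.calls += 1
        if self.calls < 3:
            raise RuntimeError("simulated transient API error")
        # Return one dummy embedding vector per input sentence.
        return [[0.0, 0.0, 0.0, 0.0] for _ in sentences]


# Same policy as the commit: random exponential backoff between 1 and 30
# seconds, give up (and re-raise the last error) after 6 attempts.
@retry(wait=wait_random_exponential(min=1, max=30), stop=stop_after_attempt(6))
def embed_with_backoff(client: FlakyClient,
                       sentences: list[str]) -> list[list[float]]:
    return client.create(sentences)


if __name__ == "__main__":
    client = FlakyClient()
    vectors = embed_with_backoff(client, ["hello", "world"])
    print(f"{len(vectors)} embeddings after {client.calls} attempts")
```

Because the decorator wraps a module-level function rather than the bound `__call__` method, the same retry policy can be reused by any caller holding a client handle, which mirrors the `embeddings_with_backoff` helper added in this diff. Note the demo pauses between simulated failures because of the 1-second minimum wait.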
