diff --git a/src/scrapeghost/apicall.py b/src/scrapeghost/apicall.py index 85d86e0..484f42e 100644 --- a/src/scrapeghost/apicall.py +++ b/src/scrapeghost/apicall.py @@ -9,6 +9,7 @@ from typing import Callable from .errors import ( + ScrapeghostError, TooManyTokens, MaxCostExceeded, BadStop, @@ -95,17 +96,19 @@ def _raw_api_request( raise MaxCostExceeded( f"Total cost {self.total_cost:.2f} exceeds max cost {self.max_cost:.2f}" ) - # json_mode = ( - # {"response_format": "json_object"} if _model_dict[model].json_mode else {} - # ) - json_mode = {} + json_mode = ( + {"response_format": "json_object"} if _model_dict[model].json_mode else {} + ) start_t = time.time() completion = client.chat.completions.create( - model=model, messages=messages, **self.model_params, **json_mode, + model=model, messages=messages, **self.model_params, **json_mode, # type: ignore ) elapsed = time.time() - start_t - p_tokens = completion.usage.prompt_tokens - c_tokens = completion.usage.completion_tokens + if completion.usage: + p_tokens = completion.usage.prompt_tokens + c_tokens = completion.usage.completion_tokens + else: + raise ScrapeghostError("no usage data returned") cost = _model_dict[model].cost(c_tokens, p_tokens) logger.info( "API response", @@ -133,7 +136,7 @@ def _raw_api_request( f"(prompt_tokens={p_tokens}, " f"completion_tokens={c_tokens})" ) - response.data = choice.message.content + response.data = choice.message.content # type: ignore return response def _api_request(self, html: str) -> Response: diff --git a/src/scrapeghost/models.py b/src/scrapeghost/models.py index 9870076..564e130 100644 --- a/src/scrapeghost/models.py +++ b/src/scrapeghost/models.py @@ -20,7 +20,7 @@ def cost(self, prompt_tokens: int, completion_tokens: int) -> float: Model("gpt-4", 0.03, 0.06, 8192, False), Model("gpt-4-32k", 0.06, 0.12, 32768, False), Model("gpt-4-1106-preview", 0.01, 0.03, 128000, True), - Model("gpt-3.5-turbo", 0.001, 0.002, 16384, True), + Model("gpt-3.5-turbo", 0.001, 0.002, 16384, False), Model("gpt-3.5-turbo-1106", 0.001, 0.002, 16384, True), ] _model_dict = {model.name: model for model in models}