make context optional; improve error handling and doc (#997)
* make context optional

* better error handling and doc

* skip instantiation if no context

* skip test
sonichi authored Apr 16, 2023
1 parent b235fe0 commit d4070e2
Showing 4 changed files with 92 additions and 41 deletions.
73 changes: 46 additions & 27 deletions flaml/autogen/oai/completion.py
@@ -2,7 +2,7 @@
import logging
import numpy as np
import time
from typing import List
from typing import List, Optional, Dict
import sys
from flaml import tune, BlendSearch
from flaml.automl.logger import logger_formatter
@@ -142,9 +142,10 @@ def _get_response(cls, config: dict, eval_only=False, use_cache=True):
return response
openai_completion = openai.ChatCompletion if config["model"] in cls.chat_models else openai.Completion
start_time = time.time()
request_timeout = cls.request_timeout
while True:
try:
response = openai_completion.create(request_timeout=cls.request_timeout, **config)
response = openai_completion.create(request_timeout=request_timeout, **config)
cls._cache.set(key, response)
return response
except (
@@ -155,14 +156,22 @@ def _get_response(cls, config: dict, eval_only=False, use_cache=True):
# transient error
logger.warning(f"retrying in {cls.retry_time} seconds...", exc_info=1)
sleep(cls.retry_time)
except (RateLimitError, Timeout):
# retry after retry_time seconds
if time.time() - start_time + cls.retry_time < cls.retry_timeout:
except (RateLimitError, Timeout) as e:
time_left = cls.retry_timeout - (time.time() - start_time + cls.retry_time)
if (
time_left > 0
and isinstance(e, RateLimitError)
or time_left > request_timeout
and isinstance(e, Timeout)
):
logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1)
elif eval_only:
raise
else:
break
if isinstance(e, Timeout):
request_timeout <<= 1
request_timeout = min(request_timeout, time_left)
sleep(cls.retry_time)
except InvalidRequestError:
if "azure" == openai.api_type and "model" in config:
@@ -472,14 +481,14 @@ def eval_func(responses, **data):
For prompt, please provide a string/Callable or a list of strings/Callables.
- If prompt is provided for chat models, it will be converted to messages under role "user".
- Do not provide both prompt and messages for chat models, but provide either of them.
- A string `prompt` template will be used to generate a prompt for each data instance
- A string template will be used to generate a prompt for each data instance
using `prompt.format(**data)`.
- A callable `prompt` template will be used to generate a prompt for each data instance
- A callable template will be used to generate a prompt for each data instance
using `prompt(data)`.
For stop, please provide a string, a list of strings, or a list of lists of strings.
For messages (chat models only), please provide a list of messages (for a single chat prefix)
or a list of lists of messages (for multiple choices of chat prefix to choose from).
Each message should be a dict with keys "role" and "content".
Each message should be a dict with keys "role" and "content". The value of "content" can be a string/Callable template.
Returns:
dict: The optimized hyperparameter setting.
@@ -610,17 +619,21 @@ def eval_func(responses, **data):
return params, analysis

@classmethod
def create(cls, context, use_cache=True, **config):
def create(cls, context: Optional[Dict] = None, use_cache: Optional[bool] = True, **config):
"""Make a completion for a given context.
Args:
context (dict): The context to instantiate the prompt.
context (dict, Optional): The context to instantiate the prompt.
It needs to contain keys that are used by the prompt template.
E.g., `prompt="Complete the following sentence: {prefix}"`.
`context={"prefix": "Today I feel"}`.
The actual prompt sent to OpenAI will be:
"Complete the following sentence: Today I feel".
use_cache (bool, Optional): Whether to use cached responses.
**config: Configuration for the completion.
Besides the parameters for the openai API call, it can also contain a seed (int) for the cache.
This is useful when implementing "controlled randomness" for the completion.
Also, the "prompt" or "messages" parameter can contain a template (str or Callable) which will be instantiated with the context.
Returns:
Responses from OpenAI API.
@@ -637,6 +650,14 @@ def create(cls, context, use_cache=True, **config):
cls.set_cache(seed)
return cls._get_response(params, eval_only=True)

@classmethod
def _instantiate(cls, template: str, context: Optional[Dict] = None):
if not context:
return template
if isinstance(template, str):
return template.format(**context)
return template(context)

@classmethod
def _construct_params(cls, data_instance, config, prompt=None, messages=None):
params = config.copy()
@@ -649,30 +670,28 @@ def _construct_params(cls, data_instance, config, prompt=None, messages=None):
if messages is None:
raise ValueError("Either prompt or messages should be in config for chat models.")
if prompt is None:
params["messages"] = [
{
"role": m["role"],
"content": m["content"].format(**data_instance)
if isinstance(m["content"], str)
else m["content"](data_instance),
}
for m in messages
]
params["messages"] = (
[
{
"role": m["role"],
"content": cls._instantiate(m["content"], data_instance),
}
for m in messages
]
if data_instance
else messages
)
elif model in cls.chat_models:
# convert prompt to messages
if isinstance(prompt, str):
prompt_msg = prompt.format(**data_instance)
else:
prompt_msg = prompt(data_instance)
params["messages"] = [
{
"role": "user",
"content": prompt_msg if isinstance(prompt, str) else prompt(data_instance),
"content": cls._instantiate(prompt, data_instance),
},
]
params.pop("prompt", None)
else:
params["prompt"] = prompt.format(**data_instance) if isinstance(prompt, str) else prompt(data_instance)
params["prompt"] = cls._instantiate(prompt, data_instance)
return params

@classmethod
@@ -811,7 +830,7 @@ def eval_func(responses, **data):

@classmethod
def cost(cls, model: str, response: dict):
"""Compute the cost of a completion.
"""Compute the cost of an API call.
Args:
model (str): The model name.
Expand All @@ -832,7 +851,7 @@ def cost(cls, model: str, response: dict):

@classmethod
def extract_text(cls, response: dict) -> List[str]:
"""Extract the text from a completion response.
"""Extract the text from a completion or chat response.
Args:
response (dict): The response from OpenAI API.
14 changes: 13 additions & 1 deletion test/openai/test_completion.py
@@ -12,6 +12,17 @@
from flaml.autogen.math_utils import eval_math_responses


def test_nocontext():
try:
import openai
import diskcache
except ImportError as exc:
print(exc)
return
response = oai.Completion.create(model="text-ada-001", prompt="1+1=", max_tokens=1)
print(response)


@pytest.mark.skipif(
sys.platform == "win32",
reason="do not run on windows",
@@ -223,5 +234,6 @@ def my_average(results):
import openai

openai.api_key_path = "test/openai/key.txt"
test_nocontext()
test_humaneval(1)
# test_math(1)
test_math(1)
4 changes: 2 additions & 2 deletions website/docs/Examples/AutoGen-OpenAI.md
@@ -116,8 +116,8 @@ print("best result on tuning data", analysis.best_result)
We can apply the tuned config to the request for an instance:

```python
responses = oai.Completion.create(context=tune_data[1], **config)
print(responses)
response = oai.Completion.create(context=tune_data[1], **config)
print(response)
print(eval_with_generated_assertions(oai.Completion.extract_text(response), **tune_data[1]))
```

42 changes: 31 additions & 11 deletions website/docs/Use-Cases/Auto-Generation.md
@@ -12,7 +12,7 @@ which can significantly affect both the utility and the cost of the generated te

The tunable hyperparameters include:
1. model - this is a required input, specifying the model ID to use.
1. prompt - the input prompt to the model, which provides the context for the text generation task.
1. prompt/messages - the input prompt/messages to the model, which provides the context for the text generation task.
1. max_tokens - the maximum number of tokens (words or word pieces) to generate in the output.
1. temperature - a value between 0 and 1 that controls the randomness of the generated text. A higher temperature will result in more random and diverse text, while a lower temperature will result in more predictable text.
1. top_p - a value between 0 and 1 that controls the sampling probability mass for each token generation. A lower top_p value will make it more likely to generate text based on the most likely tokens, while a higher value will allow the model to explore a wider range of possible tokens.
@@ -61,8 +61,8 @@ The metric to optimize is usually an aggregated metric over all the tuning data
Users can specify the (optional) search range for each hyperparameter.

1. model. Either a constant str, or multiple choices specified by `flaml.tune.choice`.
1. prompt. Either a str or a list of strs, of the prompt templates.
Each prompt template will be formatted with each data instance. For example, the prompt template can be:
1. prompt/messages. The prompt is either a str or a list of strs, giving the prompt templates; messages is a list of dicts or a list of lists of dicts, giving the message templates.
Each prompt/message template will be formatted with each data instance. For example, the prompt template can be:
"{problem} Solve the problem carefully. Simplify your answer as much as possible. Put the final answer in \\boxed{{}}."
And `{problem}` will be replaced by the "problem" field of each data instance (a short sketch of this formatting follows this list).
1. max_tokens, n, best_of. They can be constants, or specified by `flaml.tune.randint`, `flaml.tune.qrandint`, `flaml.tune.lograndint` or `flaml.qlograndint`. By default, max_tokens is searched in [50, 1000); n is searched in [1, 100); and best_of is fixed to 1.
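
Below is a short sketch of how a str prompt template is filled in for one data instance (the data instance here is illustrative; `oai.Completion` performs this formatting internally during tuning):

```python
prompt_template = (
    "{problem} Solve the problem carefully. Simplify your answer as much as possible. "
    "Put the final answer in \\boxed{{}}."
)
data_instance = {"problem": "What is 2 + 2?"}
print(prompt_template.format(**data_instance))
# What is 2 + 2? Solve the problem carefully. Simplify your answer as much as possible. Put the final answer in \boxed{}.
```
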
@@ -98,19 +98,39 @@ config, analysis = oai.Completion.tune(
`num_samples` is the number of configurations to sample. -1 means unlimited (until optimization budget is exhausted).
The returned `config` contains the optimized configuration and `analysis` contains an [ExperimentAnalysis](../reference/tune/analysis#experimentanalysis-objects) object for all the tried configurations and results.

### Perform inference with the tuned config
## Perform inference with the tuned config

One can use [`flaml.oai.Completion.create`](../reference/autogen/oai/completion#create) to perform inference. It materializes a prompt using a given context. For example,
One can use [`flaml.oai.Completion.create`](../reference/autogen/oai/completion#create) to perform inference.
There are a number of benefits of using `flaml.oai.Completion.create` to perform inference.

A template is either a format str, or a function which produces a str from several input fields.

### API unification

`flaml.oai.Completion.create` is compatible with both `openai.Completion.create` and `openai.ChatCompletion.create`, and both OpenAI API and Azure OpenAI API. So models such as "text-davinci-003", "gpt-3.5-turbo" and "gpt-4" can share a common API. When only tuning the chat-based models, `flaml.oai.ChatCompletion` can be used.
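
For example, a rough sketch (the model names follow the docs above; actual availability depends on your account and API version):

```python
from flaml import oai

# Completion-style model: the prompt is passed through as-is.
response = oai.Completion.create(prompt="1+1=", model="text-ada-001", max_tokens=1)
print(oai.Completion.extract_text(response))

# Chat-style model: the same prompt is converted to a "user" message internally.
response = oai.Completion.create(prompt="1+1=", model="gpt-3.5-turbo", max_tokens=1)
print(oai.Completion.extract_text(response))
```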

### Caching

API call results are cached locally and reused when the same request is issued. This is useful when repeating or continuing experiments, for reproducibility and cost saving. Controlled randomness is still possible: set a "seed" via [`set_cache`](../reference/autogen/oai/completion#set_cache) or pass it directly in `create()`.
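
For example, a minimal sketch (the seed value and model are arbitrary):

```python
from flaml import oai

oai.Completion.set_cache(41)  # pick a cache bucket by seed; a "seed" can also be passed in create()
config = {"model": "text-ada-001", "prompt": "1+1=", "max_tokens": 1}
first = oai.Completion.create(**config)   # issues the API call and caches the response
second = oai.Completion.create(**config)  # same request and seed: served from the local cache
```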

### Error handling

It is easy to hit errors when calling OpenAI APIs, due to connection, rate limit, or timeout issues. Some of these errors are transient, and `flaml.oai.Completion.create` handles them by retrying automatically. The initial request timeout, retry timeout and retry time interval can be configured via `flaml.oai.request_timeout`, `flaml.oai.retry_timeout` and `flaml.oai.retry_time`.
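
For example, a sketch of adjusting these settings, assuming they are exposed as the class attributes (`request_timeout`, `retry_time`, `retry_timeout`) used by `_get_response` in the diff above:

```python
from flaml import oai

oai.Completion.request_timeout = 60    # give each request up to 60 seconds
oai.Completion.retry_time = 15         # wait 15 seconds between retries
oai.Completion.retry_timeout = 5 * 60  # stop retrying after 5 minutes in total
```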

### Templating

If the provided prompt or message is a template, it will be automatically materialized with a given context. For example,

```python
response = oai.Completion.create(context={"problem": problem}, **config)
responses = oai.Completion.extract_text(response)
# Extract a list of str responses
response = oai.Completion.create(context={"problem": problem}, prompt="{problem} Solve the problem carefully.", **config)
```
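
A callable template works the same way. A sketch, reusing the tuned `config` from above (assuming `config` does not already contain a `prompt`):

```python
def math_prompt(context):
    # Build the prompt from the "problem" field of the context dict.
    return f"{context['problem']} Solve the problem carefully. Put the final answer in \\boxed{{}}."

response = oai.Completion.create(
    context={"problem": "What is the value of 2 + 2?"},
    prompt=math_prompt,
    **config,
)
```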

`flaml.oai.Completion` is compatible with both `openai.Completion` and `openai.ChatCompletion`. So models such as "text-davinci-003", "gpt-3.5-turbo" and "gpt-4" can share a common API. When only tuning the chat-based models, `flaml.oai.ChatCompletion` can be used.

`flaml.oai.Completion` also offers some additional utilities including a `test` function to conveniently evaluate the configuration over test data, a `cost` function to calculate the cost of an API call, and caching and error handling. It also supports both OpenAI API and Azure OpenAI API.
## Other utilities
`flaml.oai.Completion` also offers some additional utilities, such as:
- a [`cost`](../reference/autogen/oai/completion#cost) function to calculate the cost of an API call.
- a [`test`](../reference/autogen/oai/completion#test) function to conveniently evaluate the configuration over test data.
- an [`extract_text`](../reference/autogen/oai/completion#extract_text) function to extract the text from a completion or chat response.
- a [`set_cache`](../reference/autogen/oai/completion#set_cache) function to set the seed and cache path. Caching is introduced in the section above; its benefits include cost saving, reproducibility, and controlled randomness.
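
For example, a brief sketch reusing `config` and `tune_data` from the earlier examples:

```python
response = oai.Completion.create(context=tune_data[1], **config)
texts = oai.Completion.extract_text(response)           # list of generated texts (str)
price = oai.Completion.cost(config["model"], response)  # cost of this API call
```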

Interested in trying it yourself? Please check the following notebook examples:
* [Optimize for Code Gen](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_openai.ipynb)
