Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
nauyisu022 committed Apr 20, 2024
2 parents 1107ab2 + cdd3ca7 commit 2df5423
Showing 1 changed file with 37 additions and 96 deletions.
133 changes: 37 additions & 96 deletions trustllm_pkg/trustllm/utils/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,118 +11,60 @@
import trustllm.config
import replicate

# Load model information from configuration.
# (Removed commented-out dead code that duplicated these lookups.)
model_info = trustllm.config.model_info
online_model_list = model_info['online_model']
model_mapping = model_info['model_mapping']
# Reverse lookup: map each display name back to its repository identifier.
rev_model_mapping = {value: key for key, value in model_mapping.items()}

# Safety settings that disable all content blocking so that
# harmful-content benchmarks can elicit unfiltered model responses.
# (Collapsed the diff-duplicated list bodies into one valid definition.)
safety_setting = [
    {"category": safety_types.HarmCategory.HARM_CATEGORY_DEROGATORY, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_SEXUAL, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_TOXICITY, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_DANGEROUS, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
]

# Retrieve model information
def get_models():
    """Return the model-name mapping dict and the list of online models."""
    models = (model_mapping, online_model_list)
    return models

# Function to obtain an OAuth2 access token for the Baidu ERNIE API
def get_access_token():
    """Request an access token using the client credentials from config.

    Returns:
        The access-token string, or None if the response lacks one.
    """
    # The old copy built this URL with str.format using two arguments but
    # only one placeholder, silently dropping the client secret.
    url = (
        "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials"
        f"&client_id={trustllm.config.client_id}&client_secret={trustllm.config.client_secret}"
    )
    headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}
    # The token endpoint expects a (trivially empty) JSON body.
    response = requests.post(url, headers=headers, data=json.dumps(""))
    return response.json().get("access_token")


# Function to get responses from the ERNIE API
def get_ernie_res(string, temperature):
    """Send a single-turn chat request to the ERNIE API.

    Args:
        string: The user prompt.
        temperature: Sampling temperature; the API rejects exactly 0.0,
            so it is nudged to a tiny positive value.

    Returns:
        The 'result' field of the API response, or '' if it is missing.
    """
    if temperature == 0.0:
        temperature = 0.00000001  # ERNIE requires temperature > 0
    url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token={get_access_token()}"
    payload = json.dumps({"messages": [{"role": "user", "content": string}], 'temperature': temperature})
    headers = {'Content-Type': 'application/json'}
    response = requests.post(url, headers=headers, data=payload)
    res_data = json.loads(response.text)
    return res_data.get('result', '')

# Function to generate responses using OpenAI's API
def get_res_openai(string, model, temperature):
    """Generate a chat completion via the OpenAI API.

    Args:
        string: The user prompt.
        model: Short model key; must be 'chatgpt' or 'gpt-4'.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The completion text.

    Raises:
        KeyError: If `model` is not a known short key.
        ValueError: If the API returns an empty or null completion.
    """
    gpt_model_mapping = {"chatgpt": "gpt-3.5-turbo", "gpt-4": "gpt-4-1106-preview"}
    gpt_model = gpt_model_mapping[model]
    client = OpenAI(api_key=trustllm.config.openai_key)
    response = client.chat.completions.create(
        model=gpt_model,
        messages=[{"role": "user", "content": string}],
        temperature=temperature,
    )
    content = response.choices[0].message.content
    if not content:
        # Bug fix: the previous code RETURNED a ValueError instance
        # ("return ... else ValueError(...)") instead of raising it.
        raise ValueError("The response from the API is NULL or an empty string!")
    return content

# Function to generate responses using Deepinfra's OpenAI-compatible API
def deepinfra_api(string, model, temperature):
    """Generate a chat completion via Deepinfra's OpenAI-compatible endpoint.

    Args:
        string: The user prompt.
        model: Display model name; translated through `rev_model_mapping`.
        temperature: Sampling temperature; near-zero disables nucleus sampling.

    Returns:
        The completion text.
    """
    api_token = trustllm.config.deepinfra_api
    top_p = 1 if temperature <= 1e-5 else 0.9
    # Bug fixes: the original never assigned the constructed client to a
    # variable before using it, and openai>=1.0 takes `base_url`
    # (not `api_base`) for a custom endpoint.
    client = OpenAI(api_key=api_token, base_url="https://api.deepinfra.com/v1/openai")
    completion = client.chat.completions.create(
        model=rev_model_mapping[model],
        messages=[{"role": "user", "content": string}],
        max_tokens=5192,
        temperature=temperature,
        top_p=top_p,
    )
    return completion.choices[0].message.content


def replicate_api(string, model, temperature):
Expand Down Expand Up @@ -228,11 +170,10 @@ def gen_online(model_name, prompt, temperature, replicate=False, deepinfra=False
return res


# Convert a plain prompt into the conversation format expected by a model
def prompt2conversation(model_path, prompt):
    """Wrap `prompt` in the chat template associated with `model_path`.

    Args:
        model_path: Model identifier used to select the conversation template.
        prompt: The raw user prompt text.

    Returns:
        The fully rendered prompt string for the target model.
    """
    conv = get_conversation_template(model_path)
    conv.set_system_message('')
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)  # placeholder for the assistant's reply
    return conv.get_prompt()

0 comments on commit 2df5423

Please sign in to comment.