Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
nauyisu022 committed Sep 6, 2024
1 parent 745a0db commit 0dae8ed
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 49 deletions.
2 changes: 0 additions & 2 deletions docs/guides/generation_details.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ config.claude_api = "claude api"

config.openai_key = "openai api"

config.palm_api = "palm api"

config.ernie_client_id = "ernie client id"

config.ernie_client_secret = "ernie client secret"
Expand Down
2 changes: 0 additions & 2 deletions trustllm_pkg/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
'python-dotenv',
'urllib3',
'anthropic',
'google.generativeai',
'google-api-python-client',
'google.ai.generativelanguage',
'replicate',
'zhipuai>=2.0.1'
],
Expand Down
5 changes: 1 addition & 4 deletions trustllm_pkg/trustllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
deepinfra_api = None
ernie_api = None
claude_api = None
palm_api = None
replicate_api = None
zhipu_api = None

Expand Down Expand Up @@ -38,19 +37,17 @@
zhipu_model = ["glm-4", "glm-3-turbo"]
claude_model = ["claude-2", "claude-instant-1"]
openai_model = ["chatgpt", "gpt-4"]
google_model = ["bison-001", "gemini"]
wenxin_model = ["ernie"]
replicate_model=["vicuna-7b","vicuna-13b","vicuna-33b","chatglm3-6b","llama3-70b","llama3-8b"]

online_model = deepinfra_model + zhipu_model + claude_model + openai_model + google_model + wenxin_model+replicate_model
online_model = deepinfra_model + zhipu_model + claude_model + openai_model + wenxin_model+replicate_model

model_info = {
"online_model": online_model,
"zhipu_model": zhipu_model,
"deepinfra_model": deepinfra_model,
'claude_model': claude_model,
'openai_model': openai_model,
'google_model': google_model,
'wenxin_model': wenxin_model,
'replicate_model':replicate_model,
"model_mapping": {
Expand Down
42 changes: 1 addition & 41 deletions trustllm_pkg/trustllm/utils/generation_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os, json
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
import google.generativeai as genai
from google.generativeai.types import safety_types

from fastchat.model import load_model, get_conversation_template
from openai import OpenAI,AzureOpenAI
from tenacity import retry, wait_random_exponential, stop_after_attempt
Expand All @@ -17,16 +16,6 @@
model_mapping = model_info['model_mapping']
rev_model_mapping = {value: key for key, value in model_mapping.items()}

# Define safety settings to allow harmful content generation
# Passed as `safety_settings=` to the google.generativeai calls below
# (see gemini_api / palm_api); BLOCK_NONE disables the provider-side
# filter for each listed harm category so red-teaming prompts are not
# refused by the API before reaching the model.
safety_setting = [
    {"category": safety_types.HarmCategory.HARM_CATEGORY_DEROGATORY, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_SEXUAL, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_TOXICITY, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
    {"category": safety_types.HarmCategory.HARM_CATEGORY_DANGEROUS, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE},
]

# Retrieve model information
def get_models():
return model_mapping, online_model_list
Expand Down Expand Up @@ -98,31 +87,7 @@ def claude_api(string, model, temperature):
return completion.completion


@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def gemini_api(string, temperature):
    """Send *string* to the 'gemini-pro' model and return the raw response.

    Configures the SDK with the API key from ``trustllm.config.gemini_api``
    on every call; retried with exponential backoff on failure.
    """
    genai.configure(api_key=trustllm.config.gemini_api)
    gemini = genai.GenerativeModel('gemini-pro')
    # NOTE(review): temperature is passed as a direct kwarg here — confirm
    # the installed SDK version accepts it (newer versions expect it inside
    # a generation_config object).
    return gemini.generate_content(
        string,
        temperature=temperature,
        safety_settings=safety_setting,
    )



@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def palm_api(string, model, temperature):
    """Generate text with a PaLM model and return the completion text.

    :param string: the prompt to send.
    :param model: short model alias (currently only 'bison-001').
    :param temperature: sampling temperature forwarded to the API.
    :raises KeyError: if *model* is not a known alias.
    """
    genai.configure(api_key=trustllm.config.palm_api)

    # Map short aliases to the fully-qualified model ids the API expects.
    # (Renamed from the original's `model_mapping`, which shadowed the
    # module-level variable of the same name.)
    alias_to_model_id = {
        'bison-001': 'models/text-bison-001',
    }
    completion = genai.generate_text(
        model=alias_to_model_id[model],
        prompt=string,
        temperature=temperature,
        max_output_tokens=4000,  # upper bound on the response length
        safety_settings=safety_setting,
    )
    return completion.result


@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
Expand All @@ -148,11 +113,6 @@ def zhipu_api(string, model, temperature):
def gen_online(model_name, prompt, temperature, replicate=False, deepinfra=False):
if model_name in model_info['wenxin_model']:
res = get_ernie_res(prompt, temperature=temperature)
elif model_name in model_info['google_model']:
if model_name == 'bison-001':
res = palm_api(prompt, model=model_name, temperature=temperature)
elif model_name == 'gemini-pro':
res = gemini_api(prompt, temperature=temperature)
elif model_name in model_info['openai_model']:
res = get_res_openai(prompt, model=model_name, temperature=temperature)
elif model_name in model_info['deepinfra_model']:
Expand Down

0 comments on commit 0dae8ed

Please sign in to comment.