Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
nauyisu022 committed Jul 30, 2024
1 parent ee49766 commit 2b2e36e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion trustllm_pkg/trustllm/generation/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _generation_hf(self, prompt, tokenizer, model, temperature):
:return: The generated text as a string.
"""

prompt = prompt2conversation(prompt, self.model_path)
prompt = prompt2conversation(model_path=self.model_path,prompt=prompt,)
inputs = tokenizer([prompt])
inputs = {k: torch.tensor(v).to(self.device) for k, v in inputs.items()}
output_ids = model.generate(
Expand Down
2 changes: 1 addition & 1 deletion trustllm_pkg/trustllm/utils/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def replicate_api(string, model, temperature):
if model in ["llama3-70b","llama3-8b"]:
input["prompt_template"] = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
else:
input["prompt"]=prompt2conversation(rev_model_mapping[model],string)
input["prompt"]=prompt2conversation(model_path=rev_model_mapping[model],prompt=string)
os.environ["REPLICATE_API_TOKEN"] = trustllm.config.replicate_api
res = replicate.run(rev_model_mapping[model],
input=input
Expand Down

0 comments on commit 2b2e36e

Please sign in to comment.