From 2b2e36ed23fe5b9179e816f2a2fff587152d2c50 Mon Sep 17 00:00:00 2001
From: nauyisu022 <59754221+nauyisu022@users.noreply.github.com>
Date: Tue, 30 Jul 2024 21:33:11 +0200
Subject: [PATCH] update

---
 trustllm_pkg/trustllm/generation/generation.py  | 2 +-
 trustllm_pkg/trustllm/utils/generation_utils.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/trustllm_pkg/trustllm/generation/generation.py b/trustllm_pkg/trustllm/generation/generation.py
index fecc623..c8ed353 100644
--- a/trustllm_pkg/trustllm/generation/generation.py
+++ b/trustllm_pkg/trustllm/generation/generation.py
@@ -55,7 +55,7 @@ def _generation_hf(self, prompt, tokenizer, model, temperature):
         :return: The generated text as a string.
         """
-        prompt = prompt2conversation(prompt, self.model_path)
+        prompt = prompt2conversation(model_path=self.model_path,prompt=prompt,)
         inputs = tokenizer([prompt])
         inputs = {k: torch.tensor(v).to(self.device) for k, v in inputs.items()}
         output_ids = model.generate(
diff --git a/trustllm_pkg/trustllm/utils/generation_utils.py b/trustllm_pkg/trustllm/utils/generation_utils.py
index 9230070..7697b43 100644
--- a/trustllm_pkg/trustllm/utils/generation_utils.py
+++ b/trustllm_pkg/trustllm/utils/generation_utils.py
@@ -72,7 +72,7 @@ def replicate_api(string, model, temperature):
     if model in ["llama3-70b","llama3-8b"]:
         input["prompt_template"] = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
     else:
-        input["prompt"]=prompt2conversation(rev_model_mapping[model],string)
+        input["prompt"]=prompt2conversation(model_path=rev_model_mapping[model],prompt=string)
     os.environ["REPLICATE_API_TOKEN"] = trustllm.config.replicate_api
     res = replicate.run(rev_model_mapping[model], input=input
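
For context, the substance of this patch is an argument-binding fix. The patched call sites imply the helper's signature is prompt2conversation(model_path, prompt), so the old positional call in generation.py passed the prompt where the model path was expected (and vice versa); the call in generation_utils.py was already ordered correctly and is only switched to keyword arguments for consistency. The minimal sketch below illustrates the failure mode under that assumed signature; the helper body shown is a hypothetical stand-in, not the real template logic.

    def prompt2conversation(model_path: str, prompt: str) -> str:
        # Stand-in body: the real helper (in trustllm/utils/generation_utils.py)
        # builds a chat-template string for the given model. This placeholder
        # only makes the argument binding visible.
        return f"[{model_path}] USER: {prompt} ASSISTANT:"

    raw_prompt = "What is 2 + 2?"
    model_path = "meta-llama/Llama-2-7b-chat-hf"  # hypothetical example path

    # Pre-patch call in generation.py: positional arguments in the wrong
    # order. It still runs, but the prompt is silently used as the model path.
    broken = prompt2conversation(raw_prompt, model_path)

    # Patched call: keyword arguments bind each value to the right parameter,
    # regardless of order.
    fixed = prompt2conversation(model_path=model_path, prompt=raw_prompt)

    print(broken)  # wrong: prompt text in the model-path slot
    print(fixed)   # right: template keyed on the model path

Because both parameters are strings, the swapped call fails silently rather than raising a TypeError, which is why switching to keyword arguments is the safer style here.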