diff --git a/lve-tools/lve_tools/lve/checkers/__init__.py b/lve-tools/lve_tools/lve/checkers/__init__.py
index be30317..81c1612 100644
--- a/lve-tools/lve_tools/lve/checkers/__init__.py
+++ b/lve-tools/lve_tools/lve/checkers/__init__.py
@@ -13,6 +13,8 @@ def get_checker(checker_name: str, custom_checker_path: str = None) -> BaseCheck
     if custom_checker_path is not None:
         module_path = os.path.join(custom_checker_path)
         sys.path.append(module_path)
-        filename = os.listdir(module_path)[0]
-        import_module(filename[:filename.find(".py")])
+        files = os.listdir(module_path)
+        for filename in files:
+            if filename.startswith("checker") and filename.endswith(".py"):
+                import_module(filename[:filename.find(".py")])
     return CheckerRegistryHolder.get_checker_registry().get(checker_name)
diff --git a/lve-tools/lve_tools/lve/checkers/base.py b/lve-tools/lve_tools/lve/checkers/base.py
index 1fd034f..3e7d12a 100644
--- a/lve-tools/lve_tools/lve/checkers/base.py
+++ b/lve-tools/lve_tools/lve/checkers/base.py
@@ -67,12 +67,11 @@ def extract_response_from_prompt(self, prompt, full=False):
         response = self._extract(response)
         return response
 
-    def invoke_check(self, prompt_in, prompt_out, param_values=None, score_callback=None):
+    def invoke_check(self, prompt_out, param_values=None, score_callback=None):
         """
         Function called by LVE to invoke the checker. Main function exposed to the lve internals, but not the user.
 
         Args:
-            prompt_in: The chat corresponding to the input. Currently unused.
             prompt_out: The chat including model responses.
             param_values: The parameter values plugged into the prompt.
@@ -187,12 +186,11 @@ def __str__(self):
 
 class MultiRunBaseChecker(BaseChecker):
 
-    def invoke_check(self, prompts_in, prompts_out, param_values=None):
+    def invoke_check(self, prompts_out, param_values=None):
         """
         Function called by LVE to invoke the checker. Main function exposed to the lve internals, but not the user.
 
         Args:
-            prompts_in: List of the chats corresponding to the inputs.
             prompts_out: List of the chats including the model responses. Order should match prompts_in.
             param_values: The parameter values plugged into the prompt.
@@ -200,7 +198,6 @@ def invoke_check(self, prompts_in, prompts_out, param_values=None):
             is_safe: Whether the response is safe as determined by is_safe.
             response_or_variables: If there is only one response, returns the response as string, else returns all variables as a dictionary.
         """
-        assert len(prompts_in) == len(prompts_out)
 
         cnt_variables = sum(p.role == Role.assistant and p.variable is not None for p in prompts_out[0])
         is_safe = self.is_safe(prompts_out, param_values)
diff --git a/lve-tools/lve_tools/lve/inference.py b/lve-tools/lve_tools/lve/inference.py
index a86eaed..66fab41 100644
--- a/lve-tools/lve_tools/lve/inference.py
+++ b/lve-tools/lve_tools/lve/inference.py
@@ -290,12 +290,16 @@ async def execute_dummy(model, prompt_in, verbose=False, chunk_callback=None, **
     """
     Dummy model which fills all assistant messages with "Hello world!"
     """
+    import random
     prompt, model = preprocess_prompt_model(model, prompt_in, verbose, **model_args)
 
     # go through all messages and fill in assistant messages, sending everything before as context
     for i in range(len(prompt)):
         if prompt[i].role == Role.assistant and prompt[i].content == None:
-            prompt[i].content = model_args.get("response", "Hello world")
+            if "random_responses" in model_args:
+                prompt[i].content = random.choice(model_args["random_responses"])
+            else:
+                prompt[i].content = model_args.get("response", "Hello world")
             if chunk_callback is not None:
                 chunk_callback(prompt[i].content)
                 chunk_callback(None)
diff --git a/lve-tools/lve_tools/lve/lve.py b/lve-tools/lve_tools/lve/lve.py
index 5c46d83..5b2ad53 100644
--- a/lve-tools/lve_tools/lve/lve.py
+++ b/lve-tools/lve_tools/lve/lve.py
@@ -66,6 +66,7 @@ class TestInstance(BaseModel):
     passed: bool = True
     author: Optional[str] = None
    run_info: dict
+    prompt_out: Optional[List[Message]] = None
 
 TPrompt = Union[str, list[Message]]
 
 class MultiPrompt(BaseModel):
@@ -256,7 +257,7 @@ async def execute(self, prompt_in, verbose=False, **model_args):
         model_args_upd.update(model_args)
         return await execute_llm(self.model, prompt_in, verbose, **model_args_upd)
 
-    async def run(self, author=None, verbose=False, engine='openai', score_callback=None, chunk_callback=None, **kwargs):
+    async def run(self, store_prompt_out=False, author=None, verbose=False, engine='openai', score_callback=None, chunk_callback=None, **kwargs):
         if engine == 'lmql':
             return await self.run_with_lmql(author=author, verbose=verbose, **kwargs)
         else:
@@ -280,8 +281,8 @@ async def run(self, author=None, verbose=False, engine='openai', score_callback=
             prompt_out.append(po)
 
         checker = self.get_checker(**kwargs)
-        is_safe, response = checker.invoke_check(prompt, prompt_out, param_values, score_callback=score_callback)
-        hook("lve.check", prompt=prompt, prompt_out=response, param_values=param_values, checker_name=self.checker_args.get("checker_name", "unknown"))
+        is_safe, response = checker.invoke_check(prompt_out, param_values, score_callback=score_callback)
+        hook("lve.check", prompt_out=response, param_values=param_values, checker_name=self.checker_args.get("checker_name", "unknown"))
 
         response = checker.postprocess_response(response)
@@ -291,6 +292,7 @@ async def run(self, author=None, verbose=False, engine='openai', score_callback=
             response=response,
             run_info=run_info,
             passed=is_safe,
+            prompt_out=prompt_out if store_prompt_out else None,
         )
 
     async def run_with_lmql(self, author=None, verbose=False, **kwargs):
@@ -314,7 +316,7 @@ async def run_with_lmql(self, author=None, verbose=False, **kwargs):
         checker = self.get_checker()
         prompt_out = copy.deepcopy(prompt) + [Message(content=response, role=Role.assistant, variable='response')]
-        is_safe, response = checker.invoke_check(prompt, prompt_out, param_values)
+        is_safe, response = checker.invoke_check(prompt_out, param_values)
 
         return TestInstance(
             author=author,