make checker receive only prompt_out
mbalunovic committed Dec 13, 2023
1 parent 0efd63b commit b3643a2
Showing 4 changed files with 17 additions and 12 deletions.
6 changes: 4 additions & 2 deletions lve-tools/lve_tools/lve/checkers/__init__.py
@@ -13,6 +13,8 @@ def get_checker(checker_name: str, custom_checker_path: str = None) -> BaseCheck
     if custom_checker_path is not None:
         module_path = os.path.join(custom_checker_path)
         sys.path.append(module_path)
-        filename = os.listdir(module_path)[0]
-        import_module(filename[:filename.find(".py")])
+        files = os.listdir(module_path)
+        for filename in files:
+            if filename.startswith("checker") and filename.endswith(".py"):
+                import_module(filename[:filename.find(".py")])
     return CheckerRegistryHolder.get_checker_registry().get(checker_name)
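For illustration, here is a rough sketch of how a custom checker might be loaded after this change. The directory layout, file names, and checker name are hypothetical; only the get_checker signature and the checker*.py discovery rule come from this diff.

```python
from lve.checkers import get_checker

# Hypothetical layout: ./custom_checkers/ contains checker_keyword.py and
# checker_regex.py. After this commit, every file matching checker*.py in the
# custom checker path is imported (previously only the first directory entry
# was), so all checkers defined there end up in the registry.
CheckerCls = get_checker("KeywordGuardChecker", custom_checker_path="./custom_checkers")
checker = CheckerCls()  # constructor arguments depend on the concrete checker
```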
7 changes: 2 additions & 5 deletions lve-tools/lve_tools/lve/checkers/base.py
@@ -67,12 +67,11 @@ def extract_response_from_prompt(self, prompt, full=False):
         response = self._extract(response)
         return response
 
-    def invoke_check(self, prompt_in, prompt_out, param_values=None, score_callback=None):
+    def invoke_check(self, prompt_out, param_values=None, score_callback=None):
         """ Function called by LVE to invoke the checker.
 
         Main function exposed to the lve internals, but not the user.
 
         Args:
-            prompt_in: The chat corresponding to the input. Currently unused.
             prompt_out: The chat including model responses.
             param_values: The parameter values plugged into the prompt.
@@ -187,20 +186,18 @@ def __str__(self):
 
 class MultiRunBaseChecker(BaseChecker):
 
-    def invoke_check(self, prompts_in, prompts_out, param_values=None):
+    def invoke_check(self, prompts_out, param_values=None):
         """ Function called by LVE to invoke the checker.
 
         Main function exposed to the lve internals, but not the user.
 
         Args:
-            prompts_in: List of the chats corresponding to the inputs.
             prompts_out: List of the chats including the model responses. Order should match prompts_in.
             param_values: The parameter values plugged into the prompt.
 
         Returns:
             is_safe: Whether the response is safe as determined by is_safe.
             response_or_variables: If there is only one response, returns the response as string, else returns all variables as a dictionary.
         """
-        assert len(prompts_in) == len(prompts_out)
         cnt_variables = sum(p.role == Role.assistant and p.variable is not None for p in prompts_out[0])
         is_safe = self.is_safe(prompts_out, param_values)
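As a rough illustration of the new contract, a checker now receives only the output chat. The sketch below mirrors the variable-counting step above; the content/role/variable fields of Message follow the diff context, but the import path and Role.user are assumptions.

```python
from lve.prompt import Message, Role  # import path assumed for illustration

# prompt_out as a checker now receives it: the chat with the model's answers
# filled in, where assistant messages may be tagged with a variable name.
prompt_out = [
    Message(content="Say hi to Alice.", role=Role.user),
    Message(content="Hi Alice!", role=Role.assistant, variable="response"),
]

# Same idea as the counting step in MultiRunBaseChecker.invoke_check: with a
# single tagged assistant message, the checker reports the response as a string.
cnt_variables = sum(m.role == Role.assistant and m.variable is not None
                    for m in prompt_out)
print(cnt_variables)  # 1
```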
6 changes: 5 additions & 1 deletion lve-tools/lve_tools/lve/inference.py
@@ -290,12 +290,16 @@ async def execute_dummy(model, prompt_in, verbose=False, chunk_callback=None, **
     """
     Dummy model which fills all assistant messages with "Hello world!"
     """
+    import random
     prompt, model = preprocess_prompt_model(model, prompt_in, verbose, **model_args)
 
     # go through all messages and fill in assistant messages, sending everything before as context
     for i in range(len(prompt)):
         if prompt[i].role == Role.assistant and prompt[i].content == None:
-            prompt[i].content = model_args.get("response", "Hello world")
+            if "random_responses" in model_args:
+                prompt[i].content = random.choice(model_args["random_responses"])
+            else:
+                prompt[i].content = model_args.get("response", "Hello world")
             if chunk_callback is not None:
                 chunk_callback(prompt[i].content)
     chunk_callback(None)
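A standalone sketch of the new response-selection behaviour is below. pick_dummy_content is a hypothetical helper that mirrors the if/else added here, not a function in the repository.

```python
import random

# If "random_responses" is passed in model_args, the dummy backend fills each
# empty assistant message with a random entry from that list; otherwise it
# falls back to the fixed "response" value (default "Hello world").
def pick_dummy_content(model_args: dict) -> str:
    if "random_responses" in model_args:
        return random.choice(model_args["random_responses"])
    return model_args.get("response", "Hello world")

print(pick_dummy_content({"random_responses": ["Hello world!", "I cannot do that."]}))
print(pick_dummy_content({"response": "Fixed reply"}))
print(pick_dummy_content({}))  # -> "Hello world"
```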
10 changes: 6 additions & 4 deletions lve-tools/lve_tools/lve/lve.py
@@ -66,6 +66,7 @@ class TestInstance(BaseModel):
     passed: bool = True
     author: Optional[str] = None
     run_info: dict
+    prompt_out: Optional[List[Message]] = None
 
 TPrompt = Union[str, list[Message]]
 class MultiPrompt(BaseModel):
@@ -256,7 +257,7 @@ async def execute(self, prompt_in, verbose=False, **model_args):
         model_args_upd.update(model_args)
         return await execute_llm(self.model, prompt_in, verbose, **model_args_upd)
 
-    async def run(self, author=None, verbose=False, engine='openai', score_callback=None, chunk_callback=None, **kwargs):
+    async def run(self, store_prompt_out=False, author=None, verbose=False, engine='openai', score_callback=None, chunk_callback=None, **kwargs):
         if engine == 'lmql':
             return await self.run_with_lmql(author=author, verbose=verbose, **kwargs)
         else:
@@ -280,8 +281,8 @@ async def run(self, author=None, verbose=False, engine='openai', score_callback=
             prompt_out.append(po)
 
         checker = self.get_checker(**kwargs)
-        is_safe, response = checker.invoke_check(prompt, prompt_out, param_values, score_callback=score_callback)
-        hook("lve.check", prompt=prompt, prompt_out=response, param_values=param_values, checker_name=self.checker_args.get("checker_name", "unknown"))
+        is_safe, response = checker.invoke_check(prompt_out, param_values, score_callback=score_callback)
+        hook("lve.check", prompt_out=response, param_values=param_values, checker_name=self.checker_args.get("checker_name", "unknown"))
 
         response = checker.postprocess_response(response)
 
@@ -291,6 +292,7 @@ async def run(self, author=None, verbose=False, engine='openai', score_callback=
             response=response,
             run_info=run_info,
             passed=is_safe,
+            prompt_out=prompt_out if store_prompt_out else None,
         )
 
     async def run_with_lmql(self, author=None, verbose=False, **kwargs):
@@ -314,7 +316,7 @@ async def run_with_lmql(self, author=None, verbose=False, **kwargs):
 
         checker = self.get_checker()
         prompt_out = copy.deepcopy(prompt) + [Message(content=response, role=Role.assistant, variable='response')]
-        is_safe, response = checker.invoke_check(prompt, prompt_out, param_values)
+        is_safe, response = checker.invoke_check(prompt_out, param_values)
 
         return TestInstance(
             author=author,
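To illustrate the new knobs, here is a rough usage sketch. LVE.from_path, its import path, and the test path are assumptions for illustration; the parts taken from this diff are the store_prompt_out flag on run() and the new prompt_out field on TestInstance.

```python
import asyncio
from lve.lve import LVE  # import path assumed

async def main():
    # Hypothetical test path; replace with a real LVE from the repository.
    lve = LVE.from_path("repository/security/some_lve/openai--gpt-4")
    instance = await lve.run(store_prompt_out=True, author="alice")
    print(instance.passed)      # checker verdict, as before
    print(instance.prompt_out)  # full output chat, kept only when store_prompt_out=True

asyncio.run(main())
```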