Skip to content

Commit

Permalink
Merge pull request #8 from lukfre/decode-fix
Browse files Browse the repository at this point in the history
Decode fix
  • Loading branch information
Duguce authored Dec 25, 2024
2 parents 491160e + 6d78835 commit c6ab86f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/xfinder/modules/answer_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def compare(
if type(standard_answer_range) == str:
standard_answer_range_list = ast.literal_eval(
standard_answer_range)
else:
standard_answer_range_list = standard_answer_range
for option in standard_answer_range_list:
if option[0] == correct and \
extracted.strip().rstrip(".").lower() == option[1].strip().rstrip(".").lower():
Expand Down
4 changes: 2 additions & 2 deletions src/xfinder/modules/answer_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def _execute_local_inference(self, query: Dict[str, Any]) -> str:
output_ids = self.model.generate(
input_ids, max_new_tokens=self.max_tokens, temperature=self.temperature)
response = self.tokenizer.decode(
output_ids[0], skip_special_tokens=True)
return response.replace(prompt, '').strip()
output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
return response.strip()

def generate_output(self, question, llm_output, standard_answer_range) -> str:
formatted_query = f'Question: """{question}"""\n\nOutput sentences: """{llm_output}"""\n\nAnswer range: {standard_answer_range}\n\nKey extracted answer: '
Expand Down

0 comments on commit c6ab86f

Please sign in to comment.