Skip to content

Commit

Permalink
Update annotator.py
Browse files (browse the repository at this point in the history)
Add optional candidates
  • Loading branch information
LittlePea13 authored Aug 2, 2024
1 parent 4a09e2e commit b82a698
Showing 1 changed file with 70 additions and 46 deletions.
116 changes: 70 additions & 46 deletions relik/inference/annotator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -152,7 +152,6 @@ def __init__(
self._retriever = {
task_type: r for task_type, r in self._retriever.items() if r is not None
}

self._index: Dict[TaskType, GoldenRetriever] = {
TaskType.SPAN: None,
TaskType.TRIPLET: None,
Expand Down Expand Up @@ -390,61 +389,86 @@ def __call__(
)

windows_candidates = {TaskType.SPAN: None, TaskType.TRIPLET: None}
# candidates are provided, use them and skip retrieval
# candidates are provided, use them and skip retrieval if they exist for the specific task
if candidates is not None:
# again, check if candidates is a dict
if isinstance(candidates, Dict):
if self.task not in candidates:
raise ValueError(
f"Task `{self.task}` not found in `candidates`."
f"Please choose one of {list(TaskType)}."
)
for task_type in windows_candidates.keys():
if task_type in candidates:
_candidates = candidates[task_type]
if isinstance(_candidates, list):
_candidates = [
[
c if isinstance(c, Document) else Document(c)
for c in _candidates[w.doc_id]
]
for w in windows
]
windows_candidates[task_type] = _candidates
else:
if self._retriever is None:
raise ValueError(
"No retriever was provided, please provide a retriever or candidates."
)
# retrieve candidates for this task
retriever = self._retriever.get(task_type)
if retriever is not None:
retriever_out = retriever.retrieve(
[w.text for w in windows],
text_pair=[
(
w.doc_topic
if (w.doc_topic is not None and use_doc_topic)
else None
)
for w in windows
],
k=top_k,
batch_size=retriever_batch_size,
progress_bar=progress_bar,
**kwargs,
)
windows_candidates[task_type] = [
[p.document for p in predictions] for predictions in retriever_out
]
else:
# check if the candidates are a list of lists
if not isinstance(candidates[0], list):
candidates = [candidates]
candidates = {self.task: candidates}

for task_type, _candidates in candidates.items():
if isinstance(_candidates, list):
_candidates = [
[
c if isinstance(c, Document) else Document(c)
for c in _candidates[w.doc_id]
for task_type, _candidates in candidates.items():
if isinstance(_candidates, list):
_candidates = [
[
c if isinstance(c, Document) else Document(c)
for c in _candidates[w.doc_id]
]
for w in windows
]
for w in windows
windows_candidates[task_type] = _candidates

# retrieve candidates for tasks without provided candidates
for task_type, retriever in self._retriever.items():
if windows_candidates[task_type] is None:
if retriever is not None:
retriever_out = retriever.retrieve(
[w.text for w in windows],
text_pair=[
(
w.doc_topic
if (w.doc_topic is not None and use_doc_topic)
else None
)
for w in windows
],
k=top_k,
batch_size=retriever_batch_size,
progress_bar=progress_bar,
**kwargs,
)
windows_candidates[task_type] = [
[p.document for p in predictions] for predictions in retriever_out
]
windows_candidates[task_type] = _candidates

else:
# retrieve candidates first
if self._retriever is None:
raise ValueError(
"No retriever was provided, please provide a retriever or candidates."
)
# start_retr = time.time()
# retrieve for each task type
for task_type, retriever in self._retriever.items():
retriever_out = retriever.retrieve(
[w.text for w in windows],
text_pair=[
(
w.doc_topic
if (w.doc_topic is not None and use_doc_topic)
else None
)
for w in windows
],
k=top_k,
batch_size=retriever_batch_size,
progress_bar=progress_bar,
**kwargs,
)
windows_candidates[task_type] = [
[p.document for p in predictions] for predictions in retriever_out
]
# end_retr = time.time()
# logger.debug(f"Retrieval took {end_retr - start_retr} seconds.")

# clean up None's
windows_candidates = {
Expand Down

0 comments on commit b82a698

Please sign in to comment.