Fix pre chunking logic for negative confidences (#657)
If `return_negative_confidence` is set to `True` when predicting with a sequence
labeling model, the return format is a `List[Dict]` instead of a `List[List[Dict]]`.
As a result, the code added for pre-chunking long documents and then re-merging the
chunks into a single document needs modifications to work error-free.
rdedhia authored Oct 1, 2021
1 parent df9d8b6 commit cf485f1
Showing 2 changed files with 131 additions and 35 deletions.
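As a rough sketch of the format difference described in the commit message (the span and confidence values below are illustrative, not taken from this change), the two prediction shapes look like:

```python
# Default predict() output: one list of span dicts per input document.
preds = [
    [{"start": 0, "end": 4, "label": "animal", "text": "dogs"}],
    [{"start": 5, "end": 9, "label": "animal", "text": "cats"}],
]

# With return_negative_confidence=True: one dict per document, with the span
# list nested under "prediction" and a per-class "negative_confidence" map.
preds_neg_conf = [
    {
        "prediction": [{"start": 0, "end": 4, "label": "animal", "text": "dogs"}],
        "negative_confidence": {"animal": 0.12, "<PAD>": 0.88},
    },
    {
        "prediction": [{"start": 5, "end": 9, "label": "animal", "text": "cats"}],
        "negative_confidence": {"animal": 0.09, "<PAD>": 0.91},
    },
]
```

The merge logic therefore has to unwrap "prediction", offset the span indices, and re-aggregate "negative_confidence" per class, rather than simply concatenating lists of spans.
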
132 changes: 97 additions & 35 deletions finetune/target_models/sequence_labeling.py
@@ -1,8 +1,8 @@
import itertools
import copy
from collections import Counter
from collections import Counter, defaultdict
import math
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import tensorflow as tf
import numpy as np
@@ -110,8 +110,9 @@ def feed_shape_type_def(self):
def _target_encoder(self):
if self.multi_label:
return SequenceMultiLabelingEncoder(pad_token=self.config.pad_token)
return SequenceLabelingEncoder(pad_token=self.config.pad_token,
bio_tagging=self.config.bio_tagging)
return SequenceLabelingEncoder(
pad_token=self.config.pad_token, bio_tagging=self.config.bio_tagging
)


def _combine_and_format(subtokens, start, end, raw_text):
@@ -163,6 +164,7 @@ def _spacy_token_predictions(raw_text, tokens, probas, positions):

return spacy_results


def negative_samples(preds, labels, pad="<PAD>"):
modified_labels = []
for p, l in zip(preds, labels):
@@ -255,9 +257,13 @@ def finetune(
Xs and Y as fully annotated data. We only perform auto negative sampling on
fully annotated data.
"""
if self.config.chunk_long_sequences and self.config.auto_negative_sampling and Y is not None:
if (
self.config.chunk_long_sequences
and self.config.auto_negative_sampling
and Y is not None
):
# clear the saver to save memory.
self.saver.fallback # retrieve the fallback future.
self.saver.fallback # retrieve the fallback future.
self.saver = None
model_copy = copy.deepcopy(self)
model_copy._initialize()
@@ -274,12 +280,19 @@ def finetune(
# Heuristic to select batch size for prediction to limit memory consumption.
# Aim is to give us the smallest batch size that will give us full batches.
approx_max_tokens_per_doc = max(len(x) for x in Xs) / 5
approx_chunks_per_doc = approx_max_tokens_per_doc / (self.config.max_length - self.input_pipeline.chunker.total_context_width)
outer_batch_size = min(max(int(self.config.predict_batch_size / approx_chunks_per_doc), 1), self.config.predict_batch_size)
approx_chunks_per_doc = approx_max_tokens_per_doc / (
self.config.max_length - self.input_pipeline.chunker.total_context_width
)
outer_batch_size = min(
max(int(self.config.predict_batch_size / approx_chunks_per_doc), 1),
self.config.predict_batch_size,
)
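# Illustrative arithmetic under hypothetical config values: a longest document of
# 5,000 characters gives approx_max_tokens_per_doc ~= 1000; with max_length=512 and
# total_context_width=12 that is ~1000 / 500 = 2 chunks per doc, so a
# predict_batch_size of 16 yields outer_batch_size = min(max(16 // 2, 1), 16) = 8.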

with self.cached_predict():
for b_start in range(0, len(Xs), outer_batch_size):
initial_run_preds += model_copy.predict(Xs[b_start: b_start + outer_batch_size])
initial_run_preds += model_copy.predict(
Xs[b_start : b_start + outer_batch_size]
)
del model_copy

# Tag negative predictions with <PAD> label and add to label set
@@ -307,9 +320,13 @@ def finetune(
# TODO Determine if we need something more sophisticated for chunking
self.config.max_empty_chunk_ratio = 0.0

return super().finetune(Xs, Y=Y, context=context, update_hook=update_hook, log_hooks=log_hooks)
return super().finetune(
Xs, Y=Y, context=context, update_hook=update_hook, log_hooks=log_hooks
)

def _pre_chunk_document(self, texts: List[str]) -> Tuple[List[str], List[List[int]]]:
def _pre_chunk_document(
self, texts: List[str]
) -> Tuple[List[str], List[List[int]]]:
"""
If self.config.max_document_chars is set, "pre-chunk" any documents that
are longer than that into multiple "sub documents" to more easily process
@@ -331,36 +348,70 @@ def _pre_chunk_document(self, texts: List[str]) -> Tuple[List[str], List[List[in
if len(doc) > max_doc_len:
num_splits = math.ceil(len(doc) / max_doc_len)
for split_idx in range(num_splits):
new_texts.append(doc[split_idx * max_doc_len: (split_idx + 1) * max_doc_len])
split_indices.append(list(range(doc_idx + offset, doc_idx + offset + num_splits)))
new_texts.append(
doc[split_idx * max_doc_len : (split_idx + 1) * max_doc_len]
)
split_indices.append(
list(range(doc_idx + offset, doc_idx + offset + num_splits))
)
offset += num_splits - 1
else:
new_texts.append(doc)
split_indices.append([doc_idx + offset])

return new_texts, split_indices
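# Illustrative example with hypothetical inputs: given max_document_chars=250 and
# texts = [<100-char doc>, <600-char doc>], the second doc is split into
# ceil(600 / 250) = 3 sub documents, so new_texts has 4 entries and
# split_indices == [[0], [1, 2, 3]].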

def _merge_chunked_preds(self, preds: List[Dict], split_indices: List[List[int]]) -> List[Dict]:
def _merge_chunked_preds(
self,
preds: Union[List[List[Dict]], List[Dict]],
split_indices: List[List[int]],
return_negative_confidence: bool = False,
) -> List[List[Dict]]:
"""
If self.config.max_document_chars is set, text for long documents is split
into multiple "sub documents". Given model predictions, and the indices
specifying which documents have been split, join the labels for previously
split documents together.
preds is a list, where each element of the list corresponds to a document.
In most cases, this will be a List[List[Dict]], where we have a List[Dict]
for each document, which is a list of predictions.
However, if return_negative_confidence is True, we instead have a Dict for
each document, which contains the keys "negative_confidence" and "prediction"
Args:
preds: Model predictions
split_indices: Indices specifying how documents were split
return_negative_confidence: If True, expect preds to be List[Dict]
instead of List[List[Dict]]
Returns:
merged_preds: Model predictions after merging documents together
"""
merged_preds = []
for pred_idxs in split_indices:
if len(pred_idxs) == 1:
merged_preds.append(preds[pred_idxs[0]])
# len(pred_idxs) > 1 indicates that a document was split into multiple
# "sub documents", for which the labels need to be merged
elif return_negative_confidence:
doc_preds = {"prediction": []}
all_doc_neg_confs = defaultdict(list)
for chunk_idx, pred_idx in enumerate(pred_idxs):
offset = chunk_idx * self.config.max_document_chars
chunk_preds = preds[pred_idx]["prediction"]
# Add offset to label start/ends so that they index correctly
# into the text after joining across pre chunks
for i in range(len(chunk_preds)):
chunk_preds[i]["start"] += offset
chunk_preds[i]["end"] += offset
doc_preds["prediction"].extend(chunk_preds)
# Create list of negative confidences for each key, to later take the max of
for key, val in preds[pred_idx]["negative_confidence"].items():
all_doc_neg_confs[key].append(val)
# Get max of negative conf values for each key
doc_preds["negative_confidence"] = {
key: np.max(vals) for key, vals in all_doc_neg_confs.items()
}
merged_preds.append(doc_preds)
else:
# len(pred_idxs) > 1 indicates that a document was split into multiple
# "sub documents", for which the labels need to be merged
doc_preds = []
for chunk_idx, pred_idx in enumerate(pred_idxs):
offset = chunk_idx * self.config.max_document_chars
@@ -371,7 +422,6 @@ def _merge_chunked_preds(self, preds: List[Dict], split_indices: List[List[int]]
chunk_preds[i]["start"] += offset
chunk_preds[i]["end"] += offset
doc_preds.extend(chunk_preds)

merged_preds.append(doc_preds)

return merged_preds
@@ -406,21 +456,31 @@ def predict(
)

if self.config.max_document_chars:
preds = self._merge_chunked_preds(preds, split_indices)
preds = self._merge_chunked_preds(
preds, split_indices, return_negative_confidence
)

return preds

def _predict(
self, zipped_data, per_token=False, return_negative_confidence=False, **kwargs
):
predictions = self.process_long_sequence(zipped_data, **kwargs)
return self._predict_decode(zipped_data, predictions,
per_token=per_token,
return_negative_confidence=return_negative_confidence,
**kwargs)
return self._predict_decode(
zipped_data,
predictions,
per_token=per_token,
return_negative_confidence=return_negative_confidence,
**kwargs
)

def _predict_decode(
self, zipped_data, predictions, per_token=False, return_negative_confidence=False, **kwargs
self,
zipped_data,
predictions,
per_token=False,
return_negative_confidence=False,
**kwargs
):
"""
Produces a list of most likely class labels as determined by the fine-tuned model.
@@ -502,21 +562,22 @@ def _get_label(label):
# or the current subsequence has the wrong label
# or bio tagging is on and we have a B- tag
if (
not doc_subseqs or per_token or
(self.config.bio_tagging and bio_prefix == "B-") or
(self.config.group_bio_tagging and group_prefix == "BG-") or
(
label != doc_labels[-1] and
(
not doc_subseqs
or per_token
or (self.config.bio_tagging and bio_prefix == "B-")
or (self.config.group_bio_tagging and group_prefix == "BG-")
or (
label != doc_labels[-1]
and (
# Merge spans if the labels are the same,
# disregarding group BIO tags
# This is safe as we already hard break on BG-, so
# we will only be merging IG- tags
not self.config.group_bio_tagging or
_get_label(label) != _get_label(doc_labels[-1])
not self.config.group_bio_tagging
or _get_label(label) != _get_label(doc_labels[-1])
)
)
):
):
assert start_idx <= end_idx, "Start: {}, End: {}".format(
start_idx, end_idx
)
@@ -556,7 +617,8 @@ def _get_label(label):
probs=[prob_dicts],
none_value=self.config.pad_token,
subtoken_predictions=self.config.subtoken_predictions,
bio_tagging=self.config.bio_tagging or self.config.group_bio_tagging,
bio_tagging=self.config.bio_tagging
or self.config.group_bio_tagging,
)
if per_token:
doc_annotations.append(
34 changes: 34 additions & 0 deletions tests/test_sequence.py
@@ -445,6 +445,40 @@ def test_pre_chunking(self):
assert split_pred["start"] + max_doc_len * 2 == pred["start"]
assert split_pred["end"] + max_doc_len * 2 == pred["end"]

def test_pre_chunking_neg_confidences(self):
"""
Similar to test_pre_chunking(), except that predict is run with the
return_negative_confidence flag set to True. This changes the return format
of prediction from List[List[Dict]] to List[Dict], so it exercises a
different code path when merging chunks.
"""
max_doc_len = 250

# Create a long sequence for prediction
test_sequence = (
"I am a dog. A dog that's incredibly bright. I can talk, read, and write! "
)
test_sequences = [test_sequence * 10]

# Use animal test data to train model
path = os.path.join(os.path.dirname(__file__), "data", "testdata.json")
with open(path, "rt") as fp:
text, labels = json.load(fp)
self.model.finetune(text * 10, labels * 10)

# Predict w/ max_document_chars set
self.model.config.max_document_chars = max_doc_len
preds_mdc = self.model.predict(test_sequences, return_negative_confidence=True)

# Predict w/o max_document_chars set
self.model.config.max_document_chars = None
preds = self.model.predict(test_sequences, return_negative_confidence=True)

# Verify that the indices line up
for pred_mdc, pred in zip(preds_mdc[0]["prediction"], preds[0]["prediction"]):
assert pred_mdc["start"] == pred["start"]
assert pred_mdc["end"] == pred["end"]

def test_max_document_chars(self):
"""
If documents are "pre chunked" due to config.max_document_chars being set,
