adding token classification (#376)
* adding ner
liususan091219 authored Jan 3, 2022
1 parent 8602def commit 207b693
Showing 10 changed files with 1,118 additions and 159 deletions.
24 changes: 18 additions & 6 deletions flaml/automl.py
@@ -40,6 +40,7 @@
from .data import (
concat,
CLASSIFICATION,
TOKENCLASSIFICATION,
TS_FORECAST,
FORECAST,
REGRESSION,
@@ -866,6 +867,8 @@ def _validate_data(

# check the validity of input dimensions under the nlp mode
if _is_nlp_task(self._state.task):
from .nlp.utils import is_a_list_of_str

is_all_str = True
is_all_list = True
for column in X.columns:
@@ -874,17 +877,25 @@ def _validate_data(
"string",
), "If the task is an NLP task, X can only contain text columns"
for each_cell in X[column]:
if each_cell:
if each_cell is not None:
is_str = isinstance(each_cell, str)
is_list_of_int = isinstance(each_cell, list) and all(
isinstance(x, int) for x in each_cell
)
assert is_str or is_list_of_int, (
"Each column of the input must either be str (untokenized) "
"or a list of integers (tokenized)"
)
is_list_of_str = is_a_list_of_str(each_cell)
if self._state.task == TOKENCLASSIFICATION:
assert is_list_of_str, (
"For the token-classification task, the input column needs to be a list of string,"
"instead of string, e.g., ['EU', 'rejects','German', 'call','to','boycott','British','lamb','.',].",
"For more examples, please refer to test/nlp/test_autohf_tokenclassification.py",
)
else:
assert is_str or is_list_of_int, (
"Each column of the input must either be str (untokenized) "
"or a list of integers (tokenized)"
)
is_all_str &= is_str
is_all_list &= is_list_of_int
is_all_list &= is_list_of_int or is_list_of_str
assert is_all_str or is_all_list, (
"Currently FLAML only supports two modes for NLP: either all columns of X are string (non-tokenized), "
"or all columns of X are integer ids (tokenized)"
@@ -963,6 +974,7 @@ def _prepare_data(self, eval_method, split_ratio, n_splits):
and self._auto_augment
and self._state.fit_kwargs.get("sample_weight") is None
and self._split_type in ["stratified", "uniform"]
and self._state.task != TOKENCLASSIFICATION
):
# logger.info(f"label {pd.unique(y_train_all)}")
label_set, counts = np.unique(y_train_all, return_counts=True)
8 changes: 5 additions & 3 deletions flaml/data.py
@@ -15,12 +15,14 @@
# TODO: if your task is not specified in here, define your task as an all-capitalized word
SEQCLASSIFICATION = "seq-classification"
MULTICHOICECLASSIFICATION = "multichoice-classification"
TOKENCLASSIFICATION = "token-classification"
CLASSIFICATION = (
"binary",
"multi",
"classification",
SEQCLASSIFICATION,
MULTICHOICECLASSIFICATION,
TOKENCLASSIFICATION,
)
SEQREGRESSION = "seq-regression"
REGRESSION = ("regression", SEQREGRESSION)
@@ -34,6 +36,7 @@
SEQREGRESSION,
SEQCLASSIFICATION,
MULTICHOICECLASSIFICATION,
TOKENCLASSIFICATION,
)


@@ -354,11 +357,10 @@ def fit_transform(self, X: Union[DataFrame, np.array], y, task):
datetime_columns,
)
self._drop = drop

if (
task in CLASSIFICATION
or not pd.api.types.is_numeric_dtype(y)
(task in CLASSIFICATION or not pd.api.types.is_numeric_dtype(y))
and task not in NLG_TASKS
and task != TOKENCLASSIFICATION
):
from sklearn.preprocessing import LabelEncoder
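The added `task != TOKENCLASSIFICATION` condition keeps token-classification labels away from the label encoder, since each label is itself a sequence of tags rather than a single class. A small illustrative sketch with hypothetical data:

from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
# Scalar labels (e.g. seq-classification): one class per row, encoding is straightforward.
print(le.fit_transform(["positive", "negative", "positive"]))  # [1 0 1]

# Token-classification labels: one list of tags per row, not a single class,
# so FLAML now leaves them untouched and lets the NLP pipeline handle them.
y_token = [["B-ORG", "O"], ["O", "B-PER", "O"]]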

13 changes: 12 additions & 1 deletion flaml/ml.py
@@ -164,11 +164,21 @@ def metric_loss_score(
score = metric.compute(predictions=y_predict, references=y_true)[
metric_name
].mid.fmeasure
elif metric_name == "seqeval":
y_true = [
[x for x in each_y_true if x != -100] for each_y_true in y_true
]
y_pred = [
y_predict[each_idx][: len(y_true[each_idx])]
for each_idx in range(len(y_predict))
]
score = metric.compute(predictions=y_pred, references=y_true)[
"overall_accuracy"
]
else:
score = metric.compute(predictions=y_predict, references=y_true)[
metric_name
]

except ImportError:
raise Exception(
metric_name
@@ -226,6 +236,7 @@ def sklearn_metric_loss_score(
Returns:
score: A float number of the loss, the lower the better.
"""

metric_name = metric_name.lower()

if "r2" == metric_name:
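The seqeval branch added to metric_loss_score above first drops the -100 entries (the conventional ignore index for padding and special positions) from each reference, then truncates the matching prediction to the same length before scoring. A self-contained sketch of that trimming with made-up label ids:

# Hypothetical reference/prediction pairs; -100 marks positions to ignore.
y_true = [[0, 1, -100, -100], [2, -100]]
y_predict = [[0, 1, 1, 0], [2, 0]]

# Keep only the real labels in each reference ...
y_true_trimmed = [[x for x in each_y_true if x != -100] for each_y_true in y_true]
# ... and cut each prediction to the same length.
y_pred_trimmed = [
    y_predict[i][: len(y_true_trimmed[i])] for i in range(len(y_predict))
]

print(y_true_trimmed)  # [[0, 1], [2]]
print(y_pred_trimmed)  # [[0, 1], [2]]
# These trimmed lists are what get passed to the seqeval metric,
# from which FLAML reads "overall_accuracy".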
28 changes: 15 additions & 13 deletions flaml/model.py
@@ -25,6 +25,7 @@
TS_VALUE_COL,
SEQCLASSIFICATION,
SEQREGRESSION,
TOKENCLASSIFICATION,
SUMMARIZATION,
NLG_TASKS,
MULTICHOICECLASSIFICATION,
@@ -310,7 +311,8 @@ def __init__(self, task="seq-classification", **config):

@staticmethod
def _join(X_train, y_train):
y_train = DataFrame(y_train, columns=["label"], index=X_train.index)
y_train = DataFrame(y_train, index=X_train.index)
y_train.columns = ["label"]
train_df = X_train.join(y_train)
return train_df
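The revised _join builds the label frame first and names the column afterwards; with list-valued labels, as in token classification, this still yields a single object-dtype "label" column joined onto X_train. A quick sketch with hypothetical data:

import pandas as pd

X_train = pd.DataFrame({"tokens": [["EU", "rejects"], ["British", "lamb"]]})
y_train = pd.Series([["B-ORG", "O"], ["B-MISC", "O"]])

y_frame = pd.DataFrame(y_train, index=X_train.index)
y_frame.columns = ["label"]          # rename after construction
train_df = X_train.join(y_frame)     # one "tokens" column, one "label" column
print(train_df.columns.tolist())     # ['tokens', 'label']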

@@ -370,17 +372,12 @@ def _init_hpo_args(self, automl_fit_kwargs: dict = None):
self.custom_hpo_args = custom_hpo_args

def _preprocess(self, X, y=None, **kwargs):
from .nlp.utils import tokenize_text

# is_str = False
# for each_type in ["string", "str"]:
# try:
# is_str = is_str or (X.dtypes[0] == each_type)
# except TypeError:
# pass
from .nlp.utils import tokenize_text, is_a_list_of_str

is_str = str(X.dtypes[0]) in ("string", "str")
is_list_of_str = is_a_list_of_str(X[list(X.keys())[0]].to_list()[0])

if is_str:
if is_str or is_list_of_str:
return tokenize_text(
X=X, Y=y, task=self._task, custom_hpo_args=self.custom_hpo_args
)
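is_a_list_of_str comes from flaml/nlp/utils.py, which is not part of this diff; a minimal sketch of what such a check could look like (a hypothetical stand-in, not the actual FLAML implementation):

def is_a_list_of_str(this_obj) -> bool:
    # True only for a (non-string) sequence whose elements are all strings,
    # e.g. a cell of pre-split tokens like ["EU", "rejects", "German"].
    return isinstance(this_obj, (list, tuple)) and all(
        isinstance(x, str) for x in this_obj
    )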
@@ -391,6 +388,7 @@ def fit(self, X_train: DataFrame, y_train: Series, budget=None, **kwargs):
from transformers import EarlyStoppingCallback
from transformers.trainer_utils import set_seed
from transformers import AutoTokenizer
from transformers.data import DataCollatorWithPadding

import transformers
from datasets import Dataset
@@ -455,7 +453,7 @@ def on_epoch_end(self, args, state, control, **callback_kwargs):
X_val = kwargs.get("X_val")
y_val = kwargs.get("y_val")

if self._task not in NLG_TASKS:
if (self._task not in NLG_TASKS) and (self._task != TOKENCLASSIFICATION):
self._X_train, _ = self._preprocess(X=X_train, **kwargs)
self._y_train = y_train
else:
@@ -474,7 +472,7 @@ def on_epoch_end(self, args, state, control, **callback_kwargs):
# make sure they are the same

if X_val is not None:
if self._task not in NLG_TASKS:
if (self._task not in NLG_TASKS) and (self._task != TOKENCLASSIFICATION):
self._X_val, _ = self._preprocess(X=X_val, **kwargs)
self._y_val = y_val
else:
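With this change, token classification takes the same branch as the NLG tasks, so X and y are preprocessed together; the reason is that per-word tags have to be re-aligned to the sub-word tokens produced by the tokenizer, with -100 marking positions the loss (and the seqeval scoring above) should ignore. The actual alignment lives in tokenize_text, which is not shown in this excerpt; a typical sketch of the idea, assuming a Hugging Face fast tokenizer and placeholder label ids:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint

words = ["EU", "rejects", "German", "call"]
word_tags = [3, 0, 7, 0]  # hypothetical per-word label ids

encoding = tokenizer(words, is_split_into_words=True, truncation=True)
aligned_labels = []
previous_word_id = None
for word_id in encoding.word_ids():
    if word_id is None:                # special tokens such as [CLS] / [SEP]
        aligned_labels.append(-100)
    elif word_id != previous_word_id:  # first sub-word keeps the word's tag
        aligned_labels.append(word_tags[word_id])
    else:                              # remaining sub-words are ignored
        aligned_labels.append(-100)
    previous_word_id = word_id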
@@ -648,6 +646,8 @@ def _compute_metrics_by_dataset_name(self, eval_pred):
predictions = (
np.squeeze(predictions)
if self._task == SEQREGRESSION
else np.argmax(predictions, axis=2)
if self._task == TOKENCLASSIFICATION
else np.argmax(predictions, axis=1)
)
return {
@@ -724,7 +724,9 @@ def predict(self, X_test):
if self._task == SEQCLASSIFICATION:
return np.argmax(predictions.predictions, axis=1)
elif self._task == SEQREGRESSION:
return predictions.predictions
return predictions.predictions.reshape((len(predictions.predictions),))
elif self._task == TOKENCLASSIFICATION:
return np.argmax(predictions.predictions, axis=2)
# TODO: elif self._task == your task, return the corresponding prediction
# e.g., if your task == QUESTIONANSWERING, you need to return the answer instead
# of the index
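The new predict branches differ only in the shape of the raw model output: sequence-classification logits are (n_examples, n_labels), token-classification logits are (n_examples, seq_len, n_labels), and sequence-regression outputs are (n_examples, 1). A toy numpy sketch of the three post-processing steps:

import numpy as np

# seq-classification: one score vector per example -> argmax over axis 1
seq_logits = np.array([[0.1, 2.0], [1.5, 0.3]])    # (2 examples, 2 labels)
print(np.argmax(seq_logits, axis=1))               # [1 0]

# token-classification: one score vector per token -> argmax over axis 2
tok_logits = np.random.rand(2, 5, 3)               # (2 examples, 5 tokens, 3 tags)
print(np.argmax(tok_logits, axis=2).shape)         # (2, 5): one tag id per token

# seq-regression: predictions come back as (n_examples, 1) and are flattened
reg_preds = np.array([[0.7], [1.2]])
print(reg_preds.reshape((len(reg_preds),)))        # [0.7 1.2]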
5 changes: 5 additions & 0 deletions flaml/nlp/huggingface/switch_head_auto.py
@@ -5,9 +5,14 @@
if transformers.__version__.startswith("3"):
from transformers.modeling_electra import ElectraClassificationHead
from transformers.modeling_roberta import RobertaClassificationHead
from transformers.modeling_electra import ElectraForTokenClassification
from transformers.modeling_roberta import RobertaForTokenClassification

else:
from transformers.models.electra.modeling_electra import ElectraClassificationHead
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
from transformers.models.electra.modeling_electra import ElectraForTokenClassification
from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification

MODEL_CLASSIFICATION_HEAD_MAPPING = OrderedDict(
[
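The version-gated imports above make the Electra and RoBERTa token-classification heads available; at fine-tuning time such a head would typically be resolved through the corresponding auto class. A hedged sketch (checkpoint name and label count are placeholders, not FLAML's own loading code):

from transformers import AutoModelForTokenClassification, AutoTokenizer

model_name = "roberta-base"  # placeholder checkpoint
num_labels = 9               # e.g. the CoNLL-2003 NER tag set

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=num_labels)
# For a RoBERTa checkpoint this resolves to RobertaForTokenClassification,
# one of the classes imported above.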
