Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrated working PyTorch-CRF in MM #413

Merged
merged 22 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions mindmeld/models/tagger_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from .model import ModelConfig, Model, PytorchModel, AbstractModelFactory
from .nn_utils import get_token_classifier_cls, TokenClassificationType
from .taggers.crf import ConditionalRandomFields
from .taggers.crf import ConditionalRandomFields, PyTorchCRF
vrdn-23 marked this conversation as resolved.
Show resolved Hide resolved
from .taggers.memm import MemmModel
from ..exceptions import MindMeldError

Expand Down Expand Up @@ -73,12 +73,13 @@ class TaggerModel(Model):
CRF_TYPE = "crf"
MEMM_TYPE = "memm"
LSTM_TYPE = "lstm"
ALLOWED_CLASSIFIER_TYPES = [CRF_TYPE, MEMM_TYPE, LSTM_TYPE]
TORCH_CRF_TYPE = "torch-crf"
ALLOWED_CLASSIFIER_TYPES = [CRF_TYPE, MEMM_TYPE, LSTM_TYPE, TORCH_CRF_TYPE]

# for default model scoring types
ACCURACY_SCORING = "accuracy"
SEQ_ACCURACY_SCORING = "seq_accuracy"
SEQUENCE_MODELS = ["crf"]
SEQUENCE_MODELS = ["crf", "torch-crf"]

vrdn-23 marked this conversation as resolved.
Show resolved Hide resolved
DEFAULT_FEATURES = {
"bag-of-words-seq": {
Expand Down Expand Up @@ -131,6 +132,7 @@ def _get_model_constructor(self):
return {
TaggerModel.MEMM_TYPE: MemmModel,
TaggerModel.CRF_TYPE: ConditionalRandomFields,
TaggerModel.TORCH_CRF_TYPE: PyTorchCRF,
TaggerModel.LSTM_TYPE: LstmModel,
}[classifier_type]
except KeyError as e:
Expand Down Expand Up @@ -231,7 +233,7 @@ def fit(self, examples, labels, params=None):
"There are no labels in this label set, so we don't fit the model."
)
return self
# Extract labels - label encoders are the same accross all entity recognition models
# Extract labels - label encoders are the same across all entity recognition models
self._label_encoder = get_label_encoder(self.config)
y = self._label_encoder.encode(labels, examples=examples)

Expand All @@ -246,8 +248,8 @@ def fit(self, examples, labels, params=None):
self._current_params = params
else:
# run cross validation to select params
if self._clf.__class__ == LstmModel:
raise MindMeldError("The LSTM model does not support cross-validation")
if self._clf.__class__ in (LstmModel, PyTorchCRF):
raise MindMeldError(f"The {self._clf.__class__.__name__} model does not support cross-validation")
snow0x2d0 marked this conversation as resolved.
Show resolved Hide resolved

_, best_params = self._fit_cv(X, y, groups)
self._clf = self._fit(X, y, best_params)
Expand Down
145 changes: 145 additions & 0 deletions mindmeld/models/taggers/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .taggers import Tagger, extract_sequence_features
from ..helpers import FileBackedList
from .pytorch_crf import TorchCRF

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -182,6 +183,150 @@ def setup_model(self, config):
self._feat_binner = FeatureBinner()


class PyTorchCRF(Tagger):
"""A Conditional Random Fields model."""

@staticmethod
def _predict_proba(X):
vrdn-23 marked this conversation as resolved.
Show resolved Hide resolved
del X
pass
snow0x2d0 marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def load(model_path):
vrdn-23 marked this conversation as resolved.
Show resolved Hide resolved
del model_path
pass
snow0x2d0 marked this conversation as resolved.
Show resolved Hide resolved

def fit(self, X, y):
self._clf.fit(X, y)
return self

def set_params(self, **parameters):
self._clf = TorchCRF()
snow0x2d0 marked this conversation as resolved.
Show resolved Hide resolved
self._clf.set_params(**parameters)
return self

def get_params(self, deep=True):
return self._clf.get_params()

def predict(self, X, dynamic_resource=None):
return self._clf.predict(X)

def predict_proba(self, examples, config, resources):
"""
Args:
examples (list of mindmeld.core.Query): a list of queries to predict on
config (ModelConfig): The ModelConfig which may contain information used for feature
extraction
resources (dict): Resources which may be used for this model's feature extraction

Returns:
list of tuples of (mindmeld.core.QueryEntity): a list of predicted labels \
with confidence scores
"""
X, _, _ = self.extract_features(examples, config, resources, in_memory=True)
seq = self._clf.predict(X)
marginals_dict = self._clf.predict_marginals(X)
marginal_tuples = []
for query_index, query_seq in enumerate(seq):
query_marginal_tuples = []
for i, tag in enumerate(query_seq):
query_marginal_tuples.append([tag, marginals_dict[query_index][i][tag]])
marginal_tuples.append(query_marginal_tuples)
return marginal_tuples

def predict_proba_distribution(self, examples, config, resources):
vrdn-23 marked this conversation as resolved.
Show resolved Hide resolved
"""
Args:
examples (list of mindmeld.core.Query): a list of queries to predict on
config (ModelConfig): The ModelConfig which may contain information used for feature
extraction
resources (dict): Resources which may be used for this model's feature extraction

Returns:
list of tuples of (mindmeld.core.QueryEntity): a list of predicted labels \
with confidence scores
vrdn-23 marked this conversation as resolved.
Show resolved Hide resolved
"""
X, _, _ = self.extract_features(examples, config, resources, in_memory=True)
seq = self._clf.predict(X)
marginals_dict = self._clf.predict_marginals(X)
predictions = []
tag_maps = []
for query_index, query_seq in enumerate(seq):
tags = []
preds = []
for i in range(len(query_seq)):
vrdn-23 marked this conversation as resolved.
Show resolved Hide resolved
tags.append(list(marginals_dict[query_index][i].keys()))
preds.append(list(marginals_dict[query_index][i].values()))
tag_maps.extend(tags)
predictions.extend(preds)
return [[tag_maps, predictions]]

def extract_features(self,
examples,
config,
resources,
y=None,
fit=False,
in_memory=STORE_CRF_FEATURES_IN_MEMORY):
"""Transforms a list of examples into a feature matrix.

Args:
examples (list of mindmeld.core.Query): a list of queries
config (ModelConfig): The ModelConfig which may contain information used for feature
extraction
resources (dict): Resources which may be used for this model's feature extraction

Returns:
(list of list of str): features in CRF suite format
"""
# Extract features and classes
feats = [] if in_memory else FileBackedList()
for _, example in enumerate(examples):
feats.append(self.extract_example_features(example, config, resources))
X = self._preprocess_data(feats, fit)
return X, y, None

@staticmethod
def extract_example_features(example, config, resources):
"""Extracts feature dicts for each token in an example.

Args:
example (mindmeld.core.Query): A query.
config (ModelConfig): The ModelConfig which may contain information used for feature \
extraction.
resources (dict): Resources which may be used for this model's feature extraction.

Returns:
list[dict]: Features.
"""
return extract_sequence_features(
example, config.example_type, config.features, resources
)

def _preprocess_data(self, X, fit=False):
"""Converts data into formats of CRF suite.

Args:
X (list of dict): features of an example
fit (bool, optional): True if processing data at fit time, false for predict time.

Returns:
(list of list of str): features in CRF suite format
"""
if fit:
self._feat_binner.fit(X)

# We want to use a list for in-memory and a LineGenerator for disk based
new_X = X.__class__()
# Maintain append code structure to make sure it supports in-memory and FileBackedList()
for feat_seq in self._feat_binner.transform(X):
new_X.append(feat_seq)
return new_X

def setup_model(self, config):
self._feat_binner = FeatureBinner()


# Feature extraction for CRF


Expand Down
Loading