Skip to content

Commit

Permalink
Merge pull request #413 from cisco/vidamoda/feature/pytorch_crf
Browse files Browse the repository at this point in the history
Integrated working PyTorch-CRF in MM
  • Loading branch information
vrdn-23 authored Jul 12, 2022
2 parents da4dc72 + e2d3a2e commit 238412f
Show file tree
Hide file tree
Showing 7 changed files with 844 additions and 30 deletions.
9 changes: 8 additions & 1 deletion mindmeld/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import re
from tempfile import mkstemp
import numpy as np

import nltk
from sklearn.metrics import make_scorer
Expand Down Expand Up @@ -534,6 +535,12 @@ def add_resource(func):
return add_resource


def np_encoder(val):
if isinstance(val, np.generic):
return val.item()
raise TypeError(f"{type(val)} cannot be serialized by JSON.")


class FileBackedList:
"""
FileBackedList implements an interface for simple list use cases
Expand All @@ -553,7 +560,7 @@ def __len__(self):
def append(self, line):
if self.file_handle is None:
self.file_handle = open(self.filename, "w")
self.file_handle.write(json.dumps(line))
self.file_handle.write(json.dumps(line, default=np_encoder))
self.file_handle.write("\n")
self.num_lines += 1

Expand Down
37 changes: 24 additions & 13 deletions mindmeld/models/tagger_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from .model import ModelConfig, Model, PytorchModel, AbstractModelFactory
from .nn_utils import get_token_classifier_cls, TokenClassificationType
from .taggers.crf import ConditionalRandomFields
from .taggers.crf import ConditionalRandomFields, TorchCrfTagger
from .taggers.memm import MemmModel
from ..exceptions import MindMeldError

Expand Down Expand Up @@ -73,12 +73,14 @@ class TaggerModel(Model):
CRF_TYPE = "crf"
MEMM_TYPE = "memm"
LSTM_TYPE = "lstm"
ALLOWED_CLASSIFIER_TYPES = [CRF_TYPE, MEMM_TYPE, LSTM_TYPE]
TORCH_CRF_TYPE = "torch-crf"
ALLOWED_CLASSIFIER_TYPES = [CRF_TYPE, MEMM_TYPE, LSTM_TYPE, TORCH_CRF_TYPE]

# for default model scoring types
ACCURACY_SCORING = "accuracy"
SEQ_ACCURACY_SCORING = "seq_accuracy"
SEQUENCE_MODELS = ["crf"]
# TODO: Rename torch-crf to crf implementation. Created https://github.com/cisco/mindmeld/issues/416 for this.
SEQUENCE_MODELS = ["crf", "torch-crf"]

DEFAULT_FEATURES = {
"bag-of-words-seq": {
Expand Down Expand Up @@ -131,6 +133,7 @@ def _get_model_constructor(self):
return {
TaggerModel.MEMM_TYPE: MemmModel,
TaggerModel.CRF_TYPE: ConditionalRandomFields,
TaggerModel.TORCH_CRF_TYPE: TorchCrfTagger,
TaggerModel.LSTM_TYPE: LstmModel,
}[classifier_type]
except KeyError as e:
Expand Down Expand Up @@ -231,7 +234,7 @@ def fit(self, examples, labels, params=None):
"There are no labels in this label set, so we don't fit the model."
)
return self
# Extract labels - label encoders are the same accross all entity recognition models
# Extract labels - label encoders are the same across all entity recognition models
self._label_encoder = get_label_encoder(self.config)
y = self._label_encoder.encode(labels, examples=examples)

Expand All @@ -245,9 +248,10 @@ def fit(self, examples, labels, params=None):
self._clf = self._fit(X, y, params)
self._current_params = params
else:
non_supported_classes = (TorchCrfTagger, LstmModel) if LstmModel is not None else TorchCrfTagger
# run cross validation to select params
if self._clf.__class__ == LstmModel:
raise MindMeldError("The LSTM model does not support cross-validation")
if isinstance(self._clf, non_supported_classes):
raise MindMeldError(f"The {type(self._clf).__name__} model does not support cross-validation")

_, best_params = self._fit_cv(X, y, groups)
self._clf = self._fit(X, y, best_params)
Expand Down Expand Up @@ -393,12 +397,19 @@ def _dump(self, path):
else:
# underneath tagger dump for LSTM model, returned `model_dir` is None for MEMM & CRF
self._clf.dump(path)
metadata.update({
"current_params": self._current_params,
"label_encoder": self._label_encoder,
"no_entities": self._no_entities,
"model_config": self.config
})
if isinstance(self._clf, TorchCrfTagger):
metadata.update({
"model": self,
"model_type": "torch-crf"
})
elif isinstance(self._clf, LstmModel):
metadata.update({
"current_params": self._current_params,
"label_encoder": self._label_encoder,
"no_entities": self._no_entities,
"model_config": self.config,
"model_type": "lstm"
})

# dump model metadata
os.makedirs(os.path.dirname(path), exist_ok=True)
Expand All @@ -421,7 +432,7 @@ def load(cls, path):

# If model is serializable, it can be loaded and used as-is. But if not serializable,
# it means we need to create an instance and load necessary details for it to be used.
if not is_serializable:
if not is_serializable and metadata.get('model_type') == 'lstm':
model = cls(metadata["model_config"])

# misc resources load
Expand Down
150 changes: 150 additions & 0 deletions mindmeld/models/taggers/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .taggers import Tagger, extract_sequence_features
from ..helpers import FileBackedList
from .pytorch_crf import TorchCrfModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -182,6 +183,155 @@ def setup_model(self, config):
self._feat_binner = FeatureBinner()


class TorchCrfTagger(Tagger):
"""A Conditional Random Fields model."""

def fit(self, X, y):
self._clf.fit(X, y)
return self

# TODO: Refactor to move initialization into init() or setup_model()
def set_params(self, **parameters):
self._clf = TorchCrfModel()
self._clf.set_params(**parameters)
return self

def get_params(self, deep=True):
return self._clf.get_params()

def predict(self, X, dynamic_resource=None):
return self._clf.predict(X)

def predict_proba(self, examples, config, resources):
"""
Args:
examples (list of mindmeld.core.Query): a list of queries to predict on
config (ModelConfig): The ModelConfig which may contain information used for feature
extraction
resources (dict): Resources which may be used for this model's feature extraction
Returns:
list of tuples of (mindmeld.core.QueryEntity): a list of predicted labels \
with confidence scores
"""
X, _, _ = self.extract_features(examples, config, resources, in_memory=True)
seq = self._clf.predict(X)
marginals_dict = self._clf.predict_marginals(X)
marginal_tuples = []
for query_index, query_seq in enumerate(seq):
query_marginal_tuples = []
for i, tag in enumerate(query_seq):
query_marginal_tuples.append([tag, marginals_dict[query_index][i][tag]])
marginal_tuples.append(query_marginal_tuples)
return marginal_tuples

def predict_proba_distribution(self, examples, config, resources):
"""
Args:
examples (list of mindmeld.core.Query): a list of queries to predict on
config (ModelConfig): The ModelConfig which may contain information used for feature
extraction
resources (dict): Resources which may be used for this model's feature extraction
Returns:
list of list of ((list of str) and (list of float)): a list of predicted labels \
with confidence scores
"""
X, _, _ = self.extract_features(examples, config, resources, in_memory=True)
seq = self._clf.predict(X)
marginals_dict = self._clf.predict_marginals(X)
predictions = []
tag_maps = []
for query_index, query_seq in enumerate(seq):
tags = []
preds = []
for i, _ in enumerate(query_seq):
tags.append(list(marginals_dict[query_index][i].keys()))
preds.append(list(marginals_dict[query_index][i].values()))
tag_maps.extend(tags)
predictions.extend(preds)
return [[tag_maps, predictions]]

def extract_features(self,
examples,
config,
resources,
y=None,
fit=False,
in_memory=STORE_CRF_FEATURES_IN_MEMORY):
"""Transforms a list of examples into a feature matrix.
Args:
examples (list of mindmeld.core.Query): a list of queries
config (ModelConfig): The ModelConfig which may contain information used for feature
extraction
resources (dict): Resources which may be used for this model's feature extraction
Returns:
(list of list of str): features in CRF suite format
"""
# Extract features and classes
feats = []
# The FileBackedList now has support for indexing but it still loads the list
# eventually into memory cause of the scikit-learn train_test_split function.
# Created https://github.com/cisco/mindmeld/issues/417 for this.
if not in_memory:
logger.warning("PyTorch CRF does not currently support STORE_CRF_FEATURES_IN_MEMORY. This may be fixed in "
"a future release.")
for _, example in enumerate(examples):
feats.append(self.extract_example_features(example, config, resources))
X = self._preprocess_data(feats, fit)
return X, y, None

@staticmethod
def extract_example_features(example, config, resources):
"""Extracts feature dicts for each token in an example.
Args:
example (mindmeld.core.Query): A query.
config (ModelConfig): The ModelConfig which may contain information used for feature \
extraction.
resources (dict): Resources which may be used for this model's feature extraction.
Returns:
list[dict]: Features.
"""
return extract_sequence_features(
example, config.example_type, config.features, resources
)

def _preprocess_data(self, X, fit=False):
"""Converts data into formats of CRF suite.
Args:
X (list of list of dict): features of an example
fit (bool, optional): True if processing data at fit time, false for predict time.
Returns:
(list of list of str): features in CRF suite format
"""
if fit:
self._feat_binner.fit(X)

# We want to use a list for in-memory and a LineGenerator for disk based
new_X = X.__class__()
# Maintain append code structure to make sure it supports in-memory and FileBackedList()
for feat_seq in self._feat_binner.transform(X):
new_X.append(feat_seq)
return new_X

def setup_model(self, config):
self._feat_binner = FeatureBinner()

@property
def is_serializable(self):
return False

def dump(self, path):
best_model_save_path = os.path.join(os.path.split(path)[0], "best_crf_wts.pt")
self._clf.save_best_weights_path(best_model_save_path)


# Feature extraction for CRF


Expand Down
Loading

0 comments on commit 238412f

Please sign in to comment.