Add features:
1. Add PretrainedEmbedding to support fusion with pretrained embeddings
2. Update preprocess and features to support oov_idx-based masking for PretrainedEmbedding
xpai committed Oct 13, 2023
1 parent 93bae1d commit 893c59b
Showing 10 changed files with 216 additions and 148 deletions.
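
For context on how the new options are consumed, the preprocessing changes below read pretrained_emb, pretrain_dim, freeze_emb, and pretrain_usage from each feature column definition. A minimal sketch of a feature column that enables them, following the usual FuxiCTR feature_cols format; the feature name, file path, and dimension are hypothetical, not taken from this commit:

# Hypothetical feature_cols entry; only the keys read by the new preprocessing code are shown.
feature_cols = [
    {"name": "item_id",                        # hypothetical feature name
     "active": True,
     "dtype": "str",
     "type": "categorical",
     "pretrained_emb": "../data/item_emb.h5",  # h5 file with "key"/"value" datasets (assumed path)
     "pretrain_dim": 64,                       # dimension of the pretrained vectors (assumed)
     "freeze_emb": True,                       # matches the default col.get("freeze_emb", True)
     "pretrain_usage": "init"}                 # matches the default col.get("pretrain_usage", "init")
]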
6 changes: 3 additions & 3 deletions benchmark/run_expid.py
@@ -74,10 +74,10 @@
gc.collect()

logging.info('******** Test evaluation ********')
test_gen = H5DataLoader(feature_map, stage='test', **params).make_iterator()
test_result = {}
if test_gen:
test_result = model.evaluate(test_gen)
if params["test_data"]:
test_gen = H5DataLoader(feature_map, stage='test', **params).make_iterator()
test_result = model.evaluate(test_gen)

result_filename = Path(args['config']).name.replace(".yaml", "") + '.csv'
with open(result_filename, 'a+') as fw:
53 changes: 27 additions & 26 deletions fuxictr/preprocess/feature_processor.py
@@ -23,10 +23,10 @@
import logging
import json
import re
from functools import partial
import sklearn.preprocessing as sklearn_preprocess
from fuxictr.features import FeatureMap
from .utils import Tokenizer, Normalizer
from .tokenizer import Tokenizer
from .normalizer import Normalizer


class FeatureProcessor(object):
@@ -119,6 +119,23 @@ def fit(self, train_ddf, min_categr_count=1, num_buckets=10, **kwargs):
min_categr_count=min_categr_count)
else:
raise NotImplementedError("feature_col={}".format(feature_col))

# Expand vocab from pretrained_emb
for name, spec in self.feature_map.features.items():
if "pretrained_emb" in col:
logging.info("Loading pretrained embedding: " + name)
self.feature_map.features[name]["pretrained_emb"] = "pretrained_emb.h5"
self.feature_map.features[name]["freeze_emb"] = col.get("freeze_emb", True)
self.feature_map.features[name]["pretrain_usage"] = col.get("pretrain_usage", "init")
tokenizer = self.processor_dict[name + "::tokenizer"]
tokenizer.load_pretrained_embedding(name,
self.dtype_dict[name],
col["pretrained_emb"],
os.path.join(self.data_dir, "pretrained_emb.h5"),
freeze_emb=col.get("freeze_emb", True))
self.processor_dict[name + "::tokenizer"] = tokenizer
self.feature_map.features[name]["vocab_size"] = tokenizer.vocab_size()

# Handle share_embedding vocab re-assign
for name, spec in self.feature_map.features.items():
if spec["type"] == "numeric":
@@ -176,6 +193,8 @@ def fit_categorical_col(self, col, col_values, min_categr_count=1, num_buckets=1
self.feature_map.features[name]["embedding_dim"] = col["embedding_dim"]
if "emb_output_dim" in col:
self.feature_map.features[name]["emb_output_dim"] = col["emb_output_dim"]
if "pretrain_dim" in col:
self.feature_map.features[name]["pretrain_dim"] = col["pretrain_dim"]
if "category_processor" not in col:
tokenizer = Tokenizer(min_freq=min_categr_count,
na_value=col.get("fill_na", ""),
@@ -189,20 +208,10 @@ def fit_categorical_col(self, col, col_values, min_categr_count=1, num_buckets=1
self.feature_map.features[col["share_embedding"]] \
.update({"oov_idx": self.processor_dict[tknzr_name].vocab["__OOV__"],
"vocab_size": self.processor_dict[tknzr_name].vocab_size()})
else:
if "pretrained_emb" in col:
logging.info("Loading pretrained embedding: " + name)
self.feature_map.features[name]["pretrained_emb"] = "pretrained_emb.h5"
self.feature_map.features[name]["freeze_emb"] = col.get("freeze_emb", True)
tokenizer.load_pretrained_embedding(name,
self.dtype_dict[name],
col["pretrained_emb"],
os.path.join(self.data_dir, "pretrained_emb.h5"),
freeze_emb=col.get("freeze_emb", True))
self.processor_dict[name + "::tokenizer"] = tokenizer
self.feature_map.features[name].update({"padding_idx": 0,
"oov_idx": tokenizer.vocab["__OOV__"],
"vocab_size": tokenizer.vocab_size()})
self.processor_dict[name + "::tokenizer"] = tokenizer
else:
category_processor = col["category_processor"]
self.feature_map.features[name]["category_processor"] = category_processor
@@ -236,6 +245,8 @@ def fit_sequence_col(self, col, col_values, min_categr_count=1):
self.feature_map.features[name]["embedding_dim"] = col["embedding_dim"]
if "emb_output_dim" in col:
self.feature_map.features[name]["emb_output_dim"] = col["emb_output_dim"]
if "pretrain_dim" in col:
self.feature_map.features[name]["pretrain_dim"] = col["pretrain_dim"]
splitter = col.get("splitter")
na_value = col.get("fill_na", "")
max_len = col.get("max_len", 0)
@@ -252,21 +263,11 @@ def fit_sequence_col(self, col, col_values, min_categr_count=1):
self.feature_map.features[col["share_embedding"]] \
.update({"oov_idx": self.processor_dict[tknzr_name].vocab["__OOV__"],
"vocab_size": self.processor_dict[tknzr_name].vocab_size()})
else:
if "pretrained_emb" in col:
logging.info("Loading pretrained embedding: " + name)
self.feature_map.features[name]["pretrained_emb"] = "pretrained_emb.h5"
self.feature_map.features[name]["freeze_emb"] = col.get("freeze_emb", True)
tokenizer.load_pretrained_embedding(name,
self.dtype_dict[name],
col["pretrained_emb"],
os.path.join(self.data_dir, "pretrained_emb.h5"),
freeze_emb=col.get("freeze_emb", True))
self.processor_dict[name + "::tokenizer"] = tokenizer
self.feature_map.features[name].update({"padding_idx": 0,
"oov_idx": tokenizer.vocab["__OOV__"],
"vocab_size": tokenizer.vocab_size(),
"max_len": tokenizer.max_len})
self.processor_dict[name + "::tokenizer"] = tokenizer
"max_len": tokenizer.max_len,
"vocab_size": tokenizer.vocab_size()})

def transform(self, ddf):
logging.info("Transform feature columns...")
43 changes: 43 additions & 0 deletions fuxictr/preprocess/normalizer.py
@@ -0,0 +1,43 @@
# =========================================================================
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

import numpy as np
import sklearn.preprocessing as sklearn_preprocess


class Normalizer(object):
def __init__(self, normalizer):
if not callable(normalizer):
self.callable = False
if normalizer in ['StandardScaler', 'MinMaxScaler']:
self.normalizer = getattr(sklearn_preprocess, normalizer)()
else:
raise NotImplementedError('normalizer={}'.format(normalizer))
else:
# normalizer is a method
self.normalizer = normalizer
self.callable = True

def fit(self, X):
if not self.callable:
null_index = np.isnan(X)
self.normalizer.fit(X[~null_index].reshape(-1, 1))

def normalize(self, X):
if self.callable:
return self.normalizer(X)
else:
return self.normalizer.transform(X.reshape(-1, 1)).flatten()
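
A quick usage sketch for the relocated Normalizer, assuming the class exactly as added above; the input values are made up for illustration:

import numpy as np
from fuxictr.preprocess.normalizer import Normalizer  # module added in this commit

X = np.array([1.0, 2.0, np.nan, 4.0])

# Built-in scaler selected by name: NaNs are excluded when fitting, per Normalizer.fit()
scaler = Normalizer("StandardScaler")
scaler.fit(X)
print(scaler.normalize(X))   # NaN entries stay NaN after the sklearn transform

# Custom callable: fit() is a no-op and normalize() just applies the function
log_norm = Normalizer(np.log1p)
print(log_norm.normalize(X))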
88 changes: 23 additions & 65 deletions fuxictr/preprocess/utils.py → fuxictr/preprocess/tokenizer.py
@@ -24,13 +24,14 @@
from tqdm import tqdm
import logging
import sklearn.preprocessing as sklearn_preprocess
from keras_preprocessing.sequence import pad_sequences
from concurrent.futures import ProcessPoolExecutor, as_completed


class Tokenizer(object):
def __init__(self, num_IDs=None, na_value="", min_freq=1, splitter=None, remap=True,
def __init__(self, max_features=None, na_value="", min_freq=1, splitter=None, remap=True,
lower=False, max_len=0, padding="pre", num_workers=8):
self._num_IDs = num_IDs
self._max_features = max_features
self._na_value = na_value
self._min_freq = min_freq
self._lower = lower
@@ -60,16 +61,18 @@ def fit_on_texts(self, texts):

def build_vocab(self, word_counts):
word_counts = word_counts.items()
if self._num_IDs:
word_counts = sorted(word_counts, key=lambda x: (-x[1], x[0]))
word_counts = word_counts[0:self._num_IDs]
# sort to guarantee the determinism of index order
word_counts = sorted(word_counts, key=lambda x: (-x[1], x[0]))
if self._max_features: # keep the most frequent features
word_counts = word_counts[0:self._max_features]
words = []
for token, count in word_counts:
if count >= self._min_freq:
if token != self._na_value:
words.append(token.lower() if self._lower else token)
else:
break # already sorted in decending order
if self.remap:
words.sort() # sort to guarantee the determinism of index order
self.vocab = dict((token, idx) for idx, token in enumerate(words, 1))
else:
self.vocab = dict((token, int(token)) for token in words)
@@ -94,7 +97,7 @@ def merge_vocab(self, shared_tokenizer):
def vocab_size(self):
return max(self.vocab.values()) + 1

def add_vocab(self, word_list):
def update_vocab(self, word_list):
new_words = 0
for word in word_list:
if word not in self.vocab:
@@ -103,12 +106,18 @@ def add_vocab(self, word_list):
if new_words > 0:
self.vocab["__OOV__"] = self.vocab_size()

def expand_pretrain_vocab(self, word_list):
# Do not update OOV index here
for word in word_list:
if word not in self.vocab:
self.vocab[word] = self.vocab_size()

def encode_meta(self, values):
word_counts = Counter(list(values))
if len(self.vocab) == 0:
self.build_vocab(word_counts)
else:
self.add_vocab(word_counts.keys())
else: # for considering meta data in test data
self.update_vocab(word_counts.keys())
meta_values = [self.vocab.get(x, self.vocab["__OOV__"]) for x in values]
return np.array(meta_values)
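
The split between update_vocab and expand_pretrain_vocab above is what makes the oov_idx-based masking in the commit message possible: update_vocab re-appends __OOV__ so it stays the highest index, while expand_pretrain_vocab leaves __OOV__ untouched, so tokens known only from the pretrained table end up at indices at or above the stored oov_idx. A toy sketch of the two resulting index layouts, mimicking the logic with a plain dict instead of a fitted Tokenizer (the words are made up):

# Vocab fitted on training data: __PAD__ is 0 and __OOV__ is the last index.
vocab = {"__PAD__": 0, "apple": 1, "banana": 2, "__OOV__": 3}

def vocab_size(v):
    return max(v.values()) + 1   # same rule as Tokenizer.vocab_size()

# update_vocab-style expansion: new words appended, then __OOV__ moved back to the end.
v1 = dict(vocab)
for w in ["cherry"]:
    if w not in v1:
        v1[w] = vocab_size(v1)
v1["__OOV__"] = vocab_size(v1)
# v1 == {"__PAD__": 0, "apple": 1, "banana": 2, "__OOV__": 5, "cherry": 4}

# expand_pretrain_vocab-style expansion: __OOV__ untouched, pretrained-only words
# appended after it, so an index >= oov_idx identifies them downstream.
v2 = dict(vocab)
oov_idx = v2["__OOV__"]          # what feature_map stores as "oov_idx"
for w in ["cherry", "durian"]:
    if w not in v2:
        v2[w] = vocab_size(v2)
# v2 == {"__PAD__": 0, "apple": 1, "banana": 2, "__OOV__": 3, "cherry": 4, "durian": 5}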

@@ -129,14 +138,15 @@ def encode_sequence(self, texts):
return np.array(sequence_list)

def load_pretrained_embedding(self, feature_name, feature_dtype, pretrain_path,
output_path, freeze_emb=True):
output_path, freeze_emb=True, expand_pretrain_vocab=True):
with h5py.File(pretrain_path, 'r') as hf:
keys = hf["key"][:]
keys = keys.astype(feature_dtype) # in case mismatch of dtype between int and str
pretrained_vocab = dict(zip(keys, range(len(keys))))
pretrained_emb = hf["value"][:]
# update vocab with pretrained keys, in case new token ids appear in validation or test set
self.add_vocab(pretrained_vocab.keys())
if expand_pretrain_vocab:
self.expand_pretrain_vocab(pretrained_vocab.keys())

logging.info("{}\'s pretrained_emb shape: {}".format(feature_name, pretrained_emb.shape))
embedding_dim = pretrained_emb.shape[1]
@@ -146,7 +156,8 @@ def load_pretrained_embedding(self, feature_name, feature_dtype, pretrain_path,
embedding_matrix = np.random.normal(loc=0, scale=1.e-4, size=(self.vocab_size(), embedding_dim))
embedding_matrix[self.vocab["__PAD__"], :] = 0. # set as zero embedding for PAD
for word in pretrained_vocab.keys():
embedding_matrix[self.vocab[word]] = pretrained_emb[pretrained_vocab[word]]
if word in self.vocab:
embedding_matrix[self.vocab[word]] = pretrained_emb[pretrained_vocab[word]]
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with h5py.File(output_path, 'a') as hf: # Add different embeddings to a single h5 file
hf.create_dataset(feature_name, data=embedding_matrix)
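
load_pretrained_embedding reads the pretrained table from an h5 file with a "key" dataset (tokens) and a "value" dataset (one embedding row per token), as seen at the top of the method. A sketch of producing such a file with h5py; the tokens, dimension, and output path are hypothetical:

import h5py
import numpy as np

keys = np.array(["item_a", "item_b", "item_c"], dtype="S")  # hypothetical tokens; use ints if the feature dtype is int
values = np.random.randn(len(keys), 4).astype(np.float32)   # one 4-dim vector per key (toy values)

with h5py.File("item_emb.h5", "w") as hf:                   # hypothetical output path
    hf.create_dataset("key", data=keys)
    hf.create_dataset("value", data=values)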
@@ -166,56 +177,3 @@ def count_tokens(texts, splitter):
for token in text_split:
word_counts[token] += 1
return word_counts, max_len


class Normalizer(object):
def __init__(self, normalizer):
if not callable(normalizer):
self.callable = False
if normalizer in ['StandardScaler', 'MinMaxScaler']:
self.normalizer = getattr(sklearn_preprocess, normalizer)()
else:
raise NotImplementedError('normalizer={}'.format(normalizer))
else:
# normalizer is a method
self.normalizer = normalizer
self.callable = True

def fit(self, X):
if not self.callable:
null_index = np.isnan(X)
self.normalizer.fit(X[~null_index].reshape(-1, 1))

def normalize(self, X):
if self.callable:
return self.normalizer(X)
else:
return self.normalizer.transform(X.reshape(-1, 1)).flatten()


def pad_sequences(sequences, maxlen=None, dtype='int32',
padding='pre', truncating='pre', value=0.):
""" Pads sequences (list of list) to the ndarray of same length
This is an equivalent implementation of tf.keras.preprocessing.sequence.pad_sequences
"""

assert padding in ["pre", "post"], "Invalid padding={}.".format(padding)
assert truncating in ["pre", "post"], "Invalid truncating={}.".format(truncating)

if maxlen is None:
maxlen = max(len(x) for x in sequences)
arr = np.full((len(sequences), maxlen), value, dtype=dtype)
for idx, x in enumerate(sequences):
if len(x) == 0:
continue # empty list
if truncating == 'pre':
trunc = x[-maxlen:]
else:
trunc = x[:maxlen]
trunc = np.asarray(trunc, dtype=dtype)

if padding == 'pre':
arr[idx, -len(trunc):] = trunc
else:
arr[idx, :len(trunc)] = trunc
return arr
2 changes: 1 addition & 1 deletion fuxictr/pytorch/layers/embeddings/__init__.py
@@ -1,2 +1,2 @@
from .feature_embedding import *
from .pretrained_embedding import *
(Diffs for the remaining changed files, including fuxictr/pytorch/layers/embeddings/pretrained_embedding.py, are not rendered on this page.)
