Fixed save and load NARRE model #517

Merged (2 commits) on Jul 21, 2023
176 changes: 92 additions & 84 deletions cornac/models/narre/narre.py
@@ -17,14 +17,15 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, initializers, Input
from tensorflow.python.keras.preprocessing.sequence import pad_sequences

from ...utils import get_rng
from ...utils.init_utils import uniform


class TextProcessor(keras.Model):
def __init__(self, max_text_length, filters=64, kernel_sizes=[3], dropout_rate=0.5, name=''):
super(TextProcessor, self).__init__(name=name)
def __init__(self, max_text_length, filters=64, kernel_sizes=[3], dropout_rate=0.5, name='', **kwargs):
super(TextProcessor, self).__init__(name=name, **kwargs)
self.max_text_length = max_text_length
self.filters = filters
self.kernel_sizes = kernel_sizes
@@ -51,7 +52,6 @@ def call(self, inputs, training=False):


def get_data(batch_ids, train_set, max_text_length, by='user', max_num_review=None):
from tensorflow.python.keras.preprocessing.sequence import pad_sequences
batch_reviews, batch_id_reviews, batch_num_reviews = [], [], []
review_group = train_set.review_text.user_review if by == 'user' else train_set.review_text.item_review
for idx in batch_ids:
@@ -65,8 +65,8 @@ def get_data(batch_ids, train_set, max_text_length, by='user', max_num_review=No
reviews = train_set.review_text.batch_seq(review_ids, max_length=max_text_length)
batch_reviews.append(reviews)
batch_num_reviews.append(len(reviews))
batch_reviews = pad_sequences(batch_reviews, padding="post")
batch_id_reviews = pad_sequences(batch_id_reviews, padding="post")
batch_reviews = pad_sequences(batch_reviews, maxlen=max_num_review, padding="post")
batch_id_reviews = pad_sequences(batch_id_reviews, maxlen=max_num_review, padding="post")
batch_num_reviews = np.array(batch_num_reviews)
return batch_reviews, batch_id_reviews, batch_num_reviews

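The `maxlen=max_num_review` argument added above pads or truncates every user's and item's review list to a fixed length, so batches always share the same review dimension. A toy sketch of that behaviour (using the public `tensorflow.keras` import path rather than the `tensorflow.python.keras` one used in this file; the data is made up):

```python
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

max_num_review = 2
# Toy stand-ins for train_set.review_text.batch_seq(...): one token-id matrix per user.
batch_reviews = [
    np.ones((3, 5), dtype="int32"),   # a user with 3 reviews of 5 tokens each
    np.ones((1, 5), dtype="int32"),   # a user with a single review
]
padded = pad_sequences(batch_reviews, maxlen=max_num_review, padding="post")
print(padded.shape)  # (2, 2, 5): every user now has exactly max_num_review review rows
# The 3-review user keeps its last 2 reviews (default truncating='pre');
# the 1-review user is zero-padded at the end (padding='post').
```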
@@ -80,13 +80,69 @@ def __init__(self, init_value=0.0, name="global_bias"):
def build(self, input_shape):
self.global_bias = self.add_weight(shape=1,
initializer=tf.keras.initializers.Constant(self.init_value),
trainable=True)
trainable=True, name="add_weight")

def call(self, inputs):
return inputs + self.global_bias

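Two changes in these hunks look aimed at making the model serializable: `TextProcessor.__init__` now forwards `**kwargs` to `keras.Model.__init__`, and the bias variable in `AddGlobalBias` gets an explicit name. A tiny sketch (assuming the cornac checkout from this branch is importable) of reading the named scalar weight back out, which is also how `get_weights()` at the bottom of the file recovers `mu`:

```python
import tensorflow as tf
from cornac.models.narre.narre import AddGlobalBias  # the layer defined in this file

layer = AddGlobalBias(init_value=3.5, name="global_bias")
_ = layer(tf.zeros((1, 1)))     # calling once runs build() and creates the weight
mu = layer.get_weights()[0][0]  # 3.5, i.e. the global rating mean passed at construction
```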
class Model:
def __init__(self, n_users, n_items, vocab, global_mean, n_factors=32, embedding_size=100, id_embedding_size=32, attention_size=16, kernel_sizes=[3], n_filters=64, dropout_rate=0.5, max_text_length=50, pretrained_word_embeddings=None, verbose=False, seed=None):
class Model(keras.Model):
def __init__(self, n_users, n_items, n_vocab, embedding_matrix, global_mean, n_factors=32, embedding_size=100, id_embedding_size=32, attention_size=16, kernel_sizes=[3], n_filters=64, dropout_rate=0.5, max_text_length=50):
super().__init__()
self.l_user_review_embedding = layers.Embedding(n_vocab, embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_user_review_embedding")
self.l_item_review_embedding = layers.Embedding(n_vocab, embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_item_review_embedding")
self.l_user_iid_embedding = layers.Embedding(n_items, id_embedding_size, embeddings_initializer="uniform", name="user_iid_embedding")
self.l_item_uid_embedding = layers.Embedding(n_users, id_embedding_size, embeddings_initializer="uniform", name="item_uid_embedding")
self.l_user_embedding = layers.Embedding(n_users, id_embedding_size, embeddings_initializer="uniform", name="user_embedding")
self.l_item_embedding = layers.Embedding(n_items, id_embedding_size, embeddings_initializer="uniform", name="item_embedding")
self.user_bias = layers.Embedding(n_users, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="user_bias")
self.item_bias = layers.Embedding(n_items, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="item_bias")
self.user_text_processor = TextProcessor(max_text_length, filters=n_filters, kernel_sizes=kernel_sizes, dropout_rate=dropout_rate, name='user_text_processor')
self.item_text_processor = TextProcessor(max_text_length, filters=n_filters, kernel_sizes=kernel_sizes, dropout_rate=dropout_rate, name='item_text_processor')
self.a_user = keras.models.Sequential([
layers.Dense(attention_size, activation="relu", use_bias=True),
layers.Dense(1, activation=None, use_bias=True)
])
self.user_attention = layers.Softmax(axis=1, name="user_attention")
self.a_item = keras.models.Sequential([
layers.Dense(attention_size, activation="relu", use_bias=True),
layers.Dense(1, activation=None, use_bias=True)
])
self.item_attention = layers.Softmax(axis=1, name="item_attention")
self.user_Oi_dropout = layers.Dropout(rate=dropout_rate, name="user_Oi")
self.Xu = layers.Dense(n_factors, use_bias=True, name="Xu")
self.item_Oi_dropout = layers.Dropout(rate=dropout_rate, name="item_Oi")
self.Yi = layers.Dense(n_factors, use_bias=True, name="Yi")

self.W1 = layers.Dense(1, activation=None, use_bias=False, name="W1")
self.add_global_bias = AddGlobalBias(init_value=global_mean, name="global_bias")

def call(self, inputs, training=None):
i_user_id, i_item_id, i_user_review, i_user_iid_review, i_user_num_reviews, i_item_review, i_item_uid_review, i_item_num_reviews = inputs
user_review_h = self.user_text_processor(self.l_user_review_embedding(i_user_review), training=training)
a_user = self.a_user(tf.concat([user_review_h, self.l_user_iid_embedding(i_user_iid_review)], axis=-1))
a_user_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_user_num_reviews, [-1]), maxlen=i_user_review.shape[1]), -1)
user_attention = self.user_attention(a_user, a_user_masking)
user_Oi = self.user_Oi_dropout(tf.reduce_sum(tf.multiply(user_attention, user_review_h), 1), training=training)
Xu = self.Xu(user_Oi)
item_review_h = self.item_text_processor(self.l_item_review_embedding(i_item_review), training=training)
a_item = self.a_item(tf.concat([item_review_h, self.l_item_uid_embedding(i_item_uid_review)], axis=-1))
a_item_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_item_num_reviews, [-1]), maxlen=i_item_review.shape[1]), -1)
item_attention = self.item_attention(a_item, a_item_masking)
item_Oi = self.item_Oi_dropout(tf.reduce_sum(tf.multiply(item_attention, item_review_h), 1), training=training)
Yi = self.Yi(item_Oi)
h0 = tf.multiply(tf.add(self.l_user_embedding(i_user_id), Xu), tf.add(self.l_item_embedding(i_item_id), Yi))
r = self.add_global_bias(
tf.add_n([
self.W1(h0),
self.user_bias(i_user_id),
self.item_bias(i_item_id)
])
)
# import pdb; pdb.set_trace()
return r

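The graph is now built as a subclassed `keras.Model` instead of being wired with the functional API, so all of its variables live on one object that can be checkpointed with `save_weights` / `load_weights` after the model has been called once. A minimal sketch of that round trip with a stand-in model (`TinyModel` and the checkpoint path are hypothetical, not part of the PR):

```python
import tensorflow as tf
from tensorflow import keras

class TinyModel(keras.Model):
    """Stand-in for the subclassed NARRE Model above."""
    def __init__(self):
        super().__init__()
        self.W1 = keras.layers.Dense(1, name="W1")

    def call(self, inputs, training=None):
        return self.W1(inputs)

model = TinyModel()
_ = model(tf.zeros((1, 4)))            # variables are created on the first call
model.save_weights("narre_demo.ckpt")  # TensorFlow checkpoint format

restored = TinyModel()
_ = restored(tf.zeros((1, 4)))         # build the same variables before loading
restored.load_weights("narre_demo.ckpt")
```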
class NARREModel:
def __init__(self, n_users, n_items, vocab, global_mean, n_factors=32, embedding_size=100, id_embedding_size=32, attention_size=16, kernel_sizes=[3], n_filters=64, dropout_rate=0.5, max_text_length=50, max_num_review=32, pretrained_word_embeddings=None, verbose=False, seed=None):
self.n_users = n_users
self.n_items = n_items
self.n_vocab = vocab.size
@@ -99,6 +155,7 @@ def __init__(self, n_users, n_items, vocab, global_mean, n_factors=32, embedding
self.n_filters = n_filters
self.dropout_rate = dropout_rate
self.max_text_length = max_text_length
self.max_num_review = max_num_review
self.verbose = verbose
if seed is not None:
self.rng = get_rng(seed)
@@ -118,88 +175,39 @@ def __init__(self, n_users, n_items, vocab, global_mean, n_factors=32, embedding
print("Number of OOV words: %d" % oov_count)

embedding_matrix = initializers.Constant(embedding_matrix)
i_user_id = Input(shape=(1,), dtype="int32", name="input_user_id")
i_item_id = Input(shape=(1,), dtype="int32", name="input_item_id")
i_user_review = Input(shape=(None, self.max_text_length), dtype="int32", name="input_user_review")
i_item_review = Input(shape=(None, self.max_text_length), dtype="int32", name="input_item_review")
i_user_iid_review = Input(shape=(None,), dtype="int32", name="input_user_iid_review")
i_item_uid_review = Input(shape=(None,), dtype="int32", name="input_item_uid_review")
i_user_num_reviews = Input(shape=(1,), dtype="int32", name="input_user_number_of_review")
i_item_num_reviews = Input(shape=(1,), dtype="int32", name="input_item_number_of_review")

l_user_review_embedding = layers.Embedding(self.n_vocab, self.embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_user_review_embedding")
l_item_review_embedding = layers.Embedding(self.n_vocab, self.embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_item_review_embedding")
l_user_iid_embedding = layers.Embedding(self.n_items, self.id_embedding_size, embeddings_initializer="uniform", name="user_iid_embedding")
l_item_uid_embedding = layers.Embedding(self.n_users, self.id_embedding_size, embeddings_initializer="uniform", name="item_uid_embedding")
l_user_embedding = layers.Embedding(self.n_users, self.id_embedding_size, embeddings_initializer="uniform", name="user_embedding")
l_item_embedding = layers.Embedding(self.n_items, self.id_embedding_size, embeddings_initializer="uniform", name="item_embedding")
user_bias = layers.Embedding(self.n_users, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="user_bias")
item_bias = layers.Embedding(self.n_items, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="item_bias")

user_text_processor = TextProcessor(self.max_text_length, filters=self.n_filters, kernel_sizes=self.kernel_sizes, dropout_rate=self.dropout_rate, name='user_text_processor')
item_text_processor = TextProcessor(self.max_text_length, filters=self.n_filters, kernel_sizes=self.kernel_sizes, dropout_rate=self.dropout_rate, name='item_text_processor')

user_review_h = user_text_processor(l_user_review_embedding(i_user_review), training=True)
item_review_h = item_text_processor(l_item_review_embedding(i_item_review), training=True)
a_user = layers.Dense(1, activation=None, use_bias=True)(
layers.Dense(self.attention_size, activation="relu", use_bias=True)(
tf.concat([user_review_h, l_user_iid_embedding(i_user_iid_review)], axis=-1)
)
)
a_user_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_user_num_reviews, [-1]), maxlen=i_user_review.shape[1]), -1)
user_attention = layers.Softmax(axis=1, name="user_attention")(a_user, a_user_masking)
a_item = layers.Dense(1, activation=None, use_bias=True)(
layers.Dense(self.attention_size, activation="relu", use_bias=True)(
tf.concat([item_review_h, l_item_uid_embedding(i_item_uid_review)], axis=-1)
)
)
a_item_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_item_num_reviews, [-1]), maxlen=i_item_review.shape[1]), -1)
item_attention = layers.Softmax(axis=1, name="item_attention")(a_item, a_item_masking)

Xu = layers.Dense(self.n_factors, use_bias=True, name="Xu")(
layers.Dropout(rate=self.dropout_rate, name="user_Oi")(
tf.reduce_sum(layers.Multiply()([user_attention, user_review_h]), 1)
)
)
Yi = layers.Dense(self.n_factors, use_bias=True, name="Yi")(
layers.Dropout(rate=self.dropout_rate, name="item_Oi")(
tf.reduce_sum(layers.Multiply()([item_attention, item_review_h]), 1)
)
self.graph = Model(
Reviewer comment (Member): Same comment as for the HDRModel

self.n_users, self.n_items, self.n_vocab, embedding_matrix, self.global_mean,
self.n_factors, self.embedding_size, self.id_embedding_size, self.attention_size,
self.kernel_sizes, self.n_filters, self.dropout_rate, self.max_text_length
)

h0 = layers.Multiply(name="h0")([
layers.Add()([l_user_embedding(i_user_id), Xu]), layers.Add()([l_item_embedding(i_item_id), Yi])
])

W1 = layers.Dense(1, activation=None, use_bias=False, name="W1")
add_global_bias = AddGlobalBias(init_value=self.global_mean, name="global_bias")
r = layers.Add(name="prediction")([
W1(h0),
user_bias(i_user_id),
item_bias(i_item_id)
])
r = add_global_bias(r)
self.graph = keras.Model(inputs=[i_user_id, i_item_id, i_user_review, i_user_iid_review, i_user_num_reviews, i_item_review, i_item_uid_review, i_item_num_reviews], outputs=r)
if self.verbose:
self.graph.summary()

def get_weights(self, train_set, batch_size=64, max_num_review=None):
user_attention_review_pooling = keras.Model(inputs=[self.graph.get_layer('input_user_review').input, self.graph.get_layer('input_user_iid_review').input, self.graph.get_layer('input_user_number_of_review').input], outputs=self.graph.get_layer('Xu').output)
item_attention_review_pooling = keras.Model(inputs=[self.graph.get_layer('input_item_review').input, self.graph.get_layer('input_item_uid_review').input, self.graph.get_layer('input_item_number_of_review').input], outputs=self.graph.get_layer('Yi').output)
def get_weights(self, train_set, batch_size=64):
X = np.zeros((self.n_users, self.n_factors))
Y = np.zeros((self.n_items, self.n_factors))
for batch_users in train_set.user_iter(batch_size):
user_reviews, user_iid_reviews, user_num_reviews = get_data(batch_users, train_set, self.max_text_length, by='user', max_num_review=max_num_review)
Xu = user_attention_review_pooling([user_reviews, user_iid_reviews, user_num_reviews], training=False)
i_user_review, i_user_iid_review, i_user_num_reviews = get_data(batch_users, train_set, self.max_text_length, by='user', max_num_review=self.max_num_review)
user_review_embedding = self.graph.l_user_review_embedding(i_user_review)
user_review_h = self.graph.user_text_processor(user_review_embedding, training=False)
a_user = self.graph.a_user(tf.concat([user_review_h, self.graph.l_user_iid_embedding(i_user_iid_review)], axis=-1))
a_user_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_user_num_reviews, [-1]), maxlen=i_user_review.shape[1]), -1)
user_attention = self.graph.user_attention(a_user, a_user_masking)
user_Oi = tf.reduce_sum(tf.multiply(user_attention, user_review_h), 1)
Xu = self.graph.Xu(user_Oi)
X[batch_users] = Xu.numpy()
for batch_items in train_set.item_iter(batch_size):
item_reviews, item_uid_reviews, item_num_reviews = get_data(batch_items, train_set, self.max_text_length, by='item', max_num_review=max_num_review)
Yi = item_attention_review_pooling([item_reviews, item_uid_reviews, item_num_reviews], training=False)
i_item_review, i_item_uid_review, i_item_num_reviews = get_data(batch_items, train_set, self.max_text_length, by='item', max_num_review=self.max_num_review)
item_review_embedding = self.graph.l_item_review_embedding(i_item_review)
item_review_h = self.graph.item_text_processor(item_review_embedding, training=False)
a_item = self.graph.a_item(tf.concat([item_review_h, self.graph.l_item_uid_embedding(i_item_uid_review)], axis=-1))
a_item_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_item_num_reviews, [-1]), maxlen=i_item_review.shape[1]), -1)
item_attention = self.graph.item_attention(a_item, a_item_masking)
item_Oi = tf.reduce_sum(tf.multiply(item_attention, item_review_h), 1)
Yi = self.graph.Yi(item_Oi)
Y[batch_items] = Yi.numpy()
W1 = self.graph.get_layer('W1').get_weights()[0]
user_embedding = self.graph.get_layer('user_embedding').get_weights()[0]
item_embedding = self.graph.get_layer('item_embedding').get_weights()[0]
bu = self.graph.get_layer('user_bias').get_weights()[0]
bi = self.graph.get_layer('item_bias').get_weights()[0]
mu = self.graph.get_layer('global_bias').get_weights()[0][0]
W1 = self.graph.W1.get_weights()[0]
user_embedding = self.graph.l_user_embedding.get_weights()[0]
item_embedding = self.graph.l_item_embedding.get_weights()[0]
bu = self.graph.user_bias.get_weights()[0]
bi = self.graph.item_bias.get_weights()[0]
mu = self.graph.add_global_bias.get_weights()[0][0]
return X, Y, W1, user_embedding, item_embedding, bu, bi, mu
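For reference, the arrays returned here are sufficient to reproduce the rating computed in `Model.call()` without TensorFlow. A sketch of that computation (the helper name is illustrative, not from the PR; `u` and `i` are assumed to be valid user and item indices):

```python
import numpy as np

def predict_score(u, i, X, Y, W1, user_embedding, item_embedding, bu, bi, mu):
    # h0 = (p_u + X_u) * (q_i + Y_i), element-wise, as in Model.call()
    h0 = (user_embedding[u] + X[u]) * (item_embedding[i] + Y[i])
    # r = W1^T h0 + b_u + b_i + mu
    return (h0 @ W1).item() + bu[u].item() + bi[i].item() + mu
```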