Merge pull request #40 from mski-iksm/feature/add_export_item_embedding
Added item embedding exporting function
nishiba authored Aug 26, 2019
2 parents 928a5ce + 243178e commit 1e956fe
Showing 2 changed files with 59 additions and 7 deletions.
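For orientation, here is a minimal usage sketch of the item-embedding export added by this commit. It is a sketch under stated assumptions, not part of the change itself: `model` is assumed to be a `GraphConvolutionalMatrixCompletion` that has already been fit, `additional_dataset` a `GcmcDataset` carrying the ratings and features of the new items (the test added below shows a full construction), and the ids are illustrative.

# Hypothetical usage; `model` (a fitted GraphConvolutionalMatrixCompletion) and
# `additional_dataset` (a GcmcDataset with ratings/features for the new items)
# are assumed to be built as in the test added below.
requested_item_ids = [11, 240, 243]  # illustrative: one training item plus two new items

item_ids, item_embeddings = model.get_item_feature_with_new_items(
    item_ids=requested_item_ids, additional_dataset=additional_dataset)

# One embedding row of length encoder_size per requested item, in request order.
embedding_by_item = dict(zip(item_ids, item_embeddings))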
24 changes: 17 additions & 7 deletions redshells/model/graph_convolutional_matrix_completion.py
@@ -296,8 +296,11 @@ def predict_with_new_items(self, user_ids: List, item_ids: List, additional_data

    def get_user_feature(self, user_ids: List, item_ids: List, additional_dataset: GcmcDataset, with_user_embedding: bool = True) -> np.ndarray:
        dataset = self.graph_dataset.add_dataset(additional_dataset, add_item=True)
-        return self._get_user_feature(
-            user_ids=user_ids, item_ids=item_ids, with_user_embedding=with_user_embedding, graph=self.graph, dataset=dataset, session=self.session)
+        return self._get_feature(user_ids=user_ids, item_ids=item_ids, with_user_embedding=with_user_embedding, graph=self.graph, dataset=dataset, session=self.session, feature='user')
+
+    def get_item_feature(self, user_ids: List, item_ids: List, additional_dataset: GcmcDataset, with_user_embedding: bool = True) -> np.ndarray:
+        dataset = self.graph_dataset.add_dataset(additional_dataset, add_item=True)
+        return self._get_feature(user_ids=user_ids, item_ids=item_ids, with_user_embedding=with_user_embedding, graph=self.graph, dataset=dataset, session=self.session, feature='item')

    @classmethod
    def _predict(cls, user_ids: List, item_ids: List, with_user_embedding, graph: GraphConvolutionalMatrixCompletionGraph, dataset: GcmcGraphDataset,
@@ -320,9 +323,9 @@ def _predict(user_ids: List, item_ids: List, with_user_embedding, graph: Gr
        return predictions

    @classmethod
-    def _get_user_feature(cls, user_ids: List, item_ids: List, with_user_embedding,
-                          graph: GraphConvolutionalMatrixCompletionGraph, dataset: GcmcGraphDataset,
-                          session: tf.Session) -> np.ndarray:
+    def _get_feature(cls, user_ids: List, item_ids: List, with_user_embedding,
+                     graph: GraphConvolutionalMatrixCompletionGraph, dataset: GcmcGraphDataset,
+                     session: tf.Session, feature: str) -> np.ndarray:
        if graph is None:
            raise RuntimeError('Please call fit first.')

@@ -335,9 +338,10 @@ def _get_user_feature(cls, user_ids: List, item_ids: List, with_user_embedding,
        input_data = dict(user=user_indices, item=item_indices, user_feature_indices=user_feature_indices,
                          item_feature_indices=item_feature_indices)
        feed_dict = cls._feed_dict(input_data, graph, dataset, rating_adjacency_matrix, is_train=False)
+        encoder_map = dict(user=graph.user_encoder, item=graph.item_encoder)
        with session.as_default():
-            user_feature = session.run(graph.user_encoder, feed_dict=feed_dict)
-        return user_feature
+            feature = session.run(encoder_map[feature], feed_dict=feed_dict)
+        return feature

    @staticmethod
    def _feed_dict(input_data, graph, graph_dataset, rating_adjacency_matrix, dropout_rate: float = 0.0, learning_rate: float = 0.0, is_train: bool = True):
@@ -382,6 +386,12 @@ def get_user_feature_with_new_items(self, item_ids: List, additional_dataset: Gc
        indices = [i for i in range(len(users)) if i % len(item_ids) == 0]
        return user_ids, user_feature[indices]

+    def get_item_feature_with_new_items(self, item_ids: List, additional_dataset: GcmcDataset, with_user_embedding: bool = True) -> pd.DataFrame:
+        user_id = self.graph_dataset.user_ids[0]
+        users, items = zip(*[(user_id, item_id) for item_id in item_ids])
+        item_feature = self.get_item_feature(user_ids=users, item_ids=items, additional_dataset=additional_dataset, with_user_embedding=with_user_embedding)
+        return items, item_feature

    def _make_graph(self) -> GraphConvolutionalMatrixCompletionGraph:
        return GraphConvolutionalMatrixCompletionGraph(
            n_rating=self.graph_dataset.n_rating,
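The refactor above folds the old `_get_user_feature` helper into a single `_get_feature` that picks which encoder output to evaluate from a small lookup (`encoder_map`) keyed by the new `feature` argument, so user and item embeddings share one code path. A stripped-down, illustrative sketch of that dispatch pattern follows; the functions and names are stand-ins, not the library's API, and the input validation is part of the sketch only (the diff above does not validate `feature`).

# Illustrative only: the same "pick the encoder by key" dispatch that
# _get_feature uses to choose graph.user_encoder vs. graph.item_encoder.
def encode_user(x):
    return 'user-embedding of %s' % x

def encode_item(x):
    return 'item-embedding of %s' % x

def get_feature(feature, x):
    encoder_map = dict(user=encode_user, item=encode_item)
    # Validation added for clarity in this sketch; not present in the diff above.
    if feature not in encoder_map:
        raise ValueError('feature must be "user" or "item", got %r' % (feature,))
    return encoder_map[feature](x)

print(get_feature('item', 240))  # item-embedding of 240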
42 changes: 42 additions & 0 deletions test/model/test_graph_convolutional_matrix_completion.py
@@ -116,6 +116,48 @@ def test_get_user_feature_with_new_items(self, dummy_get_user_feature):
        self.assertEqual(len(user_features[0]), n_users)
        self.assertEqual(user_features[1].shape, (n_users, n_user_embed_dimension))

    def test_get_item_feature_with_new_items(self):
        n_users = 101
        n_items = 233
        n_data = 3007
        am1 = _make_sparse_matrix(n_users, n_items, n_data)
        am2 = 2 * _make_sparse_matrix(n_users, n_items, n_data)
        adjacency_matrix = am1 + am2
        user_ids = adjacency_matrix.tocoo().row
        item_ids = adjacency_matrix.tocoo().col
        ratings = adjacency_matrix.tocoo().data
        item_features = [{i: np.array([i]) for i in range(n_items)}]
        dataset = GcmcDataset(user_ids, item_ids, ratings, item_features=item_features)
        graph_dataset = GcmcGraphDataset(dataset, test_size=0.1)
        encoder_hidden_size = 100
        encoder_size = 100
        scope_name = 'GraphConvolutionalMatrixCompletionGraph'
        model = GraphConvolutionalMatrixCompletion(
            graph_dataset=graph_dataset,
            encoder_hidden_size=encoder_hidden_size,
            encoder_size=encoder_size,
            scope_name=scope_name,
            batch_size=1024,
            epoch_size=10,
            learning_rate=0.01,
            dropout_rate=0.7,
            normalization_type='symmetric')
        model.fit()

        user_ids = [90, 62, 3, 3]
        item_ids = [11, 236, 240, 243]
        additional_item_features = {item_id: np.array([999]) for item_id in item_ids}
        additional_dataset = GcmcDataset(np.array(user_ids), np.array(item_ids), np.array([1, 2, 1, 1]), item_features=[additional_item_features])

        target_item_ids = item_ids + [12, 13, 17, 55]  # item_ids to get embeddings

        item_feature = model.get_item_feature_with_new_items(item_ids=target_item_ids, additional_dataset=additional_dataset)
        self.assertEqual(len(item_feature), 2)
        self.assertEqual(list(item_feature[0]), target_item_ids)
        self.assertEqual(item_feature[1].shape, (len(target_item_ids), encoder_size))
        output_embedding = {k: v for k, v in zip(*item_feature)}
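        # Items 240 and 243 never appear in the training data, share the same
        # feature vector ([999]) and the same single rating (1 from user 3),
        # so their exported embeddings should (almost) coincide.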
        np.testing.assert_almost_equal(output_embedding[240], output_embedding[243])


if __name__ == '__main__':
    unittest.main()
