From edc83aa5c0aa578fc0d0606e376dc0531155382d Mon Sep 17 00:00:00 2001 From: Quoc-Tuan Truong Date: Mon, 30 Oct 2023 15:23:06 -0700 Subject: [PATCH] Add recommend() function for the base class of Recommender (#538) --- .gitignore | 3 + cornac/data/dataset.py | 89 +--- cornac/eval_methods/base_method.py | 21 +- .../propensity_stratified_evaluation.py | 39 +- cornac/models/amr/recom_amr.py | 130 +++-- cornac/models/baseline_only/recom_bo.pyx | 18 +- cornac/models/bivaecf/recom_bivaecf.py | 18 +- cornac/models/bpr/recom_bpr.pyx | 12 +- cornac/models/bpr/recom_wbpr.pyx | 2 +- cornac/models/c2pf/recom_c2pf.py | 9 +- cornac/models/causalrec/recom_causalrec.py | 177 ++++--- cornac/models/cdl/recom_cdl.py | 35 +- cornac/models/cdr/recom_cdr.py | 37 +- cornac/models/coe/recom_coe.py | 10 +- cornac/models/comparer/recom_comparer_obj.pyx | 56 +-- cornac/models/comparer/recom_comparer_sub.pyx | 47 +- cornac/models/conv_mf/recom_convmf.py | 31 +- cornac/models/ctr/recom_ctr.py | 34 +- cornac/models/cvae/recom_cvae.py | 37 +- cornac/models/cvaecf/recom_cvaecf.py | 61 +-- cornac/models/ease/recom_ease.py | 41 +- cornac/models/efm/recom_efm.pyx | 49 +- cornac/models/fm/recom_fm.pyx | 20 +- cornac/models/gcmc/gcmc.py | 470 ++++++++---------- cornac/models/gcmc/nn_modules.py | 4 +- cornac/models/gcmc/recom_gcmc.py | 20 +- cornac/models/gcmc/utils.py | 2 +- cornac/models/global_avg/recom_global_avg.py | 4 +- cornac/models/hft/recom_hft.py | 48 +- cornac/models/hpf/recom_hpf.py | 25 +- cornac/models/hrdr/recom_hrdr.py | 150 ++++-- cornac/models/ibpr/recom_ibpr.py | 8 +- cornac/models/knn/recom_knn.py | 56 +-- cornac/models/lightgcn/lightgcn.py | 4 +- cornac/models/lightgcn/recom_lightgcn.py | 26 +- cornac/models/lrppm/recom_lrppm.pyx | 41 +- cornac/models/mcf/recom_mcf.py | 41 +- cornac/models/mf/recom_mf.pyx | 32 +- cornac/models/mmmf/recom_mmmf.pyx | 2 +- cornac/models/most_pop/recom_most_pop.py | 2 +- cornac/models/mter/recom_mter.pyx | 30 +- cornac/models/narre/recom_narre.py | 188 +++++-- cornac/models/ncf/recom_ncf_base.py | 44 +- cornac/models/ngcf/ngcf.py | 8 +- cornac/models/ngcf/recom_ngcf.py | 32 +- cornac/models/nmf/recom_nmf.pyx | 23 +- .../models/online_ibpr/recom_online_ibpr.py | 8 +- cornac/models/pcrl/pcrl.py | 37 +- cornac/models/pcrl/recom_pcrl.py | 7 +- cornac/models/pmf/recom_pmf.py | 36 +- cornac/models/recommender.py | 234 +++++++-- cornac/models/sbpr/recom_sbpr.pyx | 13 +- cornac/models/skm/recom_skmeans.py | 33 +- cornac/models/sorec/recom_sorec.py | 45 +- cornac/models/trirank/recom_trirank.py | 30 +- cornac/models/vaecf/recom_vaecf.py | 16 +- cornac/models/vbpr/recom_vbpr.py | 14 +- cornac/models/vmf/recom_vmf.py | 18 +- cornac/models/wmf/recom_wmf.py | 27 +- examples/amr_clothing.py | 1 - tests/cornac/data/test_dataset.py | 126 +++-- tests/cornac/models/test_recommender.py | 55 ++ 62 files changed, 1583 insertions(+), 1353 deletions(-) create mode 100644 tests/cornac/models/test_recommender.py diff --git a/.gitignore b/.gitignore index 5ff0a0ffa..270230400 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,9 @@ __pycache__/ # C extensions *.so +cornac/models/*/*.cpp +cornac/models/*/cython/*.cpp +cornac/utils/*.cpp # Distribution / packaging bin/ diff --git a/cornac/data/dataset.py b/cornac/data/dataset.py index 40480506c..5021b0335 100644 --- a/cornac/data/dataset.py +++ b/cornac/data/dataset.py @@ -64,14 +64,13 @@ class Dataset(object): global_mean: float Average value over the rating observations. - + uir_tuple: tuple Tuple three numpy arrays (user_indices, item_indices, rating_values). timestamps: numpy.array Numpy array of timestamps corresponding to feedback in `uir_tuple`. This is only available when input data is in `UIRT` format. - """ def __init__( @@ -99,12 +98,8 @@ def __init__( self.min_rating = np.min(r_values) self.global_mean = np.mean(r_values) - self.__total_users = None - self.__total_items = None self.__user_ids = None self.__item_ids = None - self.__user_indices = None - self.__item_indices = None self.__user_data = None self.__item_data = None @@ -114,47 +109,19 @@ def __init__( self.__csc_matrix = None self.__dok_matrix = None - @property - def total_users(self): - """Total number of users including test and validation users if exists""" - return self.__total_users if self.__total_users is not None else self.num_users - - @total_users.setter - def total_users(self, input_value): - """Set total number of users for the dataset""" - assert input_value >= self.num_users - self.__total_users = input_value - - @property - def total_items(self): - """Total number of items including test and validation items if exists""" - return self.__total_items if self.__total_items is not None else self.num_items - - @total_items.setter - def total_items(self, input_value): - """Set total number of items for the dataset""" - assert input_value >= self.num_items - self.__total_items = input_value - @property def user_ids(self): - """An iterator over the raw user ids""" - return self.uid_map.keys() + """Return the list of raw user ids""" + if self.__user_ids is None: + self.__user_ids = list(self.uid_map.keys()) + return self.__user_ids @property def item_ids(self): - """An iterator over the raw item ids""" - return self.iid_map.keys() - - @property - def user_indices(self): - """An iterator over the user indices""" - return self.uid_map.values() - - @property - def item_indices(self): - """An iterator over the item indices""" - return self.iid_map.values() + """Return the list of raw item ids""" + if self.__item_ids is None: + self.__item_ids = list(self.iid_map.keys()) + return self.__item_ids @property def user_data(self): @@ -185,7 +152,7 @@ def item_data(self): @property def chrono_user_data(self): """Data organized by user sorted chronologically (timestamps required). - A dictionary where keys are users, values are tuples of three chronologically + A dictionary where keys are users, values are tuples of three chronologically sorted lists (items, ratings, timestamps) interacted by the corresponding users. """ if self.timestamps is None: @@ -214,7 +181,7 @@ def chrono_user_data(self): @property def chrono_item_data(self): """Data organized by item sorted chronologically (timestamps required). - A dictionary where keys are items, values are tuples of three chronologically + A dictionary where keys are items, values are tuples of three chronologically sorted lists (users, ratings, timestamps) interacted with the corresponding items. """ if self.timestamps is None: @@ -272,7 +239,7 @@ def dok_matrix(self): """The user-item interaction matrix in DOK sparse format""" if self.__dok_matrix is None: self.__dok_matrix = dok_matrix( - (self.num_users, self.num_items), dtype='float' + (self.num_users, self.num_items), dtype="float" ) for u, i, r in zip(*self.uir_tuple): self.__dok_matrix[u, i] = r @@ -364,27 +331,29 @@ def build( raise ValueError("data is empty after being filtered!") uir_tuple = ( - np.asarray(u_indices, dtype='int'), - np.asarray(i_indices, dtype='int'), - np.asarray(r_values, dtype='float'), + np.asarray(u_indices, dtype="int"), + np.asarray(i_indices, dtype="int"), + np.asarray(r_values, dtype="float"), ) timestamps = ( - np.fromiter((int(data[i][3]) for i in valid_idx), dtype='int') + np.fromiter((int(data[i][3]) for i in valid_idx), dtype="int") if fmt == "UIRT" else None ) - return cls( + dataset = cls( num_users=len(global_uid_map), num_items=len(global_iid_map), - uid_map=uid_map, - iid_map=iid_map, + uid_map=global_uid_map, + iid_map=global_iid_map, uir_tuple=uir_tuple, timestamps=timestamps, seed=seed, ) + return dataset + @classmethod def from_uir(cls, data, seed=None): """Constructing Dataset from UIR (User, Item, Rating) triplet data. @@ -407,7 +376,7 @@ def from_uir(cls, data, seed=None): @classmethod def from_uirt(cls, data, seed=None): - """Constructing Dataset from UIRT (User, Item, Rating, Timestamp) + """Constructing Dataset from UIRT (User, Item, Rating, Timestamp) quadruplet data. Parameters @@ -528,7 +497,6 @@ def uij_iter(self, batch_size=1, shuffle=False, neg_sampling="uniform"): batch of negative items (array of 'int') """ - if neg_sampling.lower() == "uniform": neg_population = np.arange(self.num_items) elif neg_sampling.lower() == "popularity": @@ -564,7 +532,7 @@ def user_iter(self, batch_size=1, shuffle=False): ------- iterator : batch of user indices (array of 'int') """ - user_indices = np.fromiter(self.user_indices, dtype='int') + user_indices = np.fromiter(set(self.uir_tuple[0]), dtype="int") for batch_ids in self.idx_iter(len(user_indices), batch_size, shuffle): yield user_indices[batch_ids] @@ -582,18 +550,10 @@ def item_iter(self, batch_size=1, shuffle=False): ------- iterator : batch of item indices (array of 'int') """ - item_indices = np.fromiter(self.item_indices, 'int') + item_indices = np.fromiter(set(self.uir_tuple[1]), "int") for batch_ids in self.idx_iter(len(item_indices), batch_size, shuffle): yield item_indices[batch_ids] - def is_unk_user(self, user_idx): - """Return whether or not a user is unknown given the user index""" - return user_idx >= self.num_users - - def is_unk_item(self, item_idx): - """Return whether or not an item is unknown given the item index""" - return item_idx >= self.num_items - def add_modalities(self, **kwargs): self.user_feature = kwargs.get("user_feature", None) self.item_feature = kwargs.get("item_feature", None) @@ -605,4 +565,3 @@ def add_modalities(self, **kwargs): self.item_graph = kwargs.get("item_graph", None) self.sentiment = kwargs.get("sentiment", None) self.review_text = kwargs.get("review_text", None) - diff --git a/cornac/eval_methods/base_method.py b/cornac/eval_methods/base_method.py index 5a1bed5a2..e6cd47e47 100644 --- a/cornac/eval_methods/base_method.py +++ b/cornac/eval_methods/base_method.py @@ -85,6 +85,7 @@ def rating_eval(model, metrics, test_set, user_based=False, verbose=False): gt_mat = test_set.csr_matrix pd_mat = csr_matrix((r_preds, (u_indices, i_indices)), shape=gt_mat.shape) + test_user_indices = set(u_indices) for mt in metrics: if user_based: # averaging over users user_results.append( @@ -93,7 +94,7 @@ def rating_eval(model, metrics, test_set, user_based=False, verbose=False): gt_ratings=gt_mat.getrow(user_idx).data, pd_ratings=pd_mat.getrow(user_idx).data, ).item() - for user_idx in test_set.user_indices + for user_idx in test_user_indices } ) avg_results.append(sum(user_results[-1].values()) / len(user_results[-1])) @@ -159,7 +160,7 @@ def ranking_eval( avg_results = [] user_results = [{} for _ in enumerate(metrics)] - gt_mat = test_set.csr_matrix + test_mat = test_set.csr_matrix train_mat = train_set.csr_matrix val_mat = None if val_set is None else val_set.csr_matrix @@ -170,10 +171,11 @@ def pos_items(csr_row): if rating >= rating_threshold ] + test_user_indices = set(test_set.uir_tuple[0]) for user_idx in tqdm( - test_set.user_indices, desc="Ranking", disable=not verbose, miniters=100 + test_user_indices, desc="Ranking", disable=not verbose, miniters=100 ): - test_pos_items = pos_items(gt_mat.getrow(user_idx)) + test_pos_items = pos_items(test_mat.getrow(user_idx)) if len(test_pos_items) == 0: continue @@ -183,9 +185,9 @@ def pos_items(csr_row): val_pos_items = [] if val_mat is None else pos_items(val_mat.getrow(user_idx)) train_pos_items = ( - [] - if train_set.is_unk_user(user_idx) - else pos_items(train_mat.getrow(user_idx)) + pos_items(train_mat.getrow(user_idx)) + if user_idx < train_mat.shape[0] + else [] ) # binary mask for ground-truth negative items, removing all positive items @@ -196,7 +198,7 @@ def pos_items(csr_row): if exclude_unknowns: u_gt_pos_mask = u_gt_pos_mask[: train_set.num_items] u_gt_neg_mask = u_gt_neg_mask[: train_set.num_items] - + item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0] u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0] u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0] @@ -538,9 +540,6 @@ def _build_datasets(self, train_data, test_data, val_data=None): print("Total users = {}".format(self.total_users)) print("Total items = {}".format(self.total_items)) - self.train_set.total_users = self.total_users - self.train_set.total_items = self.total_items - def _build_modalities(self): for user_modality in [ self.user_feature, diff --git a/cornac/eval_methods/propensity_stratified_evaluation.py b/cornac/eval_methods/propensity_stratified_evaluation.py index aa1751e06..08263f8ec 100644 --- a/cornac/eval_methods/propensity_stratified_evaluation.py +++ b/cornac/eval_methods/propensity_stratified_evaluation.py @@ -25,38 +25,38 @@ def ranking_eval( props=None, ): """Evaluate model on provided ranking metrics. - + Parameters ---------- model: :obj:`cornac.models.Recommender`, required Recommender model to be evaluated. - + metrics: :obj:`iterable`, required List of rating metrics :obj:`cornac.metrics.RankingMetric`. - + train_set: :obj:`cornac.data.Dataset`, required Dataset to be used for model training. This will be used to exclude observations already appeared during training. - + test_set: :obj:`cornac.data.Dataset`, required Dataset to be used for evaluation. - + val_set: :obj:`cornac.data.Dataset`, optional, default: None Dataset to be used for model selection. This will be used to exclude observations already appeared during validation. - + rating_threshold: float, optional, default: 1.0 The threshold to convert ratings into positive or negative feedback. - + exclude_unknowns: bool, optional, default: True Ignore unknown users and items during evaluation. - + verbose: bool, optional, default: False Output evaluation progress. - + props: dictionary, optional, default: None items propensity scores - + Returns ------- res: (List, List) @@ -82,12 +82,13 @@ def pos_items(csr_row): if rating >= rating_threshold ] - for user_idx in tqdm.tqdm(test_set.user_indices, disable=not verbose, miniters=100): + test_user_indices = set(test_set.uir_tuple[0]) + for user_idx in tqdm.tqdm(test_user_indices, disable=not verbose, miniters=100): test_pos_items = pos_items(gt_mat.getrow(user_idx)) if len(test_pos_items) == 0: continue - u_gt_pos = np.zeros(test_set.num_items, dtype='float') + u_gt_pos = np.zeros(test_set.num_items, dtype="float") u_gt_pos[test_pos_items] = 1 val_pos_items = [] if val_mat is None else pos_items(val_mat.getrow(user_idx)) @@ -97,7 +98,7 @@ def pos_items(csr_row): else pos_items(train_mat.getrow(user_idx)) ) - u_gt_neg = np.ones(test_set.num_items, dtype='int') + u_gt_neg = np.ones(test_set.num_items, dtype="int") u_gt_neg[test_pos_items + val_pos_items + train_pos_items] = 0 item_indices = None if exclude_unknowns else np.arange(test_set.num_items) @@ -256,7 +257,7 @@ def _estimate_propensities(self): item_freq[i] += 1 # fit the exponential param - data = np.array([e for e in item_freq.values()], dtype='float') + data = np.array([e for e in item_freq.values()], dtype="float") results = powerlaw.Fit(data, discrete=True, fit_method="Likelihood") alpha = results.power_law.alpha fmin = results.power_law.xmin @@ -276,9 +277,7 @@ def _build_stratified_dataset(self, test_data): self.stratified_sets = {} # match the corresponding propensity score for each feedback - test_props = np.array( - [self.props[i] for u, i, r in test_data], dtype='float' - ) + test_props = np.array([self.props[i] for u, i, r in test_data], dtype="float") # stratify minp = min(test_props) - 0.01 * min(test_props) @@ -338,11 +337,11 @@ def evaluate(self, model, metrics, user_based, show_validation=True): metrics: :obj:`iterable` List of metrics. - user_based: bool, required - Evaluation strategy for the rating metrics. Whether results + user_based: bool, required + Evaluation strategy for the rating metrics. Whether results are averaging based on number of users or number of ratings. - show_validation: bool, optional, default: True + show_validation: bool, optional, default: True Whether to show the results on validation set (if exists). Returns diff --git a/cornac/models/amr/recom_amr.py b/cornac/models/amr/recom_amr.py index 5173ce16b..b514337e5 100644 --- a/cornac/models/amr/recom_amr.py +++ b/cornac/models/amr/recom_amr.py @@ -18,11 +18,9 @@ from ..recommender import Recommender from ...exception import CornacException -from ...exception import ScoreException from ...utils import fast_dot -from ...utils.common import intersects from ...utils import get_rng -from ...utils.init_utils import zeros, xavier_uniform +from ...utils.init_utils import xavier_uniform class AMR(Recommender): @@ -78,24 +76,24 @@ class AMR(Recommender): ---------- * Tang, J., Du, X., He, X., Yuan, F., Tian, Q., and Chua, T. (2020). Adversarial Training Towards Robust Multimedia Recommender System. """ - + def __init__( - self, - name="AMR", - k=10, - k2=10, - n_epochs=50, - batch_size=100, - learning_rate=0.005, - lambda_w=0.01, - lambda_b=0.01, - lambda_e=0.0, - lambda_adv=1.0, - use_gpu=False, - trainable=True, - verbose=True, - init_params=None, - seed=None, + self, + name="AMR", + k=10, + k2=10, + n_epochs=50, + batch_size=100, + learning_rate=0.005, + lambda_w=0.01, + lambda_b=0.01, + lambda_e=0.0, + lambda_adv=1.0, + use_gpu=False, + trainable=True, + verbose=True, + init_params=None, + seed=None, ): super().__init__(name=name, trainable=trainable, verbose=verbose) self.k = k @@ -109,26 +107,26 @@ def __init__( self.lambda_adv = lambda_adv self.use_gpu = use_gpu self.seed = seed - + # Init params if provided self.init_params = {} if init_params is None else init_params self.gamma_user = self.init_params.get("Gu", None) self.gamma_item = self.init_params.get("Gi", None) self.emb_matrix = self.init_params.get("E", None) - + def _init(self, n_users, n_items, features): rng = get_rng(self.seed) - + if self.gamma_user is None: self.gamma_user = xavier_uniform((n_users, self.k), rng) if self.gamma_item is None: self.gamma_item = xavier_uniform((n_items, self.k), rng) if self.emb_matrix is None: self.emb_matrix = xavier_uniform((features.shape[1], self.k), rng) - + # pre-computed for faster evaluation self.theta_item = np.matmul(features, self.emb_matrix) - + def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -145,49 +143,45 @@ def fit(self, train_set, val_set=None): self : object """ Recommender.fit(self, train_set, val_set) - + if train_set.item_image is None: raise CornacException("item_image modality is required but None.") - + # Item visual feature from CNN - train_features = train_set.item_image.features[: self.train_set.total_items] + train_features = train_set.item_image.features[: self.total_items] train_features = train_features.astype(np.float32) self._init( - n_users=train_set.total_users, - n_items=train_set.total_items, - features=train_features, + n_users=self.total_users, n_items=self.total_items, features=train_features ) - + if self.trainable: - self._fit_torch(train_features) - + self._fit_torch(train_set, train_features) + return self - - def _fit_torch(self, train_features): + + def _fit_torch(self, train_set, train_features): import torch - + def _l2_loss(*tensors): l2_loss = 0 for tensor in tensors: l2_loss += tensor.pow(2).sum() return l2_loss / 2 - + def _inner(a, b): return (a * b).sum(dim=1) - + dtype = torch.float device = ( torch.device("cuda:0") if (self.use_gpu and torch.cuda.is_available()) else torch.device("cpu") ) - + # set requireds_grad=True to get the adversarial gradient # if F is not put into the optimization list of parameters # it won't be updated - F = torch.tensor( - train_features, device=device, dtype=dtype, requires_grad=True - ) + F = torch.tensor(train_features, device=device, dtype=dtype, requires_grad=True) # Learned parameters Gu = torch.tensor( self.gamma_user, device=device, dtype=dtype, requires_grad=True @@ -198,77 +192,73 @@ def _inner(a, b): E = torch.tensor( self.emb_matrix, device=device, dtype=dtype, requires_grad=True ) - + optimizer = torch.optim.Adam([Gu, Gi, E], lr=self.learning_rate) - + for epoch in range(1, self.n_epochs + 1): sum_loss = 0.0 count = 0 progress_bar = tqdm( - total=self.train_set.num_batches(self.batch_size), + total=train_set.num_batches(self.batch_size), desc="Epoch {}/{}".format(epoch, self.n_epochs), disable=not self.verbose, ) - for batch_u, batch_i, batch_j in self.train_set.uij_iter( - self.batch_size, shuffle=True + for batch_u, batch_i, batch_j in train_set.uij_iter( + self.batch_size, shuffle=True ): gamma_u = Gu[batch_u] gamma_i = Gi[batch_i] gamma_j = Gi[batch_j] feat_i = F[batch_i] feat_j = F[batch_j] - + gamma_diff = gamma_i - gamma_j feat_diff = feat_i - feat_j - - Xuij = ( - _inner(gamma_u, gamma_diff) - + _inner(gamma_u, feat_diff.mm(E)) - ) - + + Xuij = _inner(gamma_u, gamma_diff) + _inner(gamma_u, feat_diff.mm(E)) + log_likelihood = torch.nn.functional.logsigmoid(Xuij).sum() - + # adversarial part feat_i.retain_grad() feat_j.retain_grad() log_likelihood.backward(retain_graph=True) feat_i_delta = feat_i.grad feat_j_delta = feat_j.grad - + adv_feat_diff = feat_diff + (feat_i_delta - feat_j_delta) - adv_Xuij = ( - _inner(gamma_u, gamma_diff) - + _inner(gamma_u, adv_feat_diff.mm(E)) + adv_Xuij = _inner(gamma_u, gamma_diff) + _inner( + gamma_u, adv_feat_diff.mm(E) ) - + adv_log_likelihood = torch.nn.functional.logsigmoid(adv_Xuij).sum() - + reg = ( - _l2_loss(gamma_u, gamma_i, gamma_j) * self.lambda_w - + _l2_loss(E) * self.lambda_e + _l2_loss(gamma_u, gamma_i, gamma_j) * self.lambda_w + + _l2_loss(E) * self.lambda_e ) - + loss = -log_likelihood - self.lambda_adv * adv_log_likelihood + reg - + optimizer.zero_grad() loss.backward() optimizer.step() - + sum_loss += loss.data.item() count += len(batch_u) if count % (self.batch_size * 10) == 0: progress_bar.set_postfix(loss=(sum_loss / count)) progress_bar.update(1) progress_bar.close() - + print("Optimization finished!") - + self.gamma_user = Gu.data.cpu().numpy() self.gamma_item = Gi.data.cpu().numpy() self.emb_matrix = E.data.cpu().numpy() # pre-computed for faster evaluation self.theta_item = F.mm(E).data.cpu().numpy() - + def score(self, user_idx, item_idx=None): """Predict the scores/ratings of a user for an item. diff --git a/cornac/models/baseline_only/recom_bo.pyx b/cornac/models/baseline_only/recom_bo.pyx index 6a6992c33..818b94f49 100644 --- a/cornac/models/baseline_only/recom_bo.pyx +++ b/cornac/models/baseline_only/recom_bo.pyx @@ -94,9 +94,9 @@ class BaselineOnly(Recommender): self.global_mean = 0.0 def _init(self): - n_users, n_items = self.train_set.num_users, self.train_set.num_items + n_users, n_items = self.num_users, self.num_items - self.global_mean = self.train_set.global_mean + self.global_mean = self.global_mean self.u_biases = zeros(n_users) if self.u_biases is None else self.u_biases self.i_biases = zeros(n_items) if self.i_biases is None else self.i_biases @@ -131,8 +131,8 @@ class BaselineOnly(Recommender): """Fit the model parameters (Bu, Bi) with SGD """ cdef: - long num_users = self.train_set.num_users - long num_items = self.train_set.num_items + long num_users = self.num_users + long num_items = self.num_items long num_ratings = val.shape[0] int max_iter = self.max_iter int num_threads = self.num_threads @@ -195,20 +195,16 @@ class BaselineOnly(Recommender): ------- res : A scalar or a Numpy array Relative scores that the user gives to the item or to all known items - """ - unk_user = self.train_set.is_unk_user(user_idx) - if item_idx is None: known_item_scores = np.add(self.i_biases, self.global_mean) - if not unk_user: + if self.knows_user(user_idx): known_item_scores = np.add(known_item_scores, self.u_biases[user_idx]) return known_item_scores else: - unk_item = self.train_set.is_unk_item(item_idx) item_score = self.global_mean - if not unk_user: + if self.knows_item(item_idx): item_score += self.u_biases[user_idx] - if not unk_item: + if self.knows_item(item_idx): item_score += self.i_biases[item_idx] return item_score diff --git a/cornac/models/bivaecf/recom_bivaecf.py b/cornac/models/bivaecf/recom_bivaecf.py index 71eaac7c7..9315c3ed3 100644 --- a/cornac/models/bivaecf/recom_bivaecf.py +++ b/cornac/models/bivaecf/recom_bivaecf.py @@ -176,7 +176,7 @@ def fit(self, train_set, val_set=None): learn( self.bivae, - self.train_set, + train_set, n_epochs=self.n_epochs, batch_size=self.batch_size, learn_rate=self.learning_rate, @@ -184,7 +184,6 @@ def fit(self, train_set, val_set=None): verbose=self.verbose, device=self.device, ) - elif self.verbose: print("%s is trained already (trainable = False)" % (self.name)) @@ -210,33 +209,24 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - theta_u = self.bivae.mu_theta[user_idx].view(1, -1) beta = self.bivae.mu_beta known_item_scores = ( self.bivae.decode_user(theta_u, beta).cpu().numpy().ravel() ) - return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - theta_u = self.bivae.mu_theta[user_idx].view(1, -1) beta_i = self.bivae.mu_beta[item_idx].view(1, -1) pred = self.bivae.decode_user(theta_u, beta_i).cpu().numpy().ravel() - - pred = scale( - pred, self.train_set.min_rating, self.train_set.max_rating, 0.0, 1.0 - ) - + pred = scale(pred, self.min_rating, self.max_rating, 0.0, 1.0) return pred diff --git a/cornac/models/bpr/recom_bpr.pyx b/cornac/models/bpr/recom_bpr.pyx index 8ee74a7ea..8ff83e24b 100644 --- a/cornac/models/bpr/recom_bpr.pyx +++ b/cornac/models/bpr/recom_bpr.pyx @@ -140,7 +140,7 @@ class BPR(Recommender): self.i_biases = self.init_params.get('Bi', None) def _init(self): - n_users, n_items = self.train_set.total_users, self.train_set.total_items + n_users, n_items = self.total_users, self.total_items if self.u_factors is None: self.u_factors = (uniform((n_users, self.k), random_state=self.rng) - 0.5) / self.k @@ -148,12 +148,12 @@ class BPR(Recommender): self.i_factors = (uniform((n_items, self.k), random_state=self.rng) - 0.5) / self.k self.i_biases = zeros(n_items) if self.i_biases is None or self.use_bias is False else self.i_biases - def _prepare_data(self): - X = self.train_set.matrix # csr_matrix + def _prepare_data(self, train_set): + X = train_set.matrix # csr_matrix # this basically calculates the 'row' attribute of a COO matrix # without requiring us to get the whole COO matrix user_counts = np.ediff1d(X.indptr) - user_ids = np.repeat(np.arange(self.train_set.num_users), user_counts).astype(X.indices.dtype) + user_ids = np.repeat(np.arange(train_set.num_users), user_counts).astype(X.indices.dtype) return X, user_counts, user_ids @@ -179,7 +179,7 @@ class BPR(Recommender): if not self.trainable: return self - X, user_counts, user_ids = self._prepare_data() + X, user_counts, user_ids = self._prepare_data(train_set) neg_item_ids = np.arange(train_set.num_items, dtype=np.int32) cdef: @@ -213,7 +213,7 @@ class BPR(Recommender): """ cdef: long num_samples = len(user_ids), s, i_index, j_index, correct = 0, skipped = 0 - long num_items = self.train_set.num_items + long num_items = self.num_items integral f, i_id, j_id, thread_id floating z, score, temp bool use_bias = self.use_bias diff --git a/cornac/models/bpr/recom_wbpr.pyx b/cornac/models/bpr/recom_wbpr.pyx index 8b6fa12c3..e91a98c77 100644 --- a/cornac/models/bpr/recom_wbpr.pyx +++ b/cornac/models/bpr/recom_wbpr.pyx @@ -122,7 +122,7 @@ class WBPR(BPR): if not self.trainable: return self - X, user_counts, user_ids = self._prepare_data() + X, user_counts, user_ids = self._prepare_data(train_set) cdef: int num_threads = self.num_threads diff --git a/cornac/models/c2pf/recom_c2pf.py b/cornac/models/c2pf/recom_c2pf.py index 8c1e76eae..b65277c4a 100644 --- a/cornac/models/c2pf/recom_c2pf.py +++ b/cornac/models/c2pf/recom_c2pf.py @@ -145,7 +145,7 @@ def fit(self, train_set, val_set=None): """ Recommender.fit(self, train_set, val_set) - X = sp.csc_matrix(self.train_set.matrix) + X = train_set.csr_matrix # recover the striplet sparse format from csc sparse matrix X (needed to feed c++) (rid, cid, val) = sp.find(X) @@ -171,8 +171,10 @@ def fit(self, train_set, val_set=None): "L3_r": self.L3r, } - map_iid = train_set.item_indices - (rid, cid, val) = train_set.item_graph.get_train_triplet(map_iid, map_iid) + train_item_indices = set(train_set.uir_tuple[1]) + (rid, cid, val) = train_set.item_graph.get_train_triplet( + train_item_indices, train_item_indices + ) context_info = np.hstack( (rid.reshape(-1, 1), cid.reshape(-1, 1), val.reshape(-1, 1)) ) @@ -292,5 +294,4 @@ def score(self, user_idx, item_idx=None): ) # transform user_pred to a flatten array, user_pred = np.array(user_pred, dtype="float64").flatten() - return user_pred diff --git a/cornac/models/causalrec/recom_causalrec.py b/cornac/models/causalrec/recom_causalrec.py index 965c7b400..bc47ac425 100644 --- a/cornac/models/causalrec/recom_causalrec.py +++ b/cornac/models/causalrec/recom_causalrec.py @@ -85,26 +85,26 @@ class CausalRec(Recommender): ---------- * Qiu R., Wang S., Chen Z., Yin H., Huang Z. (2021). CausalRec: Causal Inference for Visual Debiasing in Visually-Aware Recommendation. """ - + def __init__( - self, - name="CausalRec", - k=10, - k2=10, - n_epochs=50, - batch_size=100, - learning_rate=0.005, - lambda_w=0.01, - lambda_b=0.01, - lambda_e=0.0, - mean_feat=None, - tanh=0, - lambda_2=0.8, - use_gpu=False, - trainable=True, - verbose=True, - init_params=None, - seed=None, + self, + name="CausalRec", + k=10, + k2=10, + n_epochs=50, + batch_size=100, + learning_rate=0.005, + lambda_w=0.01, + lambda_b=0.01, + lambda_e=0.0, + mean_feat=None, + tanh=0, + lambda_2=0.8, + use_gpu=False, + trainable=True, + verbose=True, + init_params=None, + seed=None, ): super().__init__(name=name, trainable=trainable, verbose=verbose) self.k = k @@ -120,7 +120,7 @@ def __init__( self.lambda_2 = lambda_2 self.use_gpu = use_gpu self.seed = seed - + # Init params if provided self.init_params = {} if init_params is None else init_params self.beta_item = self.init_params.get("Bi", None) @@ -133,10 +133,10 @@ def __init__( self.emb_matrix2 = self.init_params.get("E2", None) self.emb_ind_matrix2 = self.init_params.get("E_ind2", None) self.beta_prime = self.init_params.get("Bp", None) - + def _init(self, n_users, n_items, features): rng = get_rng(self.seed) - + self.beta_item = zeros(n_items) if self.beta_item is None else self.beta_item if self.gamma_user is None: self.gamma_user = xavier_uniform((n_users, self.k), rng) @@ -155,12 +155,12 @@ def _init(self, n_users, n_items, features): self.emb_ind_matrix2 = xavier_uniform((self.k, self.k), rng) if self.beta_prime is None: self.beta_prime = xavier_uniform((features.shape[1], 1), rng) - + # pre-computed for faster evaluation self.theta_item = np.matmul(features, self.emb_matrix) self.visual_bias = np.matmul(features, self.beta_prime).ravel() self.direct_theta_item = np.matmul(features, self.emb_ind_matrix) - + def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -177,43 +177,43 @@ def fit(self, train_set, val_set=None): self : object """ Recommender.fit(self, train_set, val_set) - + if train_set.item_image is None: raise CornacException("item_image modality is required but None.") - + # Item visual feature from CNN - train_features = train_set.item_image.features[: self.train_set.total_items] + train_features = train_set.item_image.features[: self.total_items] train_features = train_features.astype(np.float32) self._init( - n_users=train_set.total_users, - n_items=train_set.total_items, + n_users=self.total_users, + n_items=self.total_items, features=train_features, ) - + if self.trainable: - self._fit_torch(train_features) - + self._fit_torch(train_set, train_features) + return self - - def _fit_torch(self, train_features): + + def _fit_torch(self, train_set, train_features): import torch - + def _l2_loss(*tensors): l2_loss = 0 for tensor in tensors: l2_loss += tensor.pow(2).sum() return l2_loss / 2 - + def _inner(a, b): return (a * b).sum(dim=1) - + dtype = torch.float device = ( torch.device("cuda:0") if (self.use_gpu and torch.cuda.is_available()) else torch.device("cpu") ) - + F = torch.tensor(train_features, device=device, dtype=dtype) # Learned parameters Bi = torch.tensor( @@ -241,7 +241,7 @@ def _inner(a, b): [self.mean_feat], device=device, dtype=dtype, requires_grad=False ) param = [Bi, Gu, Gi, Tu, E, Bp, E_ind] - + if self.tanh == 2: E2 = torch.tensor( self.emb_matrix2, device=device, dtype=dtype, requires_grad=True @@ -251,30 +251,30 @@ def _inner(a, b): ) param.append(E2) param.append(E_ind2) - + optimizer = torch.optim.Adam(param, lr=self.learning_rate) - + for epoch in range(1, self.n_epochs + 1): sum_loss = 0.0 count = 0 progress_bar = tqdm( - total=self.train_set.num_batches(self.batch_size), + total=train_set.num_batches(self.batch_size), desc="Epoch {}/{}".format(epoch, self.n_epochs), disable=not self.verbose, ) - for batch_u, batch_i, batch_j in self.train_set.uij_iter( - self.batch_size, shuffle=True + for batch_u, batch_i, batch_j in train_set.uij_iter( + self.batch_size, shuffle=True ): gamma_u = Gu[batch_u] theta_u = Tu[batch_u] - + beta_i = Bi[batch_i] beta_j = Bi[batch_j] gamma_i = Gi[batch_i] gamma_j = Gi[batch_j] feat_i = F[batch_i] feat_j = F[batch_j] - + if self.tanh == 0: direct_feat_i = feat_i.mm(E) ind_feat_i = feat_i.mm(E_ind) @@ -284,10 +284,14 @@ def _inner(a, b): elif self.tanh == 2: direct_feat_i = torch.tanh(torch.tanh(feat_i.mm(E)).mm(E2)) ind_feat_i = torch.tanh(torch.tanh(feat_i.mm(E_ind)).mm(E_ind2)) - - i_m = beta_i + _inner(gamma_u, gamma_i) + _inner(gamma_u, gamma_i * ind_feat_i) + + i_m = ( + beta_i + + _inner(gamma_u, gamma_i) + + _inner(gamma_u, gamma_i * ind_feat_i) + ) i_n = _inner(theta_u, direct_feat_i) + feat_i.mm(Bp) - + if self.tanh == 0: direct_feat_j = feat_j.mm(E) ind_feat_j = feat_j.mm(E_ind) @@ -297,44 +301,52 @@ def _inner(a, b): elif self.tanh == 2: direct_feat_j = torch.tanh(torch.tanh(feat_j.mm(E)).mm(E2)) ind_feat_j = torch.tanh(torch.tanh(feat_j.mm(E_ind)).mm(E_ind2)) - - j_m = beta_j + _inner(gamma_u, gamma_j) + _inner(gamma_u, gamma_j * ind_feat_j) + + j_m = ( + beta_j + + _inner(gamma_u, gamma_j) + + _inner(gamma_u, gamma_j * ind_feat_j) + ) j_n = _inner(theta_u, direct_feat_j) + feat_j.mm(Bp) - - i_score = torch.sigmoid(i_m + i_n) * torch.sigmoid(i_m) * torch.sigmoid(i_n) - j_score = torch.sigmoid(j_m + j_n) * torch.sigmoid(j_m) * torch.sigmoid(j_n) - + + i_score = ( + torch.sigmoid(i_m + i_n) * torch.sigmoid(i_m) * torch.sigmoid(i_n) + ) + j_score = ( + torch.sigmoid(j_m + j_n) * torch.sigmoid(j_m) * torch.sigmoid(j_n) + ) + log_likelihood = torch.nn.functional.logsigmoid(i_score - j_score).sum() log_likelihood_m = torch.nn.functional.logsigmoid(i_m - j_m).sum() log_likelihood_n = torch.nn.functional.logsigmoid(i_n - j_n).sum() - + if self.tanh < 2: l2_e = _l2_loss(E, Bp, E_ind) else: l2_e = _l2_loss(E, Bp, E_ind, E2, E_ind2) - + reg = ( - _l2_loss(gamma_u, gamma_i, gamma_j, theta_u) * self.lambda_w - + _l2_loss(beta_i) * self.lambda_b - + _l2_loss(beta_j) * self.lambda_b / 10 - + l2_e * self.lambda_e + _l2_loss(gamma_u, gamma_i, gamma_j, theta_u) * self.lambda_w + + _l2_loss(beta_i) * self.lambda_b + + _l2_loss(beta_j) * self.lambda_b / 10 + + l2_e * self.lambda_e ) - + loss = -log_likelihood + reg - log_likelihood_m - log_likelihood_n - + optimizer.zero_grad() loss.backward() optimizer.step() - + sum_loss += loss.data.item() count += len(batch_u) if count % (self.batch_size * 10) == 0: progress_bar.set_postfix(loss=(sum_loss / count)) progress_bar.update(1) progress_bar.close() - + print("Optimization finished!") - + self.beta_item = Bi.data.cpu().numpy() self.gamma_user = Gu.data.cpu().numpy() self.gamma_item = Gi.data.cpu().numpy() @@ -349,7 +361,9 @@ def _inner(a, b): elif self.tanh == 1: self.theta_item = torch.tanh(self.theta_item).data.cpu().numpy() elif self.tanh == 2: - self.theta_item = torch.tanh(torch.tanh(self.theta_item).mm(E2)).data.cpu().numpy() + self.theta_item = ( + torch.tanh(torch.tanh(self.theta_item).mm(E2)).data.cpu().numpy() + ) self.visual_bias = F.mm(Bp).squeeze().data.cpu().numpy() @@ -359,7 +373,11 @@ def _inner(a, b): elif self.tanh == 1: self.ind_theta_item = torch.tanh(self.ind_theta_item).data.cpu().numpy() elif self.tanh == 2: - self.ind_theta_item = torch.tanh(torch.tanh(self.ind_theta_item).mm(E_ind2)).data.cpu().numpy() + self.ind_theta_item = ( + torch.tanh(torch.tanh(self.ind_theta_item).mm(E_ind2)) + .data.cpu() + .numpy() + ) self.beta_item_mean = Bi.mean().unsqueeze(dim=0).data.cpu().numpy() self.gamma_item_mean = Gi.mean(dim=0).unsqueeze(dim=0).data.cpu().numpy() @@ -370,7 +388,9 @@ def _inner(a, b): elif self.tanh == 1: self.mean_feat = torch.tanh(self.mean_feat).data.cpu().numpy() elif self.tanh == 2: - self.mean_feat = torch.tanh(torch.tanh(self.mean_feat).mm(E_ind2)).data.cpu().numpy() + self.mean_feat = ( + torch.tanh(torch.tanh(self.mean_feat).mm(E_ind2)).data.cpu().numpy() + ) def score(self, user_idx, item_idx=None): """Predict the debiased scores/ratings of a user for an item. @@ -383,7 +403,7 @@ def score(self, user_idx, item_idx=None): item_idx: int, optional, default: None The index of the item for which to perform score prediction. If None, scores for all known items will be returned. - + Returns ------- res : A scalar or a Numpy array @@ -393,16 +413,23 @@ def score(self, user_idx, item_idx=None): if item_idx is None: m_score = self.beta_item fast_dot(self.gamma_user[user_idx], self.gamma_item, m_score) - fast_dot(self.gamma_user[user_idx], self.gamma_item * self.ind_theta_item, m_score) + fast_dot( + self.gamma_user[user_idx], + self.gamma_item * self.ind_theta_item, + m_score, + ) m_star = self.beta_item_mean fast_dot(self.gamma_user[user_idx], self.gamma_item_mean, m_star) - fast_dot(self.gamma_user[user_idx], self.gamma_item_mean * self.mean_feat, m_star) - + fast_dot( + self.gamma_user[user_idx], self.gamma_item_mean * self.mean_feat, m_star + ) + n_score = self.visual_bias fast_dot(self.theta_user[user_idx], self.theta_item, n_score) - return expit(m_score + n_score) * expit(m_score) * expit(n_score)\ - - self.lambda_2 * expit(m_star + n_score) * expit(m_star) * expit(n_score) + return expit(m_score + n_score) * expit(m_score) * expit( + n_score + ) - self.lambda_2 * expit(m_star + n_score) * expit(m_star) * expit(n_score) else: raise NotImplementedError("The sampled evaluation is not implemented!") diff --git a/cornac/models/cdl/recom_cdl.py b/cornac/models/cdl/recom_cdl.py index 349799bf7..214e05a58 100644 --- a/cornac/models/cdl/recom_cdl.py +++ b/cornac/models/cdl/recom_cdl.py @@ -78,7 +78,7 @@ class CDL(Recommender): The batch size for SGD. trainable: boolean, optional, default: True - When False, the model is not trained and Cornac assumes that the model already + When False, the model is not trained and Cornac assumes that the model already pre-trained (U and V are not None). init_params: dictionary, optional, default: None @@ -147,8 +147,7 @@ def __init__( self.V = self.init_params.get("V", None) def _init(self): - n_users, n_items = self.train_set.num_users, self.train_set.num_items - + n_users, n_items = self.num_users, self.num_items if self.U is None: self.U = xavier_uniform((n_users, self.k), self.rng) if self.V is None: @@ -174,22 +173,21 @@ def fit(self, train_set, val_set=None): self._init() if self.trainable: - self._fit_cdl() + self._fit_cdl(train_set) return self - def _fit_cdl(self): + def _fit_cdl(self, train_set): import tensorflow.compat.v1 as tf from .cdl import Model - + tf.disable_eager_execution() - R = self.train_set.csc_matrix # csc for efficient slicing over items - n_users, n_items = self.train_set.num_users, self.train_set.num_items + R = train_set.csc_matrix # csc for efficient slicing over items - text_feature = self.train_set.item_text.batch_bow( - np.arange(n_items) - ) # bag of word feature + text_feature = train_set.item_text.batch_bow( + np.arange(self.num_items) + ) # bag-of-words features text_feature = (text_feature - text_feature.min()) / ( text_feature.max() - text_feature.min() ) # normalization @@ -204,8 +202,8 @@ def _fit_cdl(self): ) tf.set_random_seed(self.seed) model = Model( - n_users=n_users, - n_items=n_items, + n_users=self.num_users, + n_items=self.num_items, n_vocab=self.vocab_size, k=self.k, layers=layer_sizes, @@ -230,12 +228,12 @@ def _fit_cdl(self): loop = trange(self.max_iter, disable=not self.verbose) for _ in loop: corruption_mask = self.rng.binomial( - 1, 1 - self.corruption_rate, size=(n_items, self.vocab_size) + 1, 1 - self.corruption_rate, size=(self.num_items, self.vocab_size) ) sum_loss = 0 count = 0 for i, batch_ids in enumerate( - self.train_set.item_iter(self.batch_size, shuffle=True) + train_set.item_iter(self.batch_size, shuffle=True) ): batch_R = R[:, batch_ids] batch_C = np.ones(batch_R.shape) * self.b @@ -283,17 +281,14 @@ def score(self, user_idx, item_idx=None): Relative scores that the user gives to the item or to all known items """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and not self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) diff --git a/cornac/models/cdr/recom_cdr.py b/cornac/models/cdr/recom_cdr.py index db22a63f1..111ec1290 100644 --- a/cornac/models/cdr/recom_cdr.py +++ b/cornac/models/cdr/recom_cdr.py @@ -69,7 +69,7 @@ class CDR(Recommender): The name of the recommender model. trainable: boolean, optional, default: True - When False, the model is not trained and Cornac assumes that the model already + When False, the model is not trained and Cornac assumes that the model already pre-trained (U and V are not None). init_params: dictionary, optional, default: None @@ -136,12 +136,10 @@ def __init__( self.V = self.init_params.get("V", None) def _init(self): - n_users, n_items = self.train_set.num_users, self.train_set.num_items - if self.U is None: - self.U = xavier_uniform((n_users, self.k), self.rng) + self.U = xavier_uniform((self.num_users, self.k), self.rng) if self.V is None: - self.V = xavier_uniform((n_items, self.k), self.rng) + self.V = xavier_uniform((self.num_items, self.k), self.rng) def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -161,23 +159,20 @@ def fit(self, train_set, val_set=None): Recommender.fit(self, train_set, val_set) self._init() - + if self.trainable: - self._fit_cdr() + self._fit_cdr(train_set) return self - def _fit_cdr(self): + def _fit_cdr(self, train_set): import tensorflow.compat.v1 as tf from .cdr import Model tf.disable_eager_execution() - n_users = self.train_set.num_users - n_items = self.train_set.num_items - - text_feature = self.train_set.item_text.batch_bow( - np.arange(n_items) + text_feature = train_set.item_text.batch_bow( + np.arange(self.num_items) ) # bag of word feature text_feature = (text_feature - text_feature.min()) / ( text_feature.max() - text_feature.min() @@ -193,8 +188,8 @@ def _fit_cdr(self): ) tf.set_random_seed(self.seed) model = Model( - n_users=n_users, - n_items=n_items, + n_users=self.num_users, + n_items=self.num_items, n_vocab=self.vocab_size, k=self.k, layers=layer_sizes, @@ -219,12 +214,12 @@ def _fit_cdr(self): loop = trange(self.max_iter, disable=not self.verbose) for _ in loop: corruption_mask = self.rng.binomial( - 1, 1 - self.corruption_rate, (n_items, self.vocab_size) + 1, 1 - self.corruption_rate, (self.num_items, self.vocab_size) ) sum_loss = 0 count = 0 batch_count = 0 - for batch_u, batch_i, batch_j in self.train_set.uij_iter( + for batch_u, batch_i, batch_j in train_set.uij_iter( batch_size=self.batch_size, shuffle=True ): feed_dict = { @@ -273,21 +268,17 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) user_pred = self.V[item_idx, :].dot(self.U[user_idx, :]) - return user_pred diff --git a/cornac/models/coe/recom_coe.py b/cornac/models/coe/recom_coe.py index 4d472e813..28d39ec8c 100644 --- a/cornac/models/coe/recom_coe.py +++ b/cornac/models/coe/recom_coe.py @@ -113,7 +113,7 @@ def fit(self, train_set, val_set=None): print("Learning...") res = coe( - self.train_set.matrix, + train_set.matrix, k=self.k, n_epochs=self.max_iter, lamda=self.lamda, @@ -151,24 +151,20 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = np.sum( np.abs(self.V - self.U[user_idx, :]) ** 2, axis=-1 ) ** (1.0 / 2) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - user_pred = np.sum( np.abs(self.V[item_idx, :] - self.U[user_idx, :]) ** 2, axis=-1 ) ** (1.0 / 2) diff --git a/cornac/models/comparer/recom_comparer_obj.pyx b/cornac/models/comparer/recom_comparer_obj.pyx index c1584bbef..327be3a30 100644 --- a/cornac/models/comparer/recom_comparer_obj.pyx +++ b/cornac/models/comparer/recom_comparer_obj.pyx @@ -204,21 +204,21 @@ class ComparERObj(Recommender): """ Recommender.fit(self, train_set, val_set) - self._init_params() + self._init_params(train_set) if not self.trainable: return - (A, X, Y, earlier_indices, later_indices, aspect_indices, pair_counts) = self._build_matrices(self.train_set) + (A, X, Y, earlier_indices, later_indices, aspect_indices, pair_counts) = self._build_matrices(train_set) A_user_counts = np.ediff1d(A.indptr) A_item_counts = np.ediff1d(A.tocsc().indptr) - A_uids = np.repeat(np.arange(self.train_set.num_users), A_user_counts).astype(np.int32) + A_uids = np.repeat(np.arange(train_set.num_users), A_user_counts).astype(np.int32) X_user_counts = np.ediff1d(X.indptr) X_aspect_counts = np.ediff1d(X.tocsc().indptr) - X_uids = np.repeat(np.arange(self.train_set.num_users), X_user_counts).astype(np.int32) + X_uids = np.repeat(np.arange(train_set.num_users), X_user_counts).astype(np.int32) Y_item_counts = np.ediff1d(Y.indptr) Y_aspect_counts = np.ediff1d(Y.tocsc().indptr) - Y_iids = np.repeat(np.arange(self.train_set.num_items), Y_item_counts).astype(np.int32) + Y_iids = np.repeat(np.arange(train_set.num_items), Y_item_counts).astype(np.int32) self._fit_efm(self.num_threads, A.data, A_uids, A.indices, A_user_counts, A_item_counts, @@ -227,22 +227,23 @@ class ComparERObj(Recommender): earlier_indices, later_indices, aspect_indices, pair_counts, self.U1, self.U2, self.V, self.H1, self.H2) - def _init_params(self): + def _init_params(self, train_set): from ...utils import get_rng from ...utils.init_utils import uniform rng = get_rng(self.seed) num_factors = self.num_explicit_factors + self.num_latent_factors high = np.sqrt(self.rating_scale / num_factors) - self.U1 = self.init_params.get('U1', uniform((self.train_set.num_users, self.num_explicit_factors), + self.num_aspects = train_set.sentiment.num_aspects + self.U1 = self.init_params.get('U1', uniform((self.num_users, self.num_explicit_factors), high=high, random_state=rng)) - self.U2 = self.init_params.get('U2', uniform((self.train_set.num_items, self.num_explicit_factors), + self.U2 = self.init_params.get('U2', uniform((self.num_items, self.num_explicit_factors), high=high, random_state=rng)) - self.V = self.init_params.get('V', uniform((self.train_set.sentiment.num_aspects, self.num_explicit_factors), + self.V = self.init_params.get('V', uniform((self.num_aspects, self.num_explicit_factors), high=high, random_state=rng)) - self.H1 = self.init_params.get('H1', uniform((self.train_set.num_users, self.num_latent_factors), + self.H1 = self.init_params.get('H1', uniform((self.num_users, self.num_latent_factors), high=high, random_state=rng)) - self.H2 = self.init_params.get('H2', uniform((self.train_set.num_items, self.num_latent_factors), + self.H2 = self.init_params.get('H2', uniform((self.num_items, self.num_latent_factors), high=high, random_state=rng)) def get_params(self): @@ -272,9 +273,9 @@ class ComparERObj(Recommender): int DOM_MODEL = MODEL_TYPES["Dominant"] int AROUND_MODEL = MODEL_TYPES["Around"] - int num_users = self.train_set.num_users - int num_items = self.train_set.num_items - int num_aspects = self.train_set.sentiment.num_aspects + int num_users = self.num_users + int num_items = self.num_items + int num_aspects = self.num_aspects int num_explicit_factors = self.num_explicit_factors int num_latent_factors = self.num_latent_factors int min_pair_freq = self.min_pair_freq @@ -304,7 +305,6 @@ class ComparERObj(Recommender): np.ndarray[np.float32_t, ndim=2] H2_denominator = np.empty((num_items, num_latent_factors), dtype=np.float32) for t in range(1, self.max_iter + 1): - loss = 0. aspect_bpr_loss = 0. correct = 0 @@ -460,7 +460,7 @@ class ComparERObj(Recommender): if data_set is not None: for uid, iid, rating in data_set.uir_iter(): - if self.train_set.is_unk_user(uid) or self.train_set.is_unk_item(iid): + if not (self.knows_user(uid) and self.knows_item(iid)): continue ratings.append(rating) map_uid.append(uid) @@ -469,7 +469,7 @@ class ComparERObj(Recommender): map_uid = np.asarray(map_uid, dtype=np.int32).flatten() map_iid = np.asarray(map_iid, dtype=np.int32).flatten() rating_matrix = sp.csr_matrix((ratings, (map_uid, map_iid)), - shape=(self.train_set.num_users, self.train_set.num_items)) + shape=(self.num_users, self.num_items)) if self.verbose: print('Building rating matrix completed in %d s' % (time() - start_time)) @@ -485,7 +485,7 @@ class ComparERObj(Recommender): window = len(item_ids) if self.enum_window is None else min(self.enum_window, len(item_ids)) for sub_item_ids in [item_ids[i:i+window] for i in range(len(item_ids) - window + 1)]: for earlier_item_idx, later_item_idx in combinations(sub_item_ids, 2): - if self.train_set.is_unk_item(earlier_item_idx) or self.train_set.is_unk_item(later_item_idx): + if not (self.knows_item(earlier_item_idx) and self.knows_item(later_item_idx)): continue chrono_purchased_pairs[(earlier_item_idx, later_item_idx)] += 1 @@ -537,7 +537,7 @@ class ComparERObj(Recommender): map_uid = [] map_aspect_id = [] for uid, sentiment_tup_ids_by_item in sentiment.user_sentiment.items(): - if self.train_set.is_unk_user(uid): + if not self.knows_user(uid): continue user_aspects = [tup[0] for tup_id in sentiment_tup_ids_by_item.values() @@ -551,7 +551,7 @@ class ComparERObj(Recommender): map_uid = np.asarray(map_uid, dtype=np.int32).flatten() map_aspect_id = np.asarray(map_aspect_id, dtype=np.int32).flatten() X = sp.csr_matrix((attention_scores, (map_uid, map_aspect_id)), - shape=(self.train_set.num_users, sentiment.num_aspects)) + shape=(self.num_users, self.num_aspects)) if self.verbose: print('Building user aspect attention matrix completed in %d s' % (time() - start_time)) @@ -564,7 +564,7 @@ class ComparERObj(Recommender): map_iid = [] map_aspect_id = [] for iid, sentiment_tup_ids_by_user in sentiment.item_sentiment.items(): - if self.train_set.is_unk_item(iid): + if not self.knows_item(iid): continue item_aspects = [tup[0] for tup_id in sentiment_tup_ids_by_user.values() @@ -586,7 +586,7 @@ class ComparERObj(Recommender): map_iid = np.asarray(map_iid, dtype=np.int32).flatten() map_aspect_id = np.asarray(map_aspect_id, dtype=np.int32).flatten() Y = sp.csr_matrix((quality_scores, (map_iid, map_aspect_id)), - shape=(self.train_set.num_items, sentiment.num_aspects)) + shape=(self.num_items, self.num_aspects)) if self.verbose: print('Building item aspect quality matrix completed in %d s' % (time() - start_time)) @@ -594,7 +594,7 @@ class ComparERObj(Recommender): return Y def _build_matrices(self, data_set): - sentiment = self.train_set.sentiment + sentiment = data_set.sentiment A = self._build_rating_matrix(data_set) X = self._build_user_attention_matrix(data_set, sentiment) Y = self._build_item_quality_matrix(data_set, sentiment) @@ -626,7 +626,7 @@ class ComparERObj(Recommender): A = self._build_rating_matrix(self.val_set) A_user_counts = np.ediff1d(A.indptr) - A_uids = np.repeat(np.arange(self.train_set.num_users), A_user_counts).astype(A.indices.dtype) + A_uids = np.repeat(np.arange(self.num_users), A_user_counts).astype(A.indices.dtype) return -self._get_loss(self.num_threads, A.data.astype(np.float32), A_uids, A.indices, @@ -651,12 +651,12 @@ class ComparERObj(Recommender): """ if item_id is None: - if self.train_set.is_unk_user(user_id): + if not self.knows_user(user_id): raise ScoreException("Can't make score prediction for (user_id=%d" & user_id) item_scores = self.U2.dot(self.U1[user_id, :]) + self.H2.dot(self.H1[user_id, :]) return item_scores else: - if self.train_set.is_unk_user(user_id) or self.train_set.is_unk_item(item_id): + if not (self.knows_user(user_id) and self.knows_item(item_id)): raise ScoreException("Can't make score prediction for (user_id=%d, item_id=%d)" % (user_id, item_id)) item_score = self.U2[item_id, :].dot(self.U1[user_id, :]) + self.H2[item_id, :].dot(self.H1[user_id, :]) return item_score @@ -690,9 +690,9 @@ class ComparERObj(Recommender): item_scores = item_scores item_rank = item_scores.argsort()[::-1] else: - num_items = max(self.train_set.num_items, max(item_ids) + 1) + num_items = max(self.num_items, max(item_ids) + 1) item_scores = np.ones(num_items) * np.min(item_scores) - item_scores[:self.train_set.num_items] = item_scores + item_scores[:self.num_items] = item_scores item_rank = item_scores.argsort()[::-1] item_rank = intersects(item_rank, item_ids, assume_unique=True) item_scores = item_scores[item_ids] diff --git a/cornac/models/comparer/recom_comparer_sub.pyx b/cornac/models/comparer/recom_comparer_sub.pyx index 09de437ee..e1eec1c77 100644 --- a/cornac/models/comparer/recom_comparer_sub.pyx +++ b/cornac/models/comparer/recom_comparer_sub.pyx @@ -181,7 +181,7 @@ class ComparERSub(MTER): map_iid = [] map_aspect_id = [] for iid, sentiment_tup_ids_by_user in sentiment.item_sentiment.items(): - if self.train_set.is_unk_item(iid): + if not self.knows_item(iid): continue item_aspects = [tup[0] for tup_id in sentiment_tup_ids_by_user.values() @@ -203,7 +203,7 @@ class ComparERSub(MTER): map_iid = np.asarray(map_iid, dtype=np.int32).flatten() map_aspect_id = np.asarray(map_aspect_id, dtype=np.int32).flatten() Y = sp.csr_matrix((quality_scores, (map_iid, map_aspect_id)), - shape=(self.train_set.num_items, sentiment.num_aspects)) + shape=(self.num_items, sentiment.num_aspects)) if self.verbose: print('Building item aspect quality matrix completed in %d s' % (time() - start_time)) @@ -218,19 +218,22 @@ class ComparERSub(MTER): if self.verbose: print("Building data started!") - sentiment = self.train_set.sentiment + + sentiment = data_set.sentiment + self.num_aspects = sentiment.num_aspects + self.num_opinions = sentiment.num_opinions (u_indices, i_indices, r_values) = data_set.uir_tuple keys = np.array([get_key(u, i) for u, i in zip(u_indices, i_indices)], dtype=np.intp) cdef IntFloatDict rating_dict = IntFloatDict(keys, np.array(r_values, dtype=np.float64)) rating_matrix = sp.csr_matrix( (r_values, (u_indices, i_indices)), - shape=(self.train_set.num_users, self.train_set.num_items), + shape=(self.num_users, self.num_items), ) user_item_aspect = {} user_aspect_opinion = {} item_aspect_opinion = {} for u_idx, sentiment_tup_ids_by_item in tqdm(sentiment.user_sentiment.items(), disable=not self.verbose, desc='Count aspects'): - if self.train_set.is_unk_user(u_idx): + if not self.knows_user(u_idx): continue for i_idx, tup_idx in sentiment_tup_ids_by_item.items(): user_item_aspect[ @@ -296,7 +299,7 @@ class ComparERSub(MTER): window = len(item_ids) if self.enum_window is None else min(self.enum_window, len(item_ids)) for sub_item_ids in [item_ids[i:i+window] for i in range(len(item_ids) - window + 1)]: for earlier_item_idx, later_item_idx in combinations(sub_item_ids, 2): - if self.train_set.is_unk_item(earlier_item_idx) or self.train_set.is_unk_item(later_item_idx): + if not (self.knows_item(earlier_item_idx) and self.knows_item(later_item_idx)): continue chrono_purchased_pairs[(user_idx, earlier_item_idx, later_item_idx)] += 1 @@ -305,7 +308,7 @@ class ComparERSub(MTER): counted_pairs = set() not_dominated_pairs = set() for (user_idx, earlier_item_idx, later_item_idx), count in tqdm(chrono_purchased_pairs.most_common(), disable=not self.verbose, desc='Get skyline aspects'): - for k in range(self.train_set.sentiment.num_aspects - 1): # ignore rating at the last index + for k in range(self.num_aspects - 1): # ignore rating at the last index if user_item_aspect.get((user_idx, later_item_idx, k), 0) > user_item_aspect.get((user_idx, earlier_item_idx, k), 0): pair_counts[(user_idx, earlier_item_idx, later_item_idx, k)] += count not_dominated_pairs.add((user_idx, earlier_item_idx, later_item_idx)) @@ -348,7 +351,6 @@ class ComparERSub(MTER): return user_indices, earlier_indices, later_indices, aspect_indices, pair_freq - def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -366,7 +368,7 @@ class ComparERSub(MTER): """ Recommender.fit(self, train_set, val_set) - self._init() + self._init(train_set) if not self.trainable: return self @@ -416,8 +418,8 @@ class ComparERSub(MTER): YI_oids = np.array(YI_oids, dtype=np.int32) user_counts = np.ediff1d(rating_matrix.indptr).astype(np.int32) - user_ids = np.repeat(np.arange(self.train_set.num_users), user_counts).astype(np.int32) - neg_item_ids = np.arange(train_set.num_items, dtype=np.int32) + user_ids = np.repeat(np.arange(self.num_users), user_counts).astype(np.int32) + neg_item_ids = np.arange(self.num_items, dtype=np.int32) cdef: int n_threads = self.n_threads @@ -529,10 +531,10 @@ class ComparERSub(MTER): """ cdef: long s, i_index, j_index, correct = 0, skipped = 0, aspect_correct = 0 - long n_users = self.train_set.num_users - long n_items = self.train_set.num_items - long n_aspects = self.train_set.sentiment.num_aspects - long n_opinions = self.train_set.sentiment.num_opinions + long n_users = self.num_users + long n_items = self.num_items + long n_aspects = self.num_aspects + long n_opinions = self.num_opinions long n_user_factors = self.n_user_factors long n_item_factors = self.n_item_factors long n_aspect_factors = self.n_aspect_factors @@ -759,14 +761,13 @@ class ComparERSub(MTER): def rank(self, user_idx, item_indices=None): if self.alpha > 0 and self.n_top_aspects > 0: - n_items = self.train_set.num_items - n_top_aspects = min(self.n_top_aspects, self.train_set.sentiment.num_aspects) + n_top_aspects = min(self.n_top_aspects, self.num_aspects) ts1 = np.einsum("abc,a->bc", self.G1, self.U[user_idx]) ts2 = np.einsum("bc,Mb->Mc", ts1, self.I) ts3 = np.einsum("Mc,Nc->MN", ts2, self.A) top_aspect_scores = ts3[ - np.repeat(range(n_items), n_top_aspects).reshape( - n_items, n_top_aspects + np.repeat(range(self.num_items), n_top_aspects).reshape( + self.num_items, n_top_aspects ), ts3[:, :-1].argsort(axis=1)[::-1][:, :n_top_aspects], ] @@ -776,17 +777,17 @@ class ComparERSub(MTER): # check if the returned scores also cover unknown items # if not, all unknown items will be given the MIN score - if len(known_item_scores) == self.train_set.total_items: + if len(known_item_scores) == self.total_items: all_item_scores = known_item_scores else: - all_item_scores = np.ones(self.train_set.total_items) * np.min( + all_item_scores = np.ones(self.total_items) * np.min( known_item_scores ) - all_item_scores[: self.train_set.num_items] = known_item_scores + all_item_scores[: self.num_items] = known_item_scores # rank items based on their scores if item_indices is None: - item_scores = all_item_scores[: self.train_set.num_items] + item_scores = all_item_scores[: self.num_items] item_rank = item_scores.argsort()[::-1] else: item_scores = all_item_scores[item_indices] diff --git a/cornac/models/conv_mf/recom_convmf.py b/cornac/models/conv_mf/recom_convmf.py index 7ff7e4169..95a254dc6 100644 --- a/cornac/models/conv_mf/recom_convmf.py +++ b/cornac/models/conv_mf/recom_convmf.py @@ -132,10 +132,10 @@ def __init__( self.V = self.init_params.get("V", None) self.W = self.init_params.get("W", None) - def _init(self): + def _init(self, train_set): rng = get_rng(self.seed) - n_users, n_items = self.train_set.num_users, self.train_set.num_items - vocab_size = self.train_set.item_text.vocab.size + n_users, n_items = train_set.num_users, train_set.num_items + vocab_size = train_set.item_text.vocab.size if self.U is None: self.U = xavier_uniform((n_users, self.k), rng) @@ -161,10 +161,10 @@ def fit(self, train_set, val_set=None): """ Recommender.fit(self, train_set, val_set) - self._init() + self._init(train_set) if self.trainable: - self._fit_convmf() + self._fit_convmf(train_set) return self @@ -181,9 +181,9 @@ def _build_data(csr_mat): data.append(rating_list) return data - def _fit_convmf(self): - user_data = self._build_data(self.train_set.matrix) - item_data = self._build_data(self.train_set.matrix.T.tocsr()) + def _fit_convmf(self, train_set): + user_data = self._build_data(train_set.matrix) + item_data = self._build_data(train_set.matrix.T.tocsr()) n_user = len(user_data[0]) n_item = len(item_data[0]) @@ -228,7 +228,7 @@ def _fit_convmf(self): sess.run(tf.global_variables_initializer()) # init variable - document = self.train_set.item_text.batch_seq( + document = train_set.item_text.batch_seq( np.arange(n_item), max_length=self.max_len ) @@ -276,10 +276,10 @@ def _fit_convmf(self): self.cnn_epochs, desc="Optimizing CNN", disable=not self.verbose ) for _ in loop: - for batch_ids in self.train_set.item_iter( + for batch_ids in train_set.item_iter( batch_size=self.cnn_bs, shuffle=True ): - batch_seq = self.train_set.item_text.batch_seq( + batch_seq = train_set.item_text.batch_seq( batch_ids, max_length=self.max_len ) feed_dict = { @@ -338,22 +338,17 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - user_pred = self.V[item_idx, :].dot(self.U[user_idx, :]) - return user_pred diff --git a/cornac/models/ctr/recom_ctr.py b/cornac/models/ctr/recom_ctr.py index eaf72e04a..ecb1639c1 100644 --- a/cornac/models/ctr/recom_ctr.py +++ b/cornac/models/ctr/recom_ctr.py @@ -52,7 +52,7 @@ class CTR(Recommender): Added value for smoothing phi. trainable: boolean, optional, default: True - When False, the model is not trained and Cornac assumes that the model already + When False, the model is not trained and Cornac assumes that the model already pre-trained (U and V are not None). init_params: dictionary, optional, default: None @@ -107,13 +107,10 @@ def __init__( def _init(self): rng = get_rng(self.seed) - self.n_item = self.train_set.num_items - self.n_user = self.train_set.num_users - if self.U is None: - self.U = xavier_uniform((self.n_user, self.k), rng) + self.U = xavier_uniform((self.num_users, self.k), rng) if self.V is None: - self.V = xavier_uniform((self.n_item, self.k), rng) + self.V = xavier_uniform((self.num_items, self.k), rng) def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -135,7 +132,7 @@ def fit(self, train_set, val_set=None): self._init() if self.trainable: - self._fit_ctr() + self._fit_ctr(train_set) return self @@ -149,24 +146,24 @@ def _build_data(csr_mat): rating_list.append(csr_mat.data[j:k]) return index_list, rating_list - def _fit_ctr(self,): + def _fit_ctr(self, train_set): from .ctr import Model - user_data = self._build_data(self.train_set.matrix) - item_data = self._build_data(self.train_set.matrix.T.tocsr()) + user_data = self._build_data(train_set.matrix) + item_data = self._build_data(train_set.matrix.T.tocsr()) - bow_mat = self.train_set.item_text.batch_bow( - np.arange(self.n_item), keep_sparse=True + bow_mat = train_set.item_text.batch_bow( + np.arange(self.num_items), keep_sparse=True ) doc_ids, doc_cnt = self._build_data(bow_mat) # bag of word feature self.model = Model( - n_user=self.n_user, - n_item=self.n_item, + n_user=self.num_users, + n_item=self.num_items, U=self.U, V=self.V, k=self.k, - n_vocab=self.train_set.item_text.vocab.size, + n_vocab=train_set.item_text.vocab.size, lambda_u=self.lambda_u, lambda_v=self.lambda_v, a=self.a, @@ -205,17 +202,14 @@ def score(self, user_idx, item_idx=None): Relative scores that the user gives to the item or to all known items """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) diff --git a/cornac/models/cvae/recom_cvae.py b/cornac/models/cvae/recom_cvae.py index 970f0fe39..589f71d28 100644 --- a/cornac/models/cvae/recom_cvae.py +++ b/cornac/models/cvae/recom_cvae.py @@ -135,12 +135,10 @@ def __init__( def _init(self): rng = get_rng(self.seed) - n_users, n_items = self.train_set.num_users, self.train_set.num_items - if self.U is None: - self.U = xavier_uniform((n_users, self.z_dim), rng) + self.U = xavier_uniform((self.num_users, self.z_dim), rng) if self.V is None: - self.V = xavier_uniform((n_items, self.z_dim), rng) + self.V = xavier_uniform((self.num_items, self.z_dim), rng) def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -160,17 +158,17 @@ def fit(self, train_set, val_set=None): Recommender.fit(self, train_set, val_set) self._init() - + if self.trainable: - self._fit_cvae() + self._fit_cvae(train_set) return self - def _fit_cvae(self): - R = self.train_set.csc_matrix # csc for efficient slicing over items - document = self.train_set.item_text.batch_bow( - np.arange(self.train_set.num_items) - ) # bag of word feature + def _fit_cvae(self, train_set): + R = train_set.csc_matrix # csc for efficient slicing over items + document = train_set.item_text.batch_bow( + np.arange(train_set.num_items) + ) # bag-of-words features document = (document - document.min()) / ( document.max() - document.min() ) # normalization @@ -178,13 +176,13 @@ def _fit_cvae(self): # VAE initialization from .cvae import Model import tensorflow.compat.v1 as tf - + tf.disable_eager_execution() tf.set_random_seed(self.seed) model = Model( - n_users=self.train_set.num_users, - n_items=self.train_set.num_items, + n_users=train_set.num_users, + n_items=train_set.num_items, input_dim=self.input_dim, U=self.U, V=self.V, @@ -209,7 +207,7 @@ def _fit_cvae(self): for _ in loop: cf_loss, vae_loss, count = 0, 0, 0 for i, batch_ids in enumerate( - self.train_set.item_iter(self.batch_size, shuffle=True) + train_set.item_iter(self.batch_size, shuffle=True) ): batch_R = R[:, batch_ids] batch_C = np.ones(batch_R.shape) * self.b @@ -255,22 +253,17 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - user_pred = self.V[item_idx, :].dot(self.U[user_idx, :]) - return user_pred diff --git a/cornac/models/cvaecf/recom_cvaecf.py b/cornac/models/cvaecf/recom_cvaecf.py index 1c0936919..4ab9149ad 100644 --- a/cornac/models/cvaecf/recom_cvaecf.py +++ b/cornac/models/cvaecf/recom_cvaecf.py @@ -93,23 +93,23 @@ class CVAECF(Recommender): """ def __init__( - self, - name="CVAECF", - z_dim=20, - h_dim=20, - autoencoder_structure=[20], - act_fn="tanh", - likelihood="mult", - n_epochs=100, - batch_size=128, - learning_rate=0.001, - beta=1.0, - alpha_1=1.0, - alpha_2=1.0, - trainable=True, - verbose=False, - seed=None, - use_gpu=False, + self, + name="CVAECF", + z_dim=20, + h_dim=20, + autoencoder_structure=[20], + act_fn="tanh", + likelihood="mult", + n_epochs=100, + batch_size=128, + learning_rate=0.001, + beta=1.0, + alpha_1=1.0, + alpha_2=1.0, + trainable=True, + verbose=False, + seed=None, + use_gpu=False, ): Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose) self.z_dim = z_dim @@ -152,14 +152,17 @@ def fit(self, train_set, val_set=None): else torch.device("cpu") ) + self.r_mat = train_set.matrix + self.u_adj_mat = train_set.user_graph.matrix + if self.trainable: if self.seed is not None: torch.manual_seed(self.seed) torch.cuda.manual_seed(self.seed) if not hasattr(self, "cvae"): - n_items = train_set.matrix.shape[1] - n_users = train_set.matrix.shape[0] + n_items = self.r_mat.shape[1] + n_users = self.r_mat.shape[0] self.cvae = CVAE( self.z_dim, self.h_dim, @@ -171,7 +174,7 @@ def fit(self, train_set, val_set=None): learn( self.cvae, - self.train_set, + train_set, n_epochs=self.n_epochs, batch_size=self.batch_size, learn_rate=self.learning_rate, @@ -208,17 +211,17 @@ def score(self, user_idx, item_idx=None): import torch if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - y_u = self.train_set.matrix[user_idx].copy() + y_u = self.r_mat[user_idx].copy() y_u.data = np.ones(len(y_u.data)) y_u = torch.tensor(y_u.A, dtype=torch.float32, device=self.device) z_u, _ = self.cvae.encode_qz(y_u) - x_u = self.train_set.user_graph.matrix[user_idx].copy() + x_u = self.u_adj_mat[user_idx].copy() x_u.data = np.ones(len(x_u.data)) x_u = torch.tensor(x_u.A, dtype=torch.float32, device=self.device) h_u, _ = self.cvae.encode_qhx(x_u) @@ -227,24 +230,24 @@ def score(self, user_idx, item_idx=None): return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - y_u = self.train_set.matrix[user_idx].copy() + y_u = self.r_mat[user_idx].copy() y_u.data = np.ones(len(y_u.data)) y_u = torch.tensor(y_u.A, dtype=torch.float32, device=self.device) z_u, _ = self.cvae.encode_qz(y_u) - x_u = self.train_set.user_graph.matrix[user_idx].copy() + x_u = self.u_adj_mat[user_idx].copy() x_u.data = np.ones(len(x_u.data)) x_u = torch.tensor(x_u.A, dtype=torch.float32, device=self.device) h_u, _ = self.cvae.encode_qhx(x_u) - user_pred = self.cvae.decode(z_u, h_u).data.cpu().numpy().flatten()[item_idx] + user_pred = ( + self.cvae.decode(z_u, h_u).data.cpu().numpy().flatten()[item_idx] + ) return user_pred diff --git a/cornac/models/ease/recom_ease.py b/cornac/models/ease/recom_ease.py index e154dd328..149da0077 100644 --- a/cornac/models/ease/recom_ease.py +++ b/cornac/models/ease/recom_ease.py @@ -3,6 +3,7 @@ from cornac.models.recommender import Recommender from cornac.exception import ScoreException + class EASE(Recommender): """Embarrassingly Shallow Autoencoders for Sparse Data. @@ -34,15 +35,15 @@ class EASE(Recommender): """ def __init__( - self, - name="EASEá´¿", - lamb=500, - posB=True, - trainable=True, - verbose=True, - seed=None, - B=None, - U=None, + self, + name="EASEá´¿", + lamb=500, + posB=True, + trainable=True, + verbose=True, + seed=None, + B=None, + U=None, ): Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose) self.lamb = lamb @@ -70,7 +71,7 @@ def fit(self, train_set, val_set=None): Recommender.fit(self, train_set, val_set) # A rating matrix - self.U = self.train_set.matrix + self.U = train_set.matrix # Gram matrix is X^t X, compute dot product G = self.U.T.dot(self.U).toarray() @@ -82,18 +83,17 @@ def fit(self, train_set, val_set=None): P = np.linalg.inv(G) B = P / (-np.diag(P)) - + B[diag_indices] = 0.0 # if self.posB remove -ve values if self.posB: - B[B<0]=0 + B[B < 0] = 0 # save B for predictions - self.B=B - - return self + self.B = B + return self def score(self, user_idx, item_idx=None): """Predict the scores/ratings of a user for an item. @@ -114,22 +114,17 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.U[user_idx, :].dot(self.B) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - user_pred = self.B[item_idx, :].dot(self.U[user_idx, :]) - - return user_pred \ No newline at end of file + return user_pred diff --git a/cornac/models/efm/recom_efm.pyx b/cornac/models/efm/recom_efm.pyx index 4ec755e79..1b77d9da0 100644 --- a/cornac/models/efm/recom_efm.pyx +++ b/cornac/models/efm/recom_efm.pyx @@ -163,10 +163,11 @@ class EFM(Recommender): self.H1 = self.init_params.get('H1', None) self.H2 = self.init_params.get('H2', None) - def _init(self): + def _init(self, train_set): rng = get_rng(self.seed) - n_users, n_items = self.train_set.num_users, self.train_set.num_items - n_aspects = self.train_set.sentiment.num_aspects + self.num_aspects = train_set.sentiment.num_aspects + n_aspects = self.num_aspects + n_users, n_items = self.num_users, self.num_items n_efactors = self.num_explicit_factors n_lfactors = self.num_latent_factors n_factors = n_efactors + n_lfactors @@ -200,19 +201,19 @@ class EFM(Recommender): """ Recommender.fit(self, train_set, val_set) - self._init() + self._init(train_set) if self.trainable: - A, X, Y = self._build_matrices(self.train_set) + A, X, Y = self._build_matrices(train_set) A_user_counts = np.ediff1d(A.indptr) A_item_counts = np.ediff1d(A.tocsc().indptr) - A_uids = np.repeat(np.arange(self.train_set.num_users), A_user_counts).astype(A.indices.dtype) + A_uids = np.repeat(np.arange(train_set.num_users), A_user_counts).astype(A.indices.dtype) X_user_counts = np.ediff1d(X.indptr) X_aspect_counts = np.ediff1d(X.tocsc().indptr) - X_uids = np.repeat(np.arange(self.train_set.num_users), X_user_counts).astype(X.indices.dtype) + X_uids = np.repeat(np.arange(train_set.num_users), X_user_counts).astype(X.indices.dtype) Y_item_counts = np.ediff1d(Y.indptr) Y_aspect_counts = np.ediff1d(Y.tocsc().indptr) - Y_iids = np.repeat(np.arange(self.train_set.num_items), Y_item_counts).astype(Y.indices.dtype) + Y_iids = np.repeat(np.arange(train_set.num_items), Y_item_counts).astype(Y.indices.dtype) self._fit_efm( self.num_threads, @@ -235,9 +236,9 @@ class EFM(Recommender): """Fit the model parameters (U1, U2, V, H1, H2) """ cdef: - long num_users = self.train_set.num_users - long num_items = self.train_set.num_items - long num_aspects = self.train_set.sentiment.num_aspects + long num_users = self.num_users + long num_items = self.num_items + long num_aspects = self.num_aspects int num_explicit_factors = self.num_explicit_factors int num_latent_factors = self.num_latent_factors @@ -358,13 +359,13 @@ class EFM(Recommender): print('Optimization finished!') def _build_matrices(self, data_set): - sentiment = self.train_set.sentiment + sentiment = data_set.sentiment ratings = [] map_uid = [] map_iid = [] for uid, iid, rating in data_set.uir_iter(): - if self.train_set.is_unk_user(uid) or self.train_set.is_unk_item(iid): + if not (self.knows_user(uid) and self.knows_item(iid)): continue ratings.append(rating) map_uid.append(uid) @@ -374,13 +375,13 @@ class EFM(Recommender): map_uid = np.asarray(map_uid, dtype=np.int32).flatten() map_iid = np.asarray(map_iid, dtype=np.int32).flatten() A = sp.csr_matrix((ratings, (map_uid, map_iid)), - shape=(self.train_set.num_users, self.train_set.num_items)) + shape=(self.num_users, self.num_items)) attention_scores = [] map_uid = [] map_aspect_id = [] for uid, sentiment_tup_ids_by_item in sentiment.user_sentiment.items(): - if self.train_set.is_unk_user(uid): + if not self.knows_user(uid): continue user_aspects = [tup[0] for tup_id in sentiment_tup_ids_by_item.values() for tup in sentiment.sentiment[tup_id]] @@ -394,14 +395,14 @@ class EFM(Recommender): map_uid = np.asarray(map_uid, dtype=np.int32).flatten() map_aspect_id = np.asarray(map_aspect_id, dtype=np.int32).flatten() X = sp.csr_matrix((attention_scores, (map_uid, map_aspect_id)), - shape=(self.train_set.num_users, sentiment.num_aspects)) + shape=(self.num_users, self.num_aspects)) quality_scores = [] map_iid = [] map_aspect_id = [] for iid, sentiment_tup_ids_by_user in sentiment.item_sentiment.items(): - if self.train_set.is_unk_item(iid): + if not self.knows_item(iid): continue item_aspects = [tup[0] for tup_id in sentiment_tup_ids_by_user.values() for tup in sentiment.sentiment[tup_id]] @@ -423,7 +424,7 @@ class EFM(Recommender): map_iid = np.asarray(map_iid, dtype=np.int32).flatten() map_aspect_id = np.asarray(map_aspect_id, dtype=np.int32).flatten() Y = sp.csr_matrix((quality_scores, (map_iid, map_aspect_id)), - shape=(self.train_set.num_items, sentiment.num_aspects)) + shape=(self.num_items, self.num_aspects)) if self.verbose: print('Building matrices completed!') @@ -455,12 +456,12 @@ class EFM(Recommender): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException("Can't make score prediction for (user_id=%d" & user_idx) item_scores = self.U2.dot(self.U1[user_idx, :]) + self.H2.dot(self.H1[user_idx, :]) return item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item(item_idx): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException("Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx)) item_score = self.U2[item_idx, :].dot(self.U1[user_idx, :]) + self.H2[item_idx, :].dot(self.H1[user_idx, :]) return item_score @@ -492,17 +493,17 @@ class EFM(Recommender): # check if the returned scores also cover unknown items # if not, all unknown items will be given the MIN score - if len(known_item_scores) == self.train_set.total_items: + if len(known_item_scores) == self.total_items: all_item_scores = known_item_scores else: - all_item_scores = np.ones(self.train_set.total_items) * np.min( + all_item_scores = np.ones(self.total_items) * np.min( known_item_scores ) - all_item_scores[: self.train_set.num_items] = known_item_scores + all_item_scores[: self.num_items] = known_item_scores # rank items based on their scores if item_indices is None: - item_scores = all_item_scores[: self.train_set.num_items] + item_scores = all_item_scores[: self.num_items] item_rank = item_scores.argsort()[::-1] else: item_scores = all_item_scores[item_indices] diff --git a/cornac/models/fm/recom_fm.pyx b/cornac/models/fm/recom_fm.pyx index 318312acb..a6cf50683 100644 --- a/cornac/models/fm/recom_fm.pyx +++ b/cornac/models/fm/recom_fm.pyx @@ -215,7 +215,7 @@ class FM(Recommender): self.v = self.init_params.get('v', None) def _init(self): - num_features = self.train_set.total_users + self.train_set.total_items + num_features = self.total_users + self.total_items if self.w0 is None: self.w0 = 0.0 @@ -260,12 +260,12 @@ class FM(Recommender): @cython.boundscheck(False) @cython.wraparound(False) def _fit_libfm(self, train_set, val_set, double[:] w, double[:, :] v): - cdef unsigned int num_feature = self.train_set.total_users + self.train_set.total_items + cdef unsigned int num_feature = self.total_users + self.total_items - (uid, iid, val) = self.train_set.uir_tuple + (uid, iid, val) = train_set.uir_tuple cdef Data *train = _prepare_data( uid, - iid + self.train_set.total_users, + iid + self.total_users, val.astype(np.float32), num_feature, self.method in ["als", "mcmc"], @@ -280,7 +280,7 @@ class FM(Recommender): (uid, iid, val) = val_set.uir_tuple validation = _prepare_data( uid, - iid + self.train_set.total_users, + iid + self.total_users, val.astype(np.float32), num_feature, self.method in ["als", "mcmc"], @@ -345,8 +345,8 @@ class FM(Recommender): (fml).num_eval_cases = validation.num_cases fml.fm = &fm - fml.max_target = self.train_set.max_rating - fml.min_target = self.train_set.min_rating + fml.max_target = self.max_rating + fml.min_target = self.min_rating fml.meta = meta fml.task = 0 # regression @@ -385,7 +385,7 @@ class FM(Recommender): def _fm_predict(self, user_idx, item_idx): uid = user_idx - iid = item_idx + self.train_set.total_users + iid = item_idx + self.total_users score = 0.0 if self.k0: score += self.w0 @@ -418,9 +418,9 @@ class FM(Recommender): """ if item_idx is None: known_item_scores = np.fromiter( - (self._fm_predict(user_idx, i) for i in range(self.train_set.total_items)), + (self._fm_predict(user_idx, i) for i in range(self.total_items)), dtype=np.double, - count=self.train_set.total_items + count=self.total_items ) return known_item_scores else: diff --git a/cornac/models/gcmc/gcmc.py b/cornac/models/gcmc/gcmc.py index 120bcbf0b..21228d578 100644 --- a/cornac/models/gcmc/gcmc.py +++ b/cornac/models/gcmc/gcmc.py @@ -13,165 +13,14 @@ from .utils import get_optimizer, torch_net_info, torch_total_param_num -def _apply_support(graph, rating_values, data_set, symm=True): - """Adds graph support. Returns DGLGraph.""" - - def _calc_norm(val): - val = val.numpy().astype("float32") - val[val == 0.0] = np.inf - val = torch.FloatTensor(1.0 / np.sqrt(val)) - return val.unsqueeze(1) - - n_users, n_items = data_set.total_users, data_set.total_items - - user_ci = [] - user_cj = [] - item_ci = [] - item_cj = [] - - for rating in rating_values: - rating = str(rating).replace(".", "_") - user_ci.append(graph[f"rev-{rating}"].in_degrees()) - item_ci.append(graph[rating].in_degrees()) - - if symm: - user_cj.append(graph[rating].out_degrees()) - item_cj.append(graph[f"rev-{rating}"].out_degrees()) - else: - user_cj.append(torch.zeros((n_users,))) - item_cj.append(torch.zeros((n_items,))) - user_ci = _calc_norm(sum(user_ci)) - item_ci = _calc_norm(sum(item_ci)) - if symm: - user_cj = _calc_norm(sum(user_cj)) - item_cj = _calc_norm(sum(item_cj)) - else: - user_cj = torch.ones( - n_users, - ) - item_cj = torch.ones( - n_items, - ) - graph.nodes["user"].data.update({"ci": user_ci, "cj": user_cj}) - graph.nodes["item"].data.update({"ci": item_ci, "cj": item_cj}) - - return graph - - -def _generate_enc_graph(data_set, add_support=False): - """ - Generates encoding graph given a cornac data set - - Parameters - ---------- - data_set : cornac.data.dataset.Dataset - The data set as provided by cornac - add_support : bool, optional - """ - data_dict = {} - num_nodes_dict = {"user": data_set.total_users, "item": data_set.total_items} - rating_row, rating_col, rating_values = data_set.uir_tuple - for rating in set(rating_values): - ridx = np.where(rating_values == rating) - rrow = rating_row[ridx] - rcol = rating_col[ridx] - rating = str(rating).replace(".", "_") - data_dict.update( - { - ("user", str(rating), "item"): (rrow, rcol), - ("item", f"rev-{str(rating)}", "user"): (rcol, rrow), - } - ) - - graph = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict) - - # sanity check - assert ( - len(data_set.uir_tuple[2]) - == sum([graph.num_edges(et) for et in graph.etypes]) // 2 - ) - - if add_support: - graph = _apply_support( - graph=graph, - rating_values=np.unique(rating_values), - data_set=data_set, - ) - - return graph - - -def _generate_dec_graph(data_set): - """ - Generates decoding graph given a cornac data set - - Parameters - ---------- - data_set : cornac.data.dataset.Dataset - The data set as provided by cornac - - Returns - ------- - graph : dgl.heterograph - Heterograph containing user-item edges and nodes - """ - rating_pairs = data_set.uir_tuple[:2] - ones = np.ones_like(rating_pairs[0]) - user_item_ratings_coo = sp.coo_matrix( - (ones, rating_pairs), - shape=(data_set.total_users, data_set.total_items), - dtype=np.float32, - ) - - graph = dgl.bipartite_from_scipy( - user_item_ratings_coo, utype="_U", etype="_E", vtype="_V" - ) - - return dgl.heterograph( - {("user", "rate", "item"): graph.edges()}, - num_nodes_dict={"user": data_set.total_users, "item": data_set.total_items}, - ) - - -def _generate_test_user_graph(user_idx, total_users, total_items): - """ - Generates decoding graph given a cornac data set - - Parameters - ---------- - data_set : cornac.data.dataset.Dataset - The data set as provided by cornac - - Returns - ------- - graph : dgl.heterograph - Heterograph containing user-item edges and nodes - """ - u_list = np.array([user_idx for _ in range(total_items)]) - i_list = np.array([item_idx for item_idx in range(total_items)]) - - rating_pairs = (u_list, i_list) - ones = np.ones_like(rating_pairs[0]) - user_item_ratings_coo = sp.coo_matrix( - (ones, rating_pairs), - shape=(total_users, total_items), - dtype=np.float32, - ) - - graph = dgl.bipartite_from_scipy( - user_item_ratings_coo, utype="_U", etype="_E", vtype="_V" - ) - - return dgl.heterograph( - {("user", "rate", "item"): graph.edges()}, - num_nodes_dict={"user": total_users, "item": total_items}, - ) - class Model: def __init__( self, - activation_model, + rating_values, + total_users, + total_items, + activation_func, gcn_agg_units, gcn_out_units, gcn_dropout, @@ -181,14 +30,9 @@ def __init__( verbose, seed, ): - self.activation_model = activation_model - self.gcn_agg_units = gcn_agg_units - self.gcn_out_units = gcn_out_units - self.gcn_dropout = gcn_dropout - self.gcn_agg_accum = gcn_agg_accum - self.share_param = share_param - self.gen_r_num_basis_func = gen_r_num_basis_func - + self.rating_values = rating_values + self.total_users = total_users + self.total_items = total_items self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.verbose = verbose @@ -201,92 +45,180 @@ def __init__( if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - def train( - self, - train_set, - val_set, - max_iter, - learning_rate, - optimizer, - train_grad_clip, - train_valid_interval, - train_early_stopping_patience, - train_min_learning_rate, - train_decay_patience, - train_lr_decay_factor, - ): - # Prepare data for training - ( - rating_values, - nd_positive_rating_values, - train_dec_graph, - valid_enc_graph, - valid_dec_graph, - train_labels, - train_truths, - valid_truths, - ) = self._prepare_data(train_set, val_set) - # Build Net self.net = NeuralNetwork( - self.activation_model, + activation_func, rating_values, - train_set.total_users, - train_set.total_items, - self.gcn_agg_units, - self.gcn_out_units, - self.gcn_dropout, - self.gcn_agg_accum, - self.gen_r_num_basis_func, - self.share_param, + total_users, + total_items, + gcn_agg_units, + gcn_out_units, + gcn_dropout, + gcn_agg_accum, + gen_r_num_basis_func, + share_param, self.device, ).to(self.device) - optimizer = get_optimizer(optimizer)(self.net.parameters(), lr=learning_rate) - rating_loss_net = nn.CrossEntropyLoss() + def _apply_support(self, graph, symm=True): + """Adds graph support. Returns DGLGraph.""" + + def _calc_norm(val): + val = val.numpy().astype("float32") + val[val == 0.0] = np.inf + val = torch.FloatTensor(1.0 / np.sqrt(val)) + return val.unsqueeze(1) + + user_ci = [] + user_cj = [] + item_ci = [] + item_cj = [] + + for rating in self.rating_values: + rating = str(rating).replace(".", "_") + user_ci.append(graph[f"rev-{rating}"].in_degrees()) + item_ci.append(graph[rating].in_degrees()) + + if symm: + user_cj.append(graph[rating].out_degrees()) + item_cj.append(graph[f"rev-{rating}"].out_degrees()) + else: + user_cj.append(torch.zeros((self.total_users,))) + item_cj.append(torch.zeros((self.total_items,))) + user_ci = _calc_norm(sum(user_ci)) + item_ci = _calc_norm(sum(item_ci)) + if symm: + user_cj = _calc_norm(sum(user_cj)) + item_cj = _calc_norm(sum(item_cj)) + else: + user_cj = torch.ones(self.total_users,) + item_cj = torch.ones(self.total_items,) + graph.nodes["user"].data.update({"ci": user_ci, "cj": user_cj}) + graph.nodes["item"].data.update({"ci": item_ci, "cj": item_cj}) - self._train_model( - rating_values, - train_dec_graph, - valid_enc_graph, - valid_dec_graph, - train_labels, - train_truths, - valid_truths, - nd_positive_rating_values, - rating_loss_net, - max_iter, - optimizer, - learning_rate, - train_grad_clip, - train_valid_interval, - train_early_stopping_patience, - train_min_learning_rate, - train_decay_patience, - train_lr_decay_factor, + return graph + + + def _generate_enc_graph(self, data_set, add_support=False): + """ + Generates encoding graph given a cornac data set + + Parameters + ---------- + data_set : cornac.data.dataset.Dataset + The data set as provided by cornac + + add_support : bool, optional + """ + data_dict = {} + num_nodes_dict = {"user": self.total_users, "item": self.total_items} + rating_row, rating_col, rating_values = data_set.uir_tuple + for rating in set(rating_values): + ridx = np.where(rating_values == rating) + rrow = rating_row[ridx] + rcol = rating_col[ridx] + rating = str(rating).replace(".", "_") + data_dict.update( + { + ("user", str(rating), "item"): (rrow, rcol), + ("item", f"rev-{str(rating)}", "user"): (rcol, rrow), + } + ) + + graph = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict) + + # sanity check + assert ( + len(data_set.uir_tuple[2]) + == sum([graph.num_edges(et) for et in graph.etypes]) // 2 + ) + + graph = self._apply_support(graph) if add_support else graph + + return graph + + + def _generate_dec_graph(self, data_set): + """ + Generates decoding graph given a cornac data set + + Parameters + ---------- + data_set : cornac.data.dataset.Dataset + The data set as provided by cornac + + Returns + ------- + graph : dgl.heterograph + Heterograph containing user-item edges and nodes + """ + user_nodes, item_nodes, ratings = data_set.uir_tuple + user_item_ratings_coo = sp.coo_matrix( + (np.ones_like(ratings), (user_nodes, item_nodes)), + shape=(self.total_users, self.total_items), + dtype=np.float32, + ) + + graph = dgl.bipartite_from_scipy( + user_item_ratings_coo, utype="_U", etype="_E", vtype="_V" + ) + + return dgl.heterograph( + {("user", "rate", "item"): graph.edges()}, + num_nodes_dict={"user": self.total_users, "item": self.total_items}, + ) + + + def _generate_test_user_graph(self, user_idx): + """ + Generates decoding graph given a cornac data set + + Parameters + ---------- + data_set : cornac.data.dataset.Dataset + The data set as provided by cornac + + Returns + ------- + graph : dgl.heterograph + Heterograph containing user-item edges and nodes + """ + item_nodes = np.arange(self.total_items) + user_nodes = np.full_like(item_nodes, user_idx) + user_item_ratings_coo = sp.coo_matrix( + (np.ones_like(user_nodes), (user_nodes, item_nodes)), + shape=(self.total_users, self.total_items), + dtype=np.float32, + ) + + graph = dgl.bipartite_from_scipy( + user_item_ratings_coo, utype="_U", etype="_E", vtype="_V" + ) + + return dgl.heterograph( + {("user", "rate", "item"): graph.edges()}, + num_nodes_dict={"user": self.total_users, "item": self.total_items}, ) def _prepare_data(self, train_set, val_set): - rating_values = train_set.uir_tuple[2] # rating list - rating_values = np.unique(rating_values) - nd_positive_rating_values = torch.FloatTensor(rating_values).to(self.device) + nd_positive_rating_values = torch.FloatTensor(self.rating_values).to(self.device) # Prepare Data def _generate_labels(ratings): - labels = torch.LongTensor(np.searchsorted(rating_values, ratings)).to( + labels = torch.LongTensor(np.searchsorted(self.rating_values, ratings)).to( self.device ) return labels - self.train_enc_graph = _generate_enc_graph(train_set, add_support=True) - train_dec_graph = _generate_dec_graph(train_set) + self.train_enc_graph = self._generate_enc_graph(train_set, add_support=True) + train_dec_graph = self._generate_dec_graph(train_set) train_labels = _generate_labels(train_set.uir_tuple[2]) train_truths = torch.FloatTensor(train_set.uir_tuple[2]).to(self.device) def _count_pairs(graph): pair_count = 0 - for r_val in rating_values: + for r_val in self.rating_values: r_val = str(r_val).replace(".", "_") pair_count += graph.num_edges(str(r_val)) return pair_count @@ -307,7 +239,7 @@ def _count_pairs(graph): valid_enc_graph = self.train_enc_graph if val_set: - valid_dec_graph = _generate_dec_graph(val_set) + valid_dec_graph = self._generate_dec_graph(val_set) valid_truths = torch.FloatTensor(val_set.uir_tuple[2]).to(self.device) logging.info( "Valid enc graph: %s users, %s items, %s pairs", @@ -325,7 +257,6 @@ def _count_pairs(graph): valid_dec_graph = None return ( - rating_values, nd_positive_rating_values, train_dec_graph, valid_enc_graph, @@ -358,6 +289,7 @@ def _train_model( ): # initialize loss variables best_valid_rmse = np.inf + best_model_state_dict = None no_better_valid = 0 best_iter = -1 count_rmse = 0 @@ -468,13 +400,63 @@ def _train_model( logging.info(logging_str) - if valid_dec_graph: + if not (valid_dec_graph is None or best_model_state_dict is None): logging.info( "Best iter idx=%s, Best valid rmse=%.4f", best_iter, best_valid_rmse ) # load best model self.net.load_state_dict(best_model_state_dict) + + def train( + self, + train_set, + val_set, + max_iter, + learning_rate, + optimizer, + train_grad_clip, + train_valid_interval, + train_early_stopping_patience, + train_min_learning_rate, + train_decay_patience, + train_lr_decay_factor, + ): + # Prepare data for training + ( + nd_positive_rating_values, + train_dec_graph, + valid_enc_graph, + valid_dec_graph, + train_labels, + train_truths, + valid_truths, + ) = self._prepare_data(train_set, val_set) + + optimizer = get_optimizer(optimizer)(self.net.parameters(), lr=learning_rate) + rating_loss_net = nn.CrossEntropyLoss() + + self._train_model( + self.rating_values, + train_dec_graph, + valid_enc_graph, + valid_dec_graph, + train_labels, + train_truths, + valid_truths, + nd_positive_rating_values, + rating_loss_net, + max_iter, + optimizer, + learning_rate, + train_grad_clip, + train_valid_interval, + train_early_stopping_patience, + train_min_learning_rate, + train_decay_patience, + train_lr_decay_factor, + ) + def predict(self, test_set): """ @@ -493,7 +475,7 @@ def predict(self, test_set): Dictionary containing '{user_idx}-{item_idx}' as key and {score} as value. """ - test_dec_graph = _generate_dec_graph(test_set) + test_dec_graph = self._generate_dec_graph(test_set) test_dec_graph = test_dec_graph.int().to(self.device) self.net.eval() @@ -501,10 +483,7 @@ def predict(self, test_set): with torch.no_grad(): pred_ratings = self.net(self.train_enc_graph, test_dec_graph) - test_rating_values = test_set.uir_tuple[2] - test_rating_values = np.unique(test_rating_values) - - nd_positive_rating_values = torch.FloatTensor(test_rating_values).to( + nd_positive_rating_values = torch.FloatTensor(self.rating_values).to( self.device ) @@ -514,22 +493,14 @@ def predict(self, test_set): test_pred_ratings = test_pred_ratings.cpu().numpy() - uid_list = test_set.uir_tuple[0] - uid_list = np.unique(uid_list) - - u_list = np.array([user_idx for _ in range(test_set.total_items) for user_idx in uid_list]) - i_list = np.array([item_idx for item_idx in range(test_set.total_items) for _ in uid_list]) - - u_list = u_list.tolist() - i_list = i_list.tolist() - + user_nodes, item_nodes, _ = test_set.uir_tuple u_i_rating_dict = { - f"{u_list[idx]}-{i_list[idx]}": rating + f"{user_nodes[idx]}-{item_nodes[idx]}": rating for idx, rating in enumerate(test_pred_ratings) } return u_i_rating_dict - def predict_one(self, train_set, user_idx): + def predict_one(self, user_idx): """ Processes single user_idx from test set and returns numpy list of scores for all items. @@ -544,7 +515,7 @@ def predict_one(self, train_set, user_idx): test_pred_ratings : numpy.array Numpy array containing all ratings for the given user_idx. """ - test_dec_graph = _generate_test_user_graph(user_idx, train_set.total_users, train_set.total_items) + test_dec_graph = self._generate_test_user_graph(user_idx) test_dec_graph = test_dec_graph.int().to(self.device) self.net.eval() @@ -552,12 +523,7 @@ def predict_one(self, train_set, user_idx): with torch.no_grad(): pred_ratings = self.net(self.train_enc_graph, test_dec_graph) - test_rating_values = train_set.uir_tuple[2] - test_rating_values = np.unique(test_rating_values) - - nd_positive_rating_values = torch.FloatTensor(test_rating_values).to( - self.device - ) + nd_positive_rating_values = torch.FloatTensor(self.rating_values).to(self.device) test_pred_ratings = ( torch.softmax(pred_ratings, dim=1) * nd_positive_rating_values.view(1, -1) diff --git a/cornac/models/gcmc/nn_modules.py b/cornac/models/gcmc/nn_modules.py index 64faf9bf1..78a26fe77 100644 --- a/cornac/models/gcmc/nn_modules.py +++ b/cornac/models/gcmc/nn_modules.py @@ -16,7 +16,7 @@ class NeuralNetwork(nn.Module): def __init__( self, - activation_model, + activation_func, rating_values, n_users, n_items, @@ -29,7 +29,7 @@ def __init__( device, ): super(NeuralNetwork, self).__init__() - self._act = get_activation(activation_model) + self._act = get_activation(activation_func) self.encoder = GCMCLayer( rating_values, n_users, diff --git a/cornac/models/gcmc/recom_gcmc.py b/cornac/models/gcmc/recom_gcmc.py index 23cef620e..855d61877 100644 --- a/cornac/models/gcmc/recom_gcmc.py +++ b/cornac/models/gcmc/recom_gcmc.py @@ -14,6 +14,8 @@ # limitations under the License. # ============================================================================ +import numpy as np + from ..recommender import Recommender @@ -35,7 +37,7 @@ class GCMC(Recommender): optimizer: string, default: 'adam'. Supported values: 'adam','sgd'. The optimization method used for SGD - activation_model: string, default: 'leaky' + activation_func: string, default: 'leaky' The activation function used in the GCMC model. Supported values: ['leaky', 'linear','sigmoid','relu', 'tanh'] @@ -96,7 +98,7 @@ def __init__( max_iter=2000, learning_rate=0.01, optimizer="adam", - activation_model="leaky", + activation_func="leaky_relu", gcn_agg_units=500, gcn_out_units=75, gcn_dropout=0.7, @@ -116,7 +118,7 @@ def __init__( super().__init__(name=name, trainable=trainable, verbose=verbose) # architecture params - self.activation_model = activation_model + self.activation_func = activation_func self.gcn_agg_units = gcn_agg_units self.gcn_out_units = gcn_out_units self.gcn_dropout = gcn_dropout @@ -156,8 +158,13 @@ def fit(self, train_set, val_set=None): if self.trainable: from .gcmc import Model - self.model = Model( - activation_model=self.activation_model, + self.rating_values = np.unique(train_set.uir_tuple[2]) + + self.model = Model( + rating_values=self.rating_values, + total_users=self.total_users, + total_items=self.total_items, + activation_func=self.activation_func, gcn_agg_units=self.gcn_agg_units, gcn_out_units=self.gcn_out_units, gcn_dropout=self.gcn_dropout, @@ -167,6 +174,7 @@ def fit(self, train_set, val_set=None): verbose=self.verbose, seed=self.seed, ) + self.model.train( train_set, val_set, @@ -213,6 +221,6 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: # Return scores of all items for a given user - return self.model.predict_one(self.train_set, user_idx) + return self.model.predict_one(user_idx) # Return score of known user/item return self.u_i_rating_dict.get(f"{user_idx}-{item_idx}", self.default_score()) diff --git a/cornac/models/gcmc/utils.py b/cornac/models/gcmc/utils.py index f16e1496b..29e8ec53a 100644 --- a/cornac/models/gcmc/utils.py +++ b/cornac/models/gcmc/utils.py @@ -17,7 +17,7 @@ def get_activation(act): if act is None: return lambda x: x if isinstance(act, str): - if act == "leaky": + if act == "leaky_relu": return nn.LeakyReLU(0.1) elif act == "relu": return nn.ReLU() diff --git a/cornac/models/global_avg/recom_global_avg.py b/cornac/models/global_avg/recom_global_avg.py index 1c353d331..8603be323 100644 --- a/cornac/models/global_avg/recom_global_avg.py +++ b/cornac/models/global_avg/recom_global_avg.py @@ -51,6 +51,6 @@ def score(self, user_idx, item_idx=None): Relative scores that the user gives to the item or to all known items """ if item_idx is None: - return np.full(self.train_set.num_items, self.train_set.global_mean) + return np.full(self.num_items, self.global_mean) else: - return self.train_set.global_mean + return self.global_mean diff --git a/cornac/models/hft/recom_hft.py b/cornac/models/hft/recom_hft.py index 0a84059b3..a5eeecd41 100644 --- a/cornac/models/hft/recom_hft.py +++ b/cornac/models/hft/recom_hft.py @@ -72,7 +72,7 @@ class HFT(Recommender): verbose: boolean, optional, default: True When True, some running logs are displayed. - + seed: int, optional, default: None Random seed for weight initialization. @@ -118,19 +118,16 @@ def __init__( def _init(self): rng = get_rng(self.seed) - self.n_item = self.train_set.num_items - self.n_user = self.train_set.num_users - if self.alpha is None: - self.alpha = self.train_set.global_mean + self.alpha = self.global_mean if self.beta_u is None: - self.beta_u = normal(self.n_user, std=0.01, random_state=rng) + self.beta_u = normal(self.num_users, std=0.01, random_state=rng) if self.beta_i is None: - self.beta_i = normal(self.n_item, std=0.01, random_state=rng) + self.beta_i = normal(self.num_items, std=0.01, random_state=rng) if self.gamma_u is None: - self.gamma_u = normal((self.n_user, self.k), std=0.01, random_state=rng) + self.gamma_u = normal((self.num_users, self.k), std=0.01, random_state=rng) if self.gamma_i is None: - self.gamma_i = normal((self.n_item, self.k), std=0.01, random_state=rng) + self.gamma_i = normal((self.num_items, self.k), std=0.01, random_state=rng) def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -152,7 +149,7 @@ def fit(self, train_set, val_set=None): self._init() if self.trainable: - self._fit_hft() + self._fit_hft(train_set) return self @@ -166,21 +163,21 @@ def _build_data(csr_mat): rating_list.append(csr_mat.data[j:k]) return index_list, rating_list - def _fit_hft(self): + def _fit_hft(self, train_set): from .hft import Model # document data - bow_mat = self.train_set.item_text.batch_bow( - np.arange(self.n_item), keep_sparse=True + bow_mat = train_set.item_text.batch_bow( + np.arange(self.num_items), keep_sparse=True ) documents, _ = self._build_data(bow_mat) # bag of word feature # Rating data - user_data = self._build_data(self.train_set.matrix) - item_data = self._build_data(self.train_set.matrix.T.tocsr()) + user_data = self._build_data(train_set.matrix) + item_data = self._build_data(train_set.matrix.T.tocsr()) model = Model( - n_user=self.n_user, - n_item=self.n_item, + n_user=self.num_users, + n_item=self.num_items, alpha=self.alpha, beta_u=self.beta_u, beta_i=self.beta_i, @@ -202,9 +199,13 @@ def _fit_hft(self): loss = model.update_params(rating_data=(user_data, item_data)) loop.set_postfix(loss=loss) - self.alpha, self.beta_u, self.beta_i, self.gamma_u, self.gamma_i = ( - model.get_parameter() - ) + ( + self.alpha, + self.beta_u, + self.beta_i, + self.gamma_u, + self.gamma_i, + ) = model.get_parameter() if self.verbose: print("Learning completed!") @@ -227,7 +228,7 @@ def score(self, user_idx, item_idx=None): Relative scores that the user gives to the item or to all known items """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) @@ -240,9 +241,7 @@ def score(self, user_idx, item_idx=None): ) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) @@ -254,5 +253,4 @@ def score(self, user_idx, item_idx=None): + self.beta_i[item_idx] + self.gamma_i[item_idx, :].dot(self.gamma_u[user_idx, :]) ) - return user_pred diff --git a/cornac/models/hpf/recom_hpf.py b/cornac/models/hpf/recom_hpf.py index cca10e756..f2c408680 100644 --- a/cornac/models/hpf/recom_hpf.py +++ b/cornac/models/hpf/recom_hpf.py @@ -132,7 +132,7 @@ def fit(self, train_set, val_set=None): "L_r": self.Lr, } - X = sp.csc_matrix(self.train_set.matrix) + X = train_set.csc_matrix # recover the striplet sparse format from csc sparse matrix X (needed to feed c++) (rid, cid, val) = sp.find(X) val = np.array(val, dtype="float32") @@ -146,11 +146,23 @@ def fit(self, train_set, val_set=None): if self.hierarchical: res = hpf.hpf( - tX, X.shape[0], X.shape[1], self.k, self.max_iter, self.seed, init_params + tX, + X.shape[0], + X.shape[1], + self.k, + self.max_iter, + self.seed, + init_params, ) else: res = hpf.pf( - tX, X.shape[0], X.shape[1], self.k, self.max_iter, self.seed, init_params + tX, + X.shape[0], + X.shape[1], + self.k, + self.max_iter, + self.seed, + init_params, ) self.Theta = np.asarray(res["Z"]) self.Beta = np.asarray(res["W"]) @@ -185,7 +197,7 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): u_representation = np.ones(self.k) else: u_representation = self.Theta[user_idx, :] @@ -194,9 +206,7 @@ def score(self, user_idx, item_idx=None): known_item_scores = np.array(known_item_scores, dtype="float64").flatten() return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) @@ -204,5 +214,4 @@ def score(self, user_idx, item_idx=None): user_pred = self.Beta[item_idx, :].dot(self.Theta[user_idx, :]) user_pred = np.array(user_pred, dtype="float64").flatten()[0] - return user_pred diff --git a/cornac/models/hrdr/recom_hrdr.py b/cornac/models/hrdr/recom_hrdr.py index 52369e850..678214d41 100644 --- a/cornac/models/hrdr/recom_hrdr.py +++ b/cornac/models/hrdr/recom_hrdr.py @@ -21,7 +21,8 @@ from ...exception import ScoreException -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + class HRDR(Recommender): """ @@ -107,9 +108,9 @@ def __init__( max_num_review=32, batch_size=64, max_iter=20, - optimizer='adam', + optimizer="adam", learning_rate=0.001, - model_selection='last', # last or best + model_selection="last", # last or best user_based=True, trainable=True, verbose=True, @@ -159,11 +160,12 @@ def fit(self, train_set, val_set=None): if self.trainable: if not hasattr(self, "model"): from .hrdr import HRDRModel + self.model = HRDRModel( - self.train_set.num_users, - self.train_set.num_items, - self.train_set.review_text.vocab, - self.train_set.global_mean, + self.num_users, + self.num_items, + train_set.review_text.vocab, + self.global_mean, n_factors=self.n_factors, embedding_size=self.embedding_size, id_embedding_size=self.id_embedding_size, @@ -175,69 +177,128 @@ def fit(self, train_set, val_set=None): dropout_rate=self.dropout_rate, max_text_length=self.max_text_length, max_num_review=self.max_num_review, - pretrained_word_embeddings=self.init_params.get('pretrained_word_embeddings'), + pretrained_word_embeddings=self.init_params.get( + "pretrained_word_embeddings" + ), verbose=self.verbose, seed=self.seed, ) - self._fit() + self._fit_tf(train_set) return self - def _fit(self): + def _fit_tf(self, train_set): import tensorflow as tf from tensorflow import keras from .hrdr import get_data from ...eval_methods.base_method import rating_eval from ...metrics import MSE - if not hasattr(self, '_optimizer'): + + if not hasattr(self, "_optimizer"): from tensorflow import keras - if self.optimizer == 'rmsprop': - self._optimizer = keras.optimizers.RMSprop(learning_rate=self.learning_rate) + + if self.optimizer == "rmsprop": + self._optimizer = keras.optimizers.RMSprop( + learning_rate=self.learning_rate + ) else: - self._optimizer = keras.optimizers.Adam(learning_rate=self.learning_rate) + self._optimizer = keras.optimizers.Adam( + learning_rate=self.learning_rate + ) loss = keras.losses.MeanSquaredError() train_loss = keras.metrics.Mean(name="loss") - val_loss = float('inf') - best_val_loss = float('inf') + val_loss = float("inf") + best_val_loss = float("inf") self.best_epoch = None - loop = trange(self.max_iter, disable=not self.verbose, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') + loop = trange( + self.max_iter, + disable=not self.verbose, + bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}", + ) for i_epoch, _ in enumerate(loop): train_loss.reset_states() - for i, (batch_users, batch_items, batch_ratings) in enumerate(self.train_set.uir_iter(self.batch_size, shuffle=True)): - user_reviews, user_num_reviews, user_ratings = get_data(batch_users, self.train_set, self.max_text_length, by='user', max_num_review=self.max_num_review) - item_reviews, item_num_reviews, item_ratings = get_data(batch_items, self.train_set, self.max_text_length, by='item', max_num_review=self.max_num_review) + for i, (batch_users, batch_items, batch_ratings) in enumerate( + train_set.uir_iter(self.batch_size, shuffle=True) + ): + user_reviews, user_num_reviews, user_ratings = get_data( + batch_users, + train_set, + self.max_text_length, + by="user", + max_num_review=self.max_num_review, + ) + item_reviews, item_num_reviews, item_ratings = get_data( + batch_items, + train_set, + self.max_text_length, + by="item", + max_num_review=self.max_num_review, + ) with tf.GradientTape() as tape: predictions = self.model.graph( - [batch_users, batch_items, user_ratings, user_reviews, user_num_reviews, item_ratings, item_reviews, item_num_reviews], + [ + batch_users, + batch_items, + user_ratings, + user_reviews, + user_num_reviews, + item_ratings, + item_reviews, + item_num_reviews, + ], training=True, ) _loss = loss(batch_ratings, predictions) gradients = tape.gradient(_loss, self.model.graph.trainable_variables) - self._optimizer.apply_gradients(zip(gradients, self.model.graph.trainable_variables)) + self._optimizer.apply_gradients( + zip(gradients, self.model.graph.trainable_variables) + ) train_loss(_loss) if i % 10 == 0: - loop.set_postfix(loss=train_loss.result().numpy(), val_loss=val_loss, best_val_loss=best_val_loss, best_epoch=self.best_epoch) - current_weights = self.model.get_weights(self.train_set, self.batch_size) + loop.set_postfix( + loss=train_loss.result().numpy(), + val_loss=val_loss, + best_val_loss=best_val_loss, + best_epoch=self.best_epoch, + ) + current_weights = self.model.get_weights(train_set, self.batch_size) if self.val_set is not None: - self.P, self.Q, self.W1, self.bu, self.bi, self.mu, self.A = current_weights + ( + self.P, + self.Q, + self.W1, + self.bu, + self.bi, + self.mu, + self.A, + ) = current_weights [current_val_mse], _ = rating_eval( model=self, metrics=[MSE()], test_set=self.val_set, - user_based=self.user_based + user_based=self.user_based, ) val_loss = current_val_mse if best_val_loss > val_loss: best_val_loss = val_loss self.best_epoch = i_epoch + 1 best_weights = current_weights - loop.set_postfix(loss=train_loss.result().numpy(), val_loss=val_loss, best_val_loss=best_val_loss, best_epoch=self.best_epoch) + loop.set_postfix( + loss=train_loss.result().numpy(), + val_loss=val_loss, + best_val_loss=best_val_loss, + best_epoch=self.best_epoch, + ) self.losses["train_losses"].append(train_loss.result().numpy()) self.losses["val_losses"].append(val_loss) loop.close() # save weights for predictions - self.P, self.Q, self.W1, self.bu, self.bi, self.mu, self.A = best_weights if self.val_set is not None and self.model_selection == 'best' else current_weights + self.P, self.Q, self.W1, self.bu, self.bi, self.mu, self.A = ( + best_weights + if self.val_set is not None and self.model_selection == "best" + else current_weights + ) if self.verbose: print("Learning completed!") @@ -261,7 +322,7 @@ def save(self, save_dir=None): self._optimizer = _optimizer self.model.graph = graph self.model.graph.save(model_file.replace(".pkl", ".cpt")) - with open(model_file.replace(".pkl", ".opt"), 'wb') as f: + with open(model_file.replace(".pkl", ".opt"), "wb") as f: pickle.dump(self._optimizer.get_weights(), f) return model_file @@ -276,9 +337,9 @@ def load(model_path, trainable=False): provided, the latest model will be loaded. trainable: boolean, optional, default: False - Set it to True if you would like to finetune the model. By default, + Set it to True if you would like to finetune the model. By default, the model parameters are assumed to be fixed after being loaded. - + Returns ------- self : object @@ -286,17 +347,24 @@ def load(model_path, trainable=False): import tensorflow as tf from tensorflow import keras import absl.logging + absl.logging.set_verbosity(absl.logging.ERROR) model = Recommender.load(model_path, trainable) - model.model.graph = keras.models.load_model(model.load_from.replace(".pkl", ".cpt"), compile=False) - if model.optimizer == 'rmsprop': - model._optimizer = keras.optimizers.RMSprop(learning_rate=model.learning_rate) + model.model.graph = keras.models.load_model( + model.load_from.replace(".pkl", ".cpt"), compile=False + ) + if model.optimizer == "rmsprop": + model._optimizer = keras.optimizers.RMSprop( + learning_rate=model.learning_rate + ) else: model._optimizer = keras.optimizers.Adam(learning_rate=model.learning_rate) zero_grads = [tf.zeros_like(w) for w in model.model.graph.trainable_variables] - model._optimizer.apply_gradients(zip(zero_grads, model.model.graph.trainable_variables)) - with open(model.load_from.replace(".pkl", ".opt"), 'rb') as f: + model._optimizer.apply_gradients( + zip(zero_grads, model.model.graph.trainable_variables) + ) + with open(model.load_from.replace(".pkl", ".opt"), "rb") as f: optimizer_weights = pickle.load(f) model._optimizer.set_weights(optimizer_weights) @@ -320,7 +388,7 @@ def score(self, user_idx, item_idx=None): Relative scores that the user gives to the item or to all known items """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) @@ -328,13 +396,13 @@ def score(self, user_idx, item_idx=None): known_item_scores = h0.dot(self.W1) + self.bu[user_idx] + self.bi + self.mu return known_item_scores.ravel() else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) or self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) h0 = self.P[user_idx] * self.Q[item_idx] - known_item_score = h0.dot(self.W1) + self.bu[user_idx] + self.bi[item_idx] + self.mu + known_item_score = ( + h0.dot(self.W1) + self.bu[user_idx] + self.bi[item_idx] + self.mu + ) return known_item_score diff --git a/cornac/models/ibpr/recom_ibpr.py b/cornac/models/ibpr/recom_ibpr.py index 886789a63..bf4d74c94 100644 --- a/cornac/models/ibpr/recom_ibpr.py +++ b/cornac/models/ibpr/recom_ibpr.py @@ -110,7 +110,7 @@ def fit(self, train_set, val_set=None): from .ibpr import ibpr res = ibpr( - self.train_set, + train_set, k=self.k, n_epochs=self.max_iter, lamda=self.lamda, @@ -143,7 +143,7 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) @@ -151,9 +151,7 @@ def score(self, user_idx, item_idx=None): known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) diff --git a/cornac/models/knn/recom_knn.py b/cornac/models/knn/recom_knn.py index 372444a70..e15b90dc6 100644 --- a/cornac/models/knn/recom_knn.py +++ b/cornac/models/knn/recom_knn.py @@ -51,7 +51,7 @@ def _amplify(ui_mat, alpha=1.0): return ui_mat for i, w in enumerate(ui_mat.data): - ui_mat.data[i] = w ** alpha if w > 0 else -(-w) ** alpha + ui_mat.data[i] = w**alpha if w > 0 else -((-w) ** alpha) return ui_mat @@ -98,21 +98,21 @@ class UserKNN(Recommender): k: int, optional, default: 20 The number of nearest neighbors. - + similarity: str, optional, default: 'cosine' The similarity measurement. Supported types: ['cosine', 'pearson'] - + mean_centered: bool, optional, default: False Whether values of the user-item rating matrix will be centered by the mean - of their corresponding rows (mean rating of each user). - + of their corresponding rows (mean rating of each user). + weighting: str, optional, default: None The option for re-weighting the rating matrix. Supported types: ['idf', 'bm25']. If None, no weighting is applied. - + amplify: float, optional, default: 1.0 Amplifying the influence on similarity weights. - + num_threads: int, optional, default: 0 Number of parallel threads for training. If num_threads=0, all CPU cores will be utilized. If seed is not None, num_threads=1 to remove randomness from parallelization. @@ -182,21 +182,21 @@ def fit(self, train_set, val_set=None): """ Recommender.fit(self, train_set, val_set) - self.ui_mat = self.train_set.matrix.copy() + self.ui_mat = train_set.matrix.copy() self.mean_arr = np.zeros(self.ui_mat.shape[0]) - if self.train_set.min_rating != self.train_set.max_rating: # explicit feedback + if self.min_rating != self.max_rating: # explicit feedback self.ui_mat, self.mean_arr = _mean_centered(self.ui_mat) if self.mean_centered or self.similarity == "pearson": weight_mat = self.ui_mat.copy() else: - weight_mat = self.train_set.matrix.copy() + weight_mat = train_set.matrix.copy() # re-weighting if self.weighting == "idf": - weight_mat.data *= np.sqrt(_idf_weight(self.train_set.matrix)) + weight_mat.data *= np.sqrt(_idf_weight(train_set.matrix)) elif self.weighting == "bm25": - weight_mat.data *= np.sqrt(_bm25_weight(self.train_set.matrix)) + weight_mat.data *= np.sqrt(_bm25_weight(train_set.matrix)) # only need item-user matrix for prediction self.iu_mat = self.ui_mat.T.tocsr() @@ -226,12 +226,12 @@ def score(self, user_idx, item_idx=None): res : A scalar or a Numpy array Relative scores that the user gives to the item or to all known items """ - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - if item_idx is not None and self.train_set.is_unk_item(item_idx): + if item_idx is not None and not self.knows_item(item_idx): raise ScoreException( "Can't make score prediction for (item_id=%d)" % item_idx ) @@ -248,7 +248,7 @@ def score(self, user_idx, item_idx=None): ) return self.mean_arr[user_idx] + weighted_avg - weighted_avg = np.zeros(self.train_set.num_items) + weighted_avg = np.zeros(self.num_items) compute_score( True, self.sim_mat[user_idx].A.ravel(), @@ -277,18 +277,18 @@ class ItemKNN(Recommender): similarity: str, optional, default: 'cosine' The similarity measurement. Supported types: ['cosine', 'pearson'] - + mean_centered: bool, optional, default: False Whether values of the user-item rating matrix will be centered by the mean - of their corresponding rows (mean rating of each user). - + of their corresponding rows (mean rating of each user). + weighting: str, optional, default: None The option for re-weighting the rating matrix. Supported types: ['idf', 'bm25']. If None, no weighting is applied. - + amplify: float, optional, default: 1.0 Amplifying the influence on similarity weights. - + num_threads: int, optional, default: 0 Number of parallel threads for training. If num_threads=0, all CPU cores will be utilized. If seed is not None, num_threads=1 to remove randomness from parallelization. @@ -358,15 +358,15 @@ def fit(self, train_set, val_set=None): """ Recommender.fit(self, train_set, val_set) - self.ui_mat = self.train_set.matrix.copy() + self.ui_mat = train_set.matrix.copy() self.mean_arr = np.zeros(self.ui_mat.shape[0]) - if self.train_set.min_rating != self.train_set.max_rating: # explicit feedback + if self.min_rating != self.max_rating: # explicit feedback self.ui_mat, self.mean_arr = _mean_centered(self.ui_mat) if self.mean_centered: weight_mat = self.ui_mat.copy() else: - weight_mat = self.train_set.matrix.copy() + weight_mat = train_set.matrix.copy() if self.similarity == "pearson": # centered by columns weight_mat, _ = _mean_centered(weight_mat.T.tocsr()) @@ -374,9 +374,9 @@ def fit(self, train_set, val_set=None): # re-weighting if self.weighting == "idf": - weight_mat.data *= np.sqrt(_idf_weight(self.train_set.matrix)) + weight_mat.data *= np.sqrt(_idf_weight(train_set.matrix)) elif self.weighting == "bm25": - weight_mat.data *= np.sqrt(_bm25_weight(self.train_set.matrix)) + weight_mat.data *= np.sqrt(_bm25_weight(train_set.matrix)) weight_mat = weight_mat.T.tocsr() self.sim_mat = compute_similarity( @@ -403,12 +403,12 @@ def score(self, user_idx, item_idx=None): res : A scalar or a Numpy array Relative scores that the user gives to the item or to all known items """ - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - if item_idx is not None and self.train_set.is_unk_item(item_idx): + if item_idx is not None and not self.knows_item(item_idx): raise ScoreException( "Can't make score prediction for (item_id=%d)" % item_idx ) @@ -425,7 +425,7 @@ def score(self, user_idx, item_idx=None): ) return self.mean_arr[user_idx] + weighted_avg - weighted_avg = np.zeros(self.train_set.num_items) + weighted_avg = np.zeros(self.num_items) compute_score( False, self.ui_mat[user_idx].A.ravel(), diff --git a/cornac/models/lightgcn/lightgcn.py b/cornac/models/lightgcn/lightgcn.py index eedd72b42..6fcdb7f5c 100644 --- a/cornac/models/lightgcn/lightgcn.py +++ b/cornac/models/lightgcn/lightgcn.py @@ -9,7 +9,7 @@ ITEM_KEY = "item" -def construct_graph(data_set): +def construct_graph(data_set, total_users, total_items): """ Generates graph given a cornac data set @@ -24,7 +24,7 @@ def construct_graph(data_set): (USER_KEY, "user_item", ITEM_KEY): (user_indices, item_indices), (ITEM_KEY, "item_user", USER_KEY): (item_indices, user_indices), } - num_dict = {USER_KEY: data_set.total_users, ITEM_KEY: data_set.total_items} + num_dict = {USER_KEY: total_users, ITEM_KEY: total_items} return dgl.heterograph(data_dict, num_nodes_dict=num_dict) diff --git a/cornac/models/lightgcn/recom_lightgcn.py b/cornac/models/lightgcn/recom_lightgcn.py index 635fb67a3..01fc26120 100644 --- a/cornac/models/lightgcn/recom_lightgcn.py +++ b/cornac/models/lightgcn/recom_lightgcn.py @@ -129,7 +129,7 @@ def fit(self, train_set, val_set=None): if torch.cuda.is_available(): torch.cuda.manual_seed_all(self.seed) - graph = construct_graph(train_set).to(self.device) + graph = construct_graph(train_set, self.total_users, self.total_items).to(self.device) model = Model( graph, self.emb_size, @@ -186,21 +186,29 @@ def fit(self, train_set, val_set=None): self.V = i_embs.cpu().detach().numpy() if self.early_stopping is not None and self.early_stop( - **self.early_stopping + train_set, val_set, **self.early_stopping ): break - def monitor_value(self): + def monitor_value(self, train_set, val_set): """Calculating monitored value used for early stopping on validation set (`val_set`). This function will be called by `early_stop()` function. + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + Returns ------- res : float Monitored value on validation set. Return `None` if `val_set` is `None`. """ - if self.val_set is None: + if val_set is None: return None from ...metrics import Recall @@ -209,8 +217,8 @@ def monitor_value(self): recall_20 = ranking_eval( model=self, metrics=[Recall(k=20)], - train_set=self.train_set, - test_set=self.val_set + train_set=train_set, + test_set=val_set, )[0][0] return recall_20 # Section 4.1.2 in the paper, same strategy as NGCF. @@ -234,16 +242,14 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) diff --git a/cornac/models/lrppm/recom_lrppm.pyx b/cornac/models/lrppm/recom_lrppm.pyx index 1cb580af0..493f11a86 100644 --- a/cornac/models/lrppm/recom_lrppm.pyx +++ b/cornac/models/lrppm/recom_lrppm.pyx @@ -183,9 +183,10 @@ class LRPPM(Recommender): self.UA = self.init_params.get("UA", None) self.IA = self.init_params.get("IA", None) - def _init(self): - n_users, n_items = self.train_set.num_users, self.train_set.num_items - n_aspects, n_opinions = self.train_set.sentiment.num_aspects, self.train_set.sentiment.num_opinions + def _init(self, train_set): + n_users, n_items = train_set.num_users, train_set.num_items + n_aspects, n_opinions = train_set.sentiment.num_aspects, train_set.sentiment.num_opinions + self.num_aspects = n_aspects if self.U is None: U_shape = (n_users, self.n_factors) @@ -210,7 +211,7 @@ class LRPPM(Recommender): if self.verbose: print("Building data started!") - sentiment = self.train_set.sentiment + sentiment = data_set.sentiment (u_indices, i_indices, r_values) = data_set.uir_tuple keys = np.array([get_key(u, i) for u, i in zip(u_indices, i_indices)], dtype=np.intp) cdef IntFloatDict rating_dict = IntFloatDict(keys, np.array(r_values, dtype=np.float64)) @@ -218,7 +219,7 @@ class LRPPM(Recommender): item_aspect_quality = {} user_item_aspect = {} for uid, sentiment_tup_ids_by_item in sentiment.user_sentiment.items(): - if self.train_set.is_unk_user(uid): + if not self.knows_user(uid): continue for iid, tup_idx in sentiment_tup_ids_by_item.items(): for aid, oid, polarity in sentiment.sentiment[tup_idx]: @@ -272,7 +273,8 @@ class LRPPM(Recommender): """ Recommender.fit(self, train_set, val_set) - self._init() + self._init(train_set) + ( rating_dict, user_item_aspect, @@ -291,14 +293,14 @@ class LRPPM(Recommender): X_iids.append(iid) X_aids.append(aid) ui_aspect_cnt = user_item_num_aspects[(uid, iid)] - ui_neg_aspect_cnt = self.train_set.sentiment.num_aspects - ui_aspect_cnt + ui_neg_aspect_cnt = train_set.sentiment.num_aspects - ui_aspect_cnt X_l_ui.append(1.0 / (ui_aspect_cnt * ui_neg_aspect_cnt)) X_uids = np.array(X_uids, dtype=np.int32) X_iids = np.array(X_iids, dtype=np.int32) X_aids = np.array(X_aids, dtype=np.int32) X_l_ui = np.array(X_l_ui, dtype=np.float32) - (u_indices, i_indices, r_values) = self.train_set.uir_tuple + (u_indices, i_indices, r_values) = train_set.uir_tuple cdef: int n_threads = self.n_threads @@ -376,9 +378,9 @@ class LRPPM(Recommender): """ cdef: long s, i_index, j_index, correct = 0, skipped = 0 - long n_users = self.train_set.num_users - long n_items = self.train_set.num_items - long n_aspects = self.train_set.sentiment.num_aspects + long n_users = self.num_users + long n_items = self.num_items + long n_aspects = self.num_aspects long n_factors = self.n_factors int num_samples = self.n_samples int num_ranking_samples = self.n_ranking_samples @@ -498,7 +500,7 @@ class LRPPM(Recommender): """ if i_idx is None: - if self.train_set.is_unk_user(u_idx): + if not self.knows_user(u_idx): raise ScoreException( "Can't make score prediction for (user_id=%d" & u_idx ) @@ -506,8 +508,7 @@ class LRPPM(Recommender): item_scores = self.I.dot(self.U[u_idx]) return item_scores else: - if (self.train_set.is_unk_user(u_idx) - or self.train_set.is_unk_item(i_idx)): + if not (self.knows_user(u_idx) and self.knows_item(i_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (u_idx, i_idx) @@ -517,8 +518,8 @@ class LRPPM(Recommender): def rank(self, user_idx, item_indices=None): if self.alpha > 0 and self.num_top_aspects > 0: - n_items = self.train_set.num_items - num_top_aspects = min(self.num_top_aspects, self.train_set.sentiment.num_aspects) + n_items = self.num_items + num_top_aspects = min(self.num_top_aspects, self.num_aspects) item_aspect_scores = self.UA.dot(self.U[user_idx]) + self.I.dot(self.IA.T) + np.expand_dims(self.I.dot(self.U[user_idx]), axis=1) top_aspect_ids = (-item_aspect_scores).argsort(axis=1)[:, :num_top_aspects] iids = np.repeat(range(n_items), num_top_aspects).reshape(n_items, num_top_aspects) @@ -530,17 +531,17 @@ class LRPPM(Recommender): # check if the returned scores also cover unknown items # if not, all unknown items will be given the MIN score - if len(known_item_scores) == self.train_set.total_items: + if len(known_item_scores) == self.total_items: all_item_scores = known_item_scores else: - all_item_scores = np.ones(self.train_set.total_items) * np.min( + all_item_scores = np.ones(self.total_items) * np.min( known_item_scores ) - all_item_scores[: self.train_set.num_items] = known_item_scores + all_item_scores[: self.num_items] = known_item_scores # rank items based on their scores if item_indices is None: - item_scores = all_item_scores[: self.train_set.num_items] + item_scores = all_item_scores[: self.num_items] item_rank = item_scores.argsort()[::-1] else: item_scores = all_item_scores[item_indices] diff --git a/cornac/models/mcf/recom_mcf.py b/cornac/models/mcf/recom_mcf.py index af61d2790..84926b104 100644 --- a/cornac/models/mcf/recom_mcf.py +++ b/cornac/models/mcf/recom_mcf.py @@ -128,21 +128,15 @@ def fit(self, train_set, val_set=None): (rat_uid, rat_iid, rat_val) = train_set.uir_tuple # item-item affinity network - map_iid = train_set.item_indices + train_item_indices = set(train_set.uir_tuple[1]) (net_iid, net_jid, net_val) = train_set.item_graph.get_train_triplet( - map_iid, map_iid + train_item_indices, train_item_indices ) - if [self.train_set.min_rating, self.train_set.max_rating] != [0, 1]: - if self.train_set.min_rating == self.train_set.max_rating: - rat_val = scale(rat_val, 0.0, 1.0, 0.0, self.train_set.max_rating) + if [self.min_rating, self.max_rating] != [0, 1]: + if self.min_rating == self.max_rating: + rat_val = scale(rat_val, 0.0, 1.0, 0.0, self.max_rating) else: - rat_val = scale( - rat_val, - 0.0, - 1.0, - self.train_set.min_rating, - self.train_set.max_rating, - ) + rat_val = scale(rat_val, 0.0, 1.0, self.min_rating, self.max_rating) if [min(net_val), max(net_val)] != [0, 1]: if min(net_val) == max(net_val): @@ -211,37 +205,24 @@ def score(self, user_idx, item_idx=None): ------- res : A scalar or a Numpy array Relative scores that the user gives to the item or to all known items - """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - user_pred = self.V[item_idx, :].dot(self.U[user_idx, :]) - user_pred = sigmoid(user_pred) - if self.train_set.min_rating == self.train_set.max_rating: - user_pred = scale(user_pred, 0.0, self.train_set.max_rating, 0.0, 1.0) + if self.min_rating == self.max_rating: + user_pred = scale(user_pred, 0.0, self.max_rating, 0.0, 1.0) else: - user_pred = scale( - user_pred, - self.train_set.min_rating, - self.train_set.max_rating, - 0.0, - 1.0, - ) - + user_pred = scale(user_pred, self.min_rating, self.max_rating, 0.0, 1.0) return user_pred diff --git a/cornac/models/mf/recom_mf.pyx b/cornac/models/mf/recom_mf.pyx index 719ceeb3b..6fc9c316e 100644 --- a/cornac/models/mf/recom_mf.pyx +++ b/cornac/models/mf/recom_mf.pyx @@ -119,20 +119,18 @@ class MF(Recommender): self.i_factors = self.init_params.get('V', None) self.u_biases = self.init_params.get('Bu', None) self.i_biases = self.init_params.get('Bi', None) - self.global_mean = 0.0 def _init(self): rng = get_rng(self.seed) - n_users, n_items = self.train_set.num_users, self.train_set.num_items if self.u_factors is None: - self.u_factors = normal([n_users, self.k], std=0.01, random_state=rng) + self.u_factors = normal([self.num_users, self.k], std=0.01, random_state=rng) if self.i_factors is None: - self.i_factors = normal([n_items, self.k], std=0.01, random_state=rng) + self.i_factors = normal([self.num_items, self.k], std=0.01, random_state=rng) - self.u_biases = zeros(n_users) if self.u_biases is None else self.u_biases - self.i_biases = zeros(n_items) if self.i_biases is None else self.i_biases - self.global_mean = self.train_set.global_mean if self.use_bias else 0.0 + self.u_biases = zeros(self.num_users) if self.u_biases is None else self.u_biases + self.i_biases = zeros(self.num_items) if self.i_biases is None else self.i_biases + self.global_mean = self.global_mean if self.use_bias else 0.0 def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -165,11 +163,10 @@ class MF(Recommender): @cython.wraparound(False) def _fit_sgd(self, integral[:] rid, integral[:] cid, floating[:] val, floating[:, :] U, floating[:, :] V, floating[:] Bu, floating[:] Bi): - """Fit the model parameters (U, V, Bu, Bi) with SGD - """ + """Fit the model parameters (U, V, Bu, Bi) with SGD""" cdef: - long num_users = self.train_set.num_users - long num_items = self.train_set.num_items + long num_users = self.num_users + long num_items = self.num_items long num_ratings = val.shape[0] int num_factors = self.k int max_iter = self.max_iter @@ -252,26 +249,23 @@ class MF(Recommender): Relative scores that the user gives to the item or to all known items """ - unk_user = self.train_set.is_unk_user(user_idx) - if item_idx is None: known_item_scores = np.add(self.i_biases, self.global_mean) - if not unk_user: + if self.knows_user(user_idx): known_item_scores = np.add(known_item_scores, self.u_biases[user_idx]) fast_dot(self.u_factors[user_idx], self.i_factors, known_item_scores) return known_item_scores else: - unk_item = self.train_set.is_unk_item(item_idx) if self.use_bias: item_score = self.global_mean - if not unk_user: + if self.knows_user(user_idx): item_score += self.u_biases[user_idx] - if not unk_item: + if self.knows_item(item_idx): item_score += self.i_biases[item_idx] - if not unk_user and not unk_item: + if self.knows_user(user_idx) and self.knows_item(item_idx): item_score += np.dot(self.u_factors[user_idx], self.i_factors[item_idx]) else: - if unk_user or unk_item: + if not self.knows_user(user_idx) or self.knows_item(item_idx): raise ScoreException("Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx)) item_score = np.dot(self.u_factors[user_idx], self.i_factors[item_idx]) return item_score diff --git a/cornac/models/mmmf/recom_mmmf.pyx b/cornac/models/mmmf/recom_mmmf.pyx index 4cb54a2c1..cbdc5037e 100644 --- a/cornac/models/mmmf/recom_mmmf.pyx +++ b/cornac/models/mmmf/recom_mmmf.pyx @@ -108,7 +108,7 @@ class MMMF(BPR): """ cdef: long num_samples = len(user_ids), s, i_index, j_index, correct = 0, skipped = 0 - long num_items = self.train_set.num_items + long num_items = self.num_items integral f, i_id, j_id, thread_id floating z, score, temp diff --git a/cornac/models/most_pop/recom_most_pop.py b/cornac/models/most_pop/recom_most_pop.py index 7a450caf6..8215b29f5 100644 --- a/cornac/models/most_pop/recom_most_pop.py +++ b/cornac/models/most_pop/recom_most_pop.py @@ -72,7 +72,7 @@ def score(self, user_idx, item_idx=None): if item_idx is None: return self.item_pop else: - if self.train_set.is_unk_item(item_idx): + if not self.knows_item(item_idx): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) diff --git a/cornac/models/mter/recom_mter.pyx b/cornac/models/mter/recom_mter.pyx index 0601f369c..e99e28e6c 100644 --- a/cornac/models/mter/recom_mter.pyx +++ b/cornac/models/mter/recom_mter.pyx @@ -195,9 +195,10 @@ class MTER(Recommender): self.A = self.init_params.get("A", None) self.O = self.init_params.get("O", None) - def _init(self): - n_users, n_items = self.train_set.num_users, self.train_set.num_items - n_aspects, n_opinions = self.train_set.sentiment.num_aspects, self.train_set.sentiment.num_opinions + def _init(self, train_set): + n_users, n_items = train_set.num_users, train_set.num_items + n_aspects, n_opinions = train_set.sentiment.num_aspects, train_set.sentiment.num_opinions + self.num_aspects, self.num_opinions = n_aspects, n_opinions if self.G1 is None: G1_shape = (self.n_user_factors, self.n_item_factors, self.n_aspect_factors) @@ -228,19 +229,19 @@ class MTER(Recommender): if self.verbose: print("Building data started!") - sentiment = self.train_set.sentiment + sentiment = data_set.sentiment (u_indices, i_indices, r_values) = data_set.uir_tuple keys = np.array([get_key(u, i) for u, i in zip(u_indices, i_indices)], dtype=np.intp) cdef IntFloatDict rating_dict = IntFloatDict(keys, np.array(r_values, dtype=np.float64)) rating_matrix = sp.csr_matrix( (r_values, (u_indices, i_indices)), - shape=(self.train_set.num_users, self.train_set.num_items), + shape=(self.num_users, self.num_items), ) user_item_aspect = {} user_aspect_opinion = {} item_aspect_opinion = {} for u_idx, sentiment_tup_ids_by_item in sentiment.user_sentiment.items(): - if self.train_set.is_unk_user(u_idx): + if not self.knows_user(u_idx): continue for i_idx, tup_idx in sentiment_tup_ids_by_item.items(): user_item_aspect[ @@ -317,7 +318,7 @@ class MTER(Recommender): """ Recommender.fit(self, train_set, val_set) - self._init() + self._init(train_set) if not self.trainable: return self @@ -366,7 +367,7 @@ class MTER(Recommender): YI_oids = np.array(YI_oids, dtype=np.int32) user_counts = np.ediff1d(rating_matrix.indptr).astype(np.int32) - user_ids = np.repeat(np.arange(self.train_set.num_users), user_counts).astype(np.int32) + user_ids = np.repeat(np.arange(self.num_users), user_counts).astype(np.int32) neg_item_ids = np.arange(train_set.num_items, dtype=np.int32) cdef: @@ -475,10 +476,10 @@ class MTER(Recommender): """ cdef: long s, i_index, j_index, correct = 0, skipped = 0 - long n_users = self.train_set.num_users - long n_items = self.train_set.num_items - long n_aspects = self.train_set.sentiment.num_aspects - long n_opinions = self.train_set.sentiment.num_opinions + long n_users = self.num_users + long n_items = self.num_items + long n_aspects = self.num_aspects + long n_opinions = self.num_opinions long n_user_factors = self.n_user_factors long n_item_factors = self.n_item_factors long n_aspect_factors = self.n_aspect_factors @@ -692,7 +693,7 @@ class MTER(Recommender): """ if i_idx is None: - if self.train_set.is_unk_user(u_idx): + if not self.knows_user(u_idx): raise ScoreException( "Can't make score prediction for (user_id=%d" & u_idx ) @@ -705,8 +706,7 @@ class MTER(Recommender): item_scores = np.einsum("MNc,c->MN", tensor_value2, self.A[-1]).flatten() return item_scores else: - if (self.train_set.is_unk_user(u_idx) - or self.train_set.is_unk_item(i_idx)): + if not (self.knows_user(u_idx) and self.knows_item(i_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (u_idx, i_idx) diff --git a/cornac/models/narre/recom_narre.py b/cornac/models/narre/recom_narre.py index 77db92d49..af7943c06 100644 --- a/cornac/models/narre/recom_narre.py +++ b/cornac/models/narre/recom_narre.py @@ -21,7 +21,8 @@ from ...exception import ScoreException -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + class NARRE(Recommender): """Neural Attentional Rating Regression with Review-level Explanations @@ -108,9 +109,9 @@ def __init__( max_num_review=32, batch_size=64, max_iter=10, - optimizer='adam', + optimizer="adam", learning_rate=0.001, - model_selection='last', # last or best + model_selection="last", # last or best user_based=True, trainable=True, verbose=True, @@ -132,8 +133,12 @@ def __init__( self.max_iter = max_iter self.optimizer = optimizer self.learning_rate = learning_rate - if model_selection not in ['best', 'last']: - raise ValueError("model_selection is either 'best' or 'last' but {}".format(model_selection)) + if model_selection not in ["best", "last"]: + raise ValueError( + "model_selection is either 'best' or 'last' but {}".format( + model_selection + ) + ) self.model_selection = model_selection self.user_based = user_based # Init params if provided @@ -160,11 +165,12 @@ def fit(self, train_set, val_set=None): if self.trainable: if not hasattr(self, "model"): from .narre import NARREModel + self.model = NARREModel( - self.train_set.num_users, - self.train_set.num_items, - self.train_set.review_text.vocab, - self.train_set.global_mean, + train_set.num_users, + train_set.num_items, + train_set.review_text.vocab, + train_set.global_mean, n_factors=self.n_factors, embedding_size=self.embedding_size, id_embedding_size=self.id_embedding_size, @@ -174,75 +180,146 @@ def fit(self, train_set, val_set=None): dropout_rate=self.dropout_rate, max_text_length=self.max_text_length, max_num_review=self.max_num_review, - pretrained_word_embeddings=self.init_params.get('pretrained_word_embeddings'), + pretrained_word_embeddings=self.init_params.get( + "pretrained_word_embeddings" + ), verbose=self.verbose, seed=self.seed, ) - self._fit() + self._fit_tf(train_set, val_set) return self - def _fit(self): + def _fit_tf(self, train_set, val_set): import tensorflow as tf from tensorflow import keras from .narre import get_data from ...eval_methods.base_method import rating_eval from ...metrics import MSE + loss = keras.losses.MeanSquaredError() - if not hasattr(self, '_optimizer'): - if self.optimizer == 'rmsprop': - self._optimizer = keras.optimizers.RMSprop(learning_rate=self.learning_rate) - elif self.optimizer == 'adam': - self._optimizer = keras.optimizers.Adam(learning_rate=self.learning_rate) + if not hasattr(self, "_optimizer"): + if self.optimizer == "rmsprop": + self._optimizer = keras.optimizers.RMSprop( + learning_rate=self.learning_rate + ) + elif self.optimizer == "adam": + self._optimizer = keras.optimizers.Adam( + learning_rate=self.learning_rate + ) else: - raise ValueError("optimizer is either 'rmsprop' or 'adam' but {}".format(self.optimizer)) + raise ValueError( + "optimizer is either 'rmsprop' or 'adam' but {}".format( + self.optimizer + ) + ) train_loss = keras.metrics.Mean(name="loss") - val_loss = float('inf') - best_val_loss = float('inf') + val_loss = float("inf") + best_val_loss = float("inf") self.best_epoch = None - loop = trange(self.max_iter, disable=not self.verbose, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') + loop = trange( + self.max_iter, + disable=not self.verbose, + bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}", + ) for i_epoch, _ in enumerate(loop): train_loss.reset_states() - for i, (batch_users, batch_items, batch_ratings) in enumerate(self.train_set.uir_iter(self.batch_size, shuffle=True)): - user_reviews, user_iid_reviews, user_num_reviews = get_data(batch_users, self.train_set, self.max_text_length, by='user', max_num_review=self.max_num_review) - item_reviews, item_uid_reviews, item_num_reviews = get_data(batch_items, self.train_set, self.max_text_length, by='item', max_num_review=self.max_num_review) + for i, (batch_users, batch_items, batch_ratings) in enumerate( + train_set.uir_iter(self.batch_size, shuffle=True) + ): + user_reviews, user_iid_reviews, user_num_reviews = get_data( + batch_users, + train_set, + self.max_text_length, + by="user", + max_num_review=self.max_num_review, + ) + item_reviews, item_uid_reviews, item_num_reviews = get_data( + batch_items, + train_set, + self.max_text_length, + by="item", + max_num_review=self.max_num_review, + ) with tf.GradientTape() as tape: predictions = self.model.graph( - [batch_users, batch_items, user_reviews, user_iid_reviews, user_num_reviews, item_reviews, item_uid_reviews, item_num_reviews], + [ + batch_users, + batch_items, + user_reviews, + user_iid_reviews, + user_num_reviews, + item_reviews, + item_uid_reviews, + item_num_reviews, + ], training=True, ) _loss = loss(batch_ratings, predictions) gradients = tape.gradient(_loss, self.model.graph.trainable_variables) - self._optimizer.apply_gradients(zip(gradients, self.model.graph.trainable_variables)) + self._optimizer.apply_gradients( + zip(gradients, self.model.graph.trainable_variables) + ) train_loss(_loss) if i % 10 == 0: - loop.set_postfix(loss=train_loss.result().numpy(), val_loss=val_loss, best_val_loss=best_val_loss, best_epoch=self.best_epoch) - current_weights = self.model.get_weights(self.train_set, self.batch_size) - if self.val_set is not None: - self.X, self.Y, self.W1, self.user_embedding, self.item_embedding, self.bu, self.bi, self.mu = current_weights + loop.set_postfix( + loss=train_loss.result().numpy(), + val_loss=val_loss, + best_val_loss=best_val_loss, + best_epoch=self.best_epoch, + ) + current_weights = self.model.get_weights(train_set, self.batch_size) + if val_set is not None: + ( + self.X, + self.Y, + self.W1, + self.user_embedding, + self.item_embedding, + self.bu, + self.bi, + self.mu, + ) = current_weights [current_val_mse], _ = rating_eval( model=self, metrics=[MSE()], - test_set=self.val_set, - user_based=self.user_based + test_set=val_set, + user_based=self.user_based, ) val_loss = current_val_mse if best_val_loss > val_loss: best_val_loss = val_loss self.best_epoch = i_epoch + 1 best_weights = current_weights - loop.set_postfix(loss=train_loss.result().numpy(), val_loss=val_loss, best_val_loss=best_val_loss, best_epoch=self.best_epoch) + loop.set_postfix( + loss=train_loss.result().numpy(), + val_loss=val_loss, + best_val_loss=best_val_loss, + best_epoch=self.best_epoch, + ) self.losses["train_losses"].append(train_loss.result().numpy()) self.losses["val_losses"].append(val_loss) loop.close() # save weights for predictions - self.X, self.Y, self.W1, self.user_embedding, self.item_embedding, self.bu, self.bi, self.mu = best_weights if self.val_set is not None and self.model_selection == 'best' else current_weights + ( + self.X, + self.Y, + self.W1, + self.user_embedding, + self.item_embedding, + self.bu, + self.bi, + self.mu, + ) = ( + best_weights + if val_set is not None and self.model_selection == "best" + else current_weights + ) if self.verbose: print("Learning completed!") - def save(self, save_dir=None): """Save a recommender model to the filesystem. @@ -263,7 +340,7 @@ def save(self, save_dir=None): self._optimizer = _optimizer self.model.graph = graph self.model.graph.save(model_file.replace(".pkl", ".cpt")) - with open(model_file.replace(".pkl", ".opt"), 'wb') as f: + with open(model_file.replace(".pkl", ".opt"), "wb") as f: pickle.dump(self._optimizer.get_weights(), f) return model_file @@ -278,9 +355,9 @@ def load(model_path, trainable=False): provided, the latest model will be loaded. trainable: boolean, optional, default: False - Set it to True if you would like to finetune the model. By default, + Set it to True if you would like to finetune the model. By default, the model parameters are assumed to be fixed after being loaded. - + Returns ------- self : object @@ -288,17 +365,24 @@ def load(model_path, trainable=False): import tensorflow as tf from tensorflow import keras import absl.logging + absl.logging.set_verbosity(absl.logging.ERROR) model = Recommender.load(model_path, trainable) - model.model.graph = keras.models.load_model(model.load_from.replace(".pkl", ".cpt"), compile=False) - if model.optimizer == 'rmsprop': - model._optimizer = keras.optimizers.RMSprop(learning_rate=model.learning_rate) + model.model.graph = keras.models.load_model( + model.load_from.replace(".pkl", ".cpt"), compile=False + ) + if model.optimizer == "rmsprop": + model._optimizer = keras.optimizers.RMSprop( + learning_rate=model.learning_rate + ) else: model._optimizer = keras.optimizers.Adam(learning_rate=model.learning_rate) zero_grads = [tf.zeros_like(w) for w in model.model.graph.trainable_variables] - model._optimizer.apply_gradients(zip(zero_grads, model.model.graph.trainable_variables)) - with open(model.load_from.replace(".pkl", ".opt"), 'rb') as f: + model._optimizer.apply_gradients( + zip(zero_grads, model.model.graph.trainable_variables) + ) + with open(model.load_from.replace(".pkl", ".opt"), "rb") as f: optimizer_weights = pickle.load(f) model._optimizer.set_weights(optimizer_weights) @@ -322,21 +406,25 @@ def score(self, user_idx, item_idx=None): Relative scores that the user gives to the item or to all known items """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - h0 = (self.user_embedding[user_idx] + self.X[user_idx]) * (self.item_embedding + self.Y) + h0 = (self.user_embedding[user_idx] + self.X[user_idx]) * ( + self.item_embedding + self.Y + ) known_item_scores = h0.dot(self.W1) + self.bu[user_idx] + self.bi + self.mu return known_item_scores.ravel() else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - h0 = (self.user_embedding[user_idx] + self.X[user_idx]) * (self.item_embedding[item_idx] + self.Y[item_idx]) - known_item_score = h0.dot(self.W1) + self.bu[user_idx] + self.bi[item_idx] + self.mu + h0 = (self.user_embedding[user_idx] + self.X[user_idx]) * ( + self.item_embedding[item_idx] + self.Y[item_idx] + ) + known_item_score = ( + h0.dot(self.W1) + self.bu[user_idx] + self.bi[item_idx] + self.mu + ) return known_item_score diff --git a/cornac/models/ncf/recom_ncf_base.py b/cornac/models/ncf/recom_ncf_base.py index 4f47877ea..ed389b5e2 100644 --- a/cornac/models/ncf/recom_ncf_base.py +++ b/cornac/models/ncf/recom_ncf_base.py @@ -126,13 +126,13 @@ def fit(self, train_set, val_set=None): Recommender.fit(self, train_set, val_set) if self.trainable: - self.num_users = self.train_set.num_users - self.num_items = self.train_set.num_items + self.num_users = self.num_users + self.num_items = self.num_items if self.backend == "tensorflow": - self._fit_tf() + self._fit_tf(train_set, val_set) elif self.backend == "pytorch": - self._fit_pt() + self._fit_pt(train_set, val_set) else: raise ValueError(f"{self.backend} is not supported") @@ -159,7 +159,7 @@ def _get_feed_dict(self, batch_users, batch_items, batch_ratings): self.labels: batch_ratings.reshape(-1, 1), } - def _fit_tf(self): + def _fit_tf(self, train_set, val_set): if not hasattr(self, "graph"): self._build_graph_tf() @@ -168,7 +168,7 @@ def _fit_tf(self): count = 0 sum_loss = 0 for i, (batch_users, batch_items, batch_ratings) in enumerate( - self.train_set.uir_iter( + train_set.uir_iter( self.batch_size, shuffle=True, binary=True, num_zeros=self.num_neg ) ): @@ -184,7 +184,7 @@ def _fit_tf(self): loop.set_postfix(loss=(sum_loss / count)) if self.early_stopping is not None and self.early_stop( - **self.early_stopping + train_set, val_set, **self.early_stopping ): break loop.close() @@ -198,7 +198,7 @@ def _score_tf(self, user_idx, item_idx): def _build_model_pt(self): raise NotImplementedError() - def _fit_pt(self): + def _fit_pt(self, train_set, val_set): import torch import torch.nn as nn from .backend_pt import optimizer_dict @@ -225,7 +225,7 @@ def _fit_pt(self): count = 0 sum_loss = 0 for batch_id, (batch_users, batch_items, batch_ratings) in enumerate( - self.train_set.uir_iter( + train_set.uir_iter( self.batch_size, shuffle=True, binary=True, num_zeros=self.num_neg ) ): @@ -247,6 +247,12 @@ def _fit_pt(self): if batch_id % 10 == 0: loop.set_postfix(loss=(sum_loss / count)) + if self.early_stopping is not None and self.early_stop( + train_set, val_set, **self.early_stopping + ): + break + loop.close() + def _score_pt(self, user_idx, item_idx): raise NotImplementedError() @@ -303,17 +309,25 @@ def load(model_path, trainable=False): return model - def monitor_value(self): + def monitor_value(self, train_set, val_set): """Calculating monitored value used for early stopping on validation set (`val_set`). This function will be called by `early_stop()` function. + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + Returns ------- res : float Monitored value on validation set. Return `None` if `val_set` is `None`. """ - if self.val_set is None: + if val_set is None: return None from ...metrics import NDCG @@ -322,8 +336,8 @@ def monitor_value(self): ndcg_100 = ranking_eval( model=self, metrics=[NDCG(k=100)], - train_set=self.train_set, - test_set=self.val_set, + train_set=train_set, + test_set=val_set, )[0][0] return ndcg_100 @@ -345,12 +359,12 @@ def score(self, user_idx, item_idx=None): res : A scalar or a Numpy array Relative scores that the user gives to the item or to all known items """ - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - if item_idx is not None and self.train_set.is_unk_item(item_idx): + if item_idx is not None and not self.knows_item(item_idx): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) diff --git a/cornac/models/ngcf/ngcf.py b/cornac/models/ngcf/ngcf.py index 64e473bbc..dd184c55e 100644 --- a/cornac/models/ngcf/ngcf.py +++ b/cornac/models/ngcf/ngcf.py @@ -11,7 +11,7 @@ ITEM_KEY = "item" -def construct_graph(data_set): +def construct_graph(data_set, total_users, total_items): """ Generates graph given a cornac data set @@ -23,8 +23,8 @@ def construct_graph(data_set): user_indices, item_indices, _ = data_set.uir_tuple # construct graph from the train data and add self-loops - user_selfs = [i for i in range(data_set.total_users)] - item_selfs = [i for i in range(data_set.total_items)] + user_selfs = [i for i in range(total_users)] + item_selfs = [i for i in range(total_items)] data_dict = { (USER_KEY, "user_self", USER_KEY): (user_selfs, user_selfs), @@ -32,7 +32,7 @@ def construct_graph(data_set): (USER_KEY, "user_item", ITEM_KEY): (user_indices, item_indices), (ITEM_KEY, "item_user", USER_KEY): (item_indices, user_indices), } - num_dict = {USER_KEY: data_set.total_users, ITEM_KEY: data_set.total_items} + num_dict = {USER_KEY: total_users, ITEM_KEY: total_items} return dgl.heterograph(data_dict, num_nodes_dict=num_dict) diff --git a/cornac/models/ngcf/recom_ngcf.py b/cornac/models/ngcf/recom_ngcf.py index daa7b3109..719373c40 100644 --- a/cornac/models/ngcf/recom_ngcf.py +++ b/cornac/models/ngcf/recom_ngcf.py @@ -35,9 +35,9 @@ class NGCF(Recommender): Size of the output of convolution layers. dropout_rates: list, default: [0.1, 0.1, 0.1] - Dropout rate for each of the convolution layers. + Dropout rate for each of the convolution layers. - Number of values should be the same as 'layer_sizes' - + num_epochs: int, default: 1000 Maximum number of iterations or the number of epochs. @@ -133,7 +133,7 @@ def fit(self, train_set, val_set=None): if torch.cuda.is_available(): torch.cuda.manual_seed_all(self.seed) - graph = construct_graph(train_set).to(self.device) + graph = construct_graph(train_set, self.total_users, self.total_items).to(self.device) model = Model( graph, self.emb_size, @@ -191,21 +191,29 @@ def fit(self, train_set, val_set=None): self.V = i_embs.cpu().detach().numpy() if self.early_stopping is not None and self.early_stop( - **self.early_stopping + train_set, val_set, **self.early_stopping ): break - def monitor_value(self): + def monitor_value(self, train_set, val_set): """Calculating monitored value used for early stopping on validation set (`val_set`). This function will be called by `early_stop()` function. + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + Returns ------- res : float Monitored value on validation set. Return `None` if `val_set` is `None`. """ - if self.val_set is None: + if val_set is None: return None from ...metrics import Recall @@ -214,9 +222,9 @@ def monitor_value(self): recall_20 = ranking_eval( model=self, metrics=[Recall(k=20)], - train_set=self.train_set, - test_set=self.val_set, - verbose=True + train_set=train_set, + test_set=val_set, + verbose=True, )[0][0] return recall_20 # Section 4.2.3 in the paper @@ -240,16 +248,14 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) diff --git a/cornac/models/nmf/recom_nmf.pyx b/cornac/models/nmf/recom_nmf.pyx index c8a6b37ab..dde9c06cc 100644 --- a/cornac/models/nmf/recom_nmf.pyx +++ b/cornac/models/nmf/recom_nmf.pyx @@ -132,7 +132,7 @@ class NMF(Recommender): def _init(self): rng = get_rng(self.seed) - n_users, n_items = self.train_set.num_users, self.train_set.num_items + n_users, n_items = self.num_users, self.num_items if self.u_factors is None: self.u_factors = uniform((n_users, self.k), random_state=rng) @@ -141,7 +141,7 @@ class NMF(Recommender): self.u_biases = zeros(n_users) if self.u_biases is None else self.u_biases self.i_biases = zeros(n_items) if self.i_biases is None else self.i_biases - self.global_mean = self.train_set.global_mean if self.use_bias else 0.0 + self.global_mean = self.global_mean if self.use_bias else 0.0 def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -163,7 +163,7 @@ class NMF(Recommender): self._init() if self.trainable: - n_users, n_items = self.train_set.num_users, self.train_set.num_items + n_users, n_items = self.num_users, self.num_items X = train_set.matrix # csr_matrix user_counts = np.ediff1d(X.indptr) user_ids = np.repeat(np.arange(n_users), user_counts).astype(X.indices.dtype) @@ -186,8 +186,8 @@ class NMF(Recommender): """Fit the model parameters (U, V, Bu, Bi) """ cdef: - long num_users = self.train_set.num_users - long num_items = self.train_set.num_items + long num_users = self.num_users + long num_items = self.num_items long num_ratings = val.shape[0] int num_factors = self.k int max_iter = self.max_iter @@ -284,26 +284,23 @@ class NMF(Recommender): Relative scores that the user gives to the item or to all known items """ - unk_user = self.train_set.is_unk_user(user_idx) - if item_idx is None: known_item_scores = np.add(self.i_biases, self.global_mean) - if not unk_user: + if self.knows_user(user_idx): known_item_scores = np.add(known_item_scores, self.u_biases[user_idx]) fast_dot(self.u_factors[user_idx], self.i_factors, known_item_scores) return known_item_scores else: - unk_item = self.train_set.is_unk_item(item_idx) if self.use_bias: item_score = self.global_mean - if not unk_user: + if self.knows_user(user_idx): item_score += self.u_biases[user_idx] - if not unk_item: + if self.knows_item(item_idx): item_score += self.i_biases[item_idx] - if not unk_user and not unk_item: + if self.knows_user(user_idx) and self.knows_item(item_idx): item_score += np.dot(self.u_factors[user_idx], self.i_factors[item_idx]) else: - if unk_user or unk_item: + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException("Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx)) item_score = np.dot(self.u_factors[user_idx], self.i_factors[item_idx]) return item_score diff --git a/cornac/models/online_ibpr/recom_online_ibpr.py b/cornac/models/online_ibpr/recom_online_ibpr.py index 585c8b3f9..ad665f342 100644 --- a/cornac/models/online_ibpr/recom_online_ibpr.py +++ b/cornac/models/online_ibpr/recom_online_ibpr.py @@ -146,21 +146,17 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - user_pred = self.V[item_idx, :].dot(self.U[user_idx, :]) return user_pred diff --git a/cornac/models/pcrl/pcrl.py b/cornac/models/pcrl/pcrl.py index 48b262798..331c7e675 100644 --- a/cornac/models/pcrl/pcrl.py +++ b/cornac/models/pcrl/pcrl.py @@ -32,18 +32,12 @@ def __init__( w_determinist=True, init_params=None, ): - - self.train_set = train_set - self.cf_data = sp.csc_matrix( - self.train_set.matrix - ) # user-item interaction (CF data) - self.aux_data = self.train_set.item_graph.matrix[ - : self.train_set.num_items, : self.train_set.num_items + self.cf_data = train_set.csc_matrix # user-item interaction (CF data) + self.aux_data = train_set.item_graph.matrix[ + : train_set.num_items, : train_set.num_items ] # item auxiliary information (items'context in the original paper) self.k = k # the number of user and item latent factors - self.z_dims = ( - z_dims - ) # the dimension of the second hidden layer (we consider a 2-layers PCRL) + self.z_dims = z_dims # the dimension of the second hidden layer (we consider a 2-layers PCRL) self.c_dim = self.aux_data.shape[ 1 ] # the dimension of the auxiliary information matrix @@ -61,16 +55,12 @@ def __init__( self.Lr = sp.csc_matrix( (self.aux_data.shape[0], self.k) ) # Variational Rate parameters of the item factors (Beta in the paper) - self.Gs = ( - None - ) # Variational Shapre parameters of the user factors (Theta in the paper) + self.Gs = None # Variational Shapre parameters of the user factors (Theta in the paper) self.Gr = ( - None - ) # Variational Rate parameters of the user factors (Theta in the paper) + None # Variational Rate parameters of the user factors (Theta in the paper) + ) self.L = len(z_dims) # The number of deterministic hidden layers "z" - self.w_determinist = ( - w_determinist - ) # If true then deterministic wheights are used for the generator network + self.w_determinist = w_determinist # If true then deterministic wheights are used for the generator network self.sess = tf.Session() # Tensorflow session # Inference netwok parameters self.inference_params = [] @@ -113,7 +103,7 @@ def log_q(self, z, alpha, beta): # Log density of the standard normal N(0, 1) def log_t(self, epsilon): - return -0.5 * tf.log(2 * np.pi) - 0.5 * epsilon ** 2 + return -0.5 * tf.log(2 * np.pi) - 0.5 * epsilon**2 # Marsaglia and Tsang transformation def G(self, epsilon, alpha, beta): @@ -155,7 +145,6 @@ def shape_augmentation(self, alpha, B): # Collaborative filtering part of pcrl (Poisson Factorization) def pf_(self, X, k, max_iter=1, init_params=None): - # data preparation X = sp.csc_matrix(X, dtype=np.float64) M = X.copy() @@ -273,7 +262,6 @@ def inference_net(self, C, reuse=None): # The generative network (or decoder) def generative_net(self, Z, reuse=None): - # with tf.variable_scope("generative",reuse=reuse): if self.w_determinist: h2 = tf.nn.relu(tf.matmul(Z, self.generator_params[0])) @@ -315,7 +303,6 @@ def generative_net(self, Z, reuse=None): # The loss function def loss(self, C, X_g, X_, alpha, beta, z, E, Zik, Tk): - const_term = C * tf.log(1e-10 + X_) - X_ const_term = tf.reduce_sum(const_term, 1) @@ -351,7 +338,7 @@ def loss(self, C, X_g, X_, alpha, beta, z, E, Zik, Tk): ) # fitting PCRL to observed data - def learn(self): + def learn(self, train_set): # placeholders C = tf.placeholder(tf.float32, shape=[None, self.c_dim], name="C") X_ = tf.placeholder(tf.float32, shape=[None, self.c_dim], name="X_") @@ -391,7 +378,7 @@ def learn(self): ) for epoch in range(self.n_epoch): - for idx in self.train_set.item_iter(self.batch_size, shuffle=False): + for idx in train_set.item_iter(self.batch_size, shuffle=False): batch_C = self.aux_data[idx].A EE = self.sess.run(E_, feed_dict={C: batch_C}) z_c = self.sess.run(X_g, feed_dict={C: batch_C, E: EE}) @@ -404,7 +391,7 @@ def learn(self): } _, l = self.sess.run([train, loss], feed_dict=feed_dict) del (EE, z_c) - for idx in self.train_set.item_iter(2 * self.batch_size, shuffle=False): + for idx in train_set.item_iter(2 * self.batch_size, shuffle=False): batch_C = self.aux_data[idx].A self.Ls[idx], self.Lr[idx] = self.sess.run( [alpha, beta], feed_dict={C: batch_C} diff --git a/cornac/models/pcrl/recom_pcrl.py b/cornac/models/pcrl/recom_pcrl.py index fe91d6a6e..a6eef2ebe 100644 --- a/cornac/models/pcrl/recom_pcrl.py +++ b/cornac/models/pcrl/recom_pcrl.py @@ -92,13 +92,10 @@ def __init__( w_determinist=True, init_params=None, ): - Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose) self.k = k - self.z_dims = ( - z_dims - ) # the dimension of the second hidden layer (we consider a 2-layers PCRL) + self.z_dims = z_dims # the dimension of the second hidden layer (we consider a 2-layers PCRL) self.max_iter = max_iter self.batch_size = batch_size self.learning_rate = learning_rate @@ -155,7 +152,7 @@ def fit(self, train_set, val_set=None): B=1, w_determinist=self.w_determinist, init_params=init_params, - ).learn() + ).learn(train_set) self.Theta = np.array(pcrl_.Gs) / np.array(pcrl_.Gr) self.Beta = np.array(pcrl_.Ls) / np.array(pcrl_.Lr) diff --git a/cornac/models/pmf/recom_pmf.py b/cornac/models/pmf/recom_pmf.py index 0154656f9..31d95e7c5 100644 --- a/cornac/models/pmf/recom_pmf.py +++ b/cornac/models/pmf/recom_pmf.py @@ -98,7 +98,7 @@ def __init__( self.ll = np.full(max_iter, 0) self.eps = 0.000000001 - + # Init params if provided self.init_params = {} if init_params is None else init_params self.U = self.init_params.get("U", None) # matrix of user factors @@ -128,20 +128,14 @@ def fit(self, train_set, val_set=None): (uid, iid, rat) = train_set.uir_tuple rat = np.array(rat, dtype="float32") if self.variant == "non_linear": # need to map the ratings to [0,1] - if [self.train_set.min_rating, self.train_set.max_rating] != [0, 1]: - rat = scale( - rat, - 0.0, - 1.0, - self.train_set.min_rating, - self.train_set.max_rating, - ) + if [self.min_rating, self.max_rating] != [0, 1]: + rat = scale(rat, 0.0, 1.0, self.min_rating, self.max_rating) uid = np.array(uid, dtype="int32") iid = np.array(iid, dtype="int32") if self.verbose: print("Learning...") - + # use pre-trained params if exists, otherwise from constructor init_params = {"U": self.U, "V": self.V} @@ -151,8 +145,8 @@ def fit(self, train_set, val_set=None): iid, rat, k=self.k, - n_users=train_set.num_users, - n_items=train_set.num_items, + n_users=self.num_users, + n_items=self.num_items, n_ratings=len(rat), n_epochs=self.max_iter, lambda_reg=self.lambda_reg, @@ -168,8 +162,8 @@ def fit(self, train_set, val_set=None): iid, rat, k=self.k, - n_users=train_set.num_users, - n_items=train_set.num_items, + n_users=self.num_users, + n_items=self.num_items, n_ratings=len(rat), n_epochs=self.max_iter, lambda_reg=self.lambda_reg, @@ -212,7 +206,7 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) @@ -220,9 +214,7 @@ def score(self, user_idx, item_idx=None): known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not self.knows_user(user_idx) or not self.knows_item(item_idx): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) @@ -232,12 +224,6 @@ def score(self, user_idx, item_idx=None): if self.variant == "non_linear": user_pred = sigmoid(user_pred) - user_pred = scale( - user_pred, - self.train_set.min_rating, - self.train_set.max_rating, - 0.0, - 1.0, - ) + user_pred = scale(user_pred, self.min_rating, self.max_rating, 0.0, 1.0) return user_pred diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py index 06b94ad90..564e9eec9 100644 --- a/cornac/models/recommender.py +++ b/cornac/models/recommender.py @@ -27,26 +27,93 @@ class Recommender: - """Generic class for a recommender model. All recommendation models should inherit from this class + """Generic class for a recommender model. All recommendation models should inherit from this class. Parameters ---------------- name: str, required - The name of the recommender model + Name of the recommender model. trainable: boolean, optional, default: True - When False, the model is not trainable + When False, the model is not trainable. + verbose: boolean, optional, default: False + When True, running logs are displayed. + + Attributes + ---------- + num_users: int + Number of users in training data. + + num_items: int + Number of items in training data. + + total_users: int + Number of users in training, validation, and test data. + In other words, this includes unknown/unseen users. + + total_items: int + Number of items in training, validation, and test data. + In other words, this includes unknown/unseen items. + + uid_map: int + Global mapping of user ID-index. + + iid_map: int + Global mapping of item ID-index. + + max_rating: float + Maximum value among the rating observations. + + min_rating: float + Minimum value among the rating observations. + + global_mean: float + Average value over the rating observations. """ def __init__(self, name, trainable=True, verbose=False): self.name = name self.trainable = trainable self.verbose = verbose - self.train_set = None - self.val_set = None - # attributes to be ignored when being saved - self.ignored_attrs = ["train_set", "val_set"] + + self.ignored_attrs = [] # attributes to be ignored when saving model + + # useful information getting from train_set for prediction + self.num_users = None + self.num_items = None + self.uid_map = None + self.iid_map = None + self.max_rating = None + self.min_rating = None + self.global_mean = None + + self.__user_ids = None + self.__item_ids = None + + @property + def total_users(self): + """Total number of users including users in test and validation if exists""" + return len(self.uid_map) if self.uid_map is not None else self.num_users + + @property + def total_items(self): + """Total number of items including users in test and validation if exists""" + return len(self.iid_map) if self.iid_map is not None else self.num_items + + @property + def user_ids(self): + """Return the list of raw user IDs""" + if self.__user_ids is None: + self.__user_ids = list(self.uid_map.keys()) + return self.__user_ids + + @property + def item_ids(self): + """Return the list of raw item IDs""" + if self.__item_ids is None: + self.__item_ids = list(self.iid_map.keys()) + return self.__item_ids def reset_info(self): self.best_value = -np.Inf @@ -117,11 +184,9 @@ def save(self, save_dir=None): model_file = os.path.join(model_dir, "{}.pkl".format(timestamp)) saved_model = copy.deepcopy(self) - pickle.dump( saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL ) - if self.verbose: print("{} model is saved to {}".format(self.name, model_file)) @@ -153,7 +218,6 @@ def load(model_path, trainable=False): model = pickle.load(open(model_file, "rb")) model.trainable = trainable model.load_from = model_file # for further loading - return model def fit(self, train_set, val_set=None): @@ -172,10 +236,51 @@ def fit(self, train_set, val_set=None): self : object """ self.reset_info() - self.train_set = train_set.reset() - self.val_set = None if val_set is None else val_set.reset() + train_set.reset() + if val_set is not None: + val_set.reset() + + # get some useful information for prediction + self.num_users = train_set.num_users + self.num_items = train_set.num_items + self.uid_map = train_set.uid_map + self.iid_map = train_set.iid_map + self.min_rating = train_set.min_rating + self.max_rating = train_set.max_rating + self.global_mean = train_set.global_mean + return self + def knows_user(self, user_idx): + """Return whether the model knows user by its index + + Parameters + ---------- + user_idx: int, required + The index of the user (not the original user ID). + + Returns + ------- + res : bool + True if model knows the user from traning data, False otherwise. + """ + return user_idx >= 0 and user_idx < self.num_users + + def knows_item(self, item_idx): + """Return whether the model knows item by its index + + Parameters + ---------- + item_idx: int, required + The index of the item (not the original item ID). + + Returns + ------- + res : bool + True if model knows the item from traning data, False otherwise. + """ + return item_idx >= 0 and item_idx < self.num_items + def transform(self, test_set): """Transform test set into cached results accelerating the score function. This function is supposed to be called in the `cornac.eval_methods.BaseMethod` @@ -211,7 +316,7 @@ def score(self, user_idx, item_idx=None): def default_score(self): """Overwrite this function if your algorithm has special treatment for cold-start problem""" - return self.train_set.global_mean + return self.global_mean def rate(self, user_idx, item_idx, clipping=True): """Give a rating score between pair of user and item @@ -238,11 +343,7 @@ def rate(self, user_idx, item_idx, clipping=True): rating_pred = self.default_score() if clipping: - rating_pred = clip( - values=rating_pred, - lower_bound=self.train_set.min_rating, - upper_bound=self.train_set.max_rating, - ) + rating_pred = clip(rating_pred, self.min_rating, self.max_rating) return rating_pred @@ -260,54 +361,115 @@ def rank(self, user_idx, item_indices=None): Returns ------- - (item_rank, item_scores): tuple - `item_rank` contains item indices being ranked by their scores. - `item_scores` contains scores of items corresponding to their indices in the `item_indices` input. + (ranked_items, item_scores): tuple + `ranked_items` contains item indices being ranked by their scores. + `item_scores` contains scores of items corresponding to index in `item_indices` input. + """ # obtain item scores from the model try: known_item_scores = self.score(user_idx) except ScoreException: - known_item_scores = ( - np.ones(self.train_set.total_items) * self.default_score() - ) + known_item_scores = np.ones(self.total_items) * self.default_score() # check if the returned scores also cover unknown items # if not, all unknown items will be given the MIN score - if len(known_item_scores) == self.train_set.total_items: + if len(known_item_scores) == self.total_items: all_item_scores = known_item_scores else: - all_item_scores = np.ones(self.train_set.total_items) * np.min( - known_item_scores - ) - all_item_scores[: self.train_set.num_items] = known_item_scores + all_item_scores = np.ones(self.total_items) * np.min(known_item_scores) + all_item_scores[: self.num_items] = known_item_scores # rank items based on their scores if item_indices is None: - item_scores = all_item_scores[: self.train_set.num_items] - item_rank = item_scores.argsort()[::-1] + item_scores = all_item_scores[: self.num_items] + ranked_items = item_scores.argsort()[::-1] else: item_scores = all_item_scores[item_indices] - item_rank = np.array(item_indices)[item_scores.argsort()[::-1]] + ranked_items = np.array(item_indices)[item_scores.argsort()[::-1]] + + return ranked_items, item_scores + + def recommend(self, user_id, k=-1, remove_seen=False, train_set=None): + """Generate top-K item recommendations for a given user. Key difference between + this function and rank() function is that rank() function works with mapped + user/item index while this function works with original user/item ID. This helps + hide the abstraction of ID-index mapping, and make model usage and deployment cleaner. + + Parameters + ---------- + user_id: str, required + The original ID of the user. + + k: int, optional, default=-1 + Cut-off length for recommendations, k=-1 will return ranked list of all items. + + remove_seen: bool, optional, default: False + Remove seen/known items during training and validation from output recommendations. + + train_set: :obj:`cornac.data.Dataset`, optional, default: None + Training dataset needs to be provided in order to remove seen items. + + Returns + ------- + recommendations: list + Recommended items in the form of their original IDs. + """ + user_idx = self.uid_map.get(user_id, -1) + if user_idx == -1: + raise ValueError(f"{user_id} is unknown to the model.") + + if k < -1 or k > self.total_items: + raise ValueError( + f"k={k} is invalid, there are {self.total_users} users in total." + ) - return item_rank, item_scores + item_indices = np.arange(self.total_items) + if remove_seen: + seen_mask = np.zeros(len(item_indices), dtype="bool") + if train_set is None: + raise ValueError("train_set must be provided to remove seen items.") + if user_idx < train_set.csr_matrix.shape[0]: + seen_mask[train_set.csr_matrix.getrow(user_idx).indices] = True + item_indices = item_indices[~seen_mask] - def monitor_value(self): + item_rank, _ = self.rank(user_idx, item_indices) + if k != -1: + item_rank = item_rank[:k] + + recommendations = [self.item_ids[i] for i in item_rank] + return recommendations + + def monitor_value(self, train_set, val_set): """Calculating monitored value used for early stopping on validation set (`val_set`). This function will be called by `early_stop()` function. Note: `val_set` could be `None` thus it needs to be checked before usage. + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + Returns ------- :raise NotImplementedError """ raise NotImplementedError() - def early_stop(self, min_delta=0.0, patience=0): + def early_stop(self, train_set, val_set, min_delta=0.0, patience=0): """Check if training should be stopped when validation loss has stopped improving. Parameters ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + min_delta: float, optional, default: 0. The minimum increase in monitored value on validation set to be considered as improvement, i.e. an increment of less than `min_delta` will count as no improvement. @@ -322,7 +484,7 @@ def early_stop(self, min_delta=0.0, patience=0): otherwise return `False`. """ self.current_epoch += 1 - current_value = self.monitor_value() + current_value = self.monitor_value(train_set, val_set) if current_value is None: return False diff --git a/cornac/models/sbpr/recom_sbpr.pyx b/cornac/models/sbpr/recom_sbpr.pyx index 551560454..6f538b14e 100644 --- a/cornac/models/sbpr/recom_sbpr.pyx +++ b/cornac/models/sbpr/recom_sbpr.pyx @@ -116,13 +116,14 @@ class SBPR(BPR): self.lambda_v = lambda_v self.lambda_b = lambda_b - def _prepare_social_data(self): - X = self.train_set.matrix # csr_matrix - n_users, n_items = self.train_set.num_users, self.train_set.num_items + def _prepare_social_data(self, train_set): + X = train_set.matrix # csr_matrix + n_users, n_items = train_set.num_users, train_set.num_items # construct social feedback in the sparse format - (rid, cid, val) = self.train_set.user_graph.get_train_triplet( - self.train_set.user_indices, self.train_set.user_indices + train_user_indices = set(train_set.uir_tuple[0]) + (rid, cid, val) = train_set.user_graph.get_train_triplet( + train_user_indices, train_user_indices ) Y = csr_matrix((val, (rid, cid)), shape=(n_users, n_users)) social_item_ids = [] @@ -197,7 +198,7 @@ class SBPR(BPR): """ cdef: long num_samples = len(user_ids) - long num_items = self.train_set.num_items + long num_items = self.num_items long s, i_index, k_index, skipped = 0 int f, u_id, i_id, j_id, k_id, n_social_items, thread_id floating u_temp, k_rand diff --git a/cornac/models/skm/recom_skmeans.py b/cornac/models/skm/recom_skmeans.py index b60967215..139020d30 100644 --- a/cornac/models/skm/recom_skmeans.py +++ b/cornac/models/skm/recom_skmeans.py @@ -61,15 +61,15 @@ class SKMeans(Recommender): """ def __init__( - self, - k=5, - max_iter=100, - name="Skmeans", - trainable=True, - tol=1e-6, - verbose=True, - seed=None, - init_par=None, + self, + k=5, + max_iter=100, + name="Skmeans", + trainable=True, + tol=1e-6, + verbose=True, + seed=None, + init_par=None, ): Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose) self.k = k @@ -97,8 +97,7 @@ def fit(self, train_set, val_set=None): """ Recommender.fit(self, train_set, val_set) - X = self.train_set.matrix - X = sp.csr_matrix(X) + X = train_set.matrix # CSR matrix # Skmeans requires rows of X to have a unit L2 norm. We therefore need to make a copy of X as we should not modify the latter. X1 = X.copy() @@ -124,7 +123,7 @@ def fit(self, train_set, val_set=None): print("%s is trained already (trainable = False)" % (self.name)) self.user_center_sim = ( - X1 * self.centroids.T + X1 * self.centroids.T ) # user-centroid cosine similarity matrix del X1 @@ -149,7 +148,7 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) @@ -158,13 +157,11 @@ def score(self, user_idx, item_idx=None): self.user_center_sim[user_idx, :].T ) known_item_scores = known_item_scores.sum(0).A1 / ( - self.user_center_sim[user_idx, :].sum() + 1e-20 + self.user_center_sim[user_idx, :].sum() + 1e-20 ) # weighted average of cluster centroids return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) @@ -175,7 +172,7 @@ def score(self, user_idx, item_idx=None): ) # transform user_pred to a flatten array user_pred = user_pred.sum(0).A1 / ( - self.user_center_sim[user_idx, :].sum() + 1e-20 + self.user_center_sim[user_idx, :].sum() + 1e-20 ) # weighted average of cluster centroids return user_pred diff --git a/cornac/models/sorec/recom_sorec.py b/cornac/models/sorec/recom_sorec.py index 00ba7117e..bf8a6d7a7 100644 --- a/cornac/models/sorec/recom_sorec.py +++ b/cornac/models/sorec/recom_sorec.py @@ -148,13 +148,15 @@ def fit(self, train_set, val_set=None): (rat_uid, rat_iid, rat_val) = train_set.uir_tuple # user social network - map_uid = train_set.user_indices + train_user_indices = set(train_set.uir_tuple[0]) (net_uid, net_jid, net_val) = train_set.user_graph.get_train_triplet( - map_uid, map_uid + train_user_indices, train_user_indices ) if self.weight_link: - degree = train_set.user_graph.get_node_degree(map_uid, map_uid) + degree = train_set.user_graph.get_node_degree( + train_user_indices, train_user_indices + ) weighted_net_val = [] for u, j, val in zip(net_uid, net_jid, net_val): u_out = degree[int(u)][1] @@ -163,17 +165,11 @@ def fit(self, train_set, val_set=None): weighted_net_val.append(val_weighted) net_val = weighted_net_val - if [self.train_set.min_rating, self.train_set.max_rating] != [0, 1]: - if self.train_set.min_rating == self.train_set.max_rating: - rat_val = scale(rat_val, 0.0, 1.0, 0.0, self.train_set.max_rating) + if [self.min_rating, self.max_rating] != [0, 1]: + if self.min_rating == self.max_rating: + rat_val = scale(rat_val, 0.0, 1.0, 0.0, self.max_rating) else: - rat_val = scale( - rat_val, - 0.0, - 1.0, - self.train_set.min_rating, - self.train_set.max_rating, - ) + rat_val = scale(rat_val, 0.0, 1.0, self.min_rating, self.max_rating) rat_val = np.array(rat_val, dtype="float32") rat_uid = np.array(rat_uid, dtype="int32") @@ -214,7 +210,7 @@ def fit(self, train_set, val_set=None): if self.verbose: print("Learning completed") - + elif self.verbose: print("%s is trained already (trainable = False)" % self.name) @@ -235,33 +231,22 @@ def score(self, user_idx, item_idx=None): Relative scores that the user gives to the item or to all known items """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - user_pred = self.V[item_idx, :].dot(self.U[user_idx, :]) user_pred = sigmoid(user_pred) - if self.train_set.min_rating == self.train_set.max_rating: - user_pred = scale(user_pred, 0.0, self.train_set.max_rating, 0.0, 1.0) + if self.min_rating == self.max_rating: + user_pred = scale(user_pred, 0.0, self.max_rating, 0.0, 1.0) else: - user_pred = scale( - user_pred, - self.train_set.min_rating, - self.train_set.max_rating, - 0.0, - 1.0, - ) - + user_pred = scale(user_pred, self.min_rating, self.max_rating, 0.0, 1.0) return user_pred diff --git a/cornac/models/trirank/recom_trirank.py b/cornac/models/trirank/recom_trirank.py index 2b87cf65b..c951d6ad2 100644 --- a/cornac/models/trirank/recom_trirank.py +++ b/cornac/models/trirank/recom_trirank.py @@ -130,16 +130,14 @@ def __init__( self.a = self.init_params.get("a", None) self.u = self.init_params.get("u", None) - def _init(self): + def _init(self, train_set): # Initialize user, item and aspect rank. if self.p is None: - self.p = uniform(self.train_set.num_items, random_state=self.rng) + self.p = uniform(train_set.num_items, random_state=self.rng) if self.a is None: - self.a = uniform( - self.train_set.sentiment.num_aspects, random_state=self.rng - ) + self.a = uniform(train_set.sentiment.num_aspects, random_state=self.rng) if self.u is None: - self.u = uniform(self.train_set.num_users, random_state=self.rng) + self.u = uniform(train_set.num_users, random_state=self.rng) def _symmetrical_normalization(self, matrix: csr_matrix): row = [] @@ -157,6 +155,8 @@ def _symmetrical_normalization(self, matrix: csr_matrix): def _create_matrices(self, train_set): from time import time + self.r_mat = train_set.csr_matrix + start_time = time() if self.verbose: print("Building matrices started!") @@ -222,7 +222,7 @@ def fit(self, train_set, val_set=None): self : object """ Recommender.fit(self, train_set, val_set) - self._init() + self._init(train_set) if not self.trainable: return self @@ -233,11 +233,11 @@ def fit(self, train_set, val_set=None): def _online_recommendation(self, user): # Algorithm 1: Online recommendation line 5 - p_0 = self.train_set.csr_matrix[[user]] + p_0 = self.r_mat[[user]] p_0.data.fill(1) p_0 = p_0.toarray().squeeze() a_0 = self.Y[user].toarray().squeeze() - u_0 = np.zeros(self.train_set.csr_matrix.shape[0]) + u_0 = np.zeros(self.r_mat.shape[0]) u_0[user] = 1 # Algorithm 1: Online training line 6 @@ -309,21 +309,19 @@ def score(self, u_idx, i_idx=None): Relative scores that the user gives to the item or to all known items """ - if self.train_set.is_unk_user(u_idx): + if not self.knows_user(u_idx): raise ScoreException("Can't make score prediction for (user_id=%d" & u_idx) - if i_idx is not None and self.train_set.is_unk_item(i_idx): + if i_idx is not None and not self.knows_item(i_idx): raise ScoreException("Can't make score prediction for (item_id=%d" & i_idx) item_scores, *_ = self._online_recommendation(u_idx) # Set already rated items to zero. - item_scores[self.train_set.csr_matrix[u_idx].indices] = 0 + item_scores[self.r_mat[u_idx].indices] = 0 # Scale to match rating scale. item_scores = ( - item_scores - * (self.train_set.max_rating - self.train_set.min_rating) - / max(item_scores) - + self.train_set.min_rating + item_scores * (self.max_rating - self.min_rating) / max(item_scores) + + self.min_rating ) if i_idx is None: diff --git a/cornac/models/vaecf/recom_vaecf.py b/cornac/models/vaecf/recom_vaecf.py index 1e2f14c41..efc30b1c3 100644 --- a/cornac/models/vaecf/recom_vaecf.py +++ b/cornac/models/vaecf/recom_vaecf.py @@ -137,8 +137,10 @@ def fit(self, train_set, val_set=None): torch.manual_seed(self.seed) torch.cuda.manual_seed(self.seed) + self.r_mat = train_set.matrix + if not hasattr(self, "vae"): - data_dim = train_set.matrix.shape[1] + data_dim = self.r_mat.shape[1] self.vae = VAE( self.k, [data_dim] + self.autoencoder_structure, @@ -148,7 +150,7 @@ def fit(self, train_set, val_set=None): learn( self.vae, - self.train_set, + train_set, n_epochs=self.n_epochs, batch_size=self.batch_size, learn_rate=self.learning_rate, @@ -183,12 +185,12 @@ def score(self, user_idx, item_idx=None): import torch if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - x_u = self.train_set.matrix[user_idx].copy() + x_u = self.r_mat[user_idx].copy() x_u.data = np.ones(len(x_u.data)) z_u, _ = self.vae.encode( torch.tensor(x_u.A, dtype=torch.float32, device=self.device) @@ -197,15 +199,13 @@ def score(self, user_idx, item_idx=None): return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) ) - x_u = self.train_set.matrix[user_idx].copy() + x_u = self.r_mat[user_idx].copy() x_u.data = np.ones(len(x_u.data)) z_u, _ = self.vae.encode( torch.tensor(x_u.A, dtype=torch.float32, device=self.device) diff --git a/cornac/models/vbpr/recom_vbpr.py b/cornac/models/vbpr/recom_vbpr.py index 19c38d405..26c0718d5 100644 --- a/cornac/models/vbpr/recom_vbpr.py +++ b/cornac/models/vbpr/recom_vbpr.py @@ -154,20 +154,20 @@ def fit(self, train_set, val_set=None): raise CornacException("item_image modality is required but None.") # Item visual feature from CNN - train_features = train_set.item_image.features[: self.train_set.total_items] + train_features = train_set.item_image.features[: self.total_items] train_features = train_features.astype(np.float32) self._init( - n_users=train_set.total_users, - n_items=train_set.total_items, + n_users=self.total_users, + n_items=self.total_items, features=train_features, ) if self.trainable: - self._fit_torch(train_features) + self._fit_torch(train_set, train_features) return self - def _fit_torch(self, train_features): + def _fit_torch(self, train_set, train_features): import torch def _l2_loss(*tensors): @@ -213,11 +213,11 @@ def _inner(a, b): sum_loss = 0.0 count = 0 progress_bar = tqdm( - total=self.train_set.num_batches(self.batch_size), + total=train_set.num_batches(self.batch_size), desc="Epoch {}/{}".format(epoch, self.n_epochs), disable=not self.verbose, ) - for batch_u, batch_i, batch_j in self.train_set.uij_iter( + for batch_u, batch_i, batch_j in train_set.uij_iter( self.batch_size, shuffle=True ): gamma_u = Gu[batch_u] diff --git a/cornac/models/vmf/recom_vmf.py b/cornac/models/vmf/recom_vmf.py index a0b50503a..aca572f91 100644 --- a/cornac/models/vmf/recom_vmf.py +++ b/cornac/models/vmf/recom_vmf.py @@ -149,7 +149,7 @@ def fit(self, train_set, val_set=None): if self.trainable: # Item visual cnn-features - item_features = train_set.item_image.features[: self.train_set.num_items] + item_features = train_set.item_image.features[: train_set.num_items] if self.verbose: print("Learning...") @@ -157,7 +157,7 @@ def fit(self, train_set, val_set=None): from .vmf import vmf res = vmf( - self.train_set, + train_set, item_features, k=self.k, d=self.d, @@ -208,7 +208,7 @@ def score(self, user_idx, item_idx=None): """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) @@ -221,9 +221,7 @@ def score(self, user_idx, item_idx=None): # fast_dot(self.P[user_id], self.Q, known_item_scores) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) @@ -233,12 +231,6 @@ def score(self, user_idx, item_idx=None): ].dot(self.P[user_idx, :]) user_pred = sigmoid(user_pred) - user_pred = scale( - user_pred, - self.train_set.min_rating, - self.train_set.max_rating, - 0.0, - 1.0, - ) + user_pred = scale(user_pred, self.min_rating, self.max_rating, 0.0, 1.0) return user_pred diff --git a/cornac/models/wmf/recom_wmf.py b/cornac/models/wmf/recom_wmf.py index 9b9b390aa..4b16465f3 100644 --- a/cornac/models/wmf/recom_wmf.py +++ b/cornac/models/wmf/recom_wmf.py @@ -119,12 +119,10 @@ def __init__( def _init(self): rng = get_rng(self.seed) - n_users, n_items = self.train_set.num_users, self.train_set.num_items - if self.U is None: - self.U = xavier_uniform((n_users, self.k), rng) + self.U = xavier_uniform((self.num_users, self.k), rng) if self.V is None: - self.V = xavier_uniform((n_items, self.k), rng) + self.V = xavier_uniform((self.num_items, self.k), rng) def fit(self, train_set, val_set=None): """Fit the model to observations. @@ -146,11 +144,11 @@ def fit(self, train_set, val_set=None): self._init() if self.trainable: - self._fit_cf() + self._fit_cf(train_set) return self - def _fit_cf(self,): + def _fit_cf(self, train_set): import tensorflow.compat.v1 as tf from .wmf import Model @@ -158,16 +156,15 @@ def _fit_cf(self,): os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) - R = self.train_set.csc_matrix # csc for efficient slicing over items - n_users, n_items, = self.train_set.num_users, self.train_set.num_items + R = train_set.csc_matrix # csc for efficient slicing over items # Build model graph = tf.Graph() with graph.as_default(): tf.set_random_seed(self.seed) model = Model( - n_users=n_users, - n_items=n_items, + n_users=self.num_users, + n_items=self.num_items, k=self.k, lambda_u=self.lambda_u, lambda_v=self.lambda_v, @@ -184,11 +181,10 @@ def _fit_cf(self,): loop = trange(self.max_iter, disable=not self.verbose) for _ in loop: - sum_loss = 0 count = 0 for i, batch_ids in enumerate( - self.train_set.item_iter(self.batch_size, shuffle=True) + train_set.item_iter(self.batch_size, shuffle=True) ): batch_R = R[:, batch_ids] batch_C = np.ones(batch_R.shape) * self.b @@ -232,17 +228,14 @@ def score(self, user_idx, item_idx=None): Relative scores that the user gives to the item or to all known items """ if item_idx is None: - if self.train_set.is_unk_user(user_idx): + if not self.knows_user(user_idx): raise ScoreException( "Can't make score prediction for (user_id=%d)" % user_idx ) - known_item_scores = self.V.dot(self.U[user_idx, :]) return known_item_scores else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): + if not (self.knows_user(user_idx) and self.knows_item(item_idx)): raise ScoreException( "Can't make score prediction for (user_id=%d, item_id=%d)" % (user_idx, item_idx) diff --git a/examples/amr_clothing.py b/examples/amr_clothing.py index 431509610..c5304095b 100644 --- a/examples/amr_clothing.py +++ b/examples/amr_clothing.py @@ -51,7 +51,6 @@ lambda_w=1, lambda_b=0.01, lambda_e=0.0, - lmd=1.0, use_gpu=True, ) diff --git a/tests/cornac/data/test_dataset.py b/tests/cornac/data/test_dataset.py index e0d49bdef..e7a308a27 100644 --- a/tests/cornac/data/test_dataset.py +++ b/tests/cornac/data/test_dataset.py @@ -24,10 +24,9 @@ class TestDataset(unittest.TestCase): - def setUp(self): - self.triplet_data = Reader().read('./tests/data.txt') - self.uirt_data = Reader().read('./tests/data.txt', fmt='UIRT') + self.triplet_data = Reader().read("./tests/data.txt") + self.uirt_data = Reader().read("./tests/data.txt", fmt="UIRT") def test_init(self): train_set = Dataset.from_uir(self.triplet_data) @@ -41,22 +40,18 @@ def test_init(self): self.assertEqual(train_set.num_users, 10) self.assertEqual(train_set.num_items, 10) - self.assertFalse(train_set.is_unk_user(7)) - self.assertTrue(train_set.is_unk_user(13)) - - self.assertFalse(train_set.is_unk_item(3)) - self.assertTrue(train_set.is_unk_item(16)) - - self.assertEqual(train_set.uid_map['768'], 1) - self.assertEqual(train_set.iid_map['195'], 7) + self.assertEqual(train_set.uid_map["768"], 1) + self.assertEqual(train_set.iid_map["195"], 7) - self.assertSequenceEqual(list(train_set.user_indices), range(10)) - self.assertListEqual(list(train_set.user_ids), - ['76', '768', '642', '930', '329', '633', '716', '871', '543', '754']) + self.assertSetEqual( + set(train_set.user_ids), + set(["76", "768", "642", "930", "329", "633", "716", "871", "543", "754"]), + ) - self.assertSequenceEqual(list(train_set.item_indices), range(10)) - self.assertListEqual(list(train_set.item_ids), - ['93', '257', '795', '709', '705', '226', '478', '195', '737', '282']) + self.assertSetEqual( + set(train_set.item_ids), + set(["93", "257", "795", "709", "705", "226", "478", "195", "737", "282"]), + ) def test_from_uirt(self): train_set = Dataset.from_uirt(self.uirt_data) @@ -72,14 +67,23 @@ def test_exclude_unknowns_empty_error(self): def test_idx_iter(self): train_set = Dataset.from_uir(self.triplet_data) - ids = [batch_ids for batch_ids in train_set.idx_iter( - idx_range=10, batch_size=1, shuffle=False)] + ids = [ + batch_ids + for batch_ids in train_set.idx_iter( + idx_range=10, batch_size=1, shuffle=False + ) + ] npt.assert_array_equal(ids, np.arange(10).reshape(10, 1)) - ids = [batch_ids for batch_ids in train_set.idx_iter( - idx_range=10, batch_size=1, shuffle=True)] - npt.assert_raises(AssertionError, npt.assert_array_equal, - ids, np.arange(10).reshape(10, 1)) + ids = [ + batch_ids + for batch_ids in train_set.idx_iter( + idx_range=10, batch_size=1, shuffle=True + ) + ] + npt.assert_raises( + AssertionError, npt.assert_array_equal, ids, np.arange(10).reshape(10, 1) + ) def test_uir_iter(self): train_set = Dataset.from_uir(self.triplet_data) @@ -93,12 +97,15 @@ def test_uir_iter(self): ratings = [batch_ratings for _, _, batch_ratings in train_set.uir_iter()] self.assertListEqual(ratings, [4, 4, 4, 4, 3, 4, 4, 5, 3, 4]) - ratings = [batch_ratings for _, _, - batch_ratings in train_set.uir_iter(binary=True)] + ratings = [ + batch_ratings for _, _, batch_ratings in train_set.uir_iter(binary=True) + ] self.assertListEqual(ratings, [1] * 10) - ratings = [batch_ratings for _, _, - batch_ratings in train_set.uir_iter(batch_size=5, num_zeros=1)] + ratings = [ + batch_ratings + for _, _, batch_ratings in train_set.uir_iter(batch_size=5, num_zeros=1) + ] self.assertListEqual(ratings[0].tolist(), [4, 4, 4, 4, 3, 0, 0, 0, 0, 0]) self.assertListEqual(ratings[1].tolist(), [4, 4, 5, 3, 4, 0, 0, 0, 0, 0]) @@ -112,16 +119,20 @@ def test_uij_iter(self): self.assertSequenceEqual(pos_items, range(10)) neg_items = [batch_neg_items for _, _, batch_neg_items in train_set.uij_iter()] - self.assertRaises(AssertionError, self.assertSequenceEqual, - neg_items, range(10)) - - neg_items = [batch_neg_items for _, _, - batch_neg_items in train_set.uij_iter(neg_sampling='popularity')] - self.assertRaises(AssertionError, self.assertSequenceEqual, - neg_items, range(10)) + self.assertRaises( + AssertionError, self.assertSequenceEqual, neg_items, range(10) + ) + + neg_items = [ + batch_neg_items + for _, _, batch_neg_items in train_set.uij_iter(neg_sampling="popularity") + ] + self.assertRaises( + AssertionError, self.assertSequenceEqual, neg_items, range(10) + ) try: - for _ in train_set.uij_iter(neg_sampling='bla'): + for _ in train_set.uij_iter(neg_sampling="bla"): continue except ValueError: assert True @@ -129,20 +140,28 @@ def test_uij_iter(self): def test_user_iter(self): train_set = Dataset.from_uir(self.triplet_data) - npt.assert_array_equal(np.arange(10).reshape(10, 1), - [u for u in train_set.user_iter()]) - self.assertRaises(AssertionError, npt.assert_array_equal, - np.arange(10).reshape(10, 1), - [u for u in train_set.user_iter(shuffle=True)]) + npt.assert_array_equal( + np.arange(10).reshape(10, 1), [u for u in train_set.user_iter()] + ) + self.assertRaises( + AssertionError, + npt.assert_array_equal, + np.arange(10).reshape(10, 1), + [u for u in train_set.user_iter(shuffle=True)], + ) def test_item_iter(self): train_set = Dataset.from_uir(self.triplet_data) - npt.assert_array_equal(np.arange(10).reshape(10, 1), - [i for i in train_set.item_iter()]) - self.assertRaises(AssertionError, npt.assert_array_equal, - np.arange(10).reshape(10, 1), - [i for i in train_set.item_iter(shuffle=True)]) + npt.assert_array_equal( + np.arange(10).reshape(10, 1), [i for i in train_set.item_iter()] + ) + self.assertRaises( + AssertionError, + npt.assert_array_equal, + np.arange(10).reshape(10, 1), + [i for i in train_set.item_iter(shuffle=True)], + ) def test_uir_tuple(self): train_set = Dataset.from_uir(self.triplet_data) @@ -184,12 +203,12 @@ def test_chrono_user_data(self): zero_data = [] for idx in range(len(self.triplet_data)): u = self.triplet_data[idx][0] - i = self.triplet_data[-1-idx][1] - zero_data.append((u, i, 1., 0)) + i = self.triplet_data[-1 - idx][1] + zero_data.append((u, i, 1.0, 0)) train_set = Dataset.from_uirt(self.uirt_data + zero_data) self.assertEqual(len(train_set.chrono_user_data), 10) - self.assertListEqual(train_set.chrono_user_data[0][1], [1., 4.]) + self.assertListEqual(train_set.chrono_user_data[0][1], [1.0, 4.0]) self.assertListEqual(train_set.chrono_user_data[0][2], [0, 882606572]) try: @@ -201,18 +220,19 @@ def test_chrono_item_data(self): zero_data = [] for idx in range(len(self.triplet_data)): u = self.triplet_data[idx][0] - i = self.triplet_data[-1-idx][1] - zero_data.append((u, i, 1., 0)) + i = self.triplet_data[-1 - idx][1] + zero_data.append((u, i, 1.0, 0)) train_set = Dataset.from_uirt(self.uirt_data + zero_data) self.assertEqual(len(train_set.chrono_item_data), 10) - self.assertListEqual(train_set.chrono_item_data[0][1], [1., 4.]) + self.assertListEqual(train_set.chrono_item_data[0][1], [1.0, 4.0]) self.assertListEqual(train_set.chrono_item_data[0][2], [0, 882606572]) - + try: Dataset.from_uir(self.triplet_data).chrono_item_data except ValueError: assert True -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/cornac/models/test_recommender.py b/tests/cornac/models/test_recommender.py new file mode 100644 index 000000000..a59e46c42 --- /dev/null +++ b/tests/cornac/models/test_recommender.py @@ -0,0 +1,55 @@ +# Copyright 2018 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import unittest + +from cornac.data import Reader, Dataset +from cornac.models import MF + + +class TestRecommender(unittest.TestCase): + def setUp(self): + self.data = Reader().read("./tests/data.txt") + + def test_knows_x(self): + mf = MF(1, 1, seed=123) + dataset = Dataset.from_uir(self.data) + mf.fit(dataset) + + self.assertTrue(mf.knows_user(7)) + self.assertFalse(mf.knows_item(13)) + + self.assertTrue(mf.knows_item(3)) + self.assertFalse(mf.knows_item(16)) + + def test_recommend(self): + mf = MF(1, 1, seed=123) + dataset = Dataset.from_uir(self.data) + mf.fit(dataset) + self.assertFalse( + all( + [ + a == b + for a, b in zip( + mf.recommend("76", k=3, remove_seen=False), + mf.recommend("76", k=3, remove_seen=True, train_set=dataset), + ) + ] + ) + ) + + +if __name__ == "__main__": + unittest.main()