From 3ad15e6a91d491c714bad79cc50267812f12fa2a Mon Sep 17 00:00:00 2001 From: Hakan Date: Fri, 10 Jun 2022 13:52:41 +0200 Subject: [PATCH 1/6] =?UTF-8?q?EASE=E1=B4=BF=20Implementation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 + cornac/models/__init__.py | 1 + cornac/models/ease/__init__.py | 1 + cornac/models/ease/recom_ease.py | 139 +++++++++++++++++++++++++++++++ docs/source/models.rst | 5 ++ examples/ease_movielens.py | 61 ++++++++++++++ 6 files changed, 209 insertions(+) create mode 100644 cornac/models/ease/__init__.py create mode 100644 cornac/models/ease/recom_ease.py create mode 100644 examples/ease_movielens.py diff --git a/README.md b/README.md index 3a41e90da..119089777 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,7 @@ The recommender models supported by Cornac are listed below. Why don't you join | | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | [requirements.txt](cornac/models/causalrec/requirements.txt) | [causalrec_clothing.py](examples/causalrec_clothing.py) | | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | N/A | [PreferredAI/ComparER](https://github.com/PreferredAI/ComparER) | 2020 | [Adversarial Training Towards Robust Multimedia Recommender System (AMR)](cornac/models/amr), [paper](https://ieeexplore.ieee.org/document/8618394) | [requirements.txt](cornac/models/amr/requirements.txt) | [amr_clothing.py](examples/amr_clothing.py) +| 2019 | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | N/A | [ease_movielens.py](examples/ease_movielens.py) | 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | N/A | [c2pf_exp.py](examples/c2pf_example.py) | | [Multi-Task Explainable Recommendation (MTER)](cornac/models/mter), [paper](https://arxiv.org/pdf/1806.03568.pdf) | N/A | [mter_exp.py](examples/mter_example.py) | | [Neural Attention Rating Regression with Review-level Explanations (NARRE)](cornac/models/narre), [paper](http://www.thuir.cn/group/~YQLiu/publications/WWW2018_CC.pdf) | [requirements.txt](cornac/models/narre/requirements.txt) | [narre_example.py](examples/narre_example.py) @@ -153,6 +154,7 @@ The recommender models supported by Cornac are listed below. Why don't you join | | [User K-Nearest-Neighbors (UserKNN)](cornac/models/knn), [paper](https://arxiv.org/pdf/1301.7363.pdf) | N/A | [knn_movielens.py](examples/knn_movielens.py) | | [Weighted Matrix Factorization (WMF)](cornac/models/wmf), [paper](http://yifanhu.net/PUB/cf.pdf) | [requirements.txt](cornac/models/wmf/requirements.txt) | [wmf_exp.py](examples/wmf_example.py) + ## Support Your contributions at any level of the library are welcome. If you intend to contribute, please: diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index 85dcf2751..4a84333a3 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -32,6 +32,7 @@ from .cvae import CVAE from .cvaecf import CVAECF from .efm import EFM +from .ease import EASE from .global_avg import GlobalAvg from .hft import HFT from .hpf import HPF diff --git a/cornac/models/ease/__init__.py b/cornac/models/ease/__init__.py new file mode 100644 index 000000000..759f1898e --- /dev/null +++ b/cornac/models/ease/__init__.py @@ -0,0 +1 @@ +from .recom_ease import EASE \ No newline at end of file diff --git a/cornac/models/ease/recom_ease.py b/cornac/models/ease/recom_ease.py new file mode 100644 index 000000000..464a936ce --- /dev/null +++ b/cornac/models/ease/recom_ease.py @@ -0,0 +1,139 @@ +import numpy as np + +from cornac.models.recommender import Recommender +from cornac.exception import ScoreException + +class EASE(Recommender): + """Embarrassingly Shallow Autoencoders for Sparse Data. + + Parameters + ---------- + name: string, optional, default: 'EASEᴿ' + The name of the recommender model. + + lamb: float, optional, default: 500 + L2-norm regularization-parameter λ ∈ R+. + + posB: boolean, optional, default: False + Remove Negative Weights + + trainable: boolean, optional, default: True + When False, the model is not trained and Cornac assumes that the model is already \ + trained. + + verbose: boolean, optional, default: False + When True, some running logs are displayed. + + seed: int, optional, default: None + Random seed for parameters initialization. + + init_par: numpy 1d array, optional, default: None + The initial object parition, 1d array contaning the cluster label (int type starting from 0) \ + of each object (user). If par = None, then skmeans is initialized randomly. + + References + ---------- + * Steck, H. (2019, May). "Embarrassingly shallow autoencoders for sparse data." \ + In The World Wide Web Conference (pp. 3251-3257). + """ + + def __init__( + self, + name="EASEᴿ", + lamb=500, + posB=True, + trainable=True, + verbose=True, + seed=None, + B=None, + U=None, + ): + Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose) + self.lamb = lamb + self.posB = posB + self.verbose = verbose + self.seed = seed + self.B = B + self.U = U + + def fit(self, train_set, val_set=None): + """Fit the model to observations. + + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + + Returns + ------- + self : object + """ + Recommender.fit(self, train_set, val_set) + + # A rating matrix + self.U = self.train_set.matrix + + # Grahm matrix is X^t X, compute dot product + G = self.U.T.dot(self.U).toarray() + + diag_indices = np.diag_indices(G.shape[0]) + + G[diag_indices] = G.diagonal() + self.lamb + + P = np.linalg.inv(G) + + B = P / (-np.diag(P)) + + B[diag_indices] = 0.0 + + # if self.posB remove -ve values + if self.posB: + B[B<0]=0 + + # save B for predictions + self.B=B + + return self + + + def score(self, user_idx, item_idx=None): + """Predict the scores/ratings of a user for an item. + + Parameters + ---------- + user_idx: int, required + The index of the user for whom to perform score prediction. + + item_idx: int, optional, default: None + The index of the item for which to perform score prediction. + If None, scores for all known items will be returned. + + Returns + ------- + res : A scalar or a Numpy array + Relative scores that the user gives to the item or to all known items + + """ + if item_idx is None: + if self.train_set.is_unk_user(user_idx): + raise ScoreException( + "Can't make score prediction for (user_id=%d)" % user_idx + ) + + known_item_scores = self.U[user_idx, :].dot(self.B) + return known_item_scores + else: + if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( + item_idx + ): + raise ScoreException( + "Can't make score prediction for (user_id=%d, item_id=%d)" + % (user_idx, item_idx) + ) + + user_pred = self.B[item_idx, :].dot(self.U[user_idx, :]) + + return user_pred \ No newline at end of file diff --git a/docs/source/models.rst b/docs/source/models.rst index eae175be1..c436392b3 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -232,3 +232,8 @@ Weighted Matrix Factorization (WMF) -------------------------------------------------- .. automodule:: cornac.models.wmf.recom_wmf :members: + +Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ) +-------------------------------------------------- +.. automodule:: cornac.models.ease.recom_ease + :members: \ No newline at end of file diff --git a/examples/ease_movielens.py b/examples/ease_movielens.py new file mode 100644 index 000000000..ddd78c3c8 --- /dev/null +++ b/examples/ease_movielens.py @@ -0,0 +1,61 @@ +# Copyright 2018 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Example (EASEᴿ) Embarrassingly Shallow Autoencoders for Sparse Data on MovieLens data""" + +import cornac +from cornac.datasets import movielens +from cornac.eval_methods import RatioSplit + + +# Load user-item feedback +data = movielens.load_feedback(variant="1M") + +# Instantiate an evaluation method to split data into train and test sets. +ratio_split = RatioSplit( + data=data, + test_size=0.2, + exclude_unknowns=True, + verbose=True, + seed=123, + rating_threshold=0.8, +) + +ease_original = cornac.models.EASE( + lamb=500, + name="EASEᴿ (B>0)", + posB=True +) + +ease_all = cornac.models.EASE( + lamb=500, + name="EASEᴿ (B>-∞)", + posB=False +) + + +# Instantiate evaluation measures +rec_20 = cornac.metrics.Recall(k=20) +rec_50 = cornac.metrics.Recall(k=50) +ndcg_100 = cornac.metrics.NDCG(k=100) + + +# Put everything together into an experiment and run it +cornac.Experiment( + eval_method=ratio_split, + models=[ease_original, ease_all], + metrics=[rec_20, rec_50, ndcg_100], + user_based=True, #If `False`, results will be averaged over the number of ratings. + save_dir=None +).run() \ No newline at end of file From f05ee227d00af9e0ca33c1a6e54d7aa6e669e5fa Mon Sep 17 00:00:00 2001 From: Hakan Date: Tue, 14 Jun 2022 19:06:09 +0200 Subject: [PATCH 2/6] sorting corrected --- docs/source/models.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index c436392b3..eb9665b89 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -33,6 +33,11 @@ Adversarial Training Towards Robust Multimedia Recommender System (AMR) .. automodule:: cornac.models.amr.recom_amr :members: +Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ) +-------------------------------------------------- +.. automodule:: cornac.models.ease.recom_ease + :members: + Collaborative Context Poisson Factorization (C2PF) ---------------------------------------------------- .. automodule:: cornac.models.c2pf.recom_c2pf @@ -231,9 +236,4 @@ User K-Nearest-Neighbors (UserKNN) Weighted Matrix Factorization (WMF) -------------------------------------------------- .. automodule:: cornac.models.wmf.recom_wmf - :members: - -Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ) --------------------------------------------------- -.. automodule:: cornac.models.ease.recom_ease :members: \ No newline at end of file From dacd5f90d9205178afb58ba2373453587cc5babd Mon Sep 17 00:00:00 2001 From: Hakan Date: Tue, 14 Jun 2022 19:10:06 +0200 Subject: [PATCH 3/6] corrections --- cornac/models/ease/recom_ease.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cornac/models/ease/recom_ease.py b/cornac/models/ease/recom_ease.py index 464a936ce..e154dd328 100644 --- a/cornac/models/ease/recom_ease.py +++ b/cornac/models/ease/recom_ease.py @@ -27,10 +27,6 @@ class EASE(Recommender): seed: int, optional, default: None Random seed for parameters initialization. - init_par: numpy 1d array, optional, default: None - The initial object parition, 1d array contaning the cluster label (int type starting from 0) \ - of each object (user). If par = None, then skmeans is initialized randomly. - References ---------- * Steck, H. (2019, May). "Embarrassingly shallow autoencoders for sparse data." \ @@ -76,7 +72,7 @@ def fit(self, train_set, val_set=None): # A rating matrix self.U = self.train_set.matrix - # Grahm matrix is X^t X, compute dot product + # Gram matrix is X^t X, compute dot product G = self.U.T.dot(self.U).toarray() diag_indices = np.diag_indices(G.shape[0]) From 7066a779c22a1322b0542145e3df1656748c87f4 Mon Sep 17 00:00:00 2001 From: Quoc-Tuan Truong Date: Thu, 16 Jun 2022 11:46:50 +0800 Subject: [PATCH 4/6] Update __init__.py --- cornac/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index 4a84333a3..63e7315e7 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -31,8 +31,8 @@ from .ctr import CTR from .cvae import CVAE from .cvaecf import CVAECF -from .efm import EFM from .ease import EASE +from .efm import EFM from .global_avg import GlobalAvg from .hft import HFT from .hpf import HPF From 5c980515087e8a3920feca81c147ada41f7c2c16 Mon Sep 17 00:00:00 2001 From: Quoc-Tuan Truong Date: Fri, 17 Jun 2022 18:03:26 +0800 Subject: [PATCH 5/6] Update README.md --- examples/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/README.md b/examples/README.md index 7e84804bc..24ff2c01b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -68,6 +68,8 @@ [bpr_netflix.py](bpr_netflix.py) - Example to run Bayesian Personalized Ranking (BPR) with Netflix dataset. +[ease_movielens.py](ease_movielens.py) - Embarrassingly Shallow Autoencoders (EASEᴿ) with MovieLens 1M dataset. + [fm_example.py](fm_example.py) - Example to run Factorization Machines (FM) with MovieLens 100K dataset. [hpf_movielens.py](hpf_movielens.py) - (Hierarchical) Poisson Factorization vs BPR on MovieLens data. From 70fbb96362475e1c8f828b1497af4d891caad679 Mon Sep 17 00:00:00 2001 From: Hakan Date: Tue, 16 Aug 2022 11:36:14 +0200 Subject: [PATCH 6/6] z-scoREC Implementation --- README.md | 1 + cornac/models/__init__.py | 1 + cornac/models/zscorec/__init__.py | 1 + cornac/models/zscorec/recom_zscorec.py | 141 +++++++++++++++++++++++++ docs/source/models.rst | 5 + examples/README.md | 2 + examples/zscorec_movielens.py | 59 +++++++++++ 7 files changed, 210 insertions(+) create mode 100644 cornac/models/zscorec/__init__.py create mode 100644 cornac/models/zscorec/recom_zscorec.py create mode 100644 examples/zscorec_movielens.py diff --git a/README.md b/README.md index 119089777..d72f6b8ab 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,7 @@ The recommender models supported by Cornac are listed below. Why don't you join | Year | Model and paper | Additional dependencies | Examples | | :---: | --- | :---: | :---: | +| 2022 | [ImposeSVD: Incrementing PureSVD For Top-N Recommendations for Cold-Start Problems and Sparse Datasets (z-scoREC)](cornac/models/zscorec), [paper]( https://doi.org/10.1093/comjnl/bxac106) | N/A | [zscorec_netflix.py](examples/zscorec_netflix.py) | 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441759) | [requirements.txt](cornac/models/bivaecf/requirements.txt) | [PreferredAI/bi-vae](https://github.com/PreferredAI/bi-vae) | | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | [requirements.txt](cornac/models/causalrec/requirements.txt) | [causalrec_clothing.py](examples/causalrec_clothing.py) | | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | N/A | [PreferredAI/ComparER](https://github.com/PreferredAI/ComparER) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index 63e7315e7..07d5bc265 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -60,6 +60,7 @@ from .vbpr import VBPR from .vmf import VMF from .wmf import WMF +from .zscorec import zscoREC try: from .fm import FM diff --git a/cornac/models/zscorec/__init__.py b/cornac/models/zscorec/__init__.py new file mode 100644 index 000000000..4f758966e --- /dev/null +++ b/cornac/models/zscorec/__init__.py @@ -0,0 +1 @@ +from .recom_zscorec import zscoREC \ No newline at end of file diff --git a/cornac/models/zscorec/recom_zscorec.py b/cornac/models/zscorec/recom_zscorec.py new file mode 100644 index 000000000..d98d7a325 --- /dev/null +++ b/cornac/models/zscorec/recom_zscorec.py @@ -0,0 +1,141 @@ +import numpy as np + +from cornac.models.recommender import Recommender +from cornac.exception import ScoreException + +import scipy.sparse as sps +import scipy.stats as stats +import numpy as np + +class zscoREC(Recommender): + """ImposeSVD: Incrementing PureSVD For Top-N Recommendations for Cold-Start Problems and Sparse Datasets. + + Parameters + ---------- + name: string, optional, default: 'zscoREC' + The name of the recommender model. + + lamb: float, optional, default: .2 + Shifting parameter λ ∈ R + + posZ: boolean, optional, default: False + Remove Negative Weights + + trainable: boolean, optional, default: True + When False, the model is not trained and Cornac assumes that the model is already \ + trained. + + verbose: boolean, optional, default: False + When True, some running logs are displayed. + + seed: int, optional, default: None + Random seed for parameters initialization. + + References + ---------- + * Hakan Yilmazer, Selma Ayşe Özel, ImposeSVD: Incrementing PureSVD For Top-N Recommendations for Cold-Start Problems and Sparse Datasets, + The Computer Journal, 2022;, bxac106, https://doi.org/10.1093/comjnl/bxac106. + """ + + def __init__( + self, + name="zscoREC", + lamb=.2, + posZ=True, + trainable=True, + verbose=True, + seed=None, + Z=None, + U=None, + ): + Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose) + self.lamb = lamb + self.posZ = posZ + self.verbose = verbose + self.seed = seed + self.Z = Z + self.U = U + + def fit(self, train_set, val_set=None): + """Fit the model to observations. + + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + + Returns + ------- + self : object + """ + Recommender.fit(self, train_set, val_set) + + # A rating matrix + self.U = self.train_set.matrix + + # Gram matrix is X^t X, Eq.(2) in paper + G = sps.csr_matrix(self.U.T.dot(self.U), dtype="uint64") + + # Shifting Gram matrix Eq.(7) in paper + """ In a structure view, the shifting operation on the second + matrix (which is the same as the matrix for the Gram-matrix + estimation) is performed as row-based degree shifting on + Gram-matrix. + """ + W = G - np.tile(G.diagonal()*self.lam, (G.shape[0],1)).transpose() + + # Column-based z-score normalization + # fit each item's column (could be parallel for big streams) + Z = stats.mstats.zscore(np.array(W), axis=0, ddof=1, nan_policy='omit') + + # if self.posZ remove -ve values + if self.posZ: + Z[Z<0]=0 + + # save Z for predictions + self.Z = Z + + return self + + + def score(self, user_idx, item_idx=None): + """Predict the scores/ratings of a user for an item. + + Parameters + ---------- + user_idx: int, required + The index of the user for whom to perform score prediction. + + item_idx: int, optional, default: None + The index of the item for which to perform score prediction. + If None, scores for all known items will be returned. + + Returns + ------- + res : A scalar or a Numpy array + Relative scores that the user gives to the item or to all known items + + """ + if item_idx is None: + if self.train_set.is_unk_user(user_idx): + raise ScoreException( + "Can't make score prediction for (user_id=%d)" % user_idx + ) + + known_item_scores = self.U[user_idx, :].dot(self.Z) + return known_item_scores + else: + if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( + item_idx + ): + raise ScoreException( + "Can't make score prediction for (user_id=%d, item_id=%d)" + % (user_idx, item_idx) + ) + + user_pred = self.Z[item_idx, :].dot(self.U[user_idx, :]) + + return user_pred \ No newline at end of file diff --git a/docs/source/models.rst b/docs/source/models.rst index eb9665b89..2561c3252 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -10,6 +10,11 @@ Recommender (Generic Class) .. automodule:: cornac.models.recommender :members: +ImposeSVD: Incrementing PureSVD For Top-N Recommendations for Cold-Start Problems and Sparse Datasets (z-scoREC) +---------------------------------------------------- +.. automodule:: cornac.models.zscorec.recom_zscorec + :members: + Bilateral VAE for Collaborative Filtering (BiVAECF) ---------------------------------------------------- .. automodule:: cornac.models.bivaecf.recom_bivaecf diff --git a/examples/README.md b/examples/README.md index 24ff2c01b..b192ed994 100644 --- a/examples/README.md +++ b/examples/README.md @@ -93,3 +93,5 @@ [vaecf_citeulike.py](vaecf_citeulike.py) - Variational Autoencoder for Collaborative Filtering (VAECF) with CiteULike dataset. [wmf_example.py](wmf_example.py) - Weighted Matrix Factorization with CiteULike dataset. + +[zscorec_movielens.py](zscorec_movielens.py) - z-scoREC model with MovieLens 1M dataset. diff --git a/examples/zscorec_movielens.py b/examples/zscorec_movielens.py new file mode 100644 index 000000000..ea71d85ab --- /dev/null +++ b/examples/zscorec_movielens.py @@ -0,0 +1,59 @@ +# Copyright 2018 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Example (z-scoREC) ImposeSVD: Incrementing PureSVD For Top-N Recommendations for Cold-Start Problems and Sparse Datasets""" + +import cornac +from cornac.datasets import movielens +from cornac.eval_methods import RatioSplit + + +# Load user-item feedback +data = movielens.load_feedback(variant="1M") + +# Instantiate an evaluation method to split data into train and test sets. +ratio_split = RatioSplit( + data=data, + test_size=0.2, + exclude_unknowns=True, + verbose=True, + seed=123, + rating_threshold=0.8, +) + +zscorec_original = cornac.models.zscoREC( + lam=.2, + name="zscoREC (Z>0)", + posZ=True +) + +zscorec_all = cornac.models.zscoREC( + lam=.2, + name="zscoREC (Z>-∞)", + posZ=False +) + +# Instantiate evaluation measures +rec_20 = cornac.metrics.Recall(k=20) +rec_50 = cornac.metrics.Recall(k=50) +ndcg_100 = cornac.metrics.NDCG(k=100) + +# Put everything together into an experiment and run it +cornac.Experiment( + eval_method=ratio_split, + models=[zscorec_original, zscorec_all], + metrics=[rec_20, rec_50, ndcg_100], + user_based=True, #If `False`, results will be averaged over the number of ratings. + save_dir=None +).run() \ No newline at end of file