-
Notifications
You must be signed in to change notification settings - Fork 0
/
metadata.py
69 lines (47 loc) · 2.04 KB
/
metadata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
The MIT License
Copyright (c) 2018-2020 Mark Douthwaite
"""
import fire
import numpy as np
import pandas as pd
from xanthus import datasets
from xanthus.datasets.build import build
from xanthus.models import GeneralizedMatrixFactorization, utils
from xanthus.evaluate import create_rankings, score, metrics
# Seed NumPy's global random state once, at import time, so any randomness
# drawn through np.random (presumably including negative sampling) is
# repeatable between runs. This does not seed Python's `random` module or the
# model backend.
np.random.seed(seed=42)
def run(
    version="ml-latest-small",
    samples=4,
    input_dim=3,
    batch_size=256,
    epochs=1,
    n_candidates=100,
    top_k=10,
):
    """Train a GMF model on MovieLens with genre metadata and print NDCG / hit ratio.

    Downloads the requested MovieLens release and attaches each movie's genres
    as item metadata. It then splits the interactions with a leave-one-out
    policy, trains a ``GeneralizedMatrixFactorization`` model, and evaluates
    it with the ranking protocol referenced from He et al. The mean NDCG and
    hit ratio are printed; nothing is returned.

    Args:
        version: MovieLens release name. Passed to
            ``datasets.movielens.download`` and used as the directory under
            ``data/`` from which ``ratings.csv`` and ``movies.csv`` are read.
        samples: Passed as ``negative_samples`` to
            ``train_dataset.to_components`` when building training arrays.
        input_dim: Passed as ``output_dim`` to both ``to_components`` and
            ``create_rankings``, so training and evaluation inputs are built
            the same way. Presumably the width of each item's id+metadata
            input; confirm against the xanthus docs.
        batch_size: Batch size for ``model.predict`` during evaluation.
            NOTE(review): ``model.fit`` does not receive it, so training uses
            the backend's default batch size; confirm that is intended.
        epochs: Number of training epochs passed to ``model.fit``.
        n_candidates: Passed as ``n_samples`` to ``create_rankings``.
            Presumably the number of sampled items each held-out item is
            ranked against.
        top_k: Number of top-scored items kept per user by
            ``utils.reshape_recommended``. NDCG and hit ratio are computed on
            these lists.
    """
    data_dir = f"data/{version}"

    # Download the dataset. The files are assumed to land under data_dir,
    # which is read below.
    datasets.movielens.download(version=version)

    # Load interactions and rename them to the "user"/"item" columns used here.
    ratings_df = pd.read_csv(f"{data_dir}/ratings.csv")
    ratings_df = ratings_df.rename(columns={"userId": "user", "movieId": "item"})

    # Load item metadata. Each movie's "genres" value is a '|'-separated
    # string; fold it into individual lower-cased tags per item.
    item_df = pd.read_csv(f"{data_dir}/movies.csv")
    item_df = item_df.rename(columns={"movieId": "item"})
    item_df = datasets.utils.fold(
        item_df, "item", ["genres"], fn=lambda s: (t.lower() for t in s.split("|"))
    )

    train_dataset, test_dataset = build(
        ratings_df, item_df=item_df, policy="leave_one_out"
    )
    n_users, n_items = train_dataset.user_dim, train_dataset.item_dim

    # Build in-memory training arrays ('batched' models can be used instead).
    users_x, items_x, y = train_dataset.to_components(
        negative_samples=samples, output_dim=input_dim
    )

    # Initialize, compile and train the model.
    model = GeneralizedMatrixFactorization(n_users, n_items)
    model.compile(optimizer="adam", loss="binary_crossentropy")
    model.fit([users_x, items_x], y, epochs=epochs)

    # Evaluate the model using the ranking protocol described in He et al.
    users, candidates = create_rankings(
        test_dataset,
        train_dataset,
        output_dim=input_dim,
        n_samples=n_candidates,
        unravel=True,
    )
    # Held-out items, in a fixed order (shuffle=False) to line up with the
    # recommendations.
    _, test_items, _ = test_dataset.to_components(shuffle=False)
    scores = model.predict([users, candidates], verbose=1, batch_size=batch_size)
    recommended = utils.reshape_recommended(
        users, candidates, scores, top_k, mode="array"
    )

    ndcg = score(metrics.truncated_ndcg, test_items, recommended).mean()
    hr = score(metrics.hit_ratio, test_items, recommended).mean()
    print(f"NDCG={ndcg}, HitRatio={hr}")
if __name__ == "__main__":
    # Expose run()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=run)