-
Notifications
You must be signed in to change notification settings - Fork 10
/
test.cpp
201 lines (180 loc) · 6.25 KB
/
test.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
#include <omp.h>
#include <sqlite3.h>
#include "globals.h"
#include "load.h"
#include "predictor.h"
#include "optimizers.h"
using namespace std;
// Dataset dimensions; declared extern here, so the definitions live elsewhere.
extern const int MAX_USERS; // users in the entire training set
extern const int MAX_MOVIES; // movies in the entire training set (+1)
// Run configuration read by setup() and log(). parse_args() may overwrite
// method and sample_size; main() overwrites sample_size before calling setup().
static int method = 0; // 0 for sgd, 1 for reg gradient descent, 2 for conjugate gradient
// NOTE(review): parse_args() prompts "2 for bfgs" while log() records method 2
// as "cg"; confirm which optimizer method 2 actually selects.
static bool warm_start = 0; // set to 1 to load features.bin, previous solution
// Number of training ratings to use; setup() replaces the training set with a
// random subset from sample() when this is below NON_PROBE_SIZE.
static int sample_size = 1000000;
// Capacities (in records) of the buffers setup() allocates for the
// cross-validation ratings (cpp/cv.bin) and training ratings (cpp/train.bin).
static const int PROBE_SIZE = 1408395;
static const int NON_PROBE_SIZE = 99072112;
// Forward declarations. parse_args, setup, sample and log are defined later in
// this file; cost() has no definition in this file — presumably it comes from
// another translation unit (TODO confirm).
Settings parse_args(int argc, char **argv);
void setup(int& num_ratings, int& num_cv_ratings, Data *& ratings, Data *& cv_ratings);
Data *sample(Data *ratings, int sample_size, int num_ratings);
double cost(Predictor& p, Data *ratings, int num_ratings);
void log(Settings& s, double total_time, double train_cost, double cv_cost);
int comparator(const void *ii, const void *jj) {
Data i = *(Data *)ii, j = *(Data *)jj;
if (i.user > j.user) return 1;
else if (i.user == j.user && i.movie > j.movie) return 1;
return -1;
}
int comparator2(const void *ii, const void *jj) {
Data i = *(Data *)ii, j = *(Data *)jj;
if (i.movie > j.movie) return 1;
else if (i.movie == j.movie && i.user > j.user) return 1;
return -1;
}
int main(int argc, char **argv) {
int num_ratings, num_cv_ratings, user, movie;
Data *ratings, *cv_ratings;
//sample_size = 500;
sample_size = NON_PROBE_SIZE;
setup(num_ratings, num_cv_ratings, ratings, cv_ratings);
const int nu = 480189;
const int nm = 17770;
int *u_rating_ptrs = new int[nu];
int *m_rating_ptrs = new int[nm+1];
movie = ratings[0].movie;
m_rating_ptrs[1] = 0;
for (int r=1; r<sample_size; r++) {
if (ratings[r].movie != movie) {
movie = ratings[r].movie;
m_rating_ptrs[movie] = r;
}
}
cout << "movie rating counts" << endl;
Data *start, *end;
for (int m=1; m<11; m++) {
start = &ratings[m_rating_ptrs[m]];
end = &ratings[m_rating_ptrs[m+1]];
printf("%2d %4lu\n", m, end-start);
}
cout << "sorting" << endl;
time_t t1,t2; time(&t1);
qsort(ratings, sample_size, sizeof(Data), comparator);
time(&t2);
cout << "total time: " << difftime(t2,t1) << "s" << endl;
user = ratings[0].user;
u_rating_ptrs[0] = 0;
for (int r=1; r<sample_size; r++) {
if (ratings[r].user != user) {
user = ratings[r].user;
u_rating_ptrs[user] = r;
}
}
cout << "user rating counts" << endl;
for (int u=0; u<10; u++) {
start = &ratings[u_rating_ptrs[u]];
end = &ratings[u_rating_ptrs[u+1]];
printf("%2d %4lu\n", u, end-start);
}
return 0;
}
void setup(int& num_ratings, int& num_cv_ratings, Data *& ratings, Data *& cv_ratings) {
ratings = new Data[NON_PROBE_SIZE];
num_ratings = load_binary(ratings, "cpp/train.bin");
if (sample_size < NON_PROBE_SIZE)
ratings = sample(ratings, sample_size, num_ratings);
cv_ratings = new Data[PROBE_SIZE];
num_cv_ratings = load_binary(cv_ratings, "cpp/cv.bin");
}
// Partial Fisher-Yates shuffle. Moves sample_size ratings, chosen at random
// without replacement, to the back of the array and returns a pointer to the
// first of them. The array is permuted in place, and the returned pointer
// aliases `ratings` (it is not a separate allocation).
//
// A negative num_ratings is treated as 0, and sample_size is clamped to
// [0, num_ratings], so the function never indexes outside [0, num_ratings).
// NOTE(review): rand() only yields values up to RAND_MAX. On platforms where
// RAND_MAX is small (e.g. 32767), higher indices can never be drawn.
Data *sample(Data *ratings, int sample_size, int num_ratings) {
    if (num_ratings < 0)
        num_ratings = 0;
    if (sample_size < 0)
        sample_size = 0;
    if (sample_size > num_ratings)
        sample_size = num_ratings;
    for (int i = 0; i < sample_size; i++) {
        const int last = num_ratings - i - 1;  // next slot of the sampled tail
        // Draw only from the not-yet-sampled prefix [0, last]. Drawing from the
        // whole array could re-pick a tail slot, which pulls the element at
        // `last` into the sample deterministically and biases the result.
        const int r = rand() % (last + 1);
        if (r != last)
            std::swap(ratings[r], ratings[last]);
    }
    std::cout << "sampled " << sample_size << " ratings" << std::endl;
    return &ratings[num_ratings - sample_size];
}
// Reads one whitespace-delimited value from stdin into `value`. Exits with
// status 1 if extraction fails. A failed read leaves `value` unusable and puts
// cin into a fail state, so every later prompt would be skipped silently and
// the run would proceed with garbage settings.
template <typename T>
static void read_or_die(T& value) {
    if (!(std::cin >> value)) {
        std::cerr << "invalid input" << std::endl;
        exit(1);
    }
}

// Builds the run Settings from the command line.
//   argc == 1: returns a default-constructed Settings; globals are untouched.
//   argc == 2: interactive mode. Prompts on stdout and reads from stdin the
//              feature count, sample size, method, learning rate (methods 0
//              and 1 only), max epochs, K and min improvement.
//   argc == 3: 50 features, all data, method 2, max_epochs 0, K = atof(argv[2]).
// Side effects: may set the globals sample_size and method.
// Exits with status 1 on too many arguments, an unknown method, or
// unparsable interactive input.
Settings parse_args(int argc, char **argv) {
    Settings s;
    if (argc > 3) {
        std::cout << "usage: ./funk [-i] [x K]" << std::endl;
        exit(1);
    }
    if (argc == 3) {
        s.num_features = 50;
        sample_size = NON_PROBE_SIZE;
        method = 2;
        s.max_epochs = 0;
        s.K = atof(argv[2]);
        std::cout << "using K = " << s.K << std::endl;
    }
    if (argc == 2) { // interactive mode
        std::cout << "enter number of features: ";
        read_or_die(s.num_features);
        std::cout << "enter sample size (0 to use all data): ";
        read_or_die(sample_size);
        if (sample_size <= 0)
            sample_size = NON_PROBE_SIZE;
        std::cout << "enter 0 for sgd, 1 for reg grad desc, 2 for bfgs: ";
        read_or_die(method);
        if (method == 0)
            s.lrate = .001;
        else if (method == 1)
            s.lrate = .0005;
        else if (method != 2) {
            std::cout << "only 0 or 1 or 2, silly" << std::endl;
            exit(1);
        }
        if (method == 0 || method == 1) {
            std::cout << "enter learning rate (0 for " << s.lrate << "): ";
            double temp = 0;
            read_or_die(temp);
            if (temp > 0)  // non-positive keeps the method's default rate
                s.lrate = temp;
        }
        std::cout << "max iterations: ";
        read_or_die(s.max_epochs);
        std::cout << "regularization parameter (.015): ";
        read_or_die(s.K);
        std::cout << "minimum improvement (.001): ";
        read_or_die(s.min_improvement);
    }
    return s;
}
// Appends one row describing a finished run to the `log` table of log.db.
// The row records:
//   - the current datetime;
//   - the method name: "sgd", "gd", "cg", or "" for any other value of the
//     global `method`;
//   - s.num_features, cv_cost, total_time and train_cost;
//   - the learning rate, recorded as 0 for "cg";
//   - s.K;
//   - the globals sample_size and warm_start.
// log.db is opened read-write without SQLITE_OPEN_CREATE, so it must already
// exist; failure to open it terminates the process with status 1. A failed
// insert is reported on stderr and otherwise ignored.
//
// Values are bound as statement parameters instead of being sprintf'd into a
// fixed 100-byte buffer. This keeps full double precision ("%f" rounded to
// 6 decimals) and removes the overflow risk for large values. `s` is not
// modified.
void log(Settings& s, double total_time, double train_cost, double cv_cost) {
    sqlite3 *db = NULL;
    int rc = sqlite3_open_v2("log.db", &db, SQLITE_OPEN_READWRITE, NULL);
    if (rc) {
        fprintf(stderr, "Can't open database: %s\n", sqlite3_errmsg(db));
        sqlite3_close(db);
        exit(1);
    }

    const char *method_string = "";
    double lrate = s.lrate;
    if (method == 0) method_string = "sgd";
    else if (method == 1) method_string = "gd";
    else if (method == 2) {
        method_string = "cg";
        lrate = 0;  // method 2 is always logged with a learning rate of 0
    }

    static const char *const sql =
        "INSERT INTO log "
        "(datetime, method, num_features, cv_cost, time, train_cost, learning_rate, "
        "regularizer, sample_size, warm_start) "
        "VALUES (datetime('now'), ?, ?, ?, ?, ?, ?, ?, ?, ?)";
    sqlite3_stmt *stmt = NULL;
    rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
    if (rc == SQLITE_OK) {
        sqlite3_bind_text(stmt, 1, method_string, -1, SQLITE_STATIC);
        sqlite3_bind_int(stmt, 2, s.num_features);
        sqlite3_bind_double(stmt, 3, cv_cost);
        sqlite3_bind_double(stmt, 4, total_time);
        sqlite3_bind_double(stmt, 5, train_cost);
        sqlite3_bind_double(stmt, 6, lrate);
        sqlite3_bind_double(stmt, 7, s.K);
        sqlite3_bind_int(stmt, 8, sample_size);
        sqlite3_bind_int(stmt, 9, warm_start ? 1 : 0);
        rc = sqlite3_step(stmt);
    }
    // SQLITE_DONE is the only success result of stepping an INSERT; a failed
    // prepare leaves an error code in rc, which also lands here.
    if (rc != SQLITE_DONE)
        fprintf(stderr, "SQL error: %s\n", sqlite3_errmsg(db));
    sqlite3_finalize(stmt);  // harmless no-op when stmt is NULL
    sqlite3_close(db);
}