-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.cpp
128 lines (110 loc) · 4.11 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include "DatasetManager.hpp"
#include "FileGraphGenerator.hpp"
#include "GraphNN.hpp"
#include "GraphNNClassifier.hpp"
#include "GraphNNLearnableParameter.hpp"
#include "GraphNNLearner.hpp"
#include <Eigen/Core>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
namespace fs = std::filesystem;
// Forward declarations.
// NOTE(review): the original declaration took DatasetManager<GraphNN>, but the
// definition below (and the call site in main) use DatasetManager<Data>. With
// the mismatch these were two distinct overloads and the declared one had no
// definition — a guaranteed link error if the (commented-out) call in main
// were re-enabled. Declaration fixed to match the definition.
void RunAllTrainer(const GraphNNClassifier &classifier,
                   const DatasetManager<Data> &trainManager,
                   const DatasetManager<Data> &validManager,
                   int iteration, int batchSize);
GraphNNLearnableParameter InitializedParameter();
int main(int argc, char* argv[]) {
if (argc != 5) {
std::cout << "usage: program iteration batch_size dataset_path output_path" << std::endl;
return -1;
}
int iteration = std::stoi(argv[1]);
int batchSize = std::stoi(argv[2]);
fs::path datasetPath(argv[3]);
fs::path resultPath(argv[4]);
GraphNNClassifier classifier(2);
std::shared_ptr<DatasetManager<Data>> trainManager;
std::shared_ptr<DatasetManager<Data>> validManager;
try {
trainManager.reset(
new DatasetManager<Data>(
DatasetLoader::LoadGraphAndLabel((datasetPath / "train/train/").string(), 1200, 8)));
validManager.reset(
new DatasetManager<Data>(
DatasetLoader::LoadGraphAndLabel((datasetPath / "train/valid/").string(), 800, 8, 1200)));
} catch (std::string message) {
std::cout << message << std::endl;
exit(-1);
}
// RunAllTrainer(classifier, *trainManager, *validManager, iteration, batchSize);
// train using adam
std::shared_ptr<GraphNNLearner> learner(new GraphNNAdam(classifier));
GraphNNLearnableParameter parameter = InitializedParameter();
for (int i = 0; i < iteration; i++)
{
// train
DatasetBatchManager trainBatchManager(*trainManager, batchSize);
(*learner)(trainBatchManager, parameter);
std::cout << "iteration:" << i << std::endl;
}
// test
std::ofstream test((resultPath / "results/test.txt").string());
if (test.fail()) {
std::cout << "Failed to open : " << "test" << std::endl;
exit(-1);
}
std::shared_ptr<DatasetManager<GraphNN>> testManager;
try {
testManager.reset(
new DatasetManager<GraphNN>(
DatasetLoader::LoadGraph((datasetPath / "test/").string(), 1200, 8)));
} catch (std::string message) {
std::cout << message << std::endl;
exit(-1);
}
for (auto itr = testManager->begin(); itr != testManager->end(); ++itr) {
test << classifier.Classify(*itr, parameter) << std::endl;
}
return 0;
}
void RunAllTrainer(
const GraphNNClassifier &classifier,
const DatasetManager<Data> &trainManager,
const DatasetManager<Data> &validManager,
int iteration, int batchSize) {
std::vector<std::pair<GraphNNLearner*, std::string>> learner = {
std::make_pair(
new GraphNNAdam(classifier),
"../results/train_and_valid_adam.csv")};
for (auto itr : learner) {
GraphNNLearnableParameter parameter = InitializedParameter();
std::ofstream result(itr.second);
if (result.fail()) {
std::cout << "Failed to open : " << "result" << std::endl;
exit(-1);
}
for (int i = 0; i < iteration; i++) {
// train
DatasetBatchManager<Data> trainBatchManager(trainManager, batchSize);
// valid
result << classifier.AverageLoss(trainManager, parameter) << ","
<< classifier.AverageLoss(validManager, parameter) << std::endl;
(*itr.first)(trainBatchManager, parameter);
}
}
}
/// Builds the initial learnable parameters: an 8-vector A and an 8x8 matrix W
/// whose entries are drawn i.i.d. from a normal distribution with mean 0 and
/// standard deviation 0.4, plus a zero scalar bias. A fresh nondeterministic
/// seed is taken on every call.
GraphNNLearnableParameter InitializedParameter() {
    constexpr int kDim = 8;
    std::random_device seedSource;
    std::default_random_engine rng(seedSource());
    std::normal_distribution<> gaussian(0.0, 0.4);
    auto sample = [&]() { return gaussian(rng); };

    Eigen::VectorXd a(kDim);
    Eigen::MatrixXd w(kDim, kDim);
    // Draw order matches the original: a(row) first, then row `row` of w.
    for (int row = 0; row < kDim; ++row) {
        a(row) = sample();
        for (int col = 0; col < kDim; ++col) {
            w(row, col) = sample();
        }
    }
    return GraphNNLearnableParameter(w, 0.0, a);
}