-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset.h
227 lines (184 loc) · 6.03 KB
/
dataset.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
#ifndef DSL_DATASET_H
#define DSL_DATASET_H
// {{SMILE_PUBLIC_HEADER}}
#include <vector>
#include <string>
#include <cmath>
#include <cassert>
#include <cstdio>
class DSL_network;

// Sentinel stored in discrete (int) columns to mark a missing value.
#define DSL_MISSING_INT (-1)

// Sentinel stored in continuous (float) columns to mark a missing value: a NaN,
// produced by taking the square root of a negative number. A NaN never compares
// equal to anything (itself included), so float missing values cannot be found
// with ==; DSL_dataset::IsMissing presumably routes float columns through
// IsMissingFloat for exactly that reason.
// Note: every expansion evaluates std::sqrt at run time; sqrt(-1.0) is a domain
// error and may set errno (EDOM) / raise FE_INVALID depending on math_errhandling.
#define DSL_MISSING_FLOAT (static_cast<float>(std::sqrt(-1.0)))

// Per-variable (column) metadata kept by DSL_dataset.
struct DSL_datasetVarInfo
{
    // Initialize every member -- including missingFloat, which used to be left
    // indeterminate -- so a default-constructed (or copied) instance never
    // carries an uninitialized value.
    DSL_datasetVarInfo() :
        discrete(true),
        missingInt(DSL_MISSING_INT),
        missingFloat(DSL_MISSING_FLOAT) {}
    bool discrete;                       // true: int column; false: float column
    std::string id;                      // variable identifier (may be empty)
    int missingInt;                      // missing-value sentinel for discrete columns
    float missingFloat;                  // missing-value sentinel for continuous columns
    std::vector<std::string> stateNames; // outcome names (queried only for discrete vars)
};
// Options for DSL_dataset::ReadFile. A default-constructed instance holds the
// defaults listed next to each field.
struct DSL_datasetParseParams
{
    std::string missingValueToken; // default "*"; presumably the input token meaning "missing"
    int missingInt;                // default DSL_MISSING_INT; value stored for missing discrete cells
    float missingFloat;            // default DSL_MISSING_FLOAT; value stored for missing float cells
    bool columnIdsPresent;         // default true; presumably: input starts with a header of ids

    DSL_datasetParseParams()
        : missingValueToken("*")
        , missingInt(DSL_MISSING_INT)
        , missingFloat(DSL_MISSING_FLOAT)
        , columnIdsPresent(true)
    {
    }
};
// Options for DSL_dataset::WriteFile. A default-constructed instance holds the
// defaults listed next to each field.
struct DSL_datasetWriteParams
{
    std::string missingValueToken; // default "*"; presumably written in place of missing cells
    bool columnIdsPresent;         // default true; presumably: emit a header row of variable ids
    bool useStateIndices;          // default false; presumably: write state indices, not names
    char separator;                // default '\t'; column separator character
    std::string floatFormat;       // default "%g"; printf-style format (presumably for float cells)

    DSL_datasetWriteParams()
        : missingValueToken("*")
        , columnIdsPresent(true)
        , useStateIndices(false)
        , separator('\t')
        , floatFormat("%g")
    {
    }
};
// One entry of the matching produced by DSL_dataset::MatchNetwork.
// Presumably pairs a network node (and time slice) with a dataset column --
// NOTE(review): confirm against the MatchNetwork implementation.
struct DSL_datasetMatch
{
    int node;   // default -1 (unset)
    int slice;  // default 0
    int column; // default -1 (unset)

    DSL_datasetMatch()
        : node(-1)
        , slice(0)
        , column(-1)
    {
    }
};
// DS_VALIDATE_IDX(var, rec): debug-only index check used by DSL_dataset's
// inline accessors. Without NDEBUG it forwards to the member ValidateIdx
// (defined out of line); with NDEBUG it expands to a no-op expression, so the
// arguments are not evaluated and release builds pay nothing.
#if !defined(NDEBUG)
#define DS_VALIDATE_IDX(var, rec) ValidateIdx(var, rec)
#else
#define DS_VALIDATE_IDX(var, rec) ((void)0)
#endif
// DSL_dataset: an in-memory, column-oriented table. Each variable (column) is
// either discrete -- stored as std::vector<int> -- or continuous -- stored as
// std::vector<float>; `data[var]` points at that vector and `metadata[var]`
// describes it. An optional row filter (ApplyFilter/ClearFilter) remaps the
// record indices callers pass in onto physical rows (see RealIdx).
class DSL_dataset
{
public:
    DSL_dataset();
    DSL_dataset(const DSL_dataset &src);
    DSL_dataset& operator=(const DSL_dataset &src);
    // Releases all column storage via FreeData().
    ~DSL_dataset() { FreeData(); }

    // File I/O. `params` / `errOut` are optional (NULL = defaults / no error
    // text). The int result is a status code defined by the implementation --
    // NOTE(review): confirm success/failure values against the .cpp.
    int ReadFile(const std::string &filename, const DSL_datasetParseParams *params = NULL, std::string *errOut = NULL);
    int WriteFile(const std::string &filename, const DSL_datasetWriteParams *params = NULL, std::string *errOut = NULL) const;

    // MatchNetwork may change the integer indices in the dataset to ensure
    // the correct fit to outcome ids in the network - hence the method isn't const
    int MatchNetwork(const DSL_network &net, std::vector<DSL_datasetMatch> &matching, std::string &errMsg);
    void PrepareStructure(DSL_network &net);

    // Variable (column) management. NOTE(review): `id` is taken by const value;
    // switching to const& would also require changing the out-of-line
    // definitions, so the signatures are kept as-is.
    int AddIntVar(const std::string id = std::string(), int missingValue = DSL_MISSING_INT);
    int AddFloatVar(const std::string id = std::string(), float missingValue = DSL_MISSING_FLOAT);
    int RemoveVar(int var);

    // Record (row) management.
    void AddEmptyRecord();
    void SetNumberOfRecords(int numRecords);
    int RemoveRecord(int rec);

    int FindVariable(const std::string &id) const;
    int GetNumberOfVariables() const { return int(metadata.size()); }

    // Number of records visible to callers: the filter's size when a filter is
    // active, otherwise the physical row count.
    int GetNumberOfRecords() const
    {
        return filter == NULL ? GetRealRowCount() : int(filter->size());
    }
    int ApplyFilter(const std::vector<int> &filter);
    void ClearFilter();

    // Typed cell access. `rec` is a logical (filtered) record index; it is
    // translated to a physical row with RealIdx. The column type must match
    // the accessor (asserted in debug builds).
    int GetInt(int var, int rec) const
    {
        DS_VALIDATE_IDX(var, rec);
        assert(IsDiscrete(var));
        return (*static_cast<const std::vector<int> *>(data[var]))[RealIdx(rec)];
    }
    float GetFloat(int var, int rec) const
    {
        DS_VALIDATE_IDX(var, rec);
        assert(!IsDiscrete(var));
        return (*static_cast<const std::vector<float> *>(data[var]))[RealIdx(rec)];
    }
    void SetInt(int var, int rec, int value)
    {
        DS_VALIDATE_IDX(var, rec);
        assert(IsDiscrete(var));
        (*static_cast<std::vector<int> *>(data[var]))[RealIdx(rec)] = value;
    }
    void SetFloat(int var, int rec, float value)
    {
        DS_VALIDATE_IDX(var, rec);
        assert(!IsDiscrete(var));
        (*static_cast<std::vector<float> *>(data[var]))[RealIdx(rec)] = value;
    }

    // Stores the variable's missing-value sentinel in the given cell.
    void SetMissing(int var, int rec)
    {
        // Validate up front: IsDiscrete(var) and metadata[var] below are read
        // before SetInt/SetFloat would get a chance to check the indices.
        DS_VALIDATE_IDX(var, rec);
        if (IsDiscrete(var))
        {
            SetInt(var, rec, metadata[var].missingInt);
        }
        else
        {
            SetFloat(var, rec, metadata[var].missingFloat);
        }
    }
    int GetMissingInt(int var) const
    {
        assert(IsDiscrete(var));
        return metadata[var].missingInt;
    }
    float GetMissingFloat(int var) const
    {
        assert(!IsDiscrete(var));
        return metadata[var].missingFloat;
    }

    // True when the cell holds the variable's missing-value sentinel. Discrete
    // columns compare with ==; float columns go through IsMissingFloat,
    // presumably because the default float sentinel is a NaN (never == itself).
    bool IsMissing(int var, int rec) const
    {
        // Validate before indexing metadata[var] (GetInt checks only later).
        DS_VALIDATE_IDX(var, rec);
        const DSL_datasetVarInfo &m = metadata[var];
        return m.discrete ?
            GetInt(var, rec) == m.missingInt :
            IsMissingFloat(var, rec, m.missingFloat);
    }
    bool IsDiscrete(int v) const { return metadata[v].discrete; }

    // Whole-column access. The returned vector is the underlying storage,
    // indexed by physical row: it is NOT affected by an active filter.
    const std::vector<int> & GetIntData(int var) const
    {
        assert(IsDiscrete(var));
        return *static_cast<const std::vector<int> *>(data[var]);
    }
    const std::vector<float> & GetFloatData(int var) const
    {
        assert(!IsDiscrete(var));
        return *static_cast<const std::vector<float> *>(data[var]);
    }

    const DSL_datasetVarInfo& GetVariableInfo(int var) const { return metadata[var]; }
    const std::string& GetId(int var) const { return metadata[var].id; }
    int SetId(int v, const std::string &newId);
    const std::vector<std::string>& GetStateNames(int var) const
    {
        assert(metadata[var].discrete);
        return metadata[var].stateNames;
    }
    int SetStateNames(int var, const std::vector<std::string> &stateNames);

    // Column statistics / generic access (implemented out of line).
    int GetMinMaxInt(int var, int &minval, int &maxval) const;
    int GetMinMaxFloat(int var, float &minval, float &maxval) const;
    int Get(int var, int rec, int &val) const;
    int Get(int var, int rec, float &val) const;
    bool HasMissingData(int var) const;
    bool IsConstant(int var) const;

    // Judging by the signatures: discretizes variable `var` into `intervals`
    // states using `alg`, naming states with `statePrefix`; the first overload
    // also reports `edges`. NOTE(review): confirm against the implementation.
    enum DiscretizeAlgorithm { Hierarchical, UniformWidth, UniformCount };
    int Discretize(int var, DiscretizeAlgorithm alg, int intervals, const std::string &statePrefix, std::vector<double> &edges);
    int Discretize(int var, DiscretizeAlgorithm alg, int intervals, const std::string &statePrefix);

private:
    bool IsMissingFloat(int var, int rec, float missingFloat) const;
    void FreeData();
    void FreeColData(int var);
    void ValidateIdx(int var, int rec) const;

    // Maps a logical record index to a physical row: identity when no filter
    // is active, otherwise the filter's entry for `rec`.
    int RealIdx(int rec) const
    {
        return filter == NULL ? rec : (*filter)[rec];
    }
    int AddVarHelper(const std::string &id, bool discrete, int missingInt, float missingFloat);
    int GetRealRowCount() const;
    int DoWriteFile(FILE *fout, const DSL_datasetWriteParams &params, std::string &errOut) const;

    std::vector<DSL_datasetVarInfo> metadata; // one entry per variable
    std::vector<void *> data;                 // per variable: vector<int>* or vector<float>*
    std::vector<int>* filter;                 // NULL when no filter is active; presumably owned
    friend class DSL_datasetParser;           // parser fills members directly
};
#endif