-
Notifications
You must be signed in to change notification settings - Fork 3
/
QuartetCounterLookup.hpp
318 lines (291 loc) · 13.4 KB
/
QuartetCounterLookup.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
#pragma once
#include "genesis/genesis.hpp"
#include <vector>
#include <cassert>
#include <algorithm>
#include <memory>
#include "TreeInformation.hpp"
#include "quartet_lookup_table.hpp"
#include <unordered_map>
#include <cstdint>
using namespace genesis;
using namespace tree;
using namespace utils;
using namespace std;
/**
 * Row-major offset of the ordered taxon tuple (a, b, c, d) in the flat n*n*n*n counter vector.
 * The expansion refers to the names n, n_square and n_cube, so it can only be used where those
 * (the QuartetCounterLookup members) are in scope. The whole expansion is parenthesized so that it
 * composes safely inside larger expressions (e.g. `CO(...) * k` or `x - CO(...)`).
 */
#define CO(a,b,c,d) ((a) * n_cube + (b) * n_square + (c) * n + (d))
/**
 * Counts occurrences of quartet topologies in a set of evaluation trees, so that the count of any
 * topology ab|cd can afterwards be looked up in O(1).
 *
 * Let n be the number of taxa in the reference tree. The counts live in an O(n^4) lookup table:
 * either a flat vector of n^4 counters (faster) or a QuartetLookupTable (smaller; selected via savemem).
 *
 * @tparam CINT integer type used for the individual topology counters
 */
template<typename CINT>
class QuartetCounterLookup {
public:
	QuartetCounterLookup(const Tree &refTree, const std::string &evalTreesPath, size_t m, bool savemem);
	~QuartetCounterLookup() = default;
	/** Counts of the topologies ab|cd, ac|bd and ad|bc; the arguments are reference-tree node IDs of four leaves. */
	std::tuple<CINT, CINT, CINT> countQuartetOccurrences(size_t aIdx, size_t bIdx, size_t cIdx, size_t dIdx) const;
private:
	void countQuartets(const std::string &evalTreesPath, size_t m,
			const std::unordered_map<std::string, size_t> &taxonToReferenceID);
	void updateQuartets(const Tree &tree, size_t nodeIdx, const std::vector<int> &eulerTourLeaves,
			const std::vector<int> &linkToEulerLeafIndex);
	void updateQuartetsThreeLinks(size_t link1, size_t link2, size_t link3, const Tree &tree,
			const std::vector<int> &eulerTourLeaves, const std::vector<int> &linkToEulerLeafIndex);
	void updateQuartetsThreeClades(size_t startLeafIndexS1, size_t endLeafIndexS1, size_t startLeafIndexS2,
			size_t endLeafIndexS2, size_t startLeafIndexS3, size_t endLeafIndexS3,
			const std::vector<int> &eulerTourLeaves);
	std::pair<size_t, size_t> subtreeLeafIndices(size_t linkIdx, const Tree &tree,
			const std::vector<int> &linkToEulerLeafIndex);
	CINT lookupQuartetCount(size_t aIdx, size_t bIdx, size_t cIdx, size_t dIdx) const;

	std::vector<CINT> lookupTableFast; /**< larger O(n^4) lookup table storing the count of each quartet topology (used when savemem is false) */
	QuartetLookupTable<CINT> lookupTable; /**< smaller O(n^4) lookup table storing the count of each quartet topology (used when savemem is true) */
	size_t n = 0; /**< number of taxa in the reference tree */
	size_t n_square = 0; /**< n*n */
	size_t n_cube = 0; /**< n*n*n */
	std::vector<size_t> refIdToLookupID; /**< maps a reference-tree node ID to the dense lookup ID (0..n-1) of that leaf */
	bool savemem = false; /**< trade speed for less memory or not */
};
/**
 * Increment the topology counters of every quartet {a, a2, b, c} with a and a2 taken from subtree S_1,
 * b from subtree S_2 and c from subtree S_3, i.e. of every quartet whose topology is a a2 | b c.
 *
 * Each subtree is a cyclic half-open range [start, end) of positions in eulerTourLeaves. The range may
 * wrap past the end of the vector, which is why indices advance modulo its length. Every unordered pair
 * {a, a2} is visited exactly once (a2 always lies after a within S_1).
 *
 * NOTE(review): the increments are not atomic (the `omp atomic` pragmas are commented out), yet
 * countQuartets calls this from inside an OpenMP parallel loop. Confirm that no two threads can reach
 * the same counter, particularly for the shared tuples of the savemem table.
 *
 * @param startLeafIndexS1 first position in eulerTourLeaves belonging to S_1
 * @param endLeafIndexS1 position one past the last leaf of S_1 (exclusive, cyclic)
 * @param startLeafIndexS2 first position in eulerTourLeaves belonging to S_2
 * @param endLeafIndexS2 position one past the last leaf of S_2 (exclusive, cyclic)
 * @param startLeafIndexS3 first position in eulerTourLeaves belonging to S_3
 * @param endLeafIndexS3 position one past the last leaf of S_3 (exclusive, cyclic)
 * @param eulerTourLeaves lookup IDs of the tree's leaves in Euler-tour order
 */
template<typename CINT>
void QuartetCounterLookup<CINT>::updateQuartetsThreeClades(size_t startLeafIndexS1, size_t endLeafIndexS1,
		size_t startLeafIndexS2, size_t endLeafIndexS2, size_t startLeafIndexS3, size_t endLeafIndexS3,
		const std::vector<int> &eulerTourLeaves) {
	const size_t tourLength = eulerTourLeaves.size();
	// Advance one position along the cyclic Euler tour.
	auto step = [tourLength](size_t pos) { return (pos + 1) % tourLength; };

	for (size_t p = startLeafIndexS1; p != endLeafIndexS1; p = step(p)) {
		const size_t first = eulerTourLeaves[p];
		for (size_t q = step(p); q != endLeafIndexS1; q = step(q)) {
			const size_t second = eulerTourLeaves[q];
			for (size_t r = startLeafIndexS2; r != endLeafIndexS2; r = step(r)) {
				const size_t third = eulerTourLeaves[r];
				for (size_t s = startLeafIndexS3; s != endLeafIndexS3; s = step(s)) {
					const size_t fourth = eulerTourLeaves[s];
					if (!savemem) {
						//#pragma omp atomic
						lookupTableFast[CO(first, second, third, fourth)]++;
					} else {
						auto& counts = lookupTable.get_tuple(first, second, third, fourth);
						size_t slot = lookupTable.tuple_index(first, second, third, fourth);
						//#pragma omp atomic
						counts[slot]++;
					}
				}
			}
		}
	}
}
/**
 * Return the range <start, end> of positions in the Euler-tour leaf order that belong to the subtree
 * reached through the genesis TreeLink with ID linkIdx.
 *
 * The range is half-open, [start, end), and cyclic: linkToEulerLeafIndex stores, for every link, how many
 * leaves the Euler tour had emitted when it passed that link. Both values therefore lie in
 * [0, number of leaves], end may be smaller than start (wrap-around), and a value equal to the number of
 * leaves denotes position 0. Callers reduce both values modulo eulerTourLeaves.size() before iterating.
 * (Reducing them modulo the number of links, as this function used to do, was a meaningless no-op.)
 *
 * @param linkIdx the ID of the TreeLink from genesis
 * @param tree the tree
 * @param linkToEulerLeafIndex mapping of each link in the tree to a leaf count along the Euler tour;
 *        needed for determining the first and one-past-last leaf position of a subtree
 */
template<typename CINT>
std::pair<size_t, size_t> QuartetCounterLookup<CINT>::subtreeLeafIndices(size_t linkIdx, const Tree &tree,
		const std::vector<int> &linkToEulerLeafIndex) {
	const size_t outerLinkIdx = tree.link_at(linkIdx).outer().index();
	return {static_cast<size_t>(linkToEulerLeafIndex[linkIdx]),
			static_cast<size_t>(linkToEulerLeafIndex[outerLinkIdx])};
}
/**
 * Update the counters for one triple of subtrees hanging off an inner node of an evaluation tree.
 *
 * Each of the three subtrees takes a turn supplying the pair {a, b}, while the other two subtrees supply
 * c and d respectively. This covers every quartet {a, b, c, d} split by the node as ab | c | d.
 *
 * @param link1 link ID leading into the first subtree
 * @param link2 link ID leading into the second subtree
 * @param link3 link ID leading into the third subtree
 * @param tree the evaluation tree
 * @param eulerTourLeaves lookup IDs of the tree's leaves in Euler-tour order
 * @param linkToEulerLeafIndex mapping of each link in the tree to a leaf count along the Euler tour;
 *        needed for determining the first and one-past-last leaf position of each subtree
 */
template<typename CINT>
void QuartetCounterLookup<CINT>::updateQuartetsThreeLinks(size_t link1, size_t link2, size_t link3, const Tree &tree,
		const std::vector<int> &eulerTourLeaves, const std::vector<int> &linkToEulerLeafIndex) {
	const size_t leafCount = eulerTourLeaves.size();
	// Map a subtree's leaf range onto valid positions of the cyclic Euler tour.
	auto wrappedRange = [&](size_t link) {
		const std::pair<size_t, size_t> raw = subtreeLeafIndices(link, tree, linkToEulerLeafIndex);
		return std::pair<size_t, size_t>(raw.first % leafCount, raw.second % leafCount);
	};
	const std::pair<size_t, size_t> s1 = wrappedRange(link1);
	const std::pair<size_t, size_t> s2 = wrappedRange(link2);
	const std::pair<size_t, size_t> s3 = wrappedRange(link3);

	// Pair from S_1, then from S_2, then from S_3.
	updateQuartetsThreeClades(s1.first, s1.second, s2.first, s2.second, s3.first, s3.second, eulerTourLeaves);
	updateQuartetsThreeClades(s2.first, s2.second, s1.first, s1.second, s3.first, s3.second, eulerTourLeaves);
	updateQuartetsThreeClades(s3.first, s3.second, s1.first, s1.second, s2.first, s2.second, eulerTourLeaves);
}
/**
 * Update the counters of all quartets that are resolved at the given inner node of an evaluation tree.
 *
 * The links of the node are collected by walking its circular next() list. Every unordered triple of
 * those links (exactly one triple for a bifurcating node) is handed to updateQuartetsThreeLinks, which
 * counts the quartets {a, b, c, d} where a and b share one subtree and c and d lie in the other two.
 *
 * @param tree the evaluation tree
 * @param nodeIdx ID of an inner node in the evaluation tree
 * @param eulerTourLeaves lookup IDs of the tree's leaves in Euler-tour order
 * @param linkToEulerLeafIndex mapping of each link in the tree to a leaf count along the Euler tour;
 *        needed for determining the first and one-past-last leaf position of each subtree
 */
template<typename CINT>
void QuartetCounterLookup<CINT>::updateQuartets(const Tree &tree, size_t nodeIdx,
		const std::vector<int> &eulerTourLeaves, const std::vector<int> &linkToEulerLeafIndex) {
	// One link per subtree hanging off this node, in next() order starting at the node's primary link.
	std::vector<size_t> links;
	const TreeLink *current = &tree.node_at(nodeIdx).link();
	do {
		links.push_back(current->index());
		current = &current->next();
	} while (current->index() != links.front());

	const size_t degree = links.size();
	for (size_t i = 0; i + 2 < degree; ++i) {
		for (size_t j = i + 1; j + 1 < degree; ++j) {
			for (size_t k = j + 1; k < degree; ++k) {
				updateQuartetsThreeLinks(links[i], links[j], links[k], tree, eulerTourLeaves, linkToEulerLeafIndex);
			}
		}
	}
}
/**
 * Fill the lookup table by counting quartet topologies in the set of evaluation trees.
 *
 * The evaluation trees are streamed one at a time from a Newick file. For each tree, the leaves are listed
 * in Euler-tour order (already translated to lookup IDs), and every link is mapped to the number of leaves
 * the tour had emitted when it passed that link. Every inner node is then processed by updateQuartets()
 * inside an OpenMP parallel loop.
 *
 * Progress is printed whenever the share of the m expected trees processed so far reaches a new whole
 * percent. The percentage is computed with integer arithmetic from the real count, so it is also correct
 * for m < 100. No progress is printed when m == 0.
 *
 * @param evalTreesPath path to the file containing the set of evaluation trees
 * @param m number of evaluation trees (used here only for progress reporting)
 * @param taxonToReferenceID mapping of taxon names to leaf ID in reference tree; an evaluation taxon
 *        missing from it makes .at() throw std::out_of_range
 */
template<typename CINT>
void QuartetCounterLookup<CINT>::countQuartets(const std::string &evalTreesPath, size_t m,
		const std::unordered_map<std::string, size_t> &taxonToReferenceID) {
	size_t nextPercentToReport = 1;
	utils::InputStream instream(utils::make_unique<utils::FileInputSource>(evalTreesPath));
	auto itTree = NewickInputIterator(instream, DefaultTreeNewickReader());
	size_t treesProcessed = 0;

	// Reused for every tree to avoid reallocating per evaluation tree.
	std::vector<int> eulerTourLeaves; // directly containing the mapped IDs from the reference
	std::vector<int> linkToEulerLeafIndex;

	while (itTree) { // iterate over the set of evaluation trees
		Tree const& tree = *itTree;
		size_t nEval = tree.node_count();

		// do an euler tour through the tree
		eulerTourLeaves.clear();
		linkToEulerLeafIndex.assign(tree.link_count(), 0);
		for (auto it : eulertour(tree)) {
			if (it.node().is_leaf()) {
				size_t leafIdx = it.node().index();
				eulerTourLeaves.push_back(
						refIdToLookupID[taxonToReferenceID.at(tree.node_at(leafIdx).data<DefaultNodeData>().name)]);
			}
			linkToEulerLeafIndex[it.link().index()] = eulerTourLeaves.size();
		}

		// Each inner node is one loop iteration; the two vectors above are only read inside this region.
		#pragma omp parallel for schedule(dynamic)
		for (size_t j = 0; j < nEval; ++j) {
			if (!tree.node_at(j).is_leaf()) {
				updateQuartets(tree, j, eulerTourLeaves, linkToEulerLeafIndex);
			}
		}

		++treesProcessed;
		if (m > 0) {
			const size_t percentDone = treesProcessed * 100 / m;
			if (percentDone >= nextPercentToReport) {
				std::cout << "Counting quartets... " << percentDone << "%" << std::endl;
				nextPercentToReport = percentDone + 1;
			}
		}
		++itTree;
	}
}
/**
 * Build the quartet-topology lookup table from a reference tree and a file of evaluation trees.
 *
 * Every leaf of the reference tree receives a dense lookup ID (0..n-1) in the order the Euler tour meets
 * it. The table is then sized for those n taxa and filled by countQuartets(). Afterwards the table size is
 * printed to stdout (in savemem mode this is whatever lookupTable.size() reports).
 *
 * @param refTree the reference tree; taxa of the evaluation trees are looked up by name among its leaves
 * @param evalTreesPath path to the file containing the set of evaluation trees
 * @param m the number of evaluation trees
 * @param savemem if true, use the smaller (slower) QuartetLookupTable; otherwise use a flat n^4 vector
 */
template<typename CINT>
QuartetCounterLookup<CINT>::QuartetCounterLookup(Tree const &refTree, const std::string &evalTreesPath, size_t m,
		bool savemem) :
		n(0), savemem(savemem) {
	std::unordered_map<std::string, size_t> taxonToReferenceID;
	refIdToLookupID.resize(refTree.node_count());

	// Number the reference leaves consecutively in Euler-tour order.
	for (auto it : eulertour(refTree)) {
		if (!it.node().is_leaf()) {
			continue;
		}
		const size_t refId = it.node().index();
		taxonToReferenceID[it.node().data<DefaultNodeData>().name] = refId;
		refIdToLookupID[refId] = n++;
	}
	n_square = n * n;
	n_cube = n_square * n;

	// Allocate only the table variant that is actually going to be used.
	if (!savemem) {
		lookupTableFast.resize(n_cube * n);
	} else {
		lookupTable.init(n);
	}

	countQuartets(evalTreesPath, m, taxonToReferenceID);

	if (!savemem) {
		std::cout << "lookup table size in bytes: " << lookupTableFast.size() * sizeof(CINT) << "\n";
	} else {
		std::cout << "lookup table size in bytes: " << lookupTable.size() << "\n";
	}
}
/**
 * Return how many evaluation trees display the quartet topology ab|cd, read from the flat table.
 *
 * During counting, the pair {a, b} and the pair {c, d} may each have been stored in either order, so all
 * four orientations are summed. In this file it is called only from the non-savemem branch of
 * countQuartetOccurrences, since it reads lookupTableFast.
 *
 * @param aIdx reference-tree node ID of taxon a
 * @param bIdx reference-tree node ID of taxon b
 * @param cIdx reference-tree node ID of taxon c
 * @param dIdx reference-tree node ID of taxon d
 */
template<typename CINT>
CINT QuartetCounterLookup<CINT>::lookupQuartetCount(size_t aIdx, size_t bIdx, size_t cIdx, size_t dIdx) const {
	const size_t a = refIdToLookupID[aIdx];
	const size_t b = refIdToLookupID[bIdx];
	const size_t c = refIdToLookupID[cIdx];
	const size_t d = refIdToLookupID[dIdx];
	const CINT abcd = lookupTableFast[CO(a, b, c, d)];
	const CINT abdc = lookupTableFast[CO(a, b, d, c)];
	const CINT bacd = lookupTableFast[CO(b, a, c, d)];
	const CINT badc = lookupTableFast[CO(b, a, d, c)];
	return abcd + abdc + bacd + badc;
}
/**
 * Return the counts of the three possible topologies of {a, b, c, d} over all evaluation trees,
 * as the tuple (ab|cd, ac|bd, ad|bc).
 *
 * @param aIdx reference-tree node ID of taxon a
 * @param bIdx reference-tree node ID of taxon b
 * @param cIdx reference-tree node ID of taxon c
 * @param dIdx reference-tree node ID of taxon d
 */
template<typename CINT>
std::tuple<CINT, CINT, CINT> QuartetCounterLookup<CINT>::countQuartetOccurrences(size_t aIdx, size_t bIdx, size_t cIdx,
		size_t dIdx) const {
	if (!savemem) {
		// Flat table: lookupQuartetCount does its own ID translation.
		const CINT abCD = lookupQuartetCount(aIdx, bIdx, cIdx, dIdx);
		const CINT acBD = lookupQuartetCount(aIdx, cIdx, bIdx, dIdx);
		const CINT adBC = lookupQuartetCount(aIdx, dIdx, bIdx, cIdx);
		return std::tuple<CINT, CINT, CINT>(abCD, acBD, adBC);
	}

	// Compact table: all three topologies of the quartet share one tuple.
	const size_t a = refIdToLookupID[aIdx];
	const size_t b = refIdToLookupID[bIdx];
	const size_t c = refIdToLookupID[cIdx];
	const size_t d = refIdToLookupID[dIdx];
	const auto& counts = lookupTable.get_tuple(a, b, c, d);
	const CINT abCD = counts[lookupTable.tuple_index(a, b, c, d)];
	const CINT acBD = counts[lookupTable.tuple_index(a, c, b, d)];
	const CINT adBC = counts[lookupTable.tuple_index(a, d, b, c)];
	return std::tuple<CINT, CINT, CINT>(abCD, acBD, adBC);
}