-
Notifications
You must be signed in to change notification settings - Fork 4
/
cu_hash_table.h
85 lines (70 loc) · 1.93 KB
/
cu_hash_table.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#ifndef MXNET_PERMUTOHEDRAL_HASH_TABLE_H_
#define MXNET_PERMUTOHEDRAL_HASH_TABLE_H_
#include <mshadow/base.h>
namespace mxnet {
namespace op {
namespace permutohedral {
#if defined(__CUDACC__)
/*!
 * \brief Open-addressing hash table usable from device code, mapping
 *        fixed-length int16_t keys to int32_t slot indices.
 *
 * The table does not allocate or own memory; it wraps caller-provided buffers:
 *  - entries: 2 * n_keys int32_t cells. A cell holding -1 is empty; any other
 *    value is the slot index of a stored key. NOTE(review): presumably the
 *    caller fills this buffer with -1 before the first insert — confirm.
 *  - keys: key storage, key_size int16_t values per slot; slot s occupies
 *    keys[s * key_size, (s + 1) * key_size).
 *
 * Collisions are resolved by linear probing with wraparound. insert() and
 * find() only stop at a matching or empty cell, so the entries buffer must
 * never become completely full. NOTE(review): this appears to rely on callers
 * inserting at most n_keys distinct keys into the 2 * n_keys cells — confirm.
 *
 * \tparam key_size number of int16_t components per key.
 */
template<int key_size>
class CuHashTable {
 public:
  /*!
   * \param n_keys  half the number of cells in entries (table capacity is 2 * n_keys).
   * \param entries buffer of 2 * n_keys int32_t cells, as described above.
   * \param keys    key storage, key_size int16_t values per slot.
   */
  CuHashTable(int32_t n_keys, int32_t *entries, int16_t *keys)
      : n_keys_(n_keys), entries_(entries), keys_(keys) {
  }

  /*!
   * \brief Hashes the key_size components of key to a cell index in [0, 2 * n_keys_).
   */
  MSHADOW_FORCE_INLINE __device__ uint32_t hash(const int16_t *key) const {
    uint32_t h = 0;
    for (int32_t i = 0; i < key_size; i++) {
      // Unsigned arithmetic, so wraparound on overflow is well defined.
      h = (h + key[i]) * 2531011;
    }
    return h % capacity();
  }

  /*!
   * \brief Inserts key, using slot idx for its storage if the key is new.
   *
   * The key is always copied into slot idx first; slot idx is then published by
   * atomically claiming an empty cell. If an equal key is already in the table,
   * nothing is published and slot idx's copied key is left unreferenced.
   *
   * \param key key_size components to insert.
   * \param idx slot to store the key in. NOTE(review): presumably unique per
   *            call and within the keys buffer — confirm against caller.
   * \return idx if the key was newly inserted, otherwise the slot index of the
   *         equal key already present.
   */
  MSHADOW_FORCE_INLINE __device__ int32_t insert(const int16_t *key, int32_t idx) {
    // Write our key before publishing slot idx in the table.
    int16_t *own_key = slot_keys(idx);
    for (int32_t i = 0; i < key_size; i++) {
      own_key[i] = key[i];
    }
    // Make the key writes above visible device-wide before the atomicCAS below
    // can make idx visible; without this, a thread on another SM may see idx in
    // the table but read a stale key and insert a duplicate elsewhere.
    __threadfence();

    const uint32_t cap = capacity();
    uint32_t h = hash(key);
    while (true) {
      // If the cell is empty (-1), claim it with our slot index.
      const int32_t contents = atomicCAS(entries_ + h, -1, idx);
      if (contents == -1) {
        return idx;
      }
      // The cell is taken. Its key may have just been written by a thread on
      // another SM, so read it with volatile loads rather than possibly-cached
      // ordinary loads.
      if (keys_equal(static_cast<const volatile int16_t *>(slot_keys(contents)), key)) {
        return contents;
      }
      // Probe the next cell, wrapping around at the end of the table.
      if (++h == cap) h = 0;
    }
  }

  /*!
   * \brief Looks up key.
   * \return the slot index holding an equal key, or -1 if the key is absent.
   * NOTE(review): uses plain loads and is not ordered against concurrent
   * insert() calls; presumably it runs in a later kernel than the inserts — confirm.
   */
  MSHADOW_FORCE_INLINE __device__ int32_t find(const int16_t *key) const {
    const uint32_t cap = capacity();
    uint32_t h = hash(key);
    while (true) {
      const int32_t contents = entries_[h];
      if (contents == -1) return -1;
      if (keys_equal(slot_keys(contents), key)) return contents;
      // Probe the next cell, wrapping around at the end of the table.
      if (++h == cap) h = 0;
    }
  }

  int32_t n_keys_;
  int32_t *entries_;
  int16_t *keys_;

 private:
  // Number of cells in entries_ (2 * n_keys_), computed in unsigned arithmetic.
  MSHADOW_FORCE_INLINE __device__ uint32_t capacity() const {
    return static_cast<uint32_t>(n_keys_) * 2u;
  }

  // Start of the key_size values stored for slot; the offset is computed in
  // 64 bits so slot * key_size cannot overflow int32_t.
  MSHADOW_FORCE_INLINE __device__ int16_t *slot_keys(int32_t slot) const {
    return keys_ + static_cast<int64_t>(slot) * key_size;
  }

  // True when the key_size components at stored equal those of key. Ptr is
  // either a plain or a volatile pointer to int16_t.
  template<typename Ptr>
  static MSHADOW_FORCE_INLINE __device__ bool keys_equal(Ptr stored, const int16_t *key) {
    for (int32_t i = 0; i < key_size; i++) {
      if (stored[i] != key[i]) return false;
    }
    return true;
  }
};
#endif
}
}
}
#endif // MXNET_PERMUTOHEDRAL_HASH_TABLE_H_