-
Notifications
You must be signed in to change notification settings - Fork 4
/
cu_hash_table.cu
75 lines (62 loc) · 1.92 KB
/
cu_hash_table.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// #include "cu_hash_table.h"
// namespace mxnet {
// namespace op {
// namespace permutohedral {
// template<int key_size>
// CuHashTable<key_size>::CuHashTable(int32_t n_keys, int32_t *entries, int16_t *keys)
// : n_keys_(n_keys), entries_(entries), keys_(keys) {
// }
// template<int key_size>
// MSHADOW_FORCE_INLINE __device__ int32_t CuHashTable<key_size>::hash(const int16_t *key) {
// int32_t h = 0;
// for (int32_t i = 0; i < key_size; i++) {
// h = (h + key[i])* 2531011;
// }
// h = h%(2*n_keys_);
// return h;
// }
// template<int key_size>
// MSHADOW_FORCE_INLINE __device__ int32_t CuHashTable<key_size>::insert(const int16_t *key, int32_t idx) {
// int32_t h = hash(key);
// // write our key
// for (int32_t i = 0; i < key_size; i++) {
// keys_[idx*key_size+i] = key[i];
// }
// while (true) {
// int32_t *e = entries_ + h;
// // If the cell is empty (-1), write our key in it.
// int32_t contents = atomicCAS(e, -1, idx);
// if (contents == -1) {
// // If it was empty, return.
// return idx;
// } else {
// // The cell has a key in it, check if it matches
// bool match = true;
// for (int32_t i = 0; i < key_size && match; i++) {
// match = (keys_[contents*key_size+i] == key[i]);
// }
// if (match) return contents;
// }
// // increment the bucket with wraparound
// h++;
// if (h == n_keys_*2) h = 0;
// }
// }
// template<int key_size>
// MSHADOW_FORCE_INLINE __device__ int32_t CuHashTable<key_size>::find(const int16_t *key) {
// int32_t h = hash(key);
// while (true) {
// int32_t contents = entries_[h];
// if (contents == -1) return -1;
// bool match = true;
// for (int32_t i = 0; i < key_size && match; i++) {
// match = (keys_[contents*key_size+i] == key[i]);
// }
// if (match) return contents;
// h++;
// if (h == n_keys_*2) h = 0;
// }
// }
// }
// }
// }