-
Notifications
You must be signed in to change notification settings - Fork 0
/
oram.lo
305 lines (257 loc) · 8.83 KB
/
oram.lo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
(** Non-Recursive ORAM *)
(***********************
** Utility Functions **
***********************)
(** [div_2 n] is [n] divided by two, rounded down, computed by repeated
    subtraction using only equality tests.

    Both [0] and [1] are base cases, so odd inputs terminate as well (a single
    [n = 0] base case would be stepped over by odd [n]).
    Negative [n] still never reaches a base case. *)
let rec div_2 (n : int<public, _|_>) : int<public, _|_> =
  if n = 0<public, _|_> then
    0<public, _|_>
  else if n = 1<public, _|_> then
    (* Odd remainder: floor(1 / 2) = 0. *)
    0<public, _|_>
  else
    div_2 (n - 2<public, _|_>) + 1<public, _|_>
in
(** [log_2 n] counts how many times [n] can be halved with [div_2] before it
    becomes [1]; for a power of two this is its exponent.

    [n] is expected to be at least [1]: since [div_2] maps [0] to [0], an
    input of [0] never reaches the base case. *)
let rec log_2 (n : int<public, _|_>) : int<public, _|_> =
  if n = 1<public, _|_> then
    0<public, _|_>
  else
    let half = div_2 n in
    (log_2 half) + 1<public, _|_>
in
(** [pow_2 n] is two raised to the power [n], for [n >= 0].

    The recursion counts [n] down towards the [n = 0] base case; negative [n]
    never reaches it. *)
let rec pow_2 (n : int<public, _|_>) : int<public, _|_> =
  if n = 0<public, _|_> then
    1<public, _|_>
  else
    (* Recurse on n - 1 so the argument shrinks towards the base case. *)
    2<public, _|_> * (pow_2 (n - 1<public, _|_>))
in
(***********************
** Type Declarations **
***********************)
(** Data user is storing *)
(* A pair of rint values, each labelled secret with region `r0 \/ `r1.
   The recursive ORAM code below reads such a pair as the two entries kept
   under one halved index and uses one of them as a lookup tag. *)
type user_data = rint<secret, `r0 \/ `r1> * rint<secret, `r0 \/ `r1>
in
(** Data the NRORAM is storing *)
(* Fields:
   - is_some : whether this record carries a real entry; nr_rr' tests it to
               decide whether a bucket lookup produced a hit.
   - nr_tag  : secret tag; nr_evict' inspects one bit of it to choose the
               child bucket an evicted block moves to.
   - nr_val  : the user's payload. *)
type nr_data =
{ is_some : bool<secret, `r0>
; nr_tag : int<secret, `r0>
; nr_val : user_data
}
in
(** Blocks in the trivial ORAM *)
(* Fields:
   - is_some : whether the slot is occupied; trivial_add' and trivial_pop'
               test it to find free or used slots.
   - idx     : the index the block is stored under; trivial_rr' compares it
               with the requested index.
   - data    : the stored record. *)
type block =
{ is_some : bool<secret, `r0>
; idx : int<secret, `r0>
; data : nr_data
}
in
(* A trivial ORAM is a flat array of blocks; every operation on it below
   walks all slots. *)
type trivial_oram = block array
in
(* A non-recursive ORAM is an array of trivial ORAMs ("buckets"); nr_rr' and
   nr_evict address it level by level as a tree. *)
type nr_oram = trivial_oram array
in
(* The recursive ORAM pairs an array of non-recursive ORAMs (indexed by
   recursion level in tree_rr_h) with one trivial ORAM used once the level
   reaches the length of that array. *)
type oram = nr_oram array * trivial_oram
in
(********************
** Default Values **
********************)
(* Placeholder payload: a pair of two rnd values with the same label as the
   components of user_data. *)
let default_user_data =
(rnd<secret, `r0 \/ `r1>, rnd<secret, `r0 \/ `r1>)
in
(* The "nothing found" record: is_some is false, the tag is 0 and the payload
   is default_user_data. nr_rr' starts its search from this value. *)
let default_nr_data =
{ is_some = false<secret, `r0>
; nr_tag = 0<secret, `r0>
; nr_val = default_user_data
}
in
(* The empty slot: is_some is false, idx is 0 and the data is default_nr_data.
   It fills fresh trivial ORAMs and serves as the initial carried block in the
   trivial ORAM scans. *)
let default_block =
{ is_some = false<secret, `r0>
; idx = 0<secret, `r0>
; data = default_nr_data
}
in
(******************
** Trivial ORAM **
******************)
(** [trivial_init size] creates a trivial ORAM with [size] slots, every one of
    them initialised to [default_block] (the initialiser ignores the slot
    position). *)
let trivial_init (size : int<public, _|_>) =
array(size)[fun (_ : int<public, _|_>) . default_block]
in
(** [trivial_rr' toram idx] walks every slot of [toram] exactly once, carrying
    a block along. While the carried block is still empty, the first slot
    whose [idx] field equals the requested [idx] is exchanged with it; all
    other slots are written back unchanged. The carried block is returned, so
    its [is_some] field tells whether an occupied block was extracted. *)
let trivial_rr' (toram : trivial_oram) (idx : int<secret, _|_>) =
  let slots = length(toram) in
  let rec scan (pos : int<public, _|_>) (carry : block) : block =
    if pos = slots then
      carry
    else
      (* Take the slot's content out; the placeholder is overwritten below. *)
      let slot = toram[pos] <- default_block in
      let hit = not carry.is_some && idx = slot.idx in
      (* On a hit the slot content becomes the carry and the old carry goes
         into the slot; otherwise both stay where they were. *)
      let (keep, put_back) = mux(hit, slot, carry) in
      let _ = toram[pos] <- put_back in
      scan (pos + 1<public, _|_>) keep
  in
  scan 0<public, _|_> default_block
in
(** [trivial_rr toram idx] extracts the block stored under [idx] with
    [trivial_rr'] and returns only its [data] record. *)
let trivial_rr (toram : trivial_oram) (idx : int<secret, _|_>) =
  let found = trivial_rr' toram idx in
  found.data
in
(** [trivial_add' toram blk] walks every slot of [toram] once and exchanges
    the pending block with each unoccupied slot it meets, so [blk] ends up in
    the first free slot. Whatever is still pending after the last slot is
    discarded; in particular [blk] is dropped when no slot is free. *)
let trivial_add' (toram : trivial_oram) (blk : block) =
  let slots = length(toram) in
  let rec place (pos : int<public, _|_>) (pending : block) : unit<public, _|_> =
    if pos = slots then
      ()<public, _|_>
    else
      (* Take the slot's content out; the placeholder is overwritten below. *)
      let slot = toram[pos] <- default_block in
      let vacant = not slot.is_some in
      (* A vacant slot receives the pending block and hands its empty content
         on; an occupied slot keeps its content. *)
      let (carry_on, stored) = mux(vacant, slot, pending) in
      let _ = toram[pos] <- stored in
      place (pos + 1<public, _|_>) carry_on
  in
  place 0<public, _|_> blk
in
(** [trivial_add toram idx data] wraps [data] in an occupied block stored
    under [idx] and inserts it with [trivial_add']. *)
let trivial_add (toram : trivial_oram) (idx : int<secret, `r0>) (data : nr_data) =
  let fresh_block = { is_some = true<secret, `r0>
                    ; idx = idx
                    ; data = data
                    } in
  trivial_add' toram fresh_block
in
(** [trivial_pop' toram] walks every slot of [toram] once and removes the
    first occupied block, leaving the empty carried block in its place. The
    removed block is returned; if every slot is empty the result is an empty
    block. *)
let trivial_pop' (toram : trivial_oram) =
  let slots = length(toram) in
  let rec scan (pos : int<public, _|_>) (carry : block) : block =
    if pos = slots then
      carry
    else
      (* Take the slot's content out; the placeholder is overwritten below. *)
      let slot = toram[pos] <- default_block in
      let grab = not carry.is_some && slot.is_some in
      (* Only while nothing has been grabbed yet does an occupied slot trade
         places with the carry. *)
      let (keep, put_back) = mux(grab, slot, carry) in
      let _ = toram[pos] <- put_back in
      scan (pos + 1<public, _|_>) keep
  in
  scan 0<public, _|_> default_block
in
(************************
** NON-RECURSIVE ORAM **
************************)
(** ASSUMPTION:
Size of the NR ORAM is exactly 2^k - 1, k >= 2.
This ensures that the tree is full and has at least depth 2. *)
(* [nr_init size bucket_size] creates a non-recursive ORAM made of [size]
   buckets, each a trivial ORAM with [bucket_size] empty slots. The size
   assumption above is not checked here. *)
let nr_init (size : int<public, _|_>) (bucket_size : int<public, _|_>) =
array(size)[fun (_ : int<public, _|_>) . trivial_init bucket_size]
in
(** [nr_rr' nroram idx tag] looks for [idx] in one bucket per tree level, from
    the root downwards. At level [l] the bucket visited is the one at offset
    [tag & (2^l - 1)] among the buckets of that level, which start at position
    [2^l - 1]. Every visited bucket is searched with [trivial_rr], which
    removes a matching block from it. The result is the record from the
    deepest level that produced a hit, or [default_nr_data] if none did. *)
let nr_rr' (nroram : nr_oram) (idx : int<secret, _|_>) (tag : int<public, _|_>) =
  let buckets = length(nroram) in
  let height = log_2 (buckets + 1<public, _|_>) in
  let rec descend (lvl : int<public, _|_>) (best : nr_data) : nr_data =
    if lvl = height then
      best
    else
      (* Position of the first bucket of this level; also the bit mask that
         keeps the low lvl bits of the tag. *)
      let offset = (pow_2 lvl) - 1<public, _|_> in
      let here = nroram[offset + (tag & offset)] in
      let candidate = trivial_rr here idx in
      let take = candidate.is_some in
      (* Keep the new candidate on a hit, the previous best otherwise. *)
      let (chosen, _) = mux(take, candidate, best) in
      descend (lvl + 1<public, _|_>) chosen
  in
  descend 0<public, _|_> default_nr_data
in
(** [nr_rr nroram idx tag] performs the lookup of [nr_rr'] and returns only
    the user payload of the record it yields. *)
let nr_rr (nroram : nr_oram) (idx : int<secret, _|_>) (tag : int<public, _|_>) =
  let hit = nr_rr' nroram idx tag in
  hit.nr_val
in
(** [nr_evict' nroram level loc] pops one block (possibly an empty one) from
    bucket [loc], which must lie on tree level [level], and pushes it into one
    of that bucket's two children; the other child receives an empty block so
    that both children are always written.

    Bit [level] of the block's tag picks the child: clear means left, set
    means right.

    Child positions follow the layout used by [nr_rr'] and [nr_evict], where
    the level-[l] bucket on a tag's path sits at [(2^l - 1) + (tag & (2^l - 1))].
    For a bucket at [loc] on level [level], with [mask = 2^level], the children
    are therefore at [loc + mask] (tag bit clear) and [loc + 2 * mask] (tag bit
    set). Heap-style [2 * loc + 1] / [2 * loc + 2] addressing would place
    blocks in buckets the read path never visits.

    [loc] must not be a leaf, otherwise the child positions are out of range. *)
let nr_evict' (nroram : nr_oram) (level : int<public, _|_>) (loc : int<public, _|_>) =
  let mask = pow_2 level in
  let bucket = nroram[loc] in
  let popped = trivial_pop' bucket in
  let left = (popped.data.nr_tag & mask) = 0<secret, `r0> in
  (* Basically need to claim that this mux is OK *)
  let (l, r) = mux(left, popped, default_block) in
  (* Tag bit clear: the path offset is unchanged on the next level. *)
  let left_child = nroram[loc + mask] in
  (* Tag bit set: the path offset grows by mask on the next level. *)
  let right_child = nroram[loc + (2<public, _|_> * mask)] in
  let _ = trivial_add' left_child l in
  trivial_add' right_child r
in
(** [nr_evict nroram] runs one eviction pass over the tree: on every level
    except the last it picks two buckets at random (the two picks are
    independent and may coincide) and evicts one block from each into its
    children with [nr_evict'].

    The tree depth is derived from the array length as [log_2 (len + 1)], so
    the length is expected to be of the form [2^k - 1]. *)
let nr_evict (nroram : nr_oram) =
  let len = length(nroram) in
  let depth = log_2 (len + 1<public, _|_>) in
  let rec iterate (level : int<public, _|_>) : unit<public, _|_> =
    (* Buckets on the last level have no children, so stop before it. *)
    if level = depth - 1<public, _|_> then
      ()<public, _|_>
    else
      (* Position of the first bucket of this level; also the bit mask that
         reduces a random value to an offset within the level. *)
      let base = (pow_2 level) - 1<public, _|_> in
      (* Two evictions per-level *)
      let r1 = rnd<secret, `r1> in
      let loc1 = base + (reveal(r1) & base) in
      let _ = nr_evict' nroram level loc1 in
      let r2 = rnd<secret, `r1> in
      let loc2 = base + (reveal(r2) & base) in
      let _ = nr_evict' nroram level loc2 in
      iterate (level + 1<public, _|_>)
  in
  iterate 0<public, _|_>
in
(** [nr_add' nroram blk] inserts [blk] into the root bucket (position 0) and
    then runs one eviction pass with [nr_evict]. *)
let nr_add' (nroram : nr_oram) (blk : block) =
  let root = nroram[0<public, _|_>] in
  let _ = trivial_add' root blk in
  nr_evict nroram
in
(** [nr_add nroram idx tag data] stores [data] under [idx] with tag [tag]: it
    builds an occupied record and an occupied block around it and hands the
    block to [nr_add']. *)
let nr_add (nroram : nr_oram) (idx : int<secret, `r0>) (tag : int<secret, `r0>) (data : user_data) =
  let payload = { is_some = true<secret, `r0>
                ; nr_tag = tag
                ; nr_val = data
                } in
  let fresh_block = { is_some = true<secret, `r0>
                    ; idx = idx
                    ; data = payload
                    } in
  nr_add' nroram fresh_block
in
(************************
** RECURSIVE ORAM **
************************)
(** [tree_add_h oram idx level d] is a placeholder: it ignores all of its
    arguments and returns unit without touching [oram].
    NOTE(review): the call sites pass an updated pair as [d], so this is
    presumably meant to write [d] back under [idx] at [level] - confirm and
    implement. *)
let tree_add_h (oram : oram) (idx : int<secret, _|_>) (level : int<public, _|_>) (d : rint<secret, `r0 \/ `r1> * rint<secret, `r0 \/ `r1>) =
()<public, _|_>
in
(** [tree_rr_h oram idx level] reads the pair stored under [idx] at recursion
    level [level] of the recursive ORAM.

    When [level] equals the number of non-recursive ORAMs, the pair comes from
    the trivial ORAM. Otherwise the pair stored under [idx / 2] one level
    further on is fetched first; the half selected by the parity of [idx] is
    replaced with a fresh random value and its previous content is revealed
    and used as the tag for the lookup in [norams[level]].

    The comments below assume mux(c, a, b) yields (a, b) when c holds and
    (b, a) otherwise - TODO confirm against the language definition. *)
let rec tree_rr_h (oram : oram) (idx : int<secret, _|_>) (level : int<public, _|_>) : rint<secret, `r0 \/ `r1> * rint<secret, `r0 \/ `r1> =
let (norams, troram) = oram in
let levels = length(norams) in
if level = levels then
(* Base case: past the last non-recursive ORAM, use the trivial ORAM. *)
let result = trivial_rr troram idx in
result.nr_val
else
(* Fetch the pair kept under the halved index one level further on. *)
let (r0, r1) = tree_rr_h oram (idx / 2<secret, _|_>) (level + 1<public, _|_>) in
(* The parity of idx selects which half of the pair belongs to idx. *)
let zero_m2 = idx % 2<secret, _|_> = 0<secret, _|_> in
let one_m2 = idx % 2<secret, _|_> = 1<secret, _|_> in
let fresh = rnd<secret, `r0 \/ `r1> in
(* Even idx: fresh goes into the first half and the old first half becomes
   tag. Odd idx: the first half is kept and tag holds fresh for now. *)
let (r0', tag) = mux(zero_m2, fresh, r0) in
(* Odd idx: the second half receives tag (the fresh value) and the old second
   half becomes the new tag. Even idx: both are left as they are. *)
let (r1', tag) = mux(one_m2, tag, r1) in
let d' = (r0', r1') in
(* Hand the updated pair to tree_add_h for the halved index. *)
let _ = tree_add_h oram (idx / 2<secret, _|_>) (level + 1<public, _|_>) d' in
let curr_noram = norams[level] in
(* The old half is revealed here and used as the tag for this level's lookup. *)
nr_rr curr_noram idx reveal(tag)
in
(** [tree_rr oram idx] reads the entry for [idx] from the recursive ORAM by
    starting [tree_rr_h] at level 0. *)
let tree_rr (oram : oram) (idx : int<secret, _|_>) =
  let top_level = 0<public, _|_> in
  tree_rr_h oram idx top_level
in
(** [add_rr_h oram idx level d] is the insertion counterpart of [tree_rr_h].

    When [level] equals the number of non-recursive ORAMs, [d] is stored in
    the trivial ORAM under [idx] with a freshly drawn tag. Otherwise the pair
    kept under [idx / 2] one level further on is fetched and updated exactly
    as in [tree_rr_h]; the final insertion into [norams[level]] is commented
    out on purpose (see the note near the end), so this branch currently
    returns unit without storing [d].

    The comments below assume mux(c, a, b) yields (a, b) when c holds and
    (b, a) otherwise - TODO confirm against the language definition. *)
let rec add_rr_h (oram : oram) (idx : int<secret, _|_>) (level : int<public, _|_>) (d : rint<secret, `r0 \/ `r1> * rint<secret, `r0 \/ `r1>) : unit<public, _|_> =
let (norams, troram) = oram in
let levels = length(norams) in
if level = levels then
(* Base case: draw a tag and convert it with use before storing it. *)
let tag = rnd<secret, `r0> in
let tag = use(tag) in
(* The use of + 0<secret, `r0> is to raise the region of idx from _|_ to `r0 *)
trivial_add troram (idx + 0<secret, `r0>) { is_some = true<secret, `r0>
; nr_tag = tag
; nr_val = d
}
else
(* Fetch the pair kept under the halved index one level further on. *)
let (r0, r1) = tree_rr_h oram (idx / 2<secret, _|_>) (level + 1<public, _|_>) in
(* The parity of idx selects which half of the pair belongs to idx. *)
let zero_m2 = idx % 2<secret, _|_> = 0<secret, _|_> in
let one_m2 = idx % 2<secret, _|_> = 1<secret, _|_> in
let fresh = rnd<secret, `r0 \/ `r1> in
(* fresh is consumed by use here and then passed to mux below; this double
   use is what the commented-out nr_add call at the end relies on. *)
let sec_fresh = use(fresh) in
(* Even idx: fresh replaces the first half; odd idx: it replaces the second. *)
let (r0', tag) = mux(zero_m2, fresh, r0) in
let (r1', tag) = mux(one_m2, tag, r1) in
let d' = (r0', r1') in
(* Hand the updated pair to tree_add_h for the halved index. *)
let _ = tree_add_h oram (idx / 2<secret, _|_>) (level + 1<public, _|_>) d' in
(* curr_noram is only referenced by the commented-out call below. *)
let curr_noram = norams[level] in
(* Uncomment the line below to see the type error mentioned in section 5.2 of the paper *)
(* let _ = nr_add curr_noram (idx + 0<secret, `r0>) sec_fresh d in *)
()<public, _|_>
in
(** [add_rr oram idx d] stores [d] under [idx] in the recursive ORAM by
    starting [add_rr_h] at level 0. *)
let add_rr (oram : oram) (idx : int<secret, _|_>) (d : rint<secret, `r0 \/ `r1> * rint<secret, `r0 \/ `r1>) =
  let top_level = 0<public, _|_> in
  add_rr_h oram idx top_level d
in
(* Body of the whole let-chain: the program itself evaluates to unit and calls
   none of the definitions above. *)
()<public, _|_>