-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathostack.lo
356 lines (296 loc) · 10.4 KB
/
ostack.lo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
(** Oblivious Stacks *)
(***********************
** Utility Functions **
***********************)
(** [div_2 n] is [n / 2] rounded down, computed by repeated subtraction of 2.
    Assumes [n >= 0]; negative inputs do not terminate.
    Both [0] and [1] are base cases: with only [0], an odd input would step
    from [1] to [-1] and recurse forever. *)
let rec div_2 (n : int<public, _|_>) : int<public, _|_> =
  if n = 0<public, _|_> then
    0<public, _|_>
  else if n = 1<public, _|_> then
    0<public, _|_>
  else
    div_2 (n - 2<public, _|_>) + 1<public, _|_>
in
(** [log_2 n] counts how many times [n] can be halved with [div_2] before it
    reaches [1]; for [n = 2^k] the result is [k].
    Callers in this file pass [len + 1], where [len] is the NR ORAM size that
    the assumption before [nr_init] requires to be [2^k - 1], so the argument
    is a power of two.
    [n <= 0] never reaches the base case ([div_2 0] is [0]), so the recursion
    does not terminate for such inputs. *)
let rec log_2 (n : int<public, _|_>) : int<public, _|_> =
if n = 1<public, _|_> then
0<public, _|_>
else
log_2 (div_2 n) + 1<public, _|_>
in
(** [pow_2 n] is [2^n] for [n >= 0].
    The recursive call decrements [n] toward the base case [0]; incrementing
    it would move away from the base case and never terminate for [n >= 1]. *)
let rec pow_2 (n : int<public, _|_>) : int<public, _|_> =
  if n = 0<public, _|_> then
    1<public, _|_>
  else
    2<public, _|_> * (pow_2 (n - 1<public, _|_>))
in
(***********************
** Type Declarations **
***********************)
(** Value a user pushes onto and pops from the oblivious stack. *)
type user_data = int<secret, `r0>
in
(** One stack entry as stored in the NR ORAM. *)
type ostack_data =
{ is_some : bool<secret, `r0>      (* false for an empty / placeholder entry *)
; ostack_tag : nuint<secret, `r0>  (* for entries built by [stackop]: the path tag of the entry below it *)
; ostack_val : user_data           (* the stored user value *)
}
in
(** Payload of one NR ORAM entry. *)
type nr_data =
{ is_some : bool<secret, `r0>  (* false for an empty entry *)
; nr_tag : int<secret, `r0>    (* path tag; its bits steer the entry down the tree in [nr_evict'] *)
; nr_val : ostack_data         (* the stack entry carried by this ORAM entry *)
}
in
(** One slot of a trivial-ORAM bucket. *)
type block =
{ is_some : bool<secret, `r0>  (* false marks an empty slot *)
; idx : int<secret, `r0>       (* logical index matched by [trivial_rr'] *)
; data : nr_data
}
in
(** A bucket: a flat array of blocks that every operation scans in full. *)
type trivial_oram = block array
in
(** The tree of buckets; level [l] starts at index [2^l - 1] (see [nr_rr']). *)
type nr_oram = trivial_oram array
in
(** An oblivious stack. [idx_r] and [tag_r] are one-element arrays used as
    mutable cells for the index and path tag of the current top entry. *)
type ostack =
{ oram : nr_oram
; idx_r : int<secret, _|_> array
; tag_r : rint<secret, `r0> array
}
in
(********************
** Default Values **
********************)
(** Placeholder user value stored in empty stack entries. *)
let default_user_data =
0<secret, `r0>
in
(** [default_ostack_data ()] builds an empty stack entry ([is_some = false]).
    Each call draws a fresh random value and uses it, wrapped with [trust],
    as the entry tag; presumably the unit argument exists so that every call
    builds a new value instead of sharing one. *)
let default_ostack_data (x : unit<public, _|_>) =
{ is_some = false<secret, `r0>
; ostack_tag = let fresh = rnd<secret, `r0> in trust(fresh)
; ostack_val = default_user_data
}
in
(** [default_nr_data ()] builds an empty NR ORAM payload: [is_some = false],
    tag [0], and a freshly built empty stack entry. *)
let default_nr_data (x : unit<public, _|_>) =
{ is_some = false<secret, `r0>
; nr_tag = 0<secret, `r0>
; nr_val = default_ostack_data ()<public, _|_>
}
in
(** [default_block ()] builds an empty bucket slot: [is_some = false],
    index [0], and a freshly built empty payload. Every scan in this file
    writes such a block into a slot while it inspects the slot contents. *)
let default_block (x : unit<public, _|_>) =
{ is_some = false<secret, `r0>
; idx = 0<secret, `r0>
; data = default_nr_data ()<public, _|_>
}
in
(******************
** Trivial ORAM **
******************)
(** [trivial_init size] allocates a bucket of [size] empty blocks.
    The initializer must call [default_block ()] for each slot: a
    [trivial_oram] holds blocks, and storing the function [default_block]
    itself would not produce a block array. *)
let trivial_init (size : int<public, _|_>) =
  array(size)[fun (_ : int<public, _|_>) . default_block ()<public, _|_>]
in
(** [trivial_rr' toram idx] reads and removes the block whose [idx] field
    equals [idx].
    Every slot [0 .. len - 1] is visited in order regardless of [idx]. Each
    slot is overwritten with a default block; the write appears to yield the
    previous contents, which are bound to [curr]. [mux] then decides which of
    [curr] and the accumulator goes back into the slot.
    Presumably [mux(c, a, b)] yields [(a, b)] when [c] holds and [(b, a)]
    otherwise. Under that reading, a match is kept in the accumulator and a
    default block is left in its slot.
    Returns the matching block, or a default block ([is_some = false]) when
    nothing matches.
    NOTE(review): the match test does not look at [curr.is_some]. Empty slots
    (index [0]) therefore also match a lookup for [0], but they carry
    [is_some = false], so the result still reads as absent. *)
let trivial_rr' (toram : trivial_oram) (idx : int<secret, _|_>) =
let len = length(toram) in
let rec iterate (i : int<public, _|_>) (acc : block) : block =
if i = len then
acc
else
(* Take the slot out, leaving a default block behind for now. *)
let curr = toram[i] <- default_block ()<public, _|_> in
(* Only take [curr] while the accumulator is still empty. *)
let swap = not acc.is_some && idx = curr.idx in
let (l, r) = mux(swap, curr, acc) in
(* [r] goes back into the slot; [l] becomes the new accumulator. *)
let _ = toram[i] <- r in
iterate (i + 1<public, _|_>) l
in
iterate 0<public, _|_> (default_block ()<public, _|_>)
in
(** [trivial_rr toram idx] runs [trivial_rr'] and returns only the payload
    ([data]) of the block it read; its [is_some] is false when no block
    matched. *)
let trivial_rr (toram : trivial_oram) (idx : int<secret, _|_>) =
let result = trivial_rr' toram idx in
result.data
in
(** [trivial_add' toram blk] writes [blk] into the first empty slot
    ([is_some = false]) of the bucket, visiting every slot in order.
    After [blk] is placed, the accumulator holds the empty block it
    displaced. Later empty slots just trade one empty block for another.
    NOTE(review): if no slot is empty, the accumulator is discarded at the
    end and [blk] is silently lost; confirm bucket overflow is acceptable. *)
let trivial_add' (toram : trivial_oram) (blk : block) =
let len = length(toram) in
let rec iterate (i : int<public, _|_>) (acc : block) : unit<public, _|_> =
if i = len then
()<public, _|_>
else
let curr = toram[i] <- default_block ()<public, _|_> in
(* Any empty slot receives the accumulator. *)
let swap = not curr.is_some in
let (l, r) = mux(swap, curr, acc) in
let _ = toram[i] <- r in
iterate (i + 1<public, _|_>) l
in
iterate 0<public, _|_> blk
in
(** [trivial_add toram idx data] wraps [data] in a live block
    ([is_some = true]) with index [idx] and inserts it with [trivial_add']. *)
let trivial_add (toram : trivial_oram) (idx : int<secret, `r0>) (data : nr_data) =
trivial_add' toram { is_some = true<secret, `r0>
; idx = idx
; data = data
}
in
(** [trivial_pop' toram] removes and returns the first live block
    ([is_some = true]) in the bucket. It returns a default block when the
    bucket is empty. Every slot is visited in order, as in [trivial_rr']. *)
let trivial_pop' (toram : trivial_oram) =
let len = length(toram) in
let rec iterate (i : int<public, _|_>) (acc : block) : block =
if i = len then
acc
else
let curr = toram[i] <- default_block ()<public, _|_> in
(* Take the first live block; leave every other slot as it was. *)
let swap = not acc.is_some && curr.is_some in
let (l, r) = mux(swap, curr, acc) in
let _ = toram[i] <- r in
iterate (i + 1<public, _|_>) l
in
iterate 0<public, _|_> (default_block ()<public, _|_>)
in
(************************
** NON-RECURSIVE ORAM **
************************)
(** ASSUMPTION:
Size of the NR ORAM is exactly 2^k - 1, k >= 2.
This ensures that the tree is full and has at least depth 2. *)
(** [nr_init size bucket_size] allocates [size] buckets, each holding
    [bucket_size] empty blocks. [size] is expected to be [2^k - 1] with
    [k >= 2], so the buckets form a full tree of depth [k]. *)
let nr_init (size : int<public, _|_>) (bucket_size : int<public, _|_>) =
array(size)[fun (_ : int<public, _|_>) . trivial_init bucket_size]
in
(** [nr_rr' nroram idx tag] reads and removes the entry with index [idx]
    from the tree path selected by the public [tag].
    The tree has [depth = log_2 (len + 1)] levels. On level [l] the bucket
    read is [(2^l - 1) + (tag & (2^l - 1))]: the first index of the level
    plus the low [l] bits of [tag]. Every bucket on the path is read even
    after a hit.
    Returns the payload found, or a default payload ([is_some = false]).
    If hits occurred on several levels, the deepest one wins.
    NOTE(review): an entry is only found if eviction moved it down this same
    path, i.e. into the bucket given by the formula above on each level. *)
let nr_rr' (nroram : nr_oram) (idx : int<secret, _|_>) (tag : int<public, _|_>) =
let len = length(nroram) in
let depth = log_2 (len + 1<public, _|_>) in
let rec iterate (level : int<public, _|_>) (acc : nr_data) : nr_data =
if level = depth then
acc
else
(* [base] is both the first index of this level and a mask of its low [level] bits. *)
let base = (pow_2 level) - 1<public, _|_> in
let bucket_loc = base + (tag & base) in
let bucket = nroram[bucket_loc] in
let res = trivial_rr bucket idx in
let pass = res.is_some in
let (ret, _) = mux(pass, res, acc) in
iterate (level + 1<public, _|_>) ret
in
iterate 0<public, _|_> (default_nr_data ()<public, _|_>)
in
(** [nr_rr nroram idx tag] runs [nr_rr'] and returns only the stack entry
    ([nr_val]) carried by the payload it read. *)
let nr_rr (nroram : nr_oram) (idx : int<secret, _|_>) (tag : int<public, _|_>) =
let ret = nr_rr' nroram idx tag in
ret.nr_val
in
(** [nr_evict' nroram level loc] moves one block from bucket [loc], a node on
    level [level], one level down the tree.
    The block is removed with [trivial_pop']; if the bucket is empty, it is a
    default block. Bit [level] of its [nr_tag] picks the direction: [0] sends
    it to the left child, [1] to the right child.
    Both children always receive a block; the other child gets an empty
    default. So the sequence of buckets written does not depend on the
    direction.
    Child addressing must agree with [nr_rr'], which on level [l] reads
    bucket [2^l - 1 + (tag & (2^l - 1))]: the offset within a level is the
    low [l] bits of the tag. A node at offset [t] on level [level] therefore
    has children at offsets [t] (bit clear) and [t + 2^level] (bit set) on
    the next level. These are the indices [loc + 2^level] and
    [loc + 2^(level + 1)].
    Heap-style children [2 loc + 1] and [2 loc + 2] would put blocks into
    buckets that [nr_rr'] never reads for their tag.
    Assumes [level + 1] is still a level of the tree, so both children
    exist. A block sent to a full child is dropped by [trivial_add']. *)
let nr_evict' (nroram : nr_oram) (level : int<public, _|_>) (loc : int<public, _|_>) =
  (* Selects bit [level] of the tag; also the offset step to the children. *)
  let mask = pow_2 level in
  let bucket = nroram[loc] in
  let popped = trivial_pop' bucket in
  let default = default_block ()<public, _|_> in
  (* NOTE(review): the field-by-field copies before and after the mux below
     do not change any values; they appear to exist so that the type
     checker accepts the mux (see the remark on it). *)
  let popped_nu = { is_some = popped.is_some
                  ; idx = popped.idx
                  ; data = { is_some = popped.data.is_some
                           ; nr_tag = popped.data.nr_tag
                           ; nr_val = { is_some = popped.data.nr_val.is_some
                                      ; ostack_tag = popped.data.nr_val.ostack_tag
                                      ; ostack_val = popped.data.nr_val.ostack_val
                                      }
                           }
                  } in
  let default_nu = { is_some = default.is_some
                   ; idx = default.idx
                   ; data = { is_some = default.data.is_some
                            ; nr_tag = default.data.nr_tag
                            ; nr_val = { is_some = default.data.nr_val.is_some
                                       ; ostack_tag = default.data.nr_val.ostack_tag
                                       ; ostack_val = default.data.nr_val.ostack_val
                                       }
                            }
                   } in
  let left = (popped.data.nr_tag & mask) = 0<secret, `r0> in
  (* Basically need to claim that this mux is OK *)
  let (l_nu, r_nu) = mux(left, popped_nu, default_nu) in
  let l = { is_some = l_nu.is_some
          ; idx = l_nu.idx
          ; data = { is_some = l_nu.data.is_some
                   ; nr_tag = l_nu.data.nr_tag
                   ; nr_val = { is_some = l_nu.data.nr_val.is_some
                              ; ostack_tag = l_nu.data.nr_val.ostack_tag
                              ; ostack_val = l_nu.data.nr_val.ostack_val
                              }
                   }
          } in
  let r = { is_some = r_nu.is_some
          ; idx = r_nu.idx
          ; data = { is_some = r_nu.data.is_some
                   ; nr_tag = r_nu.data.nr_tag
                   ; nr_val = { is_some = r_nu.data.nr_val.is_some
                              ; ostack_tag = r_nu.data.nr_val.ostack_tag
                              ; ostack_val = r_nu.data.nr_val.ostack_val
                              }
                   }
          } in
  (* Children of offset t on this level: offset t (left) and offset
     t + 2^level (right) on the next level, matching the read path. *)
  let left_child = nroram[loc + mask] in
  let right_child = nroram[loc + (2<public, _|_> * mask)] in
  let _ = trivial_add' left_child l in
  trivial_add' right_child r
in
(** [nr_evict nroram] runs one eviction pass over the tree.
    For every level [l] from [0] to [depth - 2], it calls [nr_evict'] twice.
    Each call targets a node on that level chosen by revealing a fresh [rnd]
    value: the node offset is the low [l] bits of the revealed value. The
    last level is never evicted from, because it has no children.
    Assumes the NR ORAM size is [2^k - 1] with [k >= 2]. With a size of [0],
    [depth - 1] is [-1] and the level counter never reaches it. *)
let nr_evict (nroram : nr_oram) =
  let len = length(nroram) in
  let depth = log_2 (len + 1<public, _|_>) in
  let rec iterate (level : int<public, _|_>) : unit<public, _|_> =
    if level = depth - 1<public, _|_> then
      ()<public, _|_>
    else
      (* First index of this level; also a mask of its low [level] bits. *)
      let base = (pow_2 level) - 1<public, _|_> in
      (* Two evictions per level, each from a randomly addressed node. *)
      let r1 = rnd<secret, `r0> in
      let loc1 = base + (reveal(r1) & base) in
      let _ = nr_evict' nroram level loc1 in
      let r2 = rnd<secret, `r0> in
      let loc2 = base + (reveal(r2) & base) in
      let _ = nr_evict' nroram level loc2 in
      iterate (level + 1<public, _|_>)
  in
  iterate 0<public, _|_>
in
(** [nr_add' nroram blk] inserts [blk] into the root bucket (index [0]) with
    [trivial_add'], then runs one eviction pass with [nr_evict]. *)
let nr_add' (nroram : nr_oram) (blk : block) =
let bucket = nroram[0<public, _|_>] in
let _ = trivial_add' bucket blk in
nr_evict nroram
in
(** [nr_add nroram idx tag data] wraps [data] in a live block with index
    [idx] and path tag [tag] (both [is_some] flags set), and inserts it with
    [nr_add']. *)
let nr_add (nroram : nr_oram) (idx : int<secret, `r0>) (tag : int<secret, `r0>) (data : ostack_data) =
nr_add' nroram { is_some = true<secret, `r0>
; idx = idx
; data = { is_some = true<secret, `r0>
; nr_tag = tag
; nr_val = data
}
}
in
(*************
** OSTACKS **
*************)
(** [ostack_init size bucket_size] creates an empty oblivious stack.
    It is backed by an NR ORAM of [size] buckets (expected to be [2^k - 1],
    [k >= 2]) with [bucket_size] slots each. The top index starts at [0], and
    the top path tag starts as a fresh random value. *)
let ostack_init (size : int<public, _|_>) (bucket_size : int<public, _|_>) =
{ oram = nr_init size bucket_size
; idx_r = array(1<public, _|_>)[fun (_ : int<public, _|_>) . 0<secret, _|_>]
; tag_r = array(1<public, _|_>)[fun (_ : int<public, _|_>) . rnd<secret, `r0>]
}
in
(** [stackop ostack ispush data] performs one stack operation. Whether it is
    a push or a pop is the secret [ispush]. Both cases run the same sequence
    of ORAM operations: one path read ([nr_rr]), then one insertion with
    eviction ([nr_add']).
    Push ([ispush] true): reads the dummy index [-1] along a fresh random
    path. It then inserts [data] at index [idx + 1] under a fresh path tag;
    the new entry records the previous top tag in [ostack_tag].
    Pop ([ispush] false): reads and removes index [idx] along the current top
    tag. It restores the previous top tag from the removed entry and inserts
    an empty default block.
    Returns the [ostack_val] of the entry read. On a pop this is the popped
    value. On a push it is whatever the dummy read returned, normally the
    default [0].
    NOTE(review): popping an empty stack is not guarded; it reads index [0]
    and moves the top index to [-1]. *)
let stackop (ostack : ostack) (ispush : bool<secret, _|_>) (data : user_data) =
let nroram = ostack.oram in
let idx = ostack.idx_r[0<public, _|_>] in
(* Read the current top tag; the slot holds a fresh random value until it is
   overwritten with the new top tag at the end. *)
let tag = ostack.tag_r[0<public, _|_>] <- rnd<secret, `r0> in
(* Push reads the dummy index -1; pop reads the top index. *)
let (rr_idx, _) = mux(ispush, -1<secret, _|_>, idx) in
(* Push: read a fresh random path and keep the old top tag for the new entry.
   Pop: read the top path; the fresh value becomes [add_tag] and is only
   stored in the discarded candidate. *)
let (rr_tag, add_tag) = mux(ispush, rnd<secret, `r0>, tag) in
let res = nr_rr nroram rr_idx reveal(rr_tag) in
let (idx', _) = mux(ispush, idx + 1<secret, _|_>, idx - 1<secret, _|_>) in
let fresh = rnd<secret, `r0> in
(* New top tag: fresh on push; on pop, the tag saved in the popped entry. *)
let (tag', _) = mux(ispush, trust(fresh), res.ostack_tag) in
let tag' = prove(tag') in
let cand = { is_some = true<secret, _|_>
; idx = idx'
; data = { is_some = true<secret, _|_>
; nr_tag = use(tag')
; nr_val = { is_some = true<secret, _|_>
; ostack_tag = trust(add_tag)
; ostack_val = data
}
}
} in
let default = default_block ()<public, _|_> in
(* Only a push inserts a live block; a pop inserts an empty one, so an
   insertion happens in both cases. *)
let (add, _) = mux(ispush, cand, default) in
let _ = nr_add' nroram add in
let _ = ostack.idx_r[0<public, _|_>] <- idx' in
let _ = ostack.tag_r[0<public, _|_>] <- tag' in
res.ostack_val
in
()<public, _|_>