When possible, read in chunks when using readn
LambdaMemory has a cache now, which makes for more
solver-friendly constraints.

However, this cache is not used when evaluating a `readn`
that spans more than one word. For example, computing a
mapping's storage slot requires `keccak(val . slot)`. This
change lets `readn` pull `val` and `slot` from the cache in
word-sized chunks instead of reading them byte-by-byte via
array ops (which are slow).
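
As a concrete illustration of the scenario above, here is a minimal sketch (not part of the commit) modeled on the setup in the new test file further down; the `_State` helper and the specific key/slot values are made up for illustration:

from greed.state_plugins import SimStateSolver
from greed.solver.shortcuts import BVV, BVSort
from greed.memory import LambdaMemory

class _State:
    # Hypothetical stand-in for a full state object, mirroring DummyState in the new test
    def __init__(self):
        self.solver = SimStateSolver()

mem = LambdaMemory("mapping_example", value_sort=BVSort(8), default=BVV(0, 8), state=_State())

# A mapping lookup hashes `val . slot`: two adjacent 32-byte words in memory.
mem.writen(BVV(0, 256), BVV(0xCAFE, 256), BVV(32, 256))   # val (illustrative key)
mem.writen(BVV(32, 256), BVV(0x05, 256), BVV(32, 256))    # slot (illustrative slot number)

# Reading the 64-byte keccak preimage used to emit one array select per byte;
# with this change each aligned 32-byte word is taken from the write cache and concatenated.
preimage = mem.readn(BVV(0, 256), BVV(64, 256))

Unaligned or uncached reads still fall back to per-byte array reads, as exercised by the new fallback test.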
robmcl4 committed May 11, 2024
1 parent dfffb30 commit c9ecd76
Showing 2 changed files with 163 additions and 15 deletions.
42 changes: 27 additions & 15 deletions greed/memory/lambda_memory.py
@@ -234,33 +234,45 @@ def readn(self, index, n):
             AssertionError: if the length is 0
         """
         assert is_concrete(n), "readn with symbolic length not implemented"
-        assert bv_unsigned_value(n) != 0, "invalid readn with length=0"
+        n_val = bv_unsigned_value(n)
+        assert n_val != 0, "invalid readn with length=0"
 
         # check cache
         if (
             is_concrete(index)
-            and bv_unsigned_value(index) in self.cache[bv_unsigned_value(n)]
+            and bv_unsigned_value(index) in self.cache[n_val]
         ):
-            # print(f"cache hit {bv_unsigned_value(index)}: {self.cache[bv_unsigned_value(n)][bv_unsigned_value(index)]}")
-            return self.cache[bv_unsigned_value(n)][bv_unsigned_value(index)]
+            # print(f"cache hit {bv_unsigned_value(index)}: {self.cache[n_val][bv_unsigned_value(index)]}")
+            return self.cache[n_val][bv_unsigned_value(index)]
 
-        if bv_unsigned_value(n) == 1:
+        if n_val == 1:
             return self[index]
         else:
+            vv = []
             if is_concrete(index):
-                tag = f"READN_{self.tag}_BASE{self._base.id}_{bv_unsigned_value(index)}_{bv_unsigned_value(n)}"
+                tag = f"READN_{self.tag}_BASE{self._base.id}_{bv_unsigned_value(index)}_{n_val}"
+
+                # Optimization: attempt to read in word-size chunks, which is more cache-friendly
+                index_val = bv_unsigned_value(index)
+                for idx in range(index_val, index_val + n_val, 32):
+                    # If we can read a whole chunk, try the cache
+                    if idx + 32 <= index_val + n_val and idx in self.cache[32]:
+                        vv.append(self.cache[32][idx])
+                    else:
+                        # We cannot read a chunk, just do it the normal way
+                        for pos in range(idx, min(idx + 32, index_val + n_val)):
+                            vv.append(self[BVV(pos, 256)])
             else:
-                tag = f"READN_{self.tag}_BASE{self._base.id}_sym{index.id}_{bv_unsigned_value(n)}"
+                tag = f"READN_{self.tag}_BASE{self._base.id}_sym{index.id}_{n_val}"
+                vv = list()
+                for i in range(n_val):
+                    read_index = BV_Add(index, BVV(i, 256))
+                    vv.append(self[read_index])
 
-            res = BVS(tag, bv_unsigned_value(n) * 8)
-
-            vv = list()
-            for i in range(bv_unsigned_value(n)):
-                read_index = BV_Add(index, BVV(i, 256))
-                vv.append(self[read_index])
+            res = BVS(tag, n_val * 8)
 
             self.add_constraint(Equal(res, BV_Concat(vv)))
-            # print(f"actual readn {bv_unsigned_value(index)}:{bv_unsigned_value(n)} = {BV_Concat(vv)}")
+            # print(f"actual readn {bv_unsigned_value(index)}:{n_val} = {BV_Concat(vv)}")
             return res

@@ -271,7 +283,7 @@ def writen(self, index, v, n):
 
         for i in range(bv_unsigned_value(n)):
             m = BV_Extract((31 - i) * 8, (31 - i) * 8 + 7, v)
-            self.state.memory[BV_Add(index, BVV(i, 256))] = m
+            self[BV_Add(index, BVV(i, 256))] = m
 
         # update cache
         if is_concrete(index):
136 changes: 136 additions & 0 deletions tests/test_lambda_memory_read_from_cache.py
@@ -0,0 +1,136 @@
"""
Tests LambdaMemory's ability to use cache to avoid doing array ops
"""

from greed.state_plugins import SimStateSolver
from greed.solver.shortcuts import *
from greed.memory import LambdaMemory

def test_basic_store_read():
"""
Tests basic operation: we can store a value and read (the exact same value) back
"""
mem = _get_dummy_memory()

to_write = BVV(0xDEADBEEFCAFE, 256)

mem.writen(
BVV(0, 256),
to_write,
BVV(32, 256),
)

read = mem.readn(
BVV(0, 256),
BVV(32, 256),
)

# We should have the same value that we wrote, because it was cached
assert read.id == to_write.id

def test_readn_reads_words_from_cache():
"""
Ensure that readn() reads and concats from cache when possible, avoiding
any array operations.
"""
mem = _get_dummy_memory()

# Write two values to memory
to_write_1 = BVV(0xDEADBEEFCAFE, 256)
to_write_2 = BVV(0xCAFEBABE, 256)

mem.writen(
BVV(0, 256),
to_write_1,
BVV(32, 256),
)
mem.writen(
BVV(32, 256),
to_write_2,
BVV(32, 256),
)

read = mem.readn(
BVV(0, 256),
BVV(64, 256),
)

# We should have a symbol
assert read.operator == 'bvs'

# The solver should not have any array operations -- dfs through the
# current assertions to ensure
queue = list(mem.state.solver.constraints)

while queue:
constraint = queue.pop(0)

if getattr(constraint, 'operator', None) == 'array':
raise AssertionError('Array operation found in solver')

queue.extend(getattr(constraint, 'children', []))

# The value should be the concatenation of the two values
value = mem.state.solver.eval(read)
expected = (0xDEADBEEFCAFE << 256) | 0xCAFEBABE
assert hex(value) == hex(expected)

def test_readn_uses_array_fallback():
"""
Test that when no cache is available, readn will fall back to using array
operations.
"""
mem = _get_dummy_memory()

# Write one value to memory
to_write = BVV(0xDEADBEEFCAFE, 256)
mem.writen(
BVV(0, 256),
to_write,
BVV(32, 256),
)

# Read unaligned
read = mem.readn(
BVV(10, 256),
BVV(32, 256),
)

# We should have a symbol
assert read.operator == 'bvs'

# There should be an array op in the solver
queue = list(mem.state.solver.constraints)
while queue:
constraint = queue.pop(0)

if getattr(constraint, 'operator', None) == 'array':
break

queue.extend(getattr(constraint, 'children', []))
else:
raise AssertionError('Array operation not found in solver')

# The value should be what we wrote, but shifted
value = mem.state.solver.eval(read)
expected = 0xDEADBEEFCAFE << (8 * 10)
assert hex(value) == hex(expected)

_n_mems = 0
def _get_dummy_memory() -> LambdaMemory:
global _n_mems

ret = LambdaMemory(
f'test_{_n_mems}',
value_sort=BVSort(8),
default=BVV(0, 8),
state=DummyState()
)

_n_mems += 1

return ret

class DummyState:
def __init__(self):
self.solver = SimStateSolver()
