From c9ecd766e7695ecbbcea343deb4859c9fa66f48a Mon Sep 17 00:00:00 2001 From: Robert McLaughlin Date: Sat, 11 May 2024 01:14:41 +0000 Subject: [PATCH] When possible, read in chunks when using readn LambdaMemory has a cache now, which makes for more solver-friendly constraints. However, this cache is not used when we are evaluating a `readn` of more than one word. For example, when calculating a mapping, you need to `keccak(val . slot)`. This change makes it so that we can make use of the cache for `val` and for `slot` -- instead of reading them via array-ops (which are slow). --- greed/memory/lambda_memory.py | 42 +++--- tests/test_lambda_memory_read_from_cache.py | 136 ++++++++++++++++++++ 2 files changed, 163 insertions(+), 15 deletions(-) create mode 100644 tests/test_lambda_memory_read_from_cache.py diff --git a/greed/memory/lambda_memory.py b/greed/memory/lambda_memory.py index a6fbe39..aa9cbcd 100644 --- a/greed/memory/lambda_memory.py +++ b/greed/memory/lambda_memory.py @@ -234,33 +234,45 @@ def readn(self, index, n): AssertionError: if the length is 0 """ assert is_concrete(n), "readn with symbolic length not implemented" - assert bv_unsigned_value(n) != 0, "invalid readn with length=0" + n_val = bv_unsigned_value(n) + assert n_val != 0, "invalid readn with length=0" # check cache if ( is_concrete(index) - and bv_unsigned_value(index) in self.cache[bv_unsigned_value(n)] + and bv_unsigned_value(index) in self.cache[n_val] ): - # print(f"cache hit {bv_unsigned_value(index)}: {self.cache[bv_unsigned_value(n)][bv_unsigned_value(index)]}") - return self.cache[bv_unsigned_value(n)][bv_unsigned_value(index)] + # print(f"cache hit {bv_unsigned_value(index)}: {self.cache[n_val][bv_unsigned_value(index)]}") + return self.cache[n_val][bv_unsigned_value(index)] - if bv_unsigned_value(n) == 1: + if n_val == 1: return self[index] else: + vv = [] if is_concrete(index): - tag = f"READN_{self.tag}_BASE{self._base.id}_{bv_unsigned_value(index)}_{bv_unsigned_value(n)}" + tag = f"READN_{self.tag}_BASE{self._base.id}_{bv_unsigned_value(index)}_{n_val}" + + # Optimization: attempt to read in word-size chunks, which is more cache-friendly + index_val = bv_unsigned_value(index) + for idx in range(index_val, index_val + n_val, 32): + # If we can read a whole chunk, try the cache + if idx + 32 <= index_val + n_val and idx in self.cache[32]: + vv.append(self.cache[32][idx]) + else: + # We cannot read a chunk, just do it the normal way + for pos in range(idx, min(idx + 32, index_val + n_val)): + vv.append(self[BVV(pos, 256)]) else: - tag = f"READN_{self.tag}_BASE{self._base.id}_sym{index.id}_{bv_unsigned_value(n)}" + tag = f"READN_{self.tag}_BASE{self._base.id}_sym{index.id}_{n_val}" + vv = list() + for i in range(n_val): + read_index = BV_Add(index, BVV(i, 256)) + vv.append(self[read_index]) - res = BVS(tag, bv_unsigned_value(n) * 8) - - vv = list() - for i in range(bv_unsigned_value(n)): - read_index = BV_Add(index, BVV(i, 256)) - vv.append(self[read_index]) + res = BVS(tag, n_val * 8) self.add_constraint(Equal(res, BV_Concat(vv))) - # print(f"actual readn {bv_unsigned_value(index)}:{bv_unsigned_value(n)} = {BV_Concat(vv)}") + # print(f"actual readn {bv_unsigned_value(index)}:{n_val} = {BV_Concat(vv)}") return res def writen(self, index, v, n): @@ -271,7 +283,7 @@ def writen(self, index, v, n): for i in range(bv_unsigned_value(n)): m = BV_Extract((31 - i) * 8, (31 - i) * 8 + 7, v) - self.state.memory[BV_Add(index, BVV(i, 256))] = m + self[BV_Add(index, BVV(i, 256))] = m # update cache if is_concrete(index): diff --git a/tests/test_lambda_memory_read_from_cache.py b/tests/test_lambda_memory_read_from_cache.py new file mode 100644 index 0000000..b4a218d --- /dev/null +++ b/tests/test_lambda_memory_read_from_cache.py @@ -0,0 +1,136 @@ +""" +Tests LambdaMemory's ability to use cache to avoid doing array ops +""" + +from greed.state_plugins import SimStateSolver +from greed.solver.shortcuts import * +from greed.memory import LambdaMemory + +def test_basic_store_read(): + """ + Tests basic operation: we can store a value and read (the exact same value) back + """ + mem = _get_dummy_memory() + + to_write = BVV(0xDEADBEEFCAFE, 256) + + mem.writen( + BVV(0, 256), + to_write, + BVV(32, 256), + ) + + read = mem.readn( + BVV(0, 256), + BVV(32, 256), + ) + + # We should have the same value that we wrote, because it was cached + assert read.id == to_write.id + +def test_readn_reads_words_from_cache(): + """ + Ensure that readn() reads and concats from cache when possible, avoiding + any array operations. + """ + mem = _get_dummy_memory() + + # Write two values to memory + to_write_1 = BVV(0xDEADBEEFCAFE, 256) + to_write_2 = BVV(0xCAFEBABE, 256) + + mem.writen( + BVV(0, 256), + to_write_1, + BVV(32, 256), + ) + mem.writen( + BVV(32, 256), + to_write_2, + BVV(32, 256), + ) + + read = mem.readn( + BVV(0, 256), + BVV(64, 256), + ) + + # We should have a symbol + assert read.operator == 'bvs' + + # The solver should not have any array operations -- dfs through the + # current assertions to ensure + queue = list(mem.state.solver.constraints) + + while queue: + constraint = queue.pop(0) + + if getattr(constraint, 'operator', None) == 'array': + raise AssertionError('Array operation found in solver') + + queue.extend(getattr(constraint, 'children', [])) + + # The value should be the concatenation of the two values + value = mem.state.solver.eval(read) + expected = (0xDEADBEEFCAFE << 256) | 0xCAFEBABE + assert hex(value) == hex(expected) + +def test_readn_uses_array_fallback(): + """ + Test that when no cache is available, readn will fall back to using array + operations. + """ + mem = _get_dummy_memory() + + # Write one value to memory + to_write = BVV(0xDEADBEEFCAFE, 256) + mem.writen( + BVV(0, 256), + to_write, + BVV(32, 256), + ) + + # Read unaligned + read = mem.readn( + BVV(10, 256), + BVV(32, 256), + ) + + # We should have a symbol + assert read.operator == 'bvs' + + # There should be an array op in the solver + queue = list(mem.state.solver.constraints) + while queue: + constraint = queue.pop(0) + + if getattr(constraint, 'operator', None) == 'array': + break + + queue.extend(getattr(constraint, 'children', [])) + else: + raise AssertionError('Array operation not found in solver') + + # The value should be what we wrote, but shifted + value = mem.state.solver.eval(read) + expected = 0xDEADBEEFCAFE << (8 * 10) + assert hex(value) == hex(expected) + +_n_mems = 0 +def _get_dummy_memory() -> LambdaMemory: + global _n_mems + + ret = LambdaMemory( + f'test_{_n_mems}', + value_sort=BVSort(8), + default=BVV(0, 8), + state=DummyState() + ) + + _n_mems += 1 + + return ret + +class DummyState: + def __init__(self): + self.solver = SimStateSolver()