Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve address dealiasing #166

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 47 additions & 41 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,53 +366,47 @@ def __init__(self, contract, slot, typ):
self.slot = slot
self.typ = typ

def _decode(self, slot, typ, truncate_limit=None):
def _decode_storage(self, slot, typ, truncate_limit=None):
n = typ.memory_bytes_required
if truncate_limit is not None and n > truncate_limit:
return None # indicate failure to caller

fakemem = ByteAddressableStorage(self.accountdb, self.addr, slot)
return decode_vyper_object(fakemem, typ)

def _dealias(self, maybe_address):
try:
return self.contract.env.lookup_alias(maybe_address)
except KeyError: # not found, return the input
return maybe_address
return self.contract._decode(fakemem, typ)

def get(self, truncate_limit=None):
if isinstance(self.typ, HashMapT):
ret = {}
for k in self.contract.env.sstore_trace.get(self.addr, {}):
path = unwrap_storage_key(self.contract.env.sha3_trace, k)
if to_int(path[0]) != self.slot:
continue

path = path[1:] # drop the slot
path_t = []

ty = self.typ
for i, p in enumerate(path):
path[i] = decode_vyper_object(memoryview(p), ty.key_type)
path_t.append(ty.key_type)
ty = ty.value_type

val = self._decode(k, ty, truncate_limit)

# set val only if value is nonzero
if val:
# decode aliases as needed/possible
dealiased_path = []
for p, t in zip(path, path_t):
if isinstance(t, AddressT):
p = self._dealias(p)
dealiased_path.append(p)
setpath(ret, dealiased_path, val)
if not isinstance(self.typ, HashMapT):
return self._decode_storage(self.slot, self.typ, truncate_limit)

return ret
ret = {}
for k in self.contract.env.sstore_trace.get(self.addr, {}):
path = unwrap_storage_key(self.contract.env.sha3_trace, k)
if to_int(path[0]) != self.slot:
continue

path = path[1:] # drop the slot
path_t = []

ty = self.typ
for i, p in enumerate(path):
path[i] = self.contract._decode(memoryview(p), ty.key_type)
path_t.append(ty.key_type)
ty = ty.value_type

val = self._decode_storage(k, ty, truncate_limit)

# set val only if value is nonzero
if val:
# decode aliases as needed/possible
dealiased_path = []
for p, t in zip(path, path_t):
if isinstance(t, AddressT):
p = self.contract._dealias(p)
dealiased_path.append(p)
setpath(ret, dealiased_path, val)

return ret

else:
return self._decode(self.slot, self.typ, truncate_limit)


# data structure to represent the storage variables in a contract
Expand Down Expand Up @@ -446,7 +440,7 @@ def __init__(self, contract):
if v.is_immutable: # check that v
ofst = compiler_data.storage_layout["code_layout"][k]["offset"]
immutable_raw_bytes = data_section[ofst:]
value = decode_vyper_object(immutable_raw_bytes, v.typ)
value = contract._decode(immutable_raw_bytes, v.typ)
setattr(self, k, value)

def dump(self):
Expand Down Expand Up @@ -547,6 +541,18 @@ def __repr__(self):

return ret

def _dealias(self, maybe_address):
try:
return self.env.lookup_alias(maybe_address)
except KeyError: # not found, return the input
return maybe_address

def _decode(mem, typ):
ret = decode_vyper_object(mem, typ)
if isinstance(typ, AddressT):
return f"address \"{self._dealias(ret)}\""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks weird, can we add some explanation why it's needed?

return ret

@cached_property
def _immutables(self):
return ImmutablesModel(self)
Expand Down Expand Up @@ -582,14 +588,14 @@ def debug_frame(self, computation=None):
mem = computation._memory
frame_detail = FrameDetail(fn.name)

# ensure memory is initialized for `decode_vyper_object()`
# ensure memory is initialized for `self._decode()`
mem.extend(frame_info.frame_start, frame_info.frame_size)
for k, v in frame_info.frame_vars.items():
if v.location.name != "memory":
continue
ofst = v.pos
size = v.typ.memory_bytes_required
frame_detail[k] = decode_vyper_object(mem.read(ofst, size), v.typ)
frame_detail[k] = self._decode(mem.read(ofst, size), v.typ)

return frame_detail

Expand Down
Loading