From 42eca98cf7f8a7c7e30812e1116977d09d8e0857 Mon Sep 17 00:00:00 2001 From: Daniel Schiavini Date: Fri, 18 Oct 2024 10:45:31 +0200 Subject: [PATCH] feat: implement function injection instead of eval --- boa/contracts/vvm/vvm_contract.py | 62 +++++++++++-------------- tests/unitary/contracts/vvm/test_vvm.py | 30 ++++++++++-- 2 files changed, 54 insertions(+), 38 deletions(-) diff --git a/boa/contracts/vvm/vvm_contract.py b/boa/contracts/vvm/vvm_contract.py index 738805b0..6a136dea 100644 --- a/boa/contracts/vvm/vvm_contract.py +++ b/boa/contracts/vvm/vvm_contract.py @@ -146,17 +146,19 @@ def bytecode(self): def bytecode_runtime(self): return to_bytes(self.compiler_output["bytecode_runtime"]) - def eval(self, code, return_type=None): + def inject_function(self, fn_source_code, force=False): """ - Evaluate a vyper statement in the context of this contract. - Note that the return_type is necessary to correctly decode the result. - WARNING: This is different from the vyper eval() function, which is able - to automatically detect the return type. - :param code: A vyper statement. - :param return_type: The return type of the statement evaluation. + Inject a function into this VVM Contract without affecting the + contract's source code. useful for testing private functionality. + :param fn_source_code: The source code of the function to inject. + :param force: If True, the function will be injected even if it already exists. :returns: The result of the statement evaluation. """ - return VVMEval(code, self, return_type)() + fn = VVMInjectedFunction(fn_source_code, self) + if hasattr(self, fn.name) and not force: + raise ValueError(f"Function {fn.name} already exists on contract.") + setattr(self, fn.name, fn) + fn.contract = self @cached_property def _storage(self): @@ -204,12 +206,16 @@ class _VVMInternal(ABIFunction): @cached_property def _override_bytecode(self) -> bytes: + return to_bytes(self._compiler_output["bytecode_runtime"]) + + @cached_property + def _compiler_output(self): assert isinstance(self.contract, VVMContract) # help mypy source = "\n".join((self.contract.source_code, self.source_code)) compiled = cached_vvm.compile_source( source, vyper_version=self.contract.vyper_version ) - return to_bytes(compiled[""]["bytecode_runtime"]) + return compiled[""] @property def source_code(self) -> str: @@ -319,40 +325,26 @@ def __boa_private_{self.name}__({args_signature}) -> {self.return_type[0]}: """ -class VVMEval(_VVMInternal): +class VVMInjectedFunction(_VVMInternal): """ - A Vyper eval statement which can be used to evaluate vyper statements - via vvm-compiled contracts. This implementation has some drawbacks: - - It is very slow, as it requires the complete contract to be recompiled. - - It does not detect the return type, as it is currently not possible. - - It will temporarily change the bytecode at the contract's address. + A Vyper function that is injected into a VVM contract. + It will temporarily change the bytecode at the contract's address. """ - def __init__(self, code: str, contract: VVMContract, return_type: str = None): - abi = { - "anonymous": False, - "inputs": [], - "outputs": ([{"name": "eval", "type": return_type}] if return_type else []), - "name": "__boa_debug__", - "type": "function", - } - super().__init__(abi, contract.contract_name) + def __init__(self, code: str, contract: VVMContract): self.contract = contract self.code = code + abi = [i for i in self._compiler_output["abi"] if i not in contract.abi] + if len(abi) != 1: + err = "Expected exactly one new ABI entry after injecting function. " + err += f"Found {abi}." + raise ValueError(err) + + super().__init__(abi[0], contract.contract_name) @cached_property def source_code(self): - debug_body = self.code - return_sig = "" - if self.return_type: - return_sig = f"-> ({', '.join(self.return_type)})" - debug_body = f"return {self.code}" - return f""" -@external -@payable -def __boa_debug__() {return_sig}: - {debug_body} -""" + return self.code def _get_storage_variable_types(spec: dict) -> tuple[list[dict], str]: diff --git a/tests/unitary/contracts/vvm/test_vvm.py b/tests/unitary/contracts/vvm/test_vvm.py index 350f4521..ab29ea79 100644 --- a/tests/unitary/contracts/vvm/test_vvm.py +++ b/tests/unitary/contracts/vvm/test_vvm.py @@ -1,3 +1,5 @@ +import pytest + import boa mock_3_10_path = "tests/unitary/contracts/vvm/mock_3_10.vy" @@ -52,13 +54,35 @@ def test_vvm_internal(): assert contract._storage.hash_map.get(address, 0) == 69 -def test_vvm_eval(): +def test_vvm_inject_fn(): contract = boa.loads(mock_3_10_code, 43) - assert contract.eval("self.bar", "uint256") == 43 - assert contract.eval("self.bar = 44") is None + contract.inject_function( + """ +@external +def set_bar(bar: uint256): + self.bar = bar +""" + ) + assert contract.bar() == 43 + assert contract.set_bar(44) is None assert contract.bar() == 44 +def test_vvm_inject_fn_exists(): + contract = boa.loads(mock_3_10_code, 43) + code = """ +@external +def bytecode(): + assert False, "Function injected" +""" + with pytest.raises(ValueError) as e: + contract.inject_function(code) + assert "Function bytecode already exists" in str(e.value) + contract.inject_function(code, force=True) + with boa.reverts("Function injected"): + contract.bytecode() + + def test_forward_args_on_deploy(): with open(mock_3_10_path) as f: code = f.read()