diff --git a/boa/contracts/vyper/vyper_contract.py b/boa/contracts/vyper/vyper_contract.py index 8f4a1fdc..4bf5e9db 100644 --- a/boa/contracts/vyper/vyper_contract.py +++ b/boa/contracts/vyper/vyper_contract.py @@ -100,8 +100,8 @@ def stomp(self, address: Any, data_section=None) -> "VyperContract": address = Address(address) ret = self.deploy(override_address=address, skip_initcode=True) - vm = ret.env.vm - old_bytecode = vm.state.get_code(address.canonical_address) + vm = ret.env.evm + old_bytecode = vm.get_code(address) new_bytecode = self.compiler_data.bytecode_runtime immutables_size = self.compiler_data.global_ctx.immutable_section_bytes @@ -109,7 +109,7 @@ def stomp(self, address: Any, data_section=None) -> "VyperContract": data_section = old_bytecode[-immutables_size:] new_bytecode += data_section - vm.state.set_code(address.canonical_address, new_bytecode) + vm.set_code(address, new_bytecode) ret.env.register_contract(address, ret) ret._set_bytecode(new_bytecode) return ret diff --git a/tests/unitary/contracts/vyper/test_vyper_contract.py b/tests/unitary/contracts/vyper/test_vyper_contract.py index 9e5e498d..a7b4c36b 100644 --- a/tests/unitary/contracts/vyper/test_vyper_contract.py +++ b/tests/unitary/contracts/vyper/test_vyper_contract.py @@ -99,3 +99,53 @@ def foo() -> bool: assert c.contract_name == "" assert c.filename == "" + +def test_stomp(): + code1 = """ +VAR: immutable(uint256) + +@deploy +def __init__(): + VAR = 12345 + +@external +def foo() -> uint256: + return VAR + +@external +def bar() -> bool: + return True + """ + code2 = """ +VAR: immutable(uint256) + +@deploy +def __init__(): + VAR = 12345 + +@external +def foo() -> uint256: + return VAR + +@external +def bar() -> bool: + return False + """ + + deployer = boa.loads_partial(code1) + + c = deployer.deploy() + + assert c.foo() == 12345 + assert c.bar() is True + + deployer2 = boa.loads_partial(code2) + + c2 = deployer2.stomp(c.address) + + assert c2.foo() == 12345 + assert c2.bar() is False + + # the bytecode at the original contract has been stomped :scream: + assert c.foo() == 12345 + assert c.bar() is False