From f58c33cde50f6deaaeefff1136b18f92ef747c6b Mon Sep 17 00:00:00 2001 From: Patrick Collins <54278053+PatrickAlphaC@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:16:46 -0400 Subject: [PATCH] fix: pass contract_name to VyperContract (#338) add contract_name to VyperContract ctor. allows setting contract_name at deploy time. --------- Co-authored-by: Charles Cooper --- boa/contracts/vvm/vvm_contract.py | 3 +- boa/contracts/vyper/vyper_contract.py | 13 +++++--- boa/interpret.py | 18 +++++++---- dev-requirements.txt | 1 + .../network/anvil/test_network_env.py | 32 ++++++++++++++++++- .../network/sepolia/test_sepolia_env.py | 4 ++- .../contracts/vyper/test_vyper_contract.py | 27 ++++++++++++++++ tests/unitary/utils/test_cache.py | 8 ++--- 8 files changed, 87 insertions(+), 19 deletions(-) diff --git a/boa/contracts/vvm/vvm_contract.py b/boa/contracts/vvm/vvm_contract.py index 6eedc2b8..87038a2d 100644 --- a/boa/contracts/vvm/vvm_contract.py +++ b/boa/contracts/vvm/vvm_contract.py @@ -54,7 +54,7 @@ def constructor(self): return ABIFunction(t, contract_name=self.filename) return None - def deploy(self, *args, env=None, **kwargs): + def deploy(self, *args, contract_name=None, env=None, **kwargs): encoded_args = b"" if self.constructor is not None: encoded_args = self.constructor.prepare_calldata(*args) @@ -66,6 +66,7 @@ def deploy(self, *args, env=None, **kwargs): address, _ = env.deploy_code(bytecode=self.bytecode + encoded_args, **kwargs) + # TODO: pass thru contract_name return self.at(address) @cached_property diff --git a/boa/contracts/vyper/vyper_contract.py b/boa/contracts/vyper/vyper_contract.py index d7027a25..4bf5e9db 100644 --- a/boa/contracts/vyper/vyper_contract.py +++ b/boa/contracts/vyper/vyper_contract.py @@ -144,10 +144,13 @@ class _BaseVyperContract(_BaseEVMContract): def __init__( self, compiler_data: CompilerData, + contract_name: Optional[str] = None, env: Optional[Env] = None, filename: Optional[str] = None, ): - contract_name = Path(compiler_data.contract_path).stem + if contract_name is None: + contract_name = Path(compiler_data.contract_path).stem + super().__init__(contract_name, env, filename) self.compiler_data = compiler_data @@ -185,12 +188,11 @@ def __init__( env=None, override_address=None, blueprint_preamble=None, + contract_name=None, filename=None, gas=None, ): - # note slight code duplication with VyperContract ctor, - # maybe use common base class? - super().__init__(compiler_data, env, filename) + super().__init__(compiler_data, contract_name, env, filename) deploy_bytecode = generate_blueprint_bytecode( compiler_data.bytecode, blueprint_preamble @@ -516,10 +518,11 @@ def __init__( # whether to skip constructor skip_initcode=False, created_from: Address = None, + contract_name=None, filename: str = None, gas=None, ): - super().__init__(compiler_data, env, filename) + super().__init__(compiler_data, contract_name, env, filename) self.created_from = created_from self._computation = None diff --git a/boa/interpret.py b/boa/interpret.py index 17136d3c..8870b168 100644 --- a/boa/interpret.py +++ b/boa/interpret.py @@ -129,11 +129,15 @@ def get_module_fingerprint( def compiler_data( - source_code: str, contract_name: str, filename: str | Path, deployer=None, **kwargs + source_code: str, + contract_name: str | None, + filename: str | Path, + deployer=None, + **kwargs, ) -> CompilerData: global _disk_cache, _search_path - path = Path(contract_name) + path = Path(filename) resolved_path = Path(filename).resolve(strict=False) file_input = FileInput( @@ -164,7 +168,7 @@ def get_compiler_data(): assert isinstance(deployer, type) or deployer is None deployer_id = repr(deployer) # a unique str identifying the deployer class - cache_key = str((contract_name, fingerprint, kwargs, deployer_id)) + cache_key = str((contract_name, filename, fingerprint, kwargs, deployer_id)) return _disk_cache.caching_lookup(cache_key, get_compiler_data) @@ -188,9 +192,9 @@ def loads( ): d = loads_partial(source_code, name, filename=filename, compiler_args=compiler_args) if as_blueprint: - return d.deploy_as_blueprint(**kwargs) + return d.deploy_as_blueprint(contract_name=name, **kwargs) else: - return d.deploy(*args, **kwargs) + return d.deploy(*args, contract_name=name, **kwargs) def load_abi(filename: str, *args, name: str = None, **kwargs) -> ABIContractFactory: @@ -239,8 +243,8 @@ def loads_partial( dedent: bool = True, compiler_args: dict = None, ) -> VyperDeployer: - name = name or "VyperContract" - filename = filename or "" + if filename is None: + filename = "" if dedent: source_code = textwrap.dedent(source_code) diff --git a/dev-requirements.txt b/dev-requirements.txt index 93d94bad..eeb2522a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -12,6 +12,7 @@ pytest pytest-xdist pytest-cov sphinx-rtd-theme +requests-cache # jupyter jupyter_server diff --git a/tests/integration/network/anvil/test_network_env.py b/tests/integration/network/anvil/test_network_env.py index 0bfbd397..18f18b46 100644 --- a/tests/integration/network/anvil/test_network_env.py +++ b/tests/integration/network/anvil/test_network_env.py @@ -73,9 +73,38 @@ def test_failed_transaction(): # XXX: probably want to test deployment revert behavior -def test_deployment_db(): +def test_deployment_db_overriden_contract_name(): with set_deployments_db(DeploymentsDB(":memory:")) as db: arg = 5 + contract_name = "test_deployment" + + # contract is written to deployments db + contract = boa.loads(code, arg, contract_name=contract_name) + + # test get_deployments() + deployment = next(db.get_deployments()) + + initcode = contract.compiler_data.bytecode + arg.to_bytes(32, "big") + + # sanity check all the fields + assert deployment.contract_address == contract.address + assert deployment.contract_name == contract.contract_name + assert deployment.contract_name == contract_name + assert deployment.deployer == boa.env.eoa + assert deployment.rpc == boa.env._rpc.name + assert deployment.source_code == contract.deployer.solc_json + assert deployment.abi == contract.abi + + # some sanity checks on tx_dict and rx_dict fields + assert to_bytes(deployment.tx_dict["data"]) == initcode + assert deployment.tx_dict["chainId"] == hex(boa.env.get_chain_id()) + assert Address(deployment.receipt_dict["contractAddress"]) == contract.address + + +def test_deployment_db_no_overriden_name(): + with set_deployments_db(DeploymentsDB(":memory:")) as db: + arg = 5 + non_contract_name = "test_deployment" # contract is written to deployments db contract = boa.loads(code, arg) @@ -88,6 +117,7 @@ def test_deployment_db(): # sanity check all the fields assert deployment.contract_address == contract.address assert deployment.contract_name == contract.contract_name + assert deployment.contract_name != non_contract_name assert deployment.deployer == boa.env.eoa assert deployment.rpc == boa.env._rpc.name assert deployment.source_code == contract.deployer.solc_json diff --git a/tests/integration/network/sepolia/test_sepolia_env.py b/tests/integration/network/sepolia/test_sepolia_env.py index 0d00e631..7801f99c 100644 --- a/tests/integration/network/sepolia/test_sepolia_env.py +++ b/tests/integration/network/sepolia/test_sepolia_env.py @@ -75,9 +75,10 @@ def test_raise_exception(simple_contract, amount): def test_deployment_db(): with set_deployments_db(DeploymentsDB(":memory:")) as db: arg = 5 + contract_name = "test_deployment" # contract is written to deployments db - contract = boa.loads(code, arg) + contract = boa.loads(code, arg, contract_name=contract_name) # test get_deployments() deployment = next(db.get_deployments()) @@ -87,6 +88,7 @@ def test_deployment_db(): # sanity check all the fields assert deployment.contract_address == contract.address assert deployment.contract_name == contract.contract_name + assert deployment.contract_name == contract_name assert deployment.deployer == boa.env.eoa assert deployment.rpc == boa.env._rpc.name assert deployment.source_code == contract.deployer.solc_json diff --git a/tests/unitary/contracts/vyper/test_vyper_contract.py b/tests/unitary/contracts/vyper/test_vyper_contract.py index 70370082..4fd8e70a 100644 --- a/tests/unitary/contracts/vyper/test_vyper_contract.py +++ b/tests/unitary/contracts/vyper/test_vyper_contract.py @@ -74,6 +74,33 @@ def foo() -> bool: c.foo() +def test_contract_name(): + code = """ +@external +def foo() -> bool: + return True + """ + c = boa.loads(code, name="return_one", filename="return_one.vy") + + assert c.contract_name == "return_one" + assert c.filename == "return_one.vy" + + c = boa.loads(code, filename="a/b/return_one.vy") + + assert c.contract_name == "return_one" + assert c.filename == "a/b/return_one.vy" + + c = boa.loads(code, filename=None, name="dummy_name") + + assert c.contract_name == "dummy_name" + assert c.filename == "" + + c = boa.loads(code, filename=None, name=None) + + assert c.contract_name == "" + assert c.filename == "" + + def test_stomp(): code1 = """ VAR: immutable(uint256) diff --git a/tests/unitary/utils/test_cache.py b/tests/unitary/utils/test_cache.py index 64659821..ad58eec5 100644 --- a/tests/unitary/utils/test_cache.py +++ b/tests/unitary/utils/test_cache.py @@ -22,12 +22,12 @@ def test_cache_contract_name(): x: constant(int128) = 1000 """ assert _disk_cache is not None - test1 = compiler_data(code, "test1", __file__, VyperDeployer) - test2 = compiler_data(code, "test2", __file__, VyperDeployer) - test3 = compiler_data(code, "test1", __file__, VyperDeployer) + test1 = compiler_data(code, "test1", "test1.vy", VyperDeployer) + test2 = compiler_data(code, "test2", "test2.vy", VyperDeployer) + test3 = compiler_data(code, "test1", "test1.vy", VyperDeployer) assert _to_dict(test1) == _to_dict(test3), "Should hit the cache" assert _to_dict(test1) != _to_dict(test2), "Should be different objects" - assert str(test2.contract_path) == "test2" + assert str(test2.contract_path) == "test2.vy" def test_cache_vvm():