From b6c61cb3cba49afe786f5a3a4d7c67fa99a9306a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 24 Oct 2024 21:00:44 -0400 Subject: [PATCH] fix contract_path/contract_name API --- boa/contracts/vvm/vvm_contract.py | 3 ++- boa/interpret.py | 22 ++++++++----------- .../contracts/vyper/test_vyper_contract.py | 2 +- tests/unitary/utils/test_cache.py | 8 +++---- 4 files changed, 16 insertions(+), 19 deletions(-) diff --git a/boa/contracts/vvm/vvm_contract.py b/boa/contracts/vvm/vvm_contract.py index 6eedc2b8..87038a2d 100644 --- a/boa/contracts/vvm/vvm_contract.py +++ b/boa/contracts/vvm/vvm_contract.py @@ -54,7 +54,7 @@ def constructor(self): return ABIFunction(t, contract_name=self.filename) return None - def deploy(self, *args, env=None, **kwargs): + def deploy(self, *args, contract_name=None, env=None, **kwargs): encoded_args = b"" if self.constructor is not None: encoded_args = self.constructor.prepare_calldata(*args) @@ -66,6 +66,7 @@ def deploy(self, *args, env=None, **kwargs): address, _ = env.deploy_code(bytecode=self.bytecode + encoded_args, **kwargs) + # TODO: pass thru contract_name return self.at(address) @cached_property diff --git a/boa/interpret.py b/boa/interpret.py index dd57133c..8870b168 100644 --- a/boa/interpret.py +++ b/boa/interpret.py @@ -129,11 +129,15 @@ def get_module_fingerprint( def compiler_data( - source_code: str, contract_name: str, filename: str | Path, deployer=None, **kwargs + source_code: str, + contract_name: str | None, + filename: str | Path, + deployer=None, + **kwargs, ) -> CompilerData: global _disk_cache, _search_path - path = Path(contract_name) + path = Path(filename) resolved_path = Path(filename).resolve(strict=False) file_input = FileInput( @@ -164,7 +168,7 @@ def get_compiler_data(): assert isinstance(deployer, type) or deployer is None deployer_id = repr(deployer) # a unique str identifying the deployer class - cache_key = str((contract_name, fingerprint, kwargs, deployer_id)) + cache_key = str((contract_name, filename, fingerprint, kwargs, deployer_id)) return _disk_cache.caching_lookup(cache_key, get_compiler_data) @@ -188,9 +192,9 @@ def loads( ): d = loads_partial(source_code, name, filename=filename, compiler_args=compiler_args) if as_blueprint: - return d.deploy_as_blueprint(**kwargs) + return d.deploy_as_blueprint(contract_name=name, **kwargs) else: - return d.deploy(*args, **kwargs) + return d.deploy(*args, contract_name=name, **kwargs) def load_abi(filename: str, *args, name: str = None, **kwargs) -> ABIContractFactory: @@ -242,14 +246,6 @@ def loads_partial( if filename is None: filename = "" - if name is None: - if isinstance(filename, Path) or ( - isinstance(filename, str) and filename != "" - ): - name = Path(filename).stem - else: - name = "VyperContract" - if dedent: source_code = textwrap.dedent(source_code) diff --git a/tests/unitary/contracts/vyper/test_vyper_contract.py b/tests/unitary/contracts/vyper/test_vyper_contract.py index 6734377d..9e5e498d 100644 --- a/tests/unitary/contracts/vyper/test_vyper_contract.py +++ b/tests/unitary/contracts/vyper/test_vyper_contract.py @@ -97,5 +97,5 @@ def foo() -> bool: c = boa.loads(code, filename=None, name=None) - assert c.contract_name == "VyperContract" + assert c.contract_name == "" assert c.filename == "" diff --git a/tests/unitary/utils/test_cache.py b/tests/unitary/utils/test_cache.py index 64659821..ad58eec5 100644 --- a/tests/unitary/utils/test_cache.py +++ b/tests/unitary/utils/test_cache.py @@ -22,12 +22,12 @@ def test_cache_contract_name(): x: constant(int128) = 1000 """ assert _disk_cache is not None - test1 = compiler_data(code, "test1", __file__, VyperDeployer) - test2 = compiler_data(code, "test2", __file__, VyperDeployer) - test3 = compiler_data(code, "test1", __file__, VyperDeployer) + test1 = compiler_data(code, "test1", "test1.vy", VyperDeployer) + test2 = compiler_data(code, "test2", "test2.vy", VyperDeployer) + test3 = compiler_data(code, "test1", "test1.vy", VyperDeployer) assert _to_dict(test1) == _to_dict(test3), "Should hit the cache" assert _to_dict(test1) != _to_dict(test2), "Should be different objects" - assert str(test2.contract_path) == "test2" + assert str(test2.contract_path) == "test2.vy" def test_cache_vvm():