From 9e5b76a2fac4065326bfb9b49817d2d41e877df5 Mon Sep 17 00:00:00 2001 From: enitrat Date: Wed, 11 Dec 2024 13:15:22 +0700 Subject: [PATCH] fix pydantic arg generation for nested dictionaries --- cairo/tests/fixtures/runner.py | 7 ++----- cairo/tests/src/test_state.cairo | 4 +--- cairo/tests/utils/hints.py | 17 ++++++++++------- cairo/tests/utils/serde.py | 5 ++++- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/cairo/tests/fixtures/runner.py b/cairo/tests/fixtures/runner.py index 120e4e75..6e7f9f33 100644 --- a/cairo/tests/fixtures/runner.py +++ b/cairo/tests/fixtures/runner.py @@ -126,11 +126,8 @@ def _factory(entrypoint, *args, **kwargs): output_ptr = runner.segments.add() stack.append(output_ptr) else: - stack.append( - gen_arg( - python_type, kwargs[arg_name] if arg_name in kwargs else args[i] - ) - ) + arg_value = kwargs[arg_name] if arg_name in kwargs else args[i] + stack.append(gen_arg(python_type, arg_value)) return_fp = runner.execution_base + 2 end = runner.program_base + len(runner.program.data) diff --git a/cairo/tests/src/test_state.cairo b/cairo/tests/src/test_state.cairo index b9334716..af3ee5ad 100644 --- a/cairo/tests/src/test_state.cairo +++ b/cairo/tests/src/test_state.cairo @@ -273,9 +273,7 @@ func test__add_transfer_should_return_false_when_overflowing_recipient_balance{ alloc_locals; let state = State.init(); let (code) = alloc(); - tempvar code_hash = new Uint256( - Constants.EMPTY_CODE_HASH_LOW, Constants.EMPTY_CODE_HASH_HIGH - ); + tempvar code_hash = new Uint256(Constants.EMPTY_CODE_HASH_LOW, Constants.EMPTY_CODE_HASH_HIGH); // Sender tempvar sender = 0x10001; diff --git a/cairo/tests/utils/hints.py b/cairo/tests/utils/hints.py index 72a426e8..509a2845 100644 --- a/cairo/tests/utils/hints.py +++ b/cairo/tests/utils/hints.py @@ -41,8 +41,8 @@ def gen_arg_pydantic( To be removed once all models are removed in favor of eels types. """ if isinstance(arg, Dict): - base = segments.add() - assert base.segment_index not in dict_manager.trackers + dict_ptr = segments.add() + assert dict_ptr.segment_index not in dict_manager.trackers data = { k: gen_arg_pydantic(dict_manager, segments, v, apply_modulo_to_args) @@ -51,13 +51,16 @@ def gen_arg_pydantic( if isinstance(arg, defaultdict): data = defaultdict(arg.default_factory, data) - dict_manager.trackers[base.segment_index] = DictTracker( - data=data, current_ptr=base + # This is required for tests where we read data from DictAccess segments while no dict method has been used. + # Equivalent to doing an initial dict_read of all keys. + initial_data = flatten([(k, v, v) for k, v in data.items()]) + segments.load_data(dict_ptr, initial_data) + current_ptr = dict_ptr + len(initial_data) + dict_manager.trackers[dict_ptr.segment_index] = DictTracker( + data=data, current_ptr=current_ptr ) - # In case of a dict, it's assumed that the struct **always** have consecutive dict_start, dict_ptr - # fields. - return base, base + return dict_ptr, current_ptr if isinstance(arg, Iterable): base = segments.add() diff --git a/cairo/tests/utils/serde.py b/cairo/tests/utils/serde.py index d987b708..3a4cde86 100644 --- a/cairo/tests/utils/serde.py +++ b/cairo/tests/utils/serde.py @@ -196,7 +196,8 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any: return kwargs if python_cls in (U256, Hash32, Bytes32): - value = kwargs["value"]["low"] + kwargs["value"]["high"] * 2**128 + # The inner uint256 has been serialized in `serialize_uint256` already. + value = kwargs["value"] if python_cls == U256: return U256(value) return python_cls(value.to_bytes(32, "little")) @@ -243,6 +244,8 @@ def serialize_scope(self, scope, scope_ptr): return self.serialize_block_kakarot(scope_ptr) if scope.path == ("src", "model", "model", "Option"): return self.serialize_option(scope_ptr) + if scope.path == ("starkware", "cairo", "common", "uint256", "Uint256"): + return self.serialize_uint256(scope_ptr) try: return self.serialize_type(scope.path, scope_ptr) except MissingIdentifierError: