Skip to content

Commit

Permalink
fix pydantic arg generation for nested dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Dec 11, 2024
1 parent 4be3e70 commit 9e5b76a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
7 changes: 2 additions & 5 deletions cairo/tests/fixtures/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,8 @@ def _factory(entrypoint, *args, **kwargs):
output_ptr = runner.segments.add()
stack.append(output_ptr)
else:
stack.append(
gen_arg(
python_type, kwargs[arg_name] if arg_name in kwargs else args[i]
)
)
arg_value = kwargs[arg_name] if arg_name in kwargs else args[i]
stack.append(gen_arg(python_type, arg_value))

return_fp = runner.execution_base + 2
end = runner.program_base + len(runner.program.data)
Expand Down
4 changes: 1 addition & 3 deletions cairo/tests/src/test_state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,7 @@ func test__add_transfer_should_return_false_when_overflowing_recipient_balance{
alloc_locals;
let state = State.init();
let (code) = alloc();
tempvar code_hash = new Uint256(
Constants.EMPTY_CODE_HASH_LOW, Constants.EMPTY_CODE_HASH_HIGH
);
tempvar code_hash = new Uint256(Constants.EMPTY_CODE_HASH_LOW, Constants.EMPTY_CODE_HASH_HIGH);

// Sender
tempvar sender = 0x10001;
Expand Down
17 changes: 10 additions & 7 deletions cairo/tests/utils/hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def gen_arg_pydantic(
To be removed once all models are removed in favor of eels types.
"""
if isinstance(arg, Dict):
base = segments.add()
assert base.segment_index not in dict_manager.trackers
dict_ptr = segments.add()
assert dict_ptr.segment_index not in dict_manager.trackers

data = {
k: gen_arg_pydantic(dict_manager, segments, v, apply_modulo_to_args)
Expand All @@ -51,13 +51,16 @@ def gen_arg_pydantic(
if isinstance(arg, defaultdict):
data = defaultdict(arg.default_factory, data)

dict_manager.trackers[base.segment_index] = DictTracker(
data=data, current_ptr=base
# This is required for tests where we read data from DictAccess segments while no dict method has been used.
# Equivalent to doing an initial dict_read of all keys.
initial_data = flatten([(k, v, v) for k, v in data.items()])
segments.load_data(dict_ptr, initial_data)
current_ptr = dict_ptr + len(initial_data)
dict_manager.trackers[dict_ptr.segment_index] = DictTracker(
data=data, current_ptr=current_ptr
)

# In case of a dict, it's assumed that the struct **always** have consecutive dict_start, dict_ptr
# fields.
return base, base
return dict_ptr, current_ptr

if isinstance(arg, Iterable):
base = segments.add()
Expand Down
5 changes: 4 additions & 1 deletion cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
return kwargs

if python_cls in (U256, Hash32, Bytes32):
value = kwargs["value"]["low"] + kwargs["value"]["high"] * 2**128
# The inner uint256 has been serialized in `serialize_uint256` already.
value = kwargs["value"]
if python_cls == U256:
return U256(value)
return python_cls(value.to_bytes(32, "little"))
Expand Down Expand Up @@ -243,6 +244,8 @@ def serialize_scope(self, scope, scope_ptr):
return self.serialize_block_kakarot(scope_ptr)
if scope.path == ("src", "model", "model", "Option"):
return self.serialize_option(scope_ptr)
if scope.path == ("starkware", "cairo", "common", "uint256", "Uint256"):
return self.serialize_uint256(scope_ptr)
try:
return self.serialize_type(scope.path, scope_ptr)
except MissingIdentifierError:
Expand Down

0 comments on commit 9e5b76a

Please sign in to comment.