diff --git a/chia/data_layer/data_layer_util.py b/chia/data_layer/data_layer_util.py index cb61b279c611..c0ec5c3b239b 100644 --- a/chia/data_layer/data_layer_util.py +++ b/chia/data_layer/data_layer_util.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import IntEnum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -135,6 +135,7 @@ class CommitState(IntEnum): Node = Union["TerminalNode", "InternalNode"] +@final @dataclass(frozen=True) class TerminalNode: hash: bytes32 @@ -142,7 +143,13 @@ class TerminalNode: key: bytes value: bytes - atom: None = field(init=False, default=None) + @classmethod + def from_key_value(cls, key: bytes, value: bytes) -> TerminalNode: + return cls( + hash=leaf_hash(key=key, value=value), + key=key, + value=value, + ) @property def pair(self) -> Tuple[CLVMStorage, CLVMStorage]: @@ -232,6 +239,7 @@ def valid(self) -> bool: return True +@final @dataclass(frozen=True) class InternalNode: hash: bytes32 @@ -239,8 +247,18 @@ class InternalNode: left_hash: bytes32 right_hash: bytes32 - pair: Optional[Tuple[Node, Node]] = None - atom: None = None + left: Optional[Node] = None + right: Optional[Node] = None + + @classmethod + def from_child_nodes(cls, left: Node, right: Node) -> InternalNode: + return cls( + hash=internal_hash(left_hash=left.hash, right_hash=right.hash), + left_hash=left.hash, + right_hash=right.hash, + left=left, + right=right, + ) @classmethod def from_row(cls, row: aiosqlite.Row) -> InternalNode: diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index 1d12593eb1c9..874b08c58fbb 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -1390,7 +1390,7 @@ async def get_node(self, node_hash: bytes32) -> Node: node = row_to_node(row=row) return node - async def get_tree_as_program(self, tree_id: bytes32) -> Program: + async def get_tree_as_nodes(self, tree_id: bytes32) -> Node: async with self.db_wrapper.reader() as reader: root = await self.get_tree_root(tree_id=tree_id) # TODO: consider actual proper behavior @@ -1414,13 +1414,12 @@ async def get_tree_as_program(self, tree_id: bytes32) -> Program: hash_to_node: Dict[bytes32, Node] = {} for node in reversed(nodes): if isinstance(node, InternalNode): - node = replace(node, pair=(hash_to_node[node.left_hash], hash_to_node[node.right_hash])) + node = replace(node, left=hash_to_node[node.left_hash], right=hash_to_node[node.right_hash]) hash_to_node[node.hash] = node root_node = hash_to_node[root_node.hash] - program = Program.to(root_node) - return program + return root_node async def get_proof_of_inclusion_by_hash( self, diff --git a/tests/core/data_layer/test_data_store.py b/tests/core/data_layer/test_data_store.py index 79b67535e548..1a7d4b02bbc6 100644 --- a/tests/core/data_layer/test_data_store.py +++ b/tests/core/data_layer/test_data_store.py @@ -253,7 +253,7 @@ async def test_build_a_tree( example = await create_example(data_store, tree_id) await _debug_dump(db=data_store.db_wrapper, description="final") - actual = await data_store.get_tree_as_program(tree_id=tree_id) + actual = await data_store.get_tree_as_nodes(tree_id=tree_id) # print("actual ", actual.as_python()) # print("expected", example.expected.as_python()) assert actual == example.expected @@ -774,7 +774,7 @@ async def test_delete_from_left_both_terminal(data_store: DataStore, tree_id: by ) await data_store.delete(key=b"\x04", tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED) - result = await data_store.get_tree_as_program(tree_id=tree_id) + result = await data_store.get_tree_as_nodes(tree_id=tree_id) assert result == expected @@ -812,7 +812,7 @@ async def test_delete_from_left_other_not_terminal(data_store: DataStore, tree_i await data_store.delete(key=b"\x04", tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED) await data_store.delete(key=b"\x05", tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED) - result = await data_store.get_tree_as_program(tree_id=tree_id) + result = await data_store.get_tree_as_nodes(tree_id=tree_id) assert result == expected @@ -852,7 +852,7 @@ async def test_delete_from_right_both_terminal(data_store: DataStore, tree_id: b ) await data_store.delete(key=b"\x03", tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED) - result = await data_store.get_tree_as_program(tree_id=tree_id) + result = await data_store.get_tree_as_nodes(tree_id=tree_id) assert result == expected @@ -890,7 +890,7 @@ async def test_delete_from_right_other_not_terminal(data_store: DataStore, tree_ await data_store.delete(key=b"\x03", tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED) await data_store.delete(key=b"\x02", tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED) - result = await data_store.get_tree_as_program(tree_id=tree_id) + result = await data_store.get_tree_as_nodes(tree_id=tree_id) assert result == expected diff --git a/tests/core/data_layer/util.py b/tests/core/data_layer/util.py index 89da59a0130c..e32dffafa3f2 100644 --- a/tests/core/data_layer/util.py +++ b/tests/core/data_layer/util.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from typing import IO, TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Union, overload -from chia.data_layer.data_layer_util import NodeType, Side, Status +from chia.data_layer.data_layer_util import InternalNode, Node, NodeType, Side, Status, TerminalNode from chia.data_layer.data_store import DataStore from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.sized_bytes import bytes32 @@ -47,21 +47,19 @@ async def general_insert( @dataclass(frozen=True) class Example: - expected: Program + expected: Node terminal_nodes: List[bytes32] async def add_0123_example(data_store: DataStore, tree_id: bytes32) -> Example: - expected = Program.to( - ( - ( - (b"\x00", b"\x10\x00"), - (b"\x01", b"\x11\x01"), - ), - ( - (b"\x02", b"\x12\x02"), - (b"\x03", b"\x13\x03"), - ), + expected = InternalNode.from_child_nodes( + left=InternalNode.from_child_nodes( + left=TerminalNode.from_key_value(key=b"\x00", value=b"\x10\x00"), + right=TerminalNode.from_key_value(key=b"\x01", value=b"\x11\x01"), + ), + right=InternalNode.from_child_nodes( + left=TerminalNode.from_key_value(key=b"\x02", value=b"\x12\x02"), + right=TerminalNode.from_key_value(key=b"\x03", value=b"\x13\x03"), ), ) @@ -76,27 +74,25 @@ async def add_0123_example(data_store: DataStore, tree_id: bytes32) -> Example: async def add_01234567_example(data_store: DataStore, tree_id: bytes32) -> Example: - expected = Program.to( - ( - ( - ( - (b"\x00", b"\x10\x00"), - (b"\x01", b"\x11\x01"), - ), - ( - (b"\x02", b"\x12\x02"), - (b"\x03", b"\x13\x03"), - ), + expected = InternalNode.from_child_nodes( + left=InternalNode.from_child_nodes( + InternalNode.from_child_nodes( + left=TerminalNode.from_key_value(key=b"\x00", value=b"\x10\x00"), + right=TerminalNode.from_key_value(key=b"\x01", value=b"\x11\x01"), + ), + InternalNode.from_child_nodes( + left=TerminalNode.from_key_value(key=b"\x02", value=b"\x12\x02"), + right=TerminalNode.from_key_value(key=b"\x03", value=b"\x13\x03"), + ), + ), + right=InternalNode.from_child_nodes( + InternalNode.from_child_nodes( + left=TerminalNode.from_key_value(key=b"\x04", value=b"\x14\x04"), + right=TerminalNode.from_key_value(key=b"\x05", value=b"\x15\x05"), ), - ( - ( - (b"\x04", b"\x14\x04"), - (b"\x05", b"\x15\x05"), - ), - ( - (b"\x06", b"\x16\x06"), - (b"\x07", b"\x17\x07"), - ), + InternalNode.from_child_nodes( + left=TerminalNode.from_key_value(key=b"\x06", value=b"\x16\x06"), + right=TerminalNode.from_key_value(key=b"\x07", value=b"\x17\x07"), ), ), )