Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First crack at serialization with backrefs. #119

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions clvm/SExp.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def nullp(self):
def as_int(self):
    """Return the integer that `int_from_bytes` decodes from this object's atom."""
    atom = self.atom
    return int_from_bytes(atom)

def as_bin(self, *, allow_backrefs=False):
    """
    Serialize this object to bytes.

    The object is written to an in-memory stream with `sexp_to_stream` and the
    accumulated bytes are returned.

    `allow_backrefs` is keyword-only and is forwarded unchanged to
    `sexp_to_stream`; it defaults to False, so existing callers that pass no
    arguments keep working.
    """
    f = io.BytesIO()
    sexp_to_stream(self, f, allow_backrefs=allow_backrefs)
    return f.getvalue()

@classmethod
Expand Down
98 changes: 98 additions & 0 deletions clvm/object_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import hashlib


class ObjectCache:
    """
    `ObjectCache` provides a way to calculate and cache values for each node
    in a clvm object tree. It can be used to calculate the sha256 tree hash
    for an object and save the hash for all the child objects for building
    usage tables, for example.

    It also allows a function that's defined recursively on a clvm tree to
    have a non-recursive implementation (as it keeps a stack of uncached
    objects locally).

    Values are keyed by `id(node)`, so two distinct-but-equal objects are
    cached separately.
    """

    def __init__(self, f):
        """
        `f`: Callable[ObjectCache, CLVMObject] -> Union[None, T]

        The function `f` is expected to calculate its T value recursively based
        on the T values for the left and right child for a pair. For an atom, the
        function f must calculate the T value directly.

        If a pair is passed and one of the children does not have its T value cached
        in `ObjectCache` yet, return `None` and f will be called with each child in turn.
        Don't recurse in f; that's part of the point of this function.
        """
        self.f = f
        # id(node) -> (value, node). The node is stored next to its value,
        # presumably so the object stays alive and its id cannot be reused.
        self.lookup = {}

    def get(self, obj):
        """
        Return the cached value for `obj`, computing it (and the values of any
        uncached descendants) first if necessary.

        Uncached nodes are kept on an explicit stack instead of recursing, so
        the depth of the tree is not limited by the interpreter's recursion limit.

        Raises `ValueError` if `f` returns `None` for an atom, or if `f` returns
        `None` for a pair whose children are both already cached (retrying such
        a node could never make progress).
        """
        lookup = self.lookup
        obj_id = id(obj)
        if obj_id not in lookup:
            pending = [obj]
            while pending:
                node = pending.pop()
                node_id = id(node)
                if node_id in lookup:
                    continue
                v = self.f(self, node)
                if v is not None:
                    lookup[node_id] = (v, node)
                    continue
                if node.pair is None:
                    raise ValueError("f returned None for atom", node)
                left = node.pair[0]
                right = node.pair[1]
                if id(left) in lookup and id(right) in lookup:
                    # Nothing about the cache would change before the retry,
                    # so re-queueing this node would loop forever.
                    raise ValueError(
                        "f returned None for pair with cached children", node
                    )
                # Revisit this node after its children: they are pushed last,
                # so they are popped (and computed) first.
                pending.append(node)
                pending.append(left)
                pending.append(right)
        return lookup[obj_id][0]

    def contains(self, obj):
        """Return True if a value for this exact object (by identity) is cached."""
        return id(obj) in self.lookup


def treehash(cache, obj):
    """
    This function can be fed to `ObjectCache` to calculate the sha256 tree
    hash for all objects in a tree.

    For an atom the hash is sha256 over a 0x01 byte followed by the atom bytes.
    For a pair it is sha256 over a 0x02 byte followed by the cached hashes of
    the left and right children.

    Returns `None` for a pair when either child has no cached hash yet, which
    signals `ObjectCache` to compute the children first and retry.
    """
    # Explicit None test, matching how `ObjectCache.get` distinguishes atoms
    # from pairs.
    if obj.pair is None:
        return hashlib.sha256(b"\1" + obj.atom).digest()

    left, right = obj.pair

    # both children must already have cached values
    if not (cache.contains(left) and cache.contains(right)):
        return None

    left_hash = cache.get(left)
    right_hash = cache.get(right)
    return hashlib.sha256(b"\2" + left_hash + right_hash).digest()


def serialized_length(cache, obj):
    """
    This function can be fed to `ObjectCache` to calculate the serialized
    length for all objects in a tree.

    A pair costs one byte plus the cached lengths of both children; `None` is
    returned when either child's length is not cached yet, which signals
    `ObjectCache` to compute the children first and retry.

    An atom that is empty, or a single byte below 128, costs one byte. Any
    other atom costs its own length plus a size prefix of 1 to 5 bytes,
    depending on that length.

    Raises `ValueError` for atoms of 0x400000000 bytes or more.
    """
    # Explicit None test, matching how `ObjectCache.get` distinguishes atoms
    # from pairs.
    if obj.pair is not None:
        left, right = obj.pair

        # both children must already have cached values
        if cache.contains(left) and cache.contains(right):
            return 1 + cache.get(left) + cache.get(right)
        return None

    atom = obj.atom
    size = len(atom)
    if size == 0 or (size == 1 and atom[0] < 128):
        return 1

    # Exclusive upper bounds on atom size for a prefix of 1, 2, 3, 4, 5 bytes.
    size_limits = (0x40, 0x2000, 0x100000, 0x8000000, 0x400000000)
    for prefix_bytes, limit in enumerate(size_limits, start=1):
        if size < limit:
            return prefix_bytes + size
    raise ValueError("atom of size %d too long" % size)
178 changes: 178 additions & 0 deletions clvm/read_cache_lookup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from collections import Counter
from typing import Optional, List, Set, Tuple

import hashlib


# Direction codes stored alongside a parent hash in `parent_paths_for_child`:
# they record which side of the parent pair the child sits on.
LEFT = 0
RIGHT = 1


class ReadCacheLookup:
    """
    When deserializing a clvm object, a stack of deserialized child objects
    is created, which can be used with back-references. A `ReadCacheLookup` keeps
    track of the state of this stack and all child objects under each root
    node in the stack so that we can quickly determine if a relevant
    back-reference is available.

    In other words, if we've already serialized an object with tree hash T,
    and we encounter another object with that tree hash, we don't re-serialize
    it, but rather include a back-reference to it. This data structure lets
    us quickly determine which back-reference has the shortest path.

    Note that there is a counter. This is because the stack contains some
    child objects that are transient, and no longer appear in the stack
    at later times in the parsing. We don't want to waste time looking for
    these objects that no longer exist, so we reference-count them.

    All hashes correspond to sha256 tree hashes.
    """

    def __init__(self):
        """
        Create a new `ReadCacheLookup` object with just the null terminator
        (ie. an empty list of objects).
        """
        # tree hash of the whole read stack; starts as the hash of the empty atom
        self.root_hash = hashlib.sha256(b"\1").digest()
        # one (object hash, root hash before the push) entry per stacked object
        self.read_stack: List[Tuple[bytes, bytes]] = []
        # reference count per hash; entries at zero are ignored by `find_paths`
        self.count = Counter()
        # child hash -> list of (parent hash, LEFT or RIGHT)
        self.parent_paths_for_child = {}

    def push(self, obj_hash: bytes) -> None:
        """
        This function is used to note that an object with the given hash has just
        been pushed to the read stack, and update the lookups as appropriate.
        """
        old_root_hash = self.root_hash
        # the new root is the pair (obj_hash, old_root)
        new_root_hash = hashlib.sha256(b"\2" + obj_hash + old_root_hash).digest()

        self.read_stack.append((obj_hash, old_root_hash))

        self.count.update([obj_hash, new_root_hash])

        parents = self.parent_paths_for_child
        # the pushed object is the left child of the new root ...
        parents.setdefault(obj_hash, []).append((new_root_hash, LEFT))
        # ... and the previous root is its right child
        parents.setdefault(old_root_hash, []).append((new_root_hash, RIGHT))

        self.root_hash = new_root_hash

    def pop(self) -> Tuple[bytes, bytes]:
        """
        This function is used to note that the top object has just been popped
        from the read stack. Return the 2-tuple of the child hashes: the hash
        of the popped object and the root hash that is restored.

        Raises `IndexError` if the read stack is empty.
        """
        item = self.read_stack.pop()
        obj_hash, previous_root_hash = item
        self.count[obj_hash] -= 1
        self.count[self.root_hash] -= 1
        self.root_hash = previous_root_hash
        return item

    def pop2_and_cons(self) -> None:
        """
        This function is used to note that a "pop-and-cons" operation has just
        happened. We remove two objects, cons them together, and push the cons,
        updating the internal look-ups as necessary.
        """
        # the top of the stack is the right side of the new pair
        right_hash = self.pop()[0]
        left_hash = self.pop()[0]

        # `pop` decremented both counts; restore them, since both hashes are
        # recorded below as children of the pair that is about to be pushed
        self.count.update([left_hash, right_hash])

        new_root_hash = hashlib.sha256(b"\2" + left_hash + right_hash).digest()

        parents = self.parent_paths_for_child
        parents.setdefault(left_hash, []).append((new_root_hash, LEFT))
        parents.setdefault(right_hash, []).append((new_root_hash, RIGHT))
        self.push(new_root_hash)

    def find_paths(self, obj_hash: bytes, serialized_length: int) -> Set[bytes]:
        """
        This function looks for a path from the root to a child node with a given hash
        by using the read cache.

        The search walks breadth-first from `obj_hash` up through recorded
        parents until it reaches the current root, and returns the encoded
        form of every path found at the first (shortest) depth that reaches it.

        An empty set is returned when `serialized_length` is below 3, when no
        path exists, or when the path would become too long to be worthwhile.
        """
        valid_paths = set()
        if serialized_length < 3:
            return valid_paths

        # 1 byte for 0xfe, 1 min byte for savings
        max_bytes_for_path_encoding = serialized_length - 2
        # one bit of the encoding is taken by the leading 1
        max_path_length = max_bytes_for_path_encoding * 8 - 1

        seen = {obj_hash}
        frontier = [(obj_hash, [])]

        while frontier:
            # Parents found at this depth only count as seen from the next
            # depth on, so several equally short paths through the same parent
            # are all kept.
            seen_next = set(seen)
            next_frontier = []
            for node, path in frontier:
                if node == self.root_hash:
                    valid_paths.add(reversed_path_to_bytes(path))
                    continue

                for parent, direction in self.parent_paths_for_child.get(node, ()):
                    if self.count[parent] > 0 and parent not in seen:
                        new_path = path + [direction]
                        if len(new_path) > max_path_length:
                            return set()
                        next_frontier.append((parent, new_path))
                    seen_next.add(parent)
            if valid_paths:
                return valid_paths
            frontier = next_frontier
            seen = seen_next
        return valid_paths

    def find_path(self, obj_hash: bytes, serialized_length: int) -> Optional[bytes]:
        """
        Return the smallest (by bytes ordering) of the paths reported by
        `find_paths`, or `None` when there is none.
        """
        return min(self.find_paths(obj_hash, serialized_length), default=None)


def reversed_path_to_bytes(path: List[int]) -> bytes:
    """
    Convert a list of 0/1 (for left/right) values to a path expected by clvm.

    Read the list as a binary number whose first element is the most
    significant bit and whose last element is the least significant bit;
    prepend a 1; break into big-endian bytes. Any truthy element counts as 1.

    [] => bytes([0b1])
    [0] => bytes([0b10])
    [1] => bytes([0b11])
    [0, 0] => bytes([0b100])
    [0, 1] => bytes([0b101])
    [1, 0] => bytes([0b110])
    [1, 1] => bytes([0b111])
    [0, 0, 1] => bytes([0b1001])
    [1, 1, 1, 1, 0, 0, 0, 0, 1] => bytes([0b11, 0b11100001])
    """
    # Start from the leading 1 and shift each direction in behind it.
    value = 1
    for direction in path:
        value = (value << 1) | (1 if direction else 0)
    # bit_length() is len(path) + 1, so this is the smallest whole-byte width
    return value.to_bytes((value.bit_length() + 7) >> 3, "big")
Loading