From 2c3cc2ee31de1ede631468e4b00a78f084feffa9 Mon Sep 17 00:00:00 2001 From: Salvatore Ingala <6681844+bigspider@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:25:29 +0200 Subject: [PATCH] Generalize txmaker to most supported policies --- test_utils/taproot.py | 50 ++++ test_utils/txmaker.py | 225 ++++++++++------ test_utils/wallet_policy.py | 414 ++++++++++++++++++++++++++++++ tests_perf/test_perf_sign_psbt.py | 118 +++++++-- 4 files changed, 715 insertions(+), 92 deletions(-) create mode 100644 test_utils/taproot.py create mode 100644 test_utils/wallet_policy.py diff --git a/test_utils/taproot.py b/test_utils/taproot.py new file mode 100644 index 00000000..0ba25a2a --- /dev/null +++ b/test_utils/taproot.py @@ -0,0 +1,50 @@ +# from portions of BIP-0341 +# - https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0341.mediawiki +# Distributed under the BSD-3-Clause license + +# fmt: off + +# If you want to print values on an individual basis, use +# the pretty() function, e.g., print(pretty(foo)). +import hashlib +import struct + + +# This implementation can be sped up by storing the midstate after hashing +# tag_hash instead of rehashing it all the time. +def tagged_hash(tag: str, msg: bytes) -> bytes: + tag_hash = hashlib.sha256(tag.encode()).digest() + return hashlib.sha256(tag_hash + tag_hash + msg).digest() + + +def ser_compact_size(l): + r = b"" + if l < 253: + r = struct.pack("B", l) + elif l < 0x10000: + r = struct.pack(" bytes: return random_bytes(32) +def random_p2tr() -> bytes: + """Returns 32 random bytes. Not cryptographically secure.""" + privkey = random_bytes(32) + pubkey = bip0340.point_mul(bip0340.G, int.from_bytes(privkey, 'big')) + + return b'\x51\x20' + (pubkey[0]).to_bytes(32, 'big') + + def getScriptPubkeyFromWallet(wallet: WalletPolicy, change: bool, address_index: int) -> Script: descriptor_str = wallet.descriptor_template @@ -116,27 +136,109 @@ def createFakeWalletTransaction(n_inputs: int, n_outputs: int, output_amount: in return tx, selected_output_index, selected_output_change, selected_output_address_index -def createPsbt(wallet: WalletPolicy, input_amounts: List[int], output_amounts: List[int], output_is_change: List[bool], output_wallet: Optional[List[Optional[WalletPolicy]]] = None) -> PSBT: - if output_wallet is None: - output_wallet = [None] * len(output_amounts) +def get_placeholder_root_key(placeholder: KeyPlaceholder, keys_info: List[str]) -> Tuple[ExtendedKey, Optional[KeyOriginInfo]]: + if isinstance(placeholder, PlainKeyPlaceholder): + key_info = keys_info[placeholder.key_index] + key_origin_end_pos = key_info.find("]") + if key_origin_end_pos == -1: + xpub = key_info + root_key_origin = None + else: + xpub = key_info[key_origin_end_pos+1:] + root_key_origin = KeyOriginInfo.from_string( + key_info[1:key_origin_end_pos]) + root_pubkey = ExtendedKey.deserialize(xpub) + else: + raise ValueError("Unsupported placeholder type") + + return root_pubkey, root_key_origin - assert len(output_amounts) == len(output_is_change) - assert len(output_amounts) == len(output_wallet) - assert sum(output_amounts) <= sum(input_amounts) - # TODO: add support for wrapped segwit wallets +def fill_inout(wallet_policy: WalletPolicy, inout: Union[PartiallySignedInput, PartiallySignedOutput], is_change: bool, address_index: int): + desc_tmpl = DescriptorTemplate.from_string( + wallet_policy.descriptor_template) - if wallet.n_keys != 1: - raise NotImplementedError("Only 1-key wallets supported") - if wallet.version == WalletType.WALLET_POLICY_V1: - if wallet.descriptor_template not in ["pkh(@0)", "wpkh(@0)", "tr(@0)"]: - raise NotImplementedError("Unsupported policy type") - elif wallet.version == WalletType.WALLET_POLICY_V2: - if wallet.descriptor_template not in ["pkh(@0/**)", "wpkh(@0/**)", "tr(@0/**)"]: - raise NotImplementedError("Unsupported policy type") + if isinstance(desc_tmpl, TrDescriptorTemplate): + keypath_der_subpath = [ + desc_tmpl.key.num1 if not is_change else desc_tmpl.key.num2, + address_index + ] + + keypath_pubkey, _ = get_placeholder_root_key( + desc_tmpl.key, wallet_policy.keys_info) + + inout.tap_internal_key = keypath_pubkey.derive_pub_path( + keypath_der_subpath).pubkey[1:] + + if desc_tmpl.tree is not None: + inout.tap_merkle_root = desc_tmpl.get_taptree_hash( + wallet_policy.keys_info, is_change, address_index) + + for placeholder, tapleaf_desc in desc_tmpl.placeholders(): + root_pubkey, root_pubkey_origin = get_placeholder_root_key( + placeholder, wallet_policy.keys_info) + + placeholder_der_subpath = [ + placeholder.num1 if not is_change else placeholder.num2, + address_index + ] + + leaf_script = None + if tapleaf_desc is not None: + leaf_desc = derive_plain_descriptor( + tapleaf_desc, wallet_policy.keys_info, is_change, address_index) + s = BytesIO(leaf_desc.encode()) + desc: Miniscript = Miniscript.read_from(s, taproot=True) + leaf_script = desc.compile() + + derived_pubkey = root_pubkey.derive_pub_path( + placeholder_der_subpath) + + if root_pubkey_origin is not None: + derived_key_origin = KeyOriginInfo( + root_pubkey_origin.fingerprint, root_pubkey_origin.path + placeholder_der_subpath) + + leaf_hashes = [] + if leaf_script is not None: + # In BIP-388 compliant wallet policies, there will be only one tapleaf with a given key + leaf_hashes = [tapleaf_hash(leaf_script)] + + inout.tap_bip32_paths[derived_pubkey.pubkey[1:]] = ( + leaf_hashes, derived_key_origin) else: - raise ValueError( - f"Unknown wallet policy version: {wallet.version}") + if isinstance(desc_tmpl, WshDescriptorTemplate): + # add witnessScript + desc_str = derive_plain_descriptor( + wallet_policy.descriptor_template, wallet_policy.keys_info, is_change, address_index) + s = BytesIO(desc_str.encode()) + desc: Descriptor = Descriptor.read_from(s) + inout.witness_script = desc.witness_script().data + + for placeholder, _ in desc_tmpl.placeholders(): + root_pubkey, root_pubkey_origin = get_placeholder_root_key( + placeholder, wallet_policy.keys_info) + + placeholder_der_subpath = [ + placeholder.num1 if not is_change else placeholder.num2, + address_index + ] + + derived_pubkey = root_pubkey.derive_pub_path( + placeholder_der_subpath) + + if root_pubkey_origin is not None: + derived_key_origin = KeyOriginInfo( + root_pubkey_origin.fingerprint, root_pubkey_origin.path + placeholder_der_subpath) + + inout.hd_keypaths[derived_pubkey.pubkey] = derived_key_origin + + +def createPsbt(wallet_policy: WalletPolicy, input_amounts: List[int], output_amounts: List[int], output_is_change: List[bool]) -> PSBT: + assert output_is_change.count( + True) <= 1, "At most one change output is supported" + + assert len(output_amounts) == len(output_is_change) + assert sum(output_amounts) <= sum(input_amounts) vin: List[CTxIn] = [CTxIn() for _ in input_amounts] vout: List[CTxOut] = [CTxOut() for _ in output_amounts] @@ -150,7 +252,7 @@ def createPsbt(wallet: WalletPolicy, input_amounts: List[int], output_amounts: L n_inputs = randint(1, 10) n_outputs = randint(1, 10) prevout, idx, is_change, addr_idx = createFakeWalletTransaction( - n_inputs, n_outputs, prevout_amount, wallet) + n_inputs, n_outputs, prevout_amount, wallet_policy) prevouts.append(prevout) prevout_ns.append(idx) prevout_path_change.append(is_change) @@ -168,68 +270,43 @@ def createPsbt(wallet: WalletPolicy, input_amounts: List[int], output_amounts: L tx.vout = vout tx.wit = CTxWitness() + change_address_index = randint(0, 10_000) + for i, output_amount in enumerate(output_amounts): + tx.vout[i].nValue = output_amount + if output_is_change[i]: + script = getScriptPubkeyFromWallet( + wallet_policy, output_is_change[i], change_address_index) + + tx.vout[i].scriptPubKey = script.data + else: + # a random P2TR output + tx.vout[i].scriptPubKey = random_p2tr() + psbt.inputs = [PartiallySignedInput(0) for _ in input_amounts] psbt.outputs = [PartiallySignedOutput(0) for _ in output_amounts] - # simplification; good enough for the scripts we support now, but will need more work - is_legacy = wallet.descriptor_template.startswith("pkh(") - is_segwitv0 = wallet.descriptor_template.startswith( - "wpkh(") or wallet.descriptor_template.startswith("sh(wpkh(") - is_taproot = wallet.descriptor_template.startswith("tr(") - - key_origin = wallet.keys_info[0][1:wallet.keys_info[0].index("]")] + desc_tmpl = DescriptorTemplate.from_string( + wallet_policy.descriptor_template) - for i in range(len(input_amounts)): - if is_legacy or is_segwitv0: - # add non-witness UTXO - psbt.inputs[i].non_witness_utxo = prevouts[i] - if is_segwitv0 or is_taproot: + for input_index, input in enumerate(psbt.inputs): + if desc_tmpl.is_segwit(): # add witness UTXO - psbt.inputs[i].witness_utxo = prevouts[i].vout[prevout_ns[i]] - - path_str = f"m{key_origin[8:]}/{prevout_path_change[i]}/{prevout_path_addr_idx[i]}" - path = parse_path(path_str) - input_key: bytes = master_key.derive(path_str).key.sec() - - assert len(input_key) == 33 - - # add key and path info - if is_legacy or is_segwitv0: - psbt.inputs[i].hd_keypaths[input_key] = KeyOriginInfo( - master_key_fpr, path) - elif is_taproot: - internal_key = input_key[1:] - psbt.inputs[i].tap_bip32_paths[internal_key] = ( - {}, KeyOriginInfo(master_key_fpr, path)) - else: - raise RuntimeError("Unexpected state: unknown transaction type") + input.witness_utxo = prevouts[input_index].vout[prevout_ns[input_index]] - for i, output_amount in enumerate(output_amounts): - wallet_i = output_wallet[i] - if output_is_change[i] or wallet_i is None: - script = getScriptPubkeyFromWallet(wallet, output_is_change[i], i) - else: - script = getScriptPubkeyFromWallet(wallet_i, 0, i) + if desc_tmpl.is_legacy() or (desc_tmpl.is_segwit() and not desc_tmpl.is_taproot()): + # add non_witness_utxo for legacy or segwitv0 + input.non_witness_utxo = prevouts[input_index] - tx.vout[i].scriptPubKey = script.data - tx.vout[i].nValue = output_amount + is_change = bool(prevout_path_change[input_index]) + address_index = prevout_path_addr_idx[input_index] - if output_is_change[i]: - path_str = f"m{key_origin[8:]}/1/{i}" - path = parse_path(path_str) - output_key: bytes = master_key.derive(path_str).key.sec() - - # add key and path information for change output - if is_legacy or is_segwitv0: - psbt.outputs[i].hd_keypaths[output_key] = KeyOriginInfo( - master_key_fpr, path) - elif is_taproot: - internal_key = output_key[1:] - psbt.outputs[i].tap_bip32_paths[internal_key] = ( - {}, KeyOriginInfo(master_key_fpr, path)) - - psbt.outputs[i].tap_bip32_paths[internal_key] = ( - {}, KeyOriginInfo(master_key_fpr, path)) + fill_inout(wallet_policy, input, is_change, address_index) + + # only for the change output, we need to do the same + for output_index, output in enumerate(psbt.outputs): + if output_is_change[output_index]: + fill_inout(wallet_policy, output, is_change=True, + address_index=change_address_index) psbt.tx = tx diff --git a/test_utils/wallet_policy.py b/test_utils/wallet_policy.py new file mode 100644 index 00000000..07413faf --- /dev/null +++ b/test_utils/wallet_policy.py @@ -0,0 +1,414 @@ +# This is a partial implementation of BIP-0388: https://github.com/bitcoin/bips/blob/master/bip-0388.mediawiki +# It is used to manipulate wallet policies, but it has incomplete error checking and does not support all the +# possible types of descriptor templates from the BIP. +# Only to be used for testing purposes. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from io import BytesIO +import re +from typing import Iterator, List, Optional, Tuple, Type, Union + +from ledger_bitcoin.embit.descriptor.miniscript import Miniscript +from ledger_bitcoin.key import ExtendedKey + +from .taproot import ser_script, tagged_hash + + +def tapleaf_hash(script: Optional[bytes], leaf_version=b'\xC0') -> Optional[bytes]: + if script is None: + return None + return tagged_hash( + "TapLeaf", + leaf_version + ser_script(script) + ) + + +@dataclass +class PlainKeyPlaceholder: + key_index: int + num1: int + num2: int + + +# future extensions will have multiple subtypes (e.g.: MuSig2KeyPlaceholder) +KeyPlaceholder = PlainKeyPlaceholder + + +def parse_placeholder(placeholder_str: str) -> KeyPlaceholder: + """Parses a placeholder string to create a KeyPlaceholder object.""" + if placeholder_str.startswith('@'): + parts = placeholder_str.split('/') + key_index = int(parts[0].strip('@')) + + # Remove '<' from the start and '>' from the end + nums_part = parts[1][1:-1] + num1, num2 = map(int, nums_part.split(';')) + + return PlainKeyPlaceholder(key_index, num1, num2) + else: + raise ValueError("Invalid placeholder string") + + +def extract_placeholders(desc_tmpl: str) -> List[KeyPlaceholder]: + """Extracts and parses all placeholders in a descriptor template, from left to right.""" + + pattern = r'musig\((?:@\d+,)*(?:@\d+)\)/<\d+;\d+>/\*|@\d+/<\d+;\d+>/\*' + matches = [(match.group(), match.start()) + for match in re.finditer(pattern, desc_tmpl)] + sorted_matches = sorted(matches, key=lambda x: x[1]) + return [parse_placeholder(match[0]) for match in sorted_matches] + + +def derive_from_key_info(key_info: str, steps: List[int]) -> str: + start = key_info.find(']') + pk = ExtendedKey.deserialize(key_info[start + 1:]) + return pk.derive_pub_path(steps).to_string() + + +def derive_plain_descriptor(desc_tmpl: str, keys_info: List[str], is_change: bool, address_index: int): + """ + Given a wallet policy, and the change/address_index combination, computes the corresponding descriptor. + It replaces /** with /<0;1>/* + It also replaces each musig() key expression with the corresponding xpub. + The resulting descriptor can be used with descriptor libraries that do not support musig or wallet policies. + """ + + desc_tmpl = desc_tmpl.replace("/**", "/<0;1>/*") + desc_tmpl = desc_tmpl.replace("*", str(address_index)) + + # Replace each with M if is_change is False, otherwise with N + def replace_m_n(match: re.Match[str]): + m, n = match.groups() + return m if not is_change else n + + desc_tmpl = re.sub(r'<([^;]+);([^>]+)>', replace_m_n, desc_tmpl) + + # Replace @i/a/b with the i-th element in keys_info, deriving the key appropriately + # to get a plain xpub + def replace_key_index(match): + index, step1, step2 = [int(x) for x in match.group(1).split('/')] + return derive_from_key_info(keys_info[index], [step1, step2]) + + desc_tmpl = re.sub(r'@(\d+/\d+/\d+)', replace_key_index, desc_tmpl) + + return desc_tmpl + + +class Tree: + """ + Recursive structure that represents a taptree, or one of its subtrees. + It can either contain a single descriptor template (if it's a tapleaf), or exactly two child Trees. + """ + + def __init__(self, content: Union[str, Tuple['Tree', 'Tree']]): + if isinstance(content, str): + self.script = content + self.left, self.right = (None, None) + else: + self.script = None + self.left, self.right = content + + @property + def is_leaf(self) -> bool: + return self.script is not None + + def __str__(self): + if self.is_leaf: + return self.script + else: + return f'{{{str(self.left)},{str(self.right)}}}' + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, str]]: + """ + Generates an iterator over the placeholders contained in the scripts of the tree's leaf nodes. + + Yields: + Iterator[Tuple[KeyPlaceholder, str]]: An iterator over tuples containing a KeyPlaceholder and its associated script. + """ + + if self.is_leaf: + assert self.script is not None + for placeholder in extract_placeholders(self.script): + yield (placeholder, self.script) + else: + assert self.left is not None and self.right is not None + for placeholder, script in self.left.placeholders(): + yield (placeholder, script) + for placeholder, script in self.right.placeholders(): + yield (placeholder, script) + + def get_taptree_hash(self, keys_info: List[str], is_change: bool, address_index: int) -> bytes: + if self.is_leaf: + assert self.script is not None + leaf_desc = derive_plain_descriptor( + self.script, keys_info, is_change, address_index) + + s = BytesIO(leaf_desc.encode()) + desc: Miniscript = Miniscript.read_from( + s, taproot=True) + + return tapleaf_hash(desc.compile()) + + else: + assert self.left is not None and self.right is not None + left_h = self.left.get_taptree_hash( + keys_info, is_change, address_index) + right_h = self.left.get_taptree_hash( + keys_info, is_change, address_index) + if left_h <= right_h: + return tagged_hash("TapBranch", left_h + right_h) + else: + return tagged_hash("TapBranch", right_h + left_h) + + +class GenericParser(ABC): + def __init__(self, input: str): + self.input = input + self.index = 0 + self.length = len(input) + + @abstractmethod + def parse(self): + pass + + def parse_keyplaceholder(self): + if self.peek() == '@': + self.consume('@') + key_index = self.parse_num() + self.consume('/<') + num1 = self.parse_num() + self.consume(';') + num2 = self.parse_num() + self.consume('>/*') + return PlainKeyPlaceholder(key_index, num1, num2) + else: + raise Exception("Syntax error in key placeholder") + + def parse_tree(self) -> Tree: + if self.peek() == '{': + self.consume('{') + tree1 = self.parse_tree() + self.consume(',') + tree2 = self.parse_tree() + self.consume('}') + return Tree((tree1, tree2)) + else: + return Tree(self.parse_script()) + + def parse_script(self) -> str: + start = self.index + nesting = 0 + while self.index < self.length and (nesting > 0 or self.input[self.index] not in ('}', ',', ')')): + if self.input[self.index] == '(': + nesting += 1 + elif self.input[self.index] == ')': + nesting -= 1 + + self.index += 1 + return self.input[start:self.index] + + def parse_key_indexes(self) -> List[int]: + nums = [] + self.consume('@') + nums.append(self.parse_num()) + while self.peek() == ',': + self.consume(',@') + nums.append(self.parse_num()) + return nums + + def parse_num(self) -> int: + start = self.index + while self.index < self.length and self.input[self.index].isdigit(): + self.index += 1 + return int(self.input[start:self.index]) + + def consume(self, char: str): + if self.input[self.index:self.index+len(char)] == char: + self.index += len(char) + else: + raise Exception( + f"Syntax error: Expected '{char}'; rest: {self.input[self.index:]}") + + def peek(self) -> Optional[str]: + return self.input[self.index] if self.index < self.length else None + + +class DescriptorTemplate(ABC): + """ + Represents a generic descriptor template. + This is a base class for all specific descriptor templates. + """ + + @abstractmethod + def __init__(self): + pass + + @classmethod + @abstractmethod + def from_string(cls, input_string: str) -> 'DescriptorTemplate': + pass + + @abstractmethod + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, Optional[str]]]: + pass + + @staticmethod + def get_descriptor_type(input_string: str) -> Type['DescriptorTemplate']: + if input_string.startswith('tr('): + return TrDescriptorTemplate + elif input_string.startswith('wsh('): + return WshDescriptorTemplate + elif input_string.startswith('wpkh('): + return WpkhDescriptorTemplate + elif input_string.startswith('pkh('): + return PkhDescriptorTemplate + else: + raise ValueError("Unknown descriptor type") + + @classmethod + def from_string(cls, input_string: str) -> 'DescriptorTemplate': + descriptor_type = cls.get_descriptor_type(input_string) + return descriptor_type.from_string(input_string) + + def is_legacy(self) -> bool: + # TODO: incomplete, missing legacy sh(...) descriptors + return isinstance(self, PkhDescriptorTemplate) + + def is_segwit(self) -> bool: + # TODO: incomplete, missing sh(wsh(...)) and sh(wpkh(...)) descriptors + return isinstance(self, (WshDescriptorTemplate, WpkhDescriptorTemplate, TrDescriptorTemplate)) + + def is_taproot(self) -> bool: + return isinstance(self, TrDescriptorTemplate) + + +class TrDescriptorTemplate(DescriptorTemplate): + """ + Represents a descriptor template for a tr(KEY) or a tr(KEY,TREE). + This is minimal implementation in order to enable iterating over the placeholders, + and compile the corresponding leaf scripts. + """ + + def __init__(self, key: KeyPlaceholder, tree=Optional[Tree]): + self.key: KeyPlaceholder = key + self.tree: Optional[Tree] = tree + + @classmethod + def from_string(cls, input_string): + parser = cls.Parser(input_string.replace("/**", "/<0;1>/*")) + return parser.parse() + + class Parser(GenericParser): + def parse(self) -> 'TrDescriptorTemplate': + self.consume('tr(') + key = self.parse_keyplaceholder() + tree = None + if self.peek() == ',': + self.consume(',') + tree = self.parse_tree() + self.consume(')') + return TrDescriptorTemplate(key, tree) + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, Optional[str]]]: + """ + Generates an iterator over the placeholders contained in the template and its tree, also + yielding the corresponding leaf script descriptor (or None for the keypath placeholder). + + Yields: + Iterator[Tuple[KeyPlaceholder, Optional[str]]]: An iterator over tuples containing a KeyPlaceholder and an optional associated script. + """ + + yield (self.key, None) + + if self.tree is not None: + for placeholder, script in self.tree.placeholders(): + yield (placeholder, script) + + def get_taptree_hash(self, keys_info: List[str], is_change: bool, address_index: int) -> bytes: + if self.tree is None: + raise ValueError("There is no taptree") + return self.tree.get_taptree_hash(keys_info, is_change, address_index) + + +class WshDescriptorTemplate(DescriptorTemplate): + """ + Represents a wsh(SCRIPT) descriptor template. + This is minimal implementation in order to enable iterating over the placeholders, + and compile the corresponding leaf scripts. + """ + + def __init__(self, inner_script: str): + self.inner_script = inner_script + + @classmethod + def from_string(cls, input_string): + parser = cls.Parser(input_string.replace("/**", "/<0;1>/*")) + return parser.parse() + + class Parser(GenericParser): + def parse(self) -> 'WshDescriptorTemplate': + if self.input.startswith('wsh('): + self.consume('wsh(') + inner_script = self.parse_script() + self.consume(')') + return WshDescriptorTemplate(inner_script) + else: + raise Exception( + "Syntax error: Input does not start with 'tr('") + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, Optional[str]]]: + for placeholder in extract_placeholders(self.inner_script): + yield (placeholder, None) + + +class WpkhDescriptorTemplate(DescriptorTemplate): + """ + Represents a wpkh(KEY) descriptor template. + This is minimal implementation in order to enable iterating over the placeholders, + and compile the corresponding leaf scripts. + """ + + def __init__(self, key: KeyPlaceholder): + self.key = key + + @classmethod + def from_string(cls, input_string): + parser = cls.Parser(input_string.replace("/**", "/<0;1>/*")) + return parser.parse() + + class Parser(GenericParser): + def parse(self) -> 'WpkhDescriptorTemplate': + self.consume('wpkh(') + key = self.parse_keyplaceholder() + self.consume(')') + return WpkhDescriptorTemplate(key) + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, Optional[str]]]: + yield (self.key, None) + + +class PkhDescriptorTemplate(DescriptorTemplate): + """ + Represents a pkh(KEY) descriptor template. + This is minimal implementation in order to enable iterating over the placeholders, + and compile the corresponding leaf scripts. + """ + + def __init__(self, key: KeyPlaceholder): + self.key = key + + @classmethod + def from_string(cls, input_string): + parser = cls.Parser(input_string.replace("/**", "/<0;1>/*")) + return parser.parse() + + class Parser(GenericParser): + def parse(self) -> 'PkhDescriptorTemplate': + self.consume('pkh(') + key = self.parse_keyplaceholder() + self.consume(')') + return PkhDescriptorTemplate(key) + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, Optional[str]]]: + yield (self.key, None) diff --git a/tests_perf/test_perf_sign_psbt.py b/tests_perf/test_perf_sign_psbt.py index 8cbb8212..5cae87e4 100644 --- a/tests_perf/test_perf_sign_psbt.py +++ b/tests_perf/test_perf_sign_psbt.py @@ -1,12 +1,14 @@ from pathlib import Path +from hashlib import sha256 +import hmac import pytest from ledger_bitcoin import WalletPolicy, Client from ledger_bitcoin.psbt import PSBT -from test_utils import txmaker +from test_utils import SpeculosGlobals, txmaker tests_root: Path = Path(__file__).parent @@ -33,11 +35,40 @@ def make_psbt(wallet_policy: WalletPolicy, n_inputs: int, n_outputs: int) -> PSB return psbt +def run_test(client: Client, wallet_policy: WalletPolicy, n_inputs: int, speculos_globals: SpeculosGlobals, benchmark): + + wallet_hmac = None + if wallet_policy.name != "": + wallet_hmac = hmac.new( + speculos_globals.wallet_registration_key, wallet_policy.id, sha256).digest() + + psbt = make_psbt(wallet_policy, n_inputs, 2) + + # the following code might count repetitions incorrectly for more than 10 keys + assert len(wallet_policy.keys_info) <= 10 + + n_internal_placeholders = 0 + for key_index, key_info in enumerate(wallet_policy.keys_info): + if key_info.startswith(f"[{speculos_globals.master_key_fingerprint.hex()}"): + # this is incorrect if more than 10 keys, as key indexes are more than one digit + n_internal_placeholders += wallet_policy.descriptor_template.count( + f"@{key_index}") + + assert n_internal_placeholders >= 1 + + def sign_tx(): + result = client.sign_psbt(psbt, wallet_policy, wallet_hmac) + + assert len(result) == n_inputs * n_internal_placeholders + + benchmark.pedantic(sign_tx, rounds=1) + + @pytest.mark.parametrize("n_inputs", [1, 3, 10]) -def test_perf_sign_psbt_singlesig_pkh(client: Client, n_inputs: int, benchmark): +def test_perf_sign_psbt_singlesig_pkh(client: Client, n_inputs: int, speculos_globals: SpeculosGlobals, benchmark): # PSBT for a legacy 2-output spend (1 change address) - wallet = WalletPolicy( + wallet_policy = WalletPolicy( "", "pkh(@0/**)", [ @@ -45,21 +76,14 @@ def test_perf_sign_psbt_singlesig_pkh(client: Client, n_inputs: int, benchmark): ], ) - psbt = make_psbt(wallet, n_inputs, 2) - - def sign_tx(): - result = client.sign_psbt(psbt, wallet, None) - - assert len(result) == n_inputs - - benchmark.pedantic(sign_tx, rounds=1) + run_test(client, wallet_policy, n_inputs, speculos_globals, benchmark) @pytest.mark.parametrize("n_inputs", [1, 3, 10]) -def test_perf_sign_psbt_singlesig_wpkh(client: Client, n_inputs: int, benchmark): +def test_perf_sign_psbt_singlesig_wpkh(client: Client, n_inputs: int, speculos_globals: SpeculosGlobals, benchmark): # PSBT for a segwit 2-output spend (1 change address) - wallet = WalletPolicy( + wallet_policy = WalletPolicy( "", "wpkh(@0/**)", [ @@ -67,11 +91,69 @@ def test_perf_sign_psbt_singlesig_wpkh(client: Client, n_inputs: int, benchmark) ], ) - psbt = make_psbt(wallet, n_inputs, 2) + run_test(client, wallet_policy, n_inputs, speculos_globals, benchmark) - def sign_tx(): - result = client.sign_psbt(psbt, wallet, None) - assert len(result) == n_inputs +@pytest.mark.parametrize("n_inputs", [1, 3, 10]) +def test_perf_sign_psbt_singlesig_tr(client: Client, n_inputs: int, speculos_globals: SpeculosGlobals, benchmark): + # PSBT for a taproot 2-output spend (1 change address) + + wallet_policy = WalletPolicy( + name="", + descriptor_template="tr(@0/**)", + keys_info=[ + f"[f5acc2fd/86'/1'/0']tpubDDKYE6BREvDsSWMazgHoyQWiJwYaDDYPbCFjYxN3HFXJP5fokeiK4hwK5tTLBNEDBwrDXn8cQ4v9b2xdW62Xr5yxoQdMu1v6c7UDXYVH27U", + ], + ) + + run_test(client, wallet_policy, n_inputs, speculos_globals, benchmark) - benchmark.pedantic(sign_tx, rounds=1) + +@pytest.mark.parametrize("n_inputs", [1, 3, 10]) +def test_perf_sign_psbt_multisig2of3_wsh(client: Client, n_inputs: int, speculos_globals: SpeculosGlobals, benchmark): + wallet_policy = WalletPolicy( + name="Cold storage", + descriptor_template="wsh(sortedmulti(2,@0/**,@1/**,@2/**))", + keys_info=[ + "[f5acc2fd/48'/1'/0'/2']tpubDFAqEGNyad35aBCKUAXbQGDjdVhNueno5ZZVEn3sQbW5ci457gLR7HyTmHBg93oourBssgUxuWz1jX5uhc1qaqFo9VsybY1J5FuedLfm4dK", + "tpubDE7NQymr4AFtewpAsWtnreyq9ghkzQBXpCZjWLFVRAvnbf7vya2eMTvT2fPapNqL8SuVvLQdbUbMfWLVDCZKnsEBqp6UK93QEzL8Ck23AwF", + "tpubDF4kujkh5dAhC1pFgBToZybXdvJFXXGX4BWdDxWqP7EUpG8gxkfMQeDjGPDnTr9e4NrkFmDM1ocav3Jz6x79CRZbxGr9dzFokJLuvDDnyRh" + ], + ) + + run_test(client, wallet_policy, n_inputs, speculos_globals, benchmark) + + +@pytest.mark.parametrize("n_inputs", [1, 3, 10]) +def test_perf_sign_psbt_multisig3of5_wsh(client: Client, n_inputs: int, speculos_globals: SpeculosGlobals, benchmark): + wallet_policy = WalletPolicy( + name="Cold storage", + descriptor_template="wsh(sortedmulti(3,@0/**,@1/**,@2/**,@3/**,@4/**))", + keys_info=[ + "[f5acc2fd/48'/1'/0'/2']tpubDFAqEGNyad35aBCKUAXbQGDjdVhNueno5ZZVEn3sQbW5ci457gLR7HyTmHBg93oourBssgUxuWz1jX5uhc1qaqFo9VsybY1J5FuedLfm4dK", + "tpubDE7NQymr4AFtewpAsWtnreyq9ghkzQBXpCZjWLFVRAvnbf7vya2eMTvT2fPapNqL8SuVvLQdbUbMfWLVDCZKnsEBqp6UK93QEzL8Ck23AwF", + "tpubDF4kujkh5dAhC1pFgBToZybXdvJFXXGX4BWdDxWqP7EUpG8gxkfMQeDjGPDnTr9e4NrkFmDM1ocav3Jz6x79CRZbxGr9dzFokJLuvDDnyRh", + "tpubDD3ULTdBbyuMMMs8BCsJKgZgEnZjjbsbtV6ig3xtkQnaSc1gu9kNhmDDEW49HoLzDNA4y2TMqRzj4BugrrtcpXkjoHSoMVhJwfZLUFmv6yn", + "tpubDDyh1VAY2sHfGHE59muC5PWa3tosSTm62sNTDSmZUsx9TbyBdoVkZibYZuDoqJ8dJ6v6eYZz6SE1d6sDv45NgJFB1oqCLGzyiQBGyjexc7V" + ], + ) + + run_test(client, wallet_policy, n_inputs, speculos_globals, benchmark) + + +@pytest.mark.parametrize("n_inputs", [1, 3, 10]) +def test_perf_sign_psbt_tapminiscript_2paths(client: Client, n_inputs: int, speculos_globals: SpeculosGlobals, benchmark): + # A taproot miniscript policy where the two placeholders (in different spending paths) are internal + # The app signs for both spending paths. + wallet_policy = WalletPolicy( + name="Cold storage", + descriptor_template="wsh(or_d(multi(4,@0/<0;1>/*,@1/<0;1>/*,@2/<0;1>/*,@3/<0;1>/*),and_v(v:thresh(3,pkh(@0/<2;3>/*),a:pkh(@1/<2;3>/*),a:pkh(@2/<2;3>/*),a:pkh(@3/<2;3>/*)),older(65535))))", + keys_info=[ + "[f5acc2fd/48'/1'/0'/2']tpubDFAqEGNyad35aBCKUAXbQGDjdVhNueno5ZZVEn3sQbW5ci457gLR7HyTmHBg93oourBssgUxuWz1jX5uhc1qaqFo9VsybY1J5FuedLfm4dK", + "tpubDE7NQymr4AFtewpAsWtnreyq9ghkzQBXpCZjWLFVRAvnbf7vya2eMTvT2fPapNqL8SuVvLQdbUbMfWLVDCZKnsEBqp6UK93QEzL8Ck23AwF", + "tpubDF4kujkh5dAhC1pFgBToZybXdvJFXXGX4BWdDxWqP7EUpG8gxkfMQeDjGPDnTr9e4NrkFmDM1ocav3Jz6x79CRZbxGr9dzFokJLuvDDnyRh", + "tpubDD3ULTdBbyuMMMs8BCsJKgZgEnZjjbsbtV6ig3xtkQnaSc1gu9kNhmDDEW49HoLzDNA4y2TMqRzj4BugrrtcpXkjoHSoMVhJwfZLUFmv6yn", + ], + ) + + run_test(client, wallet_policy, n_inputs, speculos_globals, benchmark)