Skip to content

Commit

Permalink
Generalize txmaker to most supported policies
Browse files Browse the repository at this point in the history
  • Loading branch information
bigspider committed Jun 19, 2024
1 parent 861d525 commit 4c4b1d2
Show file tree
Hide file tree
Showing 4 changed files with 715 additions and 92 deletions.
50 changes: 50 additions & 0 deletions test_utils/taproot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# from portions of BIP-0341
# - https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0341.mediawiki
# Distributed under the BSD-3-Clause license

# fmt: off

# If you want to print values on an individual basis, use
# the pretty() function, e.g., print(pretty(foo)).
import hashlib
import struct


# This implementation can be sped up by storing the midstate after hashing
# tag_hash instead of rehashing it all the time.
def tagged_hash(tag: str, msg: bytes) -> bytes:
tag_hash = hashlib.sha256(tag.encode()).digest()
return hashlib.sha256(tag_hash + tag_hash + msg).digest()


def ser_compact_size(l):
r = b""
if l < 253:
r = struct.pack("B", l)
elif l < 0x10000:
r = struct.pack("<BH", 253, l)
elif l < 0x100000000:
r = struct.pack("<BI", 254, l)
else:
r = struct.pack("<BQ", 255, l)
return r

def deser_compact_size(f):
nit = struct.unpack("<B", f.read(1))[0]
if nit == 253:
nit = struct.unpack("<H", f.read(2))[0]
elif nit == 254:
nit = struct.unpack("<I", f.read(4))[0]
elif nit == 255:
nit = struct.unpack("<Q", f.read(8))[0]
return nit

def deser_string(f):
nit = deser_compact_size(f)
return f.read(nit)

def ser_string(s):
return ser_compact_size(len(s)) + s

def ser_script(s):
return ser_string(s)
225 changes: 151 additions & 74 deletions test_utils/txmaker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
# This module contains a utility function to create test PSBTs spending from an arbitrary wallet policy.
# It creates transactions spending non-existing UTXOs, and fills in the PSBTs with enough information to
# satisfy the requirements of the Ledger bitcoin app.
# It does not guarantee BIP-174 compliant PSBTs, as some fields that are not required in the
# Ledger bitcoin app might not be filled in.


from io import BytesIO
from random import randint

from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Union
from bitcoin_client.ledger_bitcoin import WalletPolicy, WalletType
from bitcoin_client.ledger_bitcoin.key import KeyOriginInfo, parse_path, get_taproot_output_key
from bitcoin_client.ledger_bitcoin.key import ExtendedKey, KeyOriginInfo, parse_path, get_taproot_output_key
from bitcoin_client.ledger_bitcoin.psbt import PSBT, PartiallySignedInput, PartiallySignedOutput
from bitcoin_client.ledger_bitcoin.tx import CScriptWitness, CTransaction, CTxIn, CTxInWitness, CTxOut, COutPoint, CTxWitness, uint256_from_str

Expand All @@ -11,6 +19,10 @@
from embit.bip32 import HDKey
from embit.bip39 import mnemonic_to_seed

from ledger_bitcoin.embit.descriptor.miniscript import Miniscript
from test_utils import bip0340
from test_utils.wallet_policy import DescriptorTemplate, KeyPlaceholder, PlainKeyPlaceholder, TrDescriptorTemplate, WshDescriptorTemplate, derive_plain_descriptor, tapleaf_hash


SPECULOS_SEED = "glory promote mansion idle axis finger extra february uncover one trip resource lawn turtle enact monster seven myth punch hobby comfort wild raise skin"
master_key = HDKey.from_seed(mnemonic_to_seed(SPECULOS_SEED))
Expand Down Expand Up @@ -40,6 +52,14 @@ def random_txid() -> bytes:
return random_bytes(32)


def random_p2tr() -> bytes:
"""Returns 32 random bytes. Not cryptographically secure."""
privkey = random_bytes(32)
pubkey = bip0340.point_mul(bip0340.G, int.from_bytes(privkey, 'big'))

return b'\x51\x20' + (pubkey[0]).to_bytes(32, 'big')


def getScriptPubkeyFromWallet(wallet: WalletPolicy, change: bool, address_index: int) -> Script:
descriptor_str = wallet.descriptor_template

Expand Down Expand Up @@ -116,27 +136,109 @@ def createFakeWalletTransaction(n_inputs: int, n_outputs: int, output_amount: in
return tx, selected_output_index, selected_output_change, selected_output_address_index


def createPsbt(wallet: WalletPolicy, input_amounts: List[int], output_amounts: List[int], output_is_change: List[bool], output_wallet: Optional[List[Optional[WalletPolicy]]] = None) -> PSBT:
if output_wallet is None:
output_wallet = [None] * len(output_amounts)
def get_placeholder_root_key(placeholder: KeyPlaceholder, keys_info: List[str]) -> Tuple[ExtendedKey, Optional[KeyOriginInfo]]:
if isinstance(placeholder, PlainKeyPlaceholder):
key_info = keys_info[placeholder.key_index]
key_origin_end_pos = key_info.find("]")
if key_origin_end_pos == -1:
xpub = key_info
root_key_origin = None
else:
xpub = key_info[key_origin_end_pos+1:]
root_key_origin = KeyOriginInfo.from_string(
key_info[1:key_origin_end_pos])
root_pubkey = ExtendedKey.deserialize(xpub)
else:
raise ValueError("Unsupported placeholder type")

return root_pubkey, root_key_origin

assert len(output_amounts) == len(output_is_change)
assert len(output_amounts) == len(output_wallet)
assert sum(output_amounts) <= sum(input_amounts)

# TODO: add support for wrapped segwit wallets
def fill_inout(wallet_policy: WalletPolicy, inout: Union[PartiallySignedInput, PartiallySignedOutput], is_change: bool, address_index: int):
desc_tmpl = DescriptorTemplate.from_string(
wallet_policy.descriptor_template)

if wallet.n_keys != 1:
raise NotImplementedError("Only 1-key wallets supported")
if wallet.version == WalletType.WALLET_POLICY_V1:
if wallet.descriptor_template not in ["pkh(@0)", "wpkh(@0)", "tr(@0)"]:
raise NotImplementedError("Unsupported policy type")
elif wallet.version == WalletType.WALLET_POLICY_V2:
if wallet.descriptor_template not in ["pkh(@0/**)", "wpkh(@0/**)", "tr(@0/**)"]:
raise NotImplementedError("Unsupported policy type")
if isinstance(desc_tmpl, TrDescriptorTemplate):
keypath_der_subpath = [
desc_tmpl.key.num1 if not is_change else desc_tmpl.key.num2,
address_index
]

keypath_pubkey, _ = get_placeholder_root_key(
desc_tmpl.key, wallet_policy.keys_info)

inout.tap_internal_key = keypath_pubkey.derive_pub_path(
keypath_der_subpath).pubkey[1:]

if desc_tmpl.tree is not None:
inout.tap_merkle_root = desc_tmpl.get_taptree_hash(
wallet_policy.keys_info, is_change, address_index)

for placeholder, tapleaf_desc in desc_tmpl.placeholders():
root_pubkey, root_pubkey_origin = get_placeholder_root_key(
placeholder, wallet_policy.keys_info)

placeholder_der_subpath = [
placeholder.num1 if not is_change else placeholder.num2,
address_index
]

leaf_script = None
if tapleaf_desc is not None:
leaf_desc = derive_plain_descriptor(
tapleaf_desc, wallet_policy.keys_info, is_change, address_index)
s = BytesIO(leaf_desc.encode())
desc: Miniscript = Miniscript.read_from(s, taproot=True)
leaf_script = desc.compile()

derived_pubkey = root_pubkey.derive_pub_path(
placeholder_der_subpath)

if root_pubkey_origin is not None:
derived_key_origin = KeyOriginInfo(
root_pubkey_origin.fingerprint, root_pubkey_origin.path + placeholder_der_subpath)

leaf_hashes = []
if leaf_script is not None:
# In BIP-388 compliant wallet policies, there will be only one tapleaf with a given key
leaf_hashes = [tapleaf_hash(leaf_script)]

inout.tap_bip32_paths[derived_pubkey.pubkey[1:]] = (
leaf_hashes, derived_key_origin)
else:
raise ValueError(
f"Unknown wallet policy version: {wallet.version}")
if isinstance(desc_tmpl, WshDescriptorTemplate):
# add witnessScript
desc_str = derive_plain_descriptor(
wallet_policy.descriptor_template, wallet_policy.keys_info, is_change, address_index)
s = BytesIO(desc_str.encode())
desc: Descriptor = Descriptor.read_from(s)
inout.witness_script = desc.witness_script().data

for placeholder, _ in desc_tmpl.placeholders():
root_pubkey, root_pubkey_origin = get_placeholder_root_key(
placeholder, wallet_policy.keys_info)

placeholder_der_subpath = [
placeholder.num1 if not is_change else placeholder.num2,
address_index
]

derived_pubkey = root_pubkey.derive_pub_path(
placeholder_der_subpath)

if root_pubkey_origin is not None:
derived_key_origin = KeyOriginInfo(
root_pubkey_origin.fingerprint, root_pubkey_origin.path + placeholder_der_subpath)

inout.hd_keypaths[derived_pubkey.pubkey] = derived_key_origin


def createPsbt(wallet_policy: WalletPolicy, input_amounts: List[int], output_amounts: List[int], output_is_change: List[bool]) -> PSBT:
assert output_is_change.count(
True) <= 1, "At most one change output is supported"

assert len(output_amounts) == len(output_is_change)
assert sum(output_amounts) <= sum(input_amounts)

vin: List[CTxIn] = [CTxIn() for _ in input_amounts]
vout: List[CTxOut] = [CTxOut() for _ in output_amounts]
Expand All @@ -150,7 +252,7 @@ def createPsbt(wallet: WalletPolicy, input_amounts: List[int], output_amounts: L
n_inputs = randint(1, 10)
n_outputs = randint(1, 10)
prevout, idx, is_change, addr_idx = createFakeWalletTransaction(
n_inputs, n_outputs, prevout_amount, wallet)
n_inputs, n_outputs, prevout_amount, wallet_policy)
prevouts.append(prevout)
prevout_ns.append(idx)
prevout_path_change.append(is_change)
Expand All @@ -168,68 +270,43 @@ def createPsbt(wallet: WalletPolicy, input_amounts: List[int], output_amounts: L
tx.vout = vout
tx.wit = CTxWitness()

change_address_index = randint(0, 10_000)
for i, output_amount in enumerate(output_amounts):
tx.vout[i].nValue = output_amount
if output_is_change[i]:
script = getScriptPubkeyFromWallet(
wallet_policy, output_is_change[i], change_address_index)

tx.vout[i].scriptPubKey = script.data
else:
# a random P2TR output
tx.vout[i].scriptPubKey = random_p2tr()

psbt.inputs = [PartiallySignedInput(0) for _ in input_amounts]
psbt.outputs = [PartiallySignedOutput(0) for _ in output_amounts]

# simplification; good enough for the scripts we support now, but will need more work
is_legacy = wallet.descriptor_template.startswith("pkh(")
is_segwitv0 = wallet.descriptor_template.startswith(
"wpkh(") or wallet.descriptor_template.startswith("sh(wpkh(")
is_taproot = wallet.descriptor_template.startswith("tr(")

key_origin = wallet.keys_info[0][1:wallet.keys_info[0].index("]")]
desc_tmpl = DescriptorTemplate.from_string(
wallet_policy.descriptor_template)

for i in range(len(input_amounts)):
if is_legacy or is_segwitv0:
# add non-witness UTXO
psbt.inputs[i].non_witness_utxo = prevouts[i]
if is_segwitv0 or is_taproot:
for input_index, input in enumerate(psbt.inputs):
if desc_tmpl.is_segwit():
# add witness UTXO
psbt.inputs[i].witness_utxo = prevouts[i].vout[prevout_ns[i]]

path_str = f"m{key_origin[8:]}/{prevout_path_change[i]}/{prevout_path_addr_idx[i]}"
path = parse_path(path_str)
input_key: bytes = master_key.derive(path_str).key.sec()

assert len(input_key) == 33

# add key and path info
if is_legacy or is_segwitv0:
psbt.inputs[i].hd_keypaths[input_key] = KeyOriginInfo(
master_key_fpr, path)
elif is_taproot:
internal_key = input_key[1:]
psbt.inputs[i].tap_bip32_paths[internal_key] = (
{}, KeyOriginInfo(master_key_fpr, path))
else:
raise RuntimeError("Unexpected state: unknown transaction type")
input.witness_utxo = prevouts[input_index].vout[prevout_ns[input_index]]

for i, output_amount in enumerate(output_amounts):
wallet_i = output_wallet[i]
if output_is_change[i] or wallet_i is None:
script = getScriptPubkeyFromWallet(wallet, output_is_change[i], i)
else:
script = getScriptPubkeyFromWallet(wallet_i, 0, i)
if desc_tmpl.is_legacy() or (desc_tmpl.is_segwit() and not desc_tmpl.is_taproot()):
# add non_witness_utxo for legacy or segwitv0
input.non_witness_utxo = prevouts[input_index]

tx.vout[i].scriptPubKey = script.data
tx.vout[i].nValue = output_amount
is_change = bool(prevout_path_change[input_index])
address_index = prevout_path_addr_idx[input_index]

if output_is_change[i]:
path_str = f"m{key_origin[8:]}/1/{i}"
path = parse_path(path_str)
output_key: bytes = master_key.derive(path_str).key.sec()

# add key and path information for change output
if is_legacy or is_segwitv0:
psbt.outputs[i].hd_keypaths[output_key] = KeyOriginInfo(
master_key_fpr, path)
elif is_taproot:
internal_key = output_key[1:]
psbt.outputs[i].tap_bip32_paths[internal_key] = (
{}, KeyOriginInfo(master_key_fpr, path))

psbt.outputs[i].tap_bip32_paths[internal_key] = (
{}, KeyOriginInfo(master_key_fpr, path))
fill_inout(wallet_policy, input, is_change, address_index)

# only for the change output, we need to do the same
for output_index, output in enumerate(psbt.outputs):
if output_is_change[output_index]:
fill_inout(wallet_policy, output, is_change=True,
address_index=change_address_index)

psbt.tx = tx

Expand Down
Loading

0 comments on commit 4c4b1d2

Please sign in to comment.