Skip to content

Commit

Permalink
SFT-1708: taproot address verification is slow but working on device
Browse files Browse the repository at this point in the history
  • Loading branch information
mjg-foundation committed Oct 14, 2023
1 parent f192540 commit 0fc1ca7
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 113 deletions.
10 changes: 0 additions & 10 deletions extmod/foundation-rust/include/foundation.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,6 @@ void foundation_secp256k1_schnorr_sign(const uint8_t (*data)[32],
const uint8_t (*secret_key)[32],
uint8_t (*signature)[64]);

/**
* Adds a tweak to an x-only public key
* - `x_only_pubkey` is the public key
* - `tweak` is the tweak value
* - `tweaked_pubkey` is the result of the tweak
*/
void foundation_secp256k1_add_tweak(const uint8_t (*x_only_pubkey)[32],
const uint8_t (*tweak)[32],
uint8_t (*tweaked_pubkey)[32]);

/**
* Receive a Uniform Resource part.
*
Expand Down
20 changes: 1 addition & 19 deletions extmod/foundation-rust/src/secp256k1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use once_cell::sync::Lazy;
use secp256k1::{
ffi::types::AlignedType, AllPreallocated, KeyPair, Message, Secp256k1, constants::SCHNORR_PUBLIC_KEY_SIZE, XOnlyPublicKey, Scalar
ffi::types::AlignedType, AllPreallocated, KeyPair, Message, Secp256k1,
};

/// cbindgen:ignore
Expand Down Expand Up @@ -37,24 +37,6 @@ pub extern "C" fn secp256k1_sign_schnorr(
signature.copy_from_slice(sig.as_ref());
}

/// Adds a tweak to an x-only public key
///
/// - `x_only_pubkey` is the public key
/// - `tweak` is the tweak value
/// - `tweaked_pubkey` is the result of the tweak
#[export_name = "foundation_secp256k1_add_tweak"]
pub extern "C" fn secp256k1_add_tweak(
x_only_pubkey: &[u8; SCHNORR_PUBLIC_KEY_SIZE],
tweak: &[u8; SCHNORR_PUBLIC_KEY_SIZE],
tweaked_pubkey: &mut [u8; SCHNORR_PUBLIC_KEY_SIZE],
) {
let pk_struct = XOnlyPublicKey::from_slice(x_only_pubkey).unwrap();
let tweak_struct = Scalar::from_be_bytes(*tweak).unwrap();
let (tweaked_pubkey_struct, _) = pk_struct.add_tweak(&PRE_ALLOCATED_CTX, &tweak_struct).unwrap();
let tweaked_pubkey_bytes = tweaked_pubkey_struct.serialize();
tweaked_pubkey.copy_from_slice(&tweaked_pubkey_bytes)
}

#[cfg(target_arch = "arm")]
fn rng() -> crate::rand::PassportRng {
crate::rand::PassportRng
Expand Down
28 changes: 0 additions & 28 deletions extmod/foundation/modfoundation-secp56k1.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,36 +33,8 @@ STATIC mp_obj_t mod_foundation_secp256k1_sign_schnorr(mp_obj_t data_obj,
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_foundation_secp256k1_sign_schnorr_obj,
mod_foundation_secp256k1_sign_schnorr);

/// def add_tweak(x_only_public_key, tweak) -> tweaked_public_key:
/// """
// """
STATIC mp_obj_t mod_foundation_secp256k1_add_tweak(mp_obj_t x_only_public_key_obj, mp_obj_t tweak_obj)
{
mp_buffer_info_t x_only_public_key;
mp_buffer_info_t tweak;
uint8_t tweaked_public_key[32];

mp_get_buffer_raise(x_only_public_key_obj, &x_only_public_key, MP_BUFFER_READ);
mp_get_buffer_raise(tweak_obj, &tweak, MP_BUFFER_READ);

if (x_only_public_key.len != 32) {
mp_raise_msg(&mp_type_ValueError, MP_ERROR_TEXT("x_only_public_key should be 32 bytes"));
}

if (tweak.len != 32) {
mp_raise_msg(&mp_type_ValueError, MP_ERROR_TEXT("tweak should be 32 bytes"));
}

foundation_secp256k1_add_tweak(x_only_public_key.buf, tweak.buf, &tweaked_public_key);

return mp_obj_new_bytes(tweaked_public_key, sizeof(tweaked_public_key));
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_foundation_secp256k1_add_tweak_obj,
mod_foundation_secp256k1_add_tweak);

STATIC const mp_rom_map_elem_t mod_foundation_secp256k1_globals_table[] = {
{ MP_ROM_QSTR(MP_QSTR_schnorr_sign), MP_ROM_PTR(&mod_foundation_secp256k1_sign_schnorr_obj) },
{ MP_ROM_QSTR(MP_QSTR_add_tweak), MP_ROM_PTR(&mod_foundation_secp256k1_add_tweak_obj) },
};
STATIC MP_DEFINE_CONST_DICT(mod_foundation_secp256k1_globals, mod_foundation_secp256k1_globals_table);

Expand Down
1 change: 1 addition & 0 deletions ports/stm32/boards/Passport/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
'sflash.py',
'stash.py',
'stat.py',
'taproot.py',
't9.py',
'utils.py',
'version.py',
Expand Down
14 changes: 5 additions & 9 deletions ports/stm32/boards/Passport/modules/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from serializations import hash160, ser_compact_size
from ucollections import namedtuple
from opcodes import OP_CHECKMULTISIG
from utils import hash_tap_tweak
from foundation import secp256k1
from taproot import output_script

# See SLIP 132 <https://github.com/satoshilabs/slips/blob/master/slip-0132.md>
# for background on these version bytes. Not to be confused with SLIP-32 which involves Bech32.
Expand Down Expand Up @@ -119,16 +118,13 @@ def address(cls, node, addr_fmt):
if addr_fmt & AFC_BECH32:
# bech32 encoded segwit p2pkh
return tcc.codecs.bech32_encode(cls.bech32_hrp, 0, raw)
elif addr_fmt & AFC_BECH32M:

elif addr_fmt & AFC_BECH32M:
pubkey = node.public_key()
internal_key = pubkey[1::]
# internal_key = int.from_bytes(pubkey[1::], "big")
tweak = hash_tap_tweak(internal_key)
# tweak = int.from_bytes(hash_tap_tweak(internal_key), "big")
# TODO: expose point_add and point_multiply
# output_key = internal_key + trezorcrypto.secp256k1.multiply(tweak, G)
output_key = secp256k1.add_tweak(internal_key, tweak)
print("internal_key: {}".format(b2a_hex(internal_key)))
output_key = output_script(internal_key, None)[2::]
print("output_key: {}".format(b2a_hex(output_key)))
return tcc.codecs.bech32_encode(cls.bech32_hrp, 1, output_key)

# see bip-141, "P2WPKH nested in BIP16 P2SH" section
Expand Down
8 changes: 6 additions & 2 deletions ports/stm32/boards/Passport/modules/menus.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,17 @@ def multisig_menu():
from pages import MultisigPolicySettingPage, ErrorPage
from flows import ImportMultisigWalletFromMicroSDFlow, ImportMultisigWalletFromQRFlow
from utils import escape_text
from common import settings

xfp = settings.get('xfp')
multisigs = MultisigWallet.get_by_xfp(xfp)

if not MultisigWallet.exists():
if len(multisigs) == 0:
items = [{'icon': 'ICON_TWO_KEYS', 'label': '(None setup yet)', 'page': ErrorPage,
'args': {'text': "You haven't imported any multisig wallets yet."}}]
else:
items = []
for ms in MultisigWallet.get_all():
for ms in multisigs:
nice_name = '%d/%d: %s' % (ms.M, ms.N, escape_text(ms.name))
items.append({
'icon': 'ICON_TWO_KEYS',
Expand Down
163 changes: 163 additions & 0 deletions ports/stm32/boards/Passport/modules/taproot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# TODO: add BSD-2-Clause license for BIP340 and BSD-3-Clase for BIP341

import ubinascii
import utime

# Set DEBUG to True to get a detailed debug output including
# intermediate values during key generation, signing, and
# verification. This is implemented via calls to the
# debug_print_vars() function.
#
# If you want to print values on an individual basis, use
# the pretty() function, e.g., print(pretty(foo)).
DEBUG = False

p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
SECP256K1_ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

# Points are tuples of X and Y coordinates and the point at infinity is
# represented by the None keyword.
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)

# Point = Tuple[int, int]

# This implementation can be sped up by storing the midstate after hashing
# tag_hash instead of rehashing it all the time.
# def tagged_hash(tag, msg) -> bytes:
# tag_hash = uhashlib.sha256(tag.encode()).digest()
# return uhashlib.sha256(tag_hash + tag_hash + msg).digest()


# TODO: optimize for TapTweak
def hash_tap_tweak(data):
from serializations import sha256
from public_constants import TAP_TWEAK_SHA256
from ubinascii import unhexlify as a2b_hex

tag_hash = a2b_hex(TAP_TWEAK_SHA256)
return sha256(tag_hash + tag_hash + data)


def is_infinite(P):
return P is None


def x(P):
assert not is_infinite(P)
return P[0]


def y(P):
assert not is_infinite(P)
return P[1]


def point_add(P1, P2):
if P1 is None:
return P2
if P2 is None:
return P1
if (x(P1) == x(P2)) and (y(P1) != y(P2)):
return None
if P1 == P2:
lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p
else:
lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p
x3 = (lam * lam - x(P1) - x(P2)) % p
return (x3, (lam * (x(P1) - x3) - y(P1)) % p)


def point_mul(P, n):
R = None
for i in range(256):
if (n >> i) & 1:
R = point_add(R, P)
P = point_add(P, P)
return R


def bytes_from_int(x):
return x.to_bytes(32, "big")


def bytes_from_point(P):
return bytes_from_int(x(P))


def xor_bytes(b0, b1):
return bytes(x ^ y for (x, y) in zip(b0, b1))


def lift_x(x):
if x >= p:
return None
y_sq = (pow(x, 3, p) + 7) % p
y = pow(y_sq, (p + 1) // 4, p)
if pow(y, 2, p) != y_sq:
return None
return (x, y if y & 1 == 0 else p - y)


def int_from_bytes(b):
return int.from_bytes(b, "big")


def has_even_y(P):
assert not is_infinite(P)
return y(P) % 2 == 0


def tweak_internal_key(internal_key, h):
print("1 {}".format(utime.ticks_ms()))
t = int_from_bytes(hash_tap_tweak(internal_key + h))
print("2 {}".format(utime.ticks_ms()))
if t >= SECP256K1_ORDER:
raise ValueError
print("3 {}".format(utime.ticks_ms()))
P = lift_x(int_from_bytes(internal_key))
print("4 {}".format(utime.ticks_ms()))
if P is None:
raise ValueError
print("5 {}".format(utime.ticks_ms()))
Q = point_add(P, point_mul(G, t))
print("6 {}".format(utime.ticks_ms()))
return 0 if has_even_y(Q) else 1, bytes_from_int(x(Q))


# def taproot_tweak_seckey(seckey0, h):
# seckey0 = int_from_bytes(seckey0)
# P = point_mul(G, seckey0)
# seckey = seckey0 if has_even_y(P) else SECP256K1_ORDER - seckey0
# t = int_from_bytes(tagged_hash("TapTweak", bytes_from_int(x(P)) + h))
# if t >= SECP256K1_ORDER:
# raise ValueError
# return bytes_from_int((seckey + t) % SECP256K1_ORDER)
#
#
# def taproot_tree_helper(script_tree):
# if isinstance(script_tree, tuple):
# leaf_version, script = script_tree
# h = tagged_hash("TapLeaf", bytes([leaf_version]) + ser_script(script))
# return ([((leaf_version, script), bytes())], h)
# left, left_h = taproot_tree_helper(script_tree[0])
# right, right_h = taproot_tree_helper(script_tree[1])
# ret = [(l, c + right_h) for l, c in left] + [(l, c + left_h) for l, c in right]
# if right_h < left_h:
# left_h, right_h = right_h, left_h
# return (ret, tagged_hash("TapBranch", left_h + right_h))
#
#
def output_script(internal_pubkey, script_tree):
"""Given a internal public key and a tree of scripts, compute the output script.
script_tree is either:
- a (leaf_version, script) tuple (leaf_version is 0xc0 for [[bip-0342.mediawiki|BIP342]] scripts)
- a list of two elements, each with the same structure as script_tree itself
- None
"""
if script_tree is None:
h = bytes()
else:
_, h = taproot_tree_helper(script_tree)
_, output_pubkey = tweak_internal_key(internal_pubkey, h)
return bytes([0x51, 0x20]) + output_pubkey
Original file line number Diff line number Diff line change
Expand Up @@ -25,41 +25,41 @@ async def search_for_address_task(
from errors import Error
from uasyncio import sleep_ms

try:
with stash.SensitiveValues() as sv:
if multisig_wallet:
# NOTE: Can't easily reverse order here, so this is slightly less efficient
for (curr_idx, paths, curr_address, script) in multisig_wallet.yield_addresses(
start_address_idx,
max_to_check,
change_idx=1 if is_change else 0):
# print('Multisig: curr_idx={}: paths={} curr_address = {}'.format(curr_idx, paths, curr_address))
# try:
with stash.SensitiveValues() as sv:
if multisig_wallet:
# NOTE: Can't easily reverse order here, so this is slightly less efficient
for (curr_idx, paths, curr_address, script) in multisig_wallet.yield_addresses(
start_address_idx,
max_to_check,
change_idx=1 if is_change else 0):
# print('Multisig: curr_idx={}: paths={} curr_address = {}'.format(curr_idx, paths, curr_address))

if curr_address == address:
# NOTE: Paths are the full paths of the addresses of each signer
await on_done(curr_idx, paths, None)
return
await sleep_ms(1)
if curr_address == address:
# NOTE: Paths are the full paths of the addresses of each signer
await on_done(curr_idx, paths, None)
return
await sleep_ms(1)

else:
r = range(start_address_idx, start_address_idx + max_to_check)
if reverse:
r = reversed(r)
else:
r = range(start_address_idx, start_address_idx + max_to_check)
if reverse:
r = reversed(r)

for curr_idx in r:
addr_path = '{}/{}/{}'.format(path, is_change, curr_idx) # Zero for non-change address
# print('Singlesig: addr_path={}'.format(addr_path))
node = sv.derive_path(addr_path)
curr_address = sv.chain.address(node, addr_type)
# print(' curr_idx={}: path={} addr_type={} curr_address = {}'.format(curr_idx, addr_path,
# addr_type, curr_address))
if curr_address == address:
await on_done(curr_idx, addr_path, None)
return
await sleep_ms(1)
for curr_idx in r:
addr_path = '{}/{}/{}'.format(path, is_change, curr_idx) # Zero for non-change address
# print('Singlesig: addr_path={}'.format(addr_path))
node = sv.derive_path(addr_path)
curr_address = sv.chain.address(node, addr_type)
# print(' curr_idx={}: path={} addr_type={} curr_address = {}'.format(curr_idx, addr_path,
# addr_type, curr_address))
if curr_address == address:
await on_done(curr_idx, addr_path, None)
return
await sleep_ms(1)

await on_done(-1, None, Error.ADDRESS_NOT_FOUND)
except Exception as e:
# print('EXCEPTION: e={}'.format(e))
# Any address handling exceptions result in no address found
await on_done(-1, None, Error.ADDRESS_NOT_FOUND)
await on_done(-1, None, Error.ADDRESS_NOT_FOUND)
# except Exception as e:
# # print('EXCEPTION: e={}'.format(e))
# # Any address handling exceptions result in no address found
# await on_done(-1, None, Error.ADDRESS_NOT_FOUND)
8 changes: 0 additions & 8 deletions ports/stm32/boards/Passport/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,12 +1442,4 @@ def escape_text(text):
return text.replace("#", "##")


def hash_tap_tweak(data):
from serializations import sha256
from public_constants import TAP_TWEAK_SHA256

tag_hash = a2b_hex(TAP_TWEAK_SHA256)
return sha256(tag_hash + tag_hash + data)


# EOF
Loading

0 comments on commit 0fc1ca7

Please sign in to comment.