# (c) Copyright 2020 by Stepan Snigirev, see <https://github.com/diybitcoinhardware/embit/blob/master/LICENSE>
#
# Changes (c) Copyright 2023 by Coinkite Inc. This file is covered by license found in COPYING-CC.
#
import ngu, chains, ustruct, stash
|
|
from io import BytesIO
|
|
from public_constants import MAX_PATH_DEPTH
|
|
from binascii import unhexlify as a2b_hex
|
|
from binascii import hexlify as b2a_hex
|
|
from utils import keypath_to_str, str_to_keypath, swab32, xfp2str
|
|
from serializations import ser_compact_size
|
|
|
|
|
|
# Marker stored in KeyDerivationInfo.indexes for the trailing "*" step
WILDCARD = "*"

# Compressed (33-byte SEC) point compared against key pubkeys to detect an
# unspendable taproot internal key; appears to be the BIP-341 "H" NUMS point.
PROVABLY_UNSPENDABLE = a2b_hex(
    "0250929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0")

# sha256(b"MuSig2MuSig2MuSig2") - chain code used for synthetic MuSig2 xpubs
MUSIG_CHAIN_CODE = a2b_hex(
    "868087ca02a6f974c4598924c36b57762d32cb45717167e300622c7167e38965")

# Descriptor checksum alphabets: characters allowed in a descriptor, and the
# 32 symbols used to render the 8-character checksum.
INPUT_CHARSET = "0123456789()[],'/*abcdefgh@:$%{}IJKLMNOPQRSTUVWXYZ&+-.;<=>?!^_|~ijklmnopqrstuvwxyzABCDEFGH`#\"\\ "
CHECKSUM_CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
def polymod(c, val):
    """Advance the descriptor-checksum state ``c`` by one 5-bit symbol ``val``.

    The state is shifted left by 5 (keeping its low 35 bits) and ``val`` is
    XOR-ed in; each of the five bits shifted out above bit 35 selects one
    generator constant to XOR into the result.
    """
    generators = (0xf5dee51989, 0xa9fdca3312, 0x1bab10e32d,
                  0x3706b1677a, 0x644d626ffd)
    top = c >> 35
    shifted = ((c & 0x7ffffffff) << 5) ^ val
    for bit, gen in enumerate(generators):
        if (top >> bit) & 1:
            shifted ^= gen
    return shifted
def descriptor_checksum(desc):
    """Compute the 8-character checksum of descriptor string ``desc``.

    Each character's position in INPUT_CHARSET is split in two: the low
    5 bits are fed to polymod directly, while the high bits (the character
    "class") are packed three at a time in base 3 and fed as one extra symbol
    per group of three characters.

    Raises ValueError (carrying the offending character) if ``desc``
    contains a character outside INPUT_CHARSET.
    """
    state = 1
    group = 0
    group_len = 0

    for ch in desc:
        pos = INPUT_CHARSET.find(ch)
        if pos == -1:
            raise ValueError(ch)

        state = polymod(state, pos & 31)
        group = group * 3 + (pos >> 5)
        group_len += 1
        if group_len == 3:
            # flush a complete group of three character classes
            state = polymod(state, group)
            group = 0
            group_len = 0

    if group_len:
        # flush the final, partial group
        state = polymod(state, group)

    for _ in range(8):
        state = polymod(state, 0)
    state ^= 1

    # render the result as 8 five-bit symbols, most significant first
    return "".join(
        CHECKSUM_CHARSET[(state >> shift) & 31] for shift in range(35, -1, -5)
    )
def append_checksum(desc):
    """Return ``desc`` with ``#<checksum>`` appended."""
    suffix = "#" + descriptor_checksum(desc)
    return desc + suffix
def parse_desc_str(string):
    """Flatten a multi-line descriptor text into a single line.

    Every line is stripped of surrounding whitespace; blank lines and lines
    starting with ``#`` are dropped; the rest are concatenated without a
    separator.
    """
    kept = []
    for raw in string.split("\n"):
        line = raw.strip()
        if line and not line.startswith("#"):
            kept.append(line)
    return "".join(kept)
def read_until(s, chars=b",)(#"):
    """Read stream ``s`` one byte at a time until a byte from ``chars`` or EOF.

    Returns ``(data, delim)``. ``data`` holds the bytes read before the
    delimiter, and ``delim`` is the one-byte delimiter that ended the scan.
    The delimiter is consumed from the stream. ``delim`` is None when the
    stream ran out first.
    """
    out = bytearray()
    while True:
        byte = s.read(1)
        if not len(byte):
            # stream exhausted before any delimiter was seen
            return bytes(out), None
        if byte in chars:
            return bytes(out), byte
        out += byte
def musig_synthetic_node(agg_pk_bytes):
    """Build an HD node from a MuSig2 aggregate public key.

    The node pairs the 33-byte compressed aggregate key with the fixed
    MUSIG_CHAIN_CODE. This is presumably the BIP-328 synthetic xpub
    construction, so that unhardened children can be derived from it.
    """
    assert len(agg_pk_bytes) == 33  # need non-xonly pubkey
    synthetic = ngu.hdnode.HDNode()
    synthetic.from_chaincode_pubkey(MUSIG_CHAIN_CODE, agg_pk_bytes)
    return synthetic
class KeyOriginInfo:
    """Key origin information: master fingerprint plus derivation path.

    Corresponds to the ``[xxxxxxxx/path]`` prefix of a descriptor key.
    """

    def __init__(self, fingerprint: bytes, derivation: list, cc_fp=None):
        # 4 raw fingerprint bytes, as written in hex in the descriptor
        self.fingerprint = fingerprint
        # integer child indexes as produced by str_to_keypath (presumably
        # with the hardened bit already applied - verify against utils)
        self.derivation = derivation
        # lazily computed little-endian integer form of ``fingerprint``
        self._cc_fp = cc_fp

    def __eq__(self, other):
        # Comparing against an unrelated type (e.g. None) is simply "not
        # equal" rather than an AttributeError. Return False instead of
        # NotImplemented to stay safe on MicroPython.
        if not isinstance(other, KeyOriginInfo):
            return False
        return self.psbt_derivation() == other.psbt_derivation()

    def __hash__(self):
        return hash(tuple(self.psbt_derivation()))

    @property
    def cc_fp(self):
        """Fingerprint as a little-endian 32-bit int (cached)."""
        if self._cc_fp is None:
            self._cc_fp = ustruct.unpack('<I', self.fingerprint)[0]
        return self._cc_fp

    def str_derivation(self):
        """Derivation path as an ``m/...`` string."""
        return keypath_to_str(self.derivation, prefix='m/', skip=0)

    def psbt_derivation(self):
        """Return ``[cc_fp, idx0, idx1, ...]`` - the PSBT key-path layout."""
        return [self.cc_fp] + list(self.derivation)

    @classmethod
    def from_string(cls, s: str):
        """Parse ``"xxxxxxxx/a/b'/..."`` (origin text without the brackets).

        Raises AssertionError if the fingerprint is not 4 bytes or the path
        is deeper than MAX_PATH_DEPTH.
        """
        arr = s.split("/")
        xfp = a2b_hex(arr[0])
        assert len(xfp) == 4
        arr[0] = "m"
        path = "/".join(arr)
        derivation = str_to_keypath(xfp, path)[1:]  # ignoring xfp here, already stored
        assert len(derivation) <= MAX_PATH_DEPTH, "origin too deep"
        return cls(xfp, derivation)

    def __str__(self):
        rv = "%s" % b2a_hex(self.fingerprint).decode()
        if self.derivation:
            rv += "/%s" % keypath_to_str(self.derivation, prefix='', skip=0)
        return rv
class KeyDerivationInfo:
    """Sub-derivation that follows an xpub in a descriptor key expression.

    ``indexes`` is a tuple of path steps. Each step is one of:
      - an int: an unhardened child number;
      - a two-int tuple ``(ext, int)``: the ``<a;b>`` multipath step;
      - WILDCARD: the trailing ``*``.

    ``multi_path_index`` is the position of the multipath tuple inside
    ``indexes``, or None when not set.
    """

    def __init__(self, indexes=None):
        self.indexes = indexes
        if self.indexes is None:
            # default sub-derivation: <0;1>/*
            self.indexes = ((0, 1), WILDCARD)
            self.multi_path_index = 0
        else:
            self.multi_path_index = None

    def __hash__(self):
        return hash(self.indexes)

    @staticmethod
    def not_hardened(x):
        assert (b"'" not in x) and (b"h" not in x), "Cannot use hardened sub derivation path"

    @classmethod
    def _parse_index(cls, raw):
        """Convert one textual path step (bytes) into an unhardened child number."""
        cls.not_hardened(raw)
        num = int(raw.decode())
        # A numeric hardened index (>= 2**31) or a negative number would
        # slip past the marker-based check above.
        assert 0 <= num < 0x80000000, "Invalid sub derivation index"
        return num

    def get_ext_int(self):
        """Return the ``(ext, int)`` multipath pair."""
        return self.indexes[self.multi_path_index]

    @classmethod
    def parse(cls, s):
        """Parse a sub-derivation from stream ``s``.

        The stream is positioned just after the ``/`` that follows the xpub.
        Parsing stops at the trailing ``*``, or after consuming a ``,`` or
        ``)``, or at end of stream. A plain ``0/*`` is normalized to the
        default ``<0;1>/*``. Any other derivation must contain exactly one
        ``<a;b>`` step and end with ``*``. Raises AssertionError or
        ValueError on malformed input.
        """
        err = "Malformed key derivation"
        multi_i = None
        idxs = []
        while True:
            got, char = read_until(s, b"<,)/")
            if char == b"<":
                assert multi_i is None, "too many multipaths"
                # nothing may precede "<" within the same path step
                assert not got, err
                ext_num, char = read_until(s, b";")
                assert char, err
                int_num, char = read_until(s, b">")
                assert char, err
                assert b";" not in int_num, "Solved cardinality > 2"

                ext_i = cls._parse_index(ext_num)
                int_i = cls._parse_index(int_num)
                assert ext_i != int_i  # cannot be the same

                multi_i = len(idxs)
                idxs.append((ext_i, int_i))

            else:
                # char is one of "/", ",", ")" or None (end of stream)
                if got == b"*":
                    # every derivation has to end with wildcard (only ranged keys allowed)
                    idxs.append(WILDCARD)
                    break
                elif got:
                    idxs.append(cls._parse_index(got))

                # comma and parenthesis not allowed in subderivation, marker of the end;
                # running out of input also ends the derivation
                if char is None or char in b",)":
                    break

        assert idxs and idxs[-1] == WILDCARD, "All keys must be ranged"
        if idxs == [0, WILDCARD]:
            # normalize and instead save as <0;1> as change derivation was not provided
            return cls()

        assert multi_i is not None, "need multipath"
        assert len(idxs[multi_i]) == 2, "wrong multipath"

        obj = cls(tuple(idxs))
        obj.multi_path_index = multi_i
        return obj

    def to_string(self, external=True, internal=True):
        """Render as ``a/<b;c>/*``.

        When exactly one of ``external``/``internal`` is True, the multipath
        step collapses to that single branch. Otherwise the full ``<b;c>``
        form is kept.
        """
        res = []
        for i in self.indexes:
            if isinstance(i, tuple):
                if internal is True and external is False:
                    i = str(i[1])
                elif internal is False and external is True:
                    i = str(i[0])
                else:
                    i = "<%d;%d>" % (i[0], i[1])
            else:
                i = str(i)
            res.append(i)
        return "/".join(res)

    def der_index(self, idx, change=False):
        """Resolve the child index to derive next.

        - ``idx`` is a list: return the first candidate that is one of the
          multipath branches. Raise AssertionError if none matches.
        - ``idx`` is None: pick the external or internal multipath branch
          according to ``change``. This can be called unbound with
          ``self=None``, in which case it returns 1 or 0.
        - otherwise ``idx`` is returned unchanged.
        """
        if isinstance(idx, list):
            branches = self.indexes[self.multi_path_index or 0]
            for i in idx:
                if i in branches:
                    idx = i
                    break
            else:
                assert False, "no matching multipath index"

        elif idx is None:
            # derive according to key subderivation if any
            if self is None:
                idx = 1 if change else 0
            elif self.multi_path_index is not None:
                ext, inter = self.indexes[self.multi_path_index]
                idx = inter if change else ext

        return idx
class ExtendedKey:
    """Extended public key taken from a descriptor key expression.

    Bundles the HD node holding the xpub, its origin info (``[xfp/path]``)
    and the sub-derivation that follows it (``/<0;1>/*`` ...). It also
    records whether the key is used in a taproot (x-only) context and the
    chain detected from the xpub prefix.
    """

    def __init__(self, node, origin, derivation=None, taproot=False, chain_type=None):
        self.origin = origin  # KeyOriginInfo (read_from always provides one)
        self.node = node  # ngu.hdnode.HDNode with the public key
        # default sub-derivation is <0;1>/*
        self.derivation = derivation or KeyDerivationInfo()
        self.taproot = taproot  # serialize as 32-byte x-only key
        self.chain_type = chain_type  # "BTC" / "XTN" as detected by parse_key

    def __eq__(self, other):
        # NOTE(review): equality is hash based (pubkey + sub-derivation)
        return hash(self) == hash(other)

    def __hash__(self):
        return hash(self.node.pubkey()) + hash(self.derivation)

    def __len__(self):
        # 1-byte push length + <33:sec>, or + <32:xonly> in taproot context
        return 34 - int(self.taproot)

    @property
    def fingerprint(self):
        return self.origin.fingerprint

    def serialize(self):
        return self.key_bytes()

    def compile(self):
        """Key bytes prefixed with their compact-size length."""
        d = self.serialize()
        return ser_compact_size(len(d)) + d

    @classmethod
    def parse_key(cls, key_str):
        """Deserialize an extended *public* key (bytes) into ``(node, chain_type)``.

        Raises AssertionError for anything that is not a ``?pub`` or that
        carries a private key.
        """
        assert key_str[1:4].lower() == b"pub", "only extended pubkeys allowed"
        # Descriptors are expected to use plain xpub/tpub. SLIP-132 prefixes
        # are tolerated here and only mapped to a chain; any implied address
        # format is ignored.
        hint = key_str[0:1].lower()
        if hint == b"x":
            chain_type = "BTC"
        elif hint == b"t":
            chain_type = "XTN"
        else:
            # slip (ignore any implied address format)
            chain_type = "BTC" if hint in b"yz" else "XTN"

        node = ngu.hdnode.HDNode()
        node.deserialize(key_str)
        try:
            assert node.privkey() is None, "no privkeys"
        except ValueError:
            # ValueError is thrown from libngu if key is public
            pass

        return node, chain_type

    def validate(self, my_xfp, disable_checks=False):
        """Sanity-check this key and report whether it belongs to this device.

        Always checks the chain type and that the pubkey is a valid point.
        Unless ``disable_checks`` is set, it also checks that the origin
        info agrees with the xpub depth and parent fingerprint. For our own
        key (origin xfp == ``my_xfp``) it re-derives the pubkey from the
        seed and compares. Returns True when the key is ours.
        """
        assert self.chain_type == chains.current_key_chain().ctype, "wrong chain"

        # xfp is always available, even if key was serialized without origin info
        # upon parse root origin info is generated from key itself
        xfp = self.origin.cc_fp
        is_mine = (xfp == my_xfp)

        # raises ValueError on invalid pubkey (should be in libngu)
        # invalid public key not allowed even with disable checks
        ngu.secp256k1.pubkey(self.node.pubkey())

        if not disable_checks:
            depth = self.node.depth()
            # we now allow blinded keys that have depth X but derivation len is 0,
            # where only fingerprint constitutes key origin
            # only check if derivation length is greater than 0
            if self.origin.derivation:
                assert len(self.origin.derivation) == depth, \
                    "deriv len != xpub depth (xfp=%s)" % xfp2str(xfp)
                if depth == 0:
                    # blinded keys allowed
                    assert swab32(self.node.my_fp()) == xfp, "master xfp mismatch"
                elif depth == 1:
                    target = swab32(self.node.parent_fp())
                    assert xfp == target, 'xfp depth=1 wrong'

        if is_mine:
            # it's supposed to be my key, so I should be able to generate pubkey
            # - might indicate collision on xfp value between co-signers,
            #   and that's not supported
            deriv = self.origin.str_derivation()
            with stash.SensitiveValues() as sv:
                chk_node = sv.derive_path(deriv)
                assert self.node.pubkey() == chk_node.pubkey(), \
                    "[%s/%s] wrong pubkey" % (xfp2str(xfp), deriv[2:])

        return is_mine

    def derive(self, idx=None, change=False):
        """Return a new ExtendedKey for the unhardened child ``idx``.

        ``idx`` is resolved through the sub-derivation (see
        KeyDerivationInfo.der_index). The child's origin path is extended
        by ``idx``, and its sub-derivation drops the step just consumed.
        """
        if self.derivation:
            idx = self.derivation.der_index(idx, change)
        else:
            # 0 is a perfectly valid child index - only reject "missing"
            assert idx is not None, "need index"

        new_node = self.node.copy()
        new_node.derive(idx, False)
        if self.origin:
            origin = KeyOriginInfo(self.origin.fingerprint,
                                   list(self.origin.derivation) + [idx],
                                   self.origin.cc_fp)
        else:
            # no origin known: this (parent) key itself is the origin
            cc_fp = swab32(self.node.my_fp())
            origin = KeyOriginInfo(ustruct.pack('<I', cc_fp), [idx], cc_fp)

        new_der = None
        if self.derivation:
            new_der = KeyDerivationInfo(self.derivation.indexes[1:])

        # keep the chain so the derived key still passes validate()
        return type(self)(new_node, origin, new_der, taproot=self.taproot,
                          chain_type=self.chain_type)

    @classmethod
    def read_from(cls, s, taproot=False, musig=False):
        """Parse ``[xfp/path]xpub/sub/derivation`` from byte stream ``s``.

        The delimiter that ends the key (``,`` or ``)``) is left in the
        stream for the caller. With ``musig=True`` the key must not carry
        its own sub-derivation and must not itself be a ``musig(``.
        """
        first = s.read(1)
        origin = None

        if first == b"[":
            prefix, char = read_until(s, b"]")
            if char != b"]":
                raise ValueError("Invalid key - missing ] in key origin info")
            origin = KeyOriginInfo.from_string(prefix.decode())
        else:
            s.seek(-1, 1)

        k, char = read_until(s, b",)/")
        if musig:
            # checked regardless of the delimiter; a nested "musig(" is
            # normally terminated by "," and must still be rejected here
            assert b"musig(" not in k, "nested musig not allowed"
            assert char != b"/", "key derivation not allowed inside musig"

        der = None
        if char == b"/":
            der = KeyDerivationInfo.parse(s)
        if char is not None:
            s.seek(-1, 1)

        # parse key
        node, chain_type = cls.parse_key(k)
        if origin is None:
            cc_fp = swab32(node.my_fp())
            origin = KeyOriginInfo(ustruct.pack('<I', cc_fp), [], cc_fp)
        return cls(node, origin, der, chain_type=chain_type, taproot=taproot)

    @classmethod
    def from_cc_data(cls, xfp, deriv, xpub):
        """Build from Coldcard-style xfp (int or hex str), ``m/...`` path and xpub string."""
        xfp_str = xfp if isinstance(xfp, str) else xfp2str(xfp)
        koi = KeyOriginInfo.from_string("%s/%s" % (xfp_str, deriv.replace("m/", "")))
        node, chain_type = cls.parse_key(xpub.encode())

        return cls(node, koi, KeyDerivationInfo(), chain_type=chain_type)

    @classmethod
    def from_cc_json(cls, vals, af_str):
        """Build from a Coldcard export dict for address-format key ``af_str``."""
        key_exp = af_str + "_key_exp"
        if key_exp in vals:
            # new firmware, prefer key expression
            return cls.from_string(vals[key_exp])

        # TODO: legacy export - the (SLIP-132) key is re-serialized as a
        #       plain xpub/tpub for the current chain before parsing
        node, _, _, _ = chains.slip132_deserialize(vals[af_str])
        ek = chains.current_chain().serialize_public(node)
        return cls.from_cc_data(vals["xfp"], vals["%s_deriv" % af_str], ek)

    @classmethod
    def from_psbt_xpub(cls, ek_bytes, xfp_path):
        """Build from a PSBT global xpub record (serialized key + [xfp, *path])."""
        xfp, *path = xfp_path
        koi = KeyOriginInfo(a2b_hex(xfp2str(xfp)), path)
        # TODO this should be done by C code, no need to base58 encode/decode
        # byte-serialized key should be decodable
        ek = ngu.codecs.b58_encode(ek_bytes)
        node, chain_type = cls.parse_key(ek.encode())

        return cls(node, koi, KeyDerivationInfo(), chain_type=chain_type)

    @property
    def is_provably_unspendable(self):
        return self.node.pubkey() == PROVABLY_UNSPENDABLE

    @property
    def prefix(self):
        if self.origin and self.origin.derivation:
            return "[%s]" % self.origin
        # just a bare [xfp]key - omit origin info (just xfp)
        # or no origin at all
        return ""

    def key_bytes(self):
        """33-byte SEC pubkey, or the 32-byte x-only form in taproot context."""
        kb = self.node.pubkey()
        if self.taproot:
            # xonly
            kb = kb[1:]
        return kb

    def extended_public_key(self):
        return chains.current_chain().serialize_public(self.node)

    def to_string(self, external=True, internal=True):
        """Render as ``[origin]xpub/sub-derivation``.

        The sub-derivation is omitted when both ``external`` and
        ``internal`` are False.
        """
        key = self.prefix
        key += self.extended_public_key()
        if self.derivation and (external or internal):
            key += "/" + self.derivation.to_string(external, internal)

        return key

    @classmethod
    def from_string(cls, s):
        s = BytesIO(s.encode())
        return cls.read_from(s)
class MusigKey:
    """MuSig2 aggregate key expression ``musig(K1,K2,...)/<sub-derivation>``."""

    def __init__(self, keys, der=None, node=None):
        self.keys = keys  # participant ExtendedKeys, in descriptor order
        # default sub-derivation is <0;1>/*
        self.derivation = der or KeyDerivationInfo()
        self._node = node  # cached synthetic HD node (computed on demand)

    def __len__(self):
        return 33  # length + <32:xonly>

    def __eq__(self, other):
        # NOTE(review): equality is hash based (aggregate pubkey + sub-derivation)
        return hash(self) == hash(other)

    def __hash__(self):
        return hash(self.node.pubkey()) + hash(self.derivation)

    def serialize(self):
        return self.key_bytes()

    def compile(self):
        """Key bytes prefixed with their compact-size length."""
        d = self.serialize()
        return ser_compact_size(len(d)) + d

    @property
    def node(self):
        """Synthetic HD node built from the aggregate pubkey (cached)."""
        if self._node is None:
            self._node = musig_synthetic_node(self.aggregate_pubkey().to_bytes())
        return self._node

    def validate(self, my_xfp, disable_checks=False):
        """Validate every participant key.

        Rejects unspendable or duplicate participants and more than one own
        key. Returns the number of own keys (0 or 1).
        """
        has_mine = 0
        for k in self.keys:
            assert not k.is_provably_unspendable, "unspendable key inside musig"
            if k.validate(my_xfp, disable_checks):
                has_mine += 1

        assert len(self.keys) == len(set(self.keys)), "musig keys not unique"
        assert has_mine <= 1, "multiple own keys in musig"
        return has_mine

    def key_bytes(self):
        """32-byte x-only form of the (possibly derived) synthetic node pubkey."""
        return ngu.secp256k1.pubkey(self.node.pubkey()).to_xonly().to_bytes()

    def aggregate_pubkey(self):
        """Aggregate the participant pubkeys, in the given order (no sorting)."""
        keyagg_cache = ngu.secp256k1.MusigKeyAggCache()
        secp_pubkeys = [ngu.secp256k1.pubkey(k.node.pubkey()) for k in self.keys]
        ngu.secp256k1.musig_pubkey_agg(secp_pubkeys, keyagg_cache)
        return keyagg_cache.agg_pubkey()

    def to_string(self, external=True, internal=True):
        """Render as ``musig(k1,k2,...)/sub-derivation``."""
        base = "musig(%s)" % (",".join([k.to_string(False, False) for k in self.keys]))
        base += "/" + self.derivation.to_string(external, internal)
        return base

    def derive(self, idx=None, change=False):
        """Return a new MusigKey whose synthetic node is the unhardened child ``idx``."""
        idx = self.derivation.der_index(idx, change)
        new_node = self.node.copy()
        new_node.derive(idx, False)

        return type(self)(self.keys, KeyDerivationInfo(self.derivation.indexes[1:]),
                          node=new_node)

    @property
    def is_provably_unspendable(self):
        return False

    @classmethod
    def read_from(cls, s, taproot=True):
        """Parse ``musig(k1,k2,...)[/sub-derivation]`` from byte stream ``s``.

        Any byte following the expression is left in the stream for the
        caller.
        """
        assert taproot, "musig in non-taproot context"
        assert s.read(6) == b"musig(", "not musig()"

        der = None
        keys = []
        while True:
            k = ExtendedKey.read_from(s, taproot=taproot, musig=True)
            k.taproot = taproot
            # read_from(musig=True) already rejected per-key derivations;
            # drop the default one installed by the ExtendedKey constructor
            k.derivation = None
            keys.append(k)
            c = s.read(1)
            if c == b")":
                sep = s.read(1)
                if sep == b"/":
                    der = KeyDerivationInfo.parse(s)
                    s.seek(-1, 1)
                elif sep:
                    # not ours - put it back; at end of stream there is
                    # nothing to put back
                    s.seek(-1, 1)
                break

            assert c == b",", "Malformed musig"

        return cls(keys, der)

    @classmethod
    def from_string(cls, s):
        s = BytesIO(s.encode())
        return cls.read_from(s)
class KeyExpression:
    """Dispatcher that parses either a ``musig(...)`` or a plain extended key."""

    @classmethod
    def read_from(cls, s, taproot=False):
        """Peek at stream ``s`` and delegate to MusigKey or ExtendedKey."""
        head = s.read(6)
        # Rewind only what was actually read: near the end of the stream
        # fewer than 6 bytes come back, and seeking back a fixed 6 would
        # land before the key expression.
        s.seek(-len(head), 1)
        if head == b"musig(":
            return MusigKey.read_from(s, taproot=taproot)
        return ExtendedKey.read_from(s, taproot=taproot)
def bip388_wallet_policy_to_descriptor(desc_tmplt, keys_info):
    """Expand a BIP-388 wallet policy template into a descriptor string.

    Every ``@i`` placeholder is replaced with ``keys_info[i]``. The highest
    index goes first so that ``@1`` cannot clobber the prefix of ``@10``.
    Finally each ``/**`` shorthand becomes ``/<0;1>/*``.
    """
    for idx in reversed(range(len(keys_info))):
        desc_tmplt = desc_tmplt.replace("@%d" % idx, keys_info[idx])
    return desc_tmplt.replace("/**", "/<0;1>/*")
def bip388_validate_policy(desc_tmplt, keys_info):
    """Check the key placeholders of a BIP-388 policy template (bytes).

    Every ``@N`` must be followed by ``/`` and then ``<`` or ``*``. Inside
    ``musig(@a,@b,...)`` the derivation comes after the closing paren.
    Placeholders must be numbered 0..len(keys_info)-1 in order of first
    appearance. Raises AssertionError or ValueError otherwise.
    """
    s = BytesIO(desc_tmplt)
    r = []
    while True:
        g1, char = read_until(s, b"@")
        if not char:
            # no more - done
            break

        # key derivation info required for policy
        g2, char = read_until(s, b"/")
        assert char, "key derivation missing"
        if g1.endswith(b"musig("):
            # key derivations not allowed inside musig: everything up to the
            # first "/" must be the placeholder list closed by ")"
            assert g2[-1:] == b")"

            for i, num in enumerate(g2[:-1].split(b",")):
                if i:
                    # 0th element has @ already removed
                    assert num[0:1] == b"@"
                    num = num[1:]

                num = int(num.decode())
                if num not in r:
                    r.append(num)

        else:
            num = int(g2.decode())
            if num not in r:
                r.append(num)

        # b"" (end of input) is "in" every bytes object, so test it explicitly
        nxt = s.read(1)
        assert nxt and nxt in b"<*", "need multipath"

    assert len(r) == len(keys_info), "Invalid policy"
    assert r == list(range(len(r))), "Out of order"