558 lines
17 KiB
Python
558 lines
17 KiB
Python
# (c) Copyright 2023 by Coinkite Inc. This file is covered by license found in COPYING-CC.
|
|
#
|
|
# Copyright (c) 2020 Stepan Snigirev MIT License embit/arguments.py
|
|
#
|
|
import ngu, chains, ustruct
|
|
from io import BytesIO
|
|
from public_constants import AF_P2SH, AF_P2WSH_P2SH, AF_P2WSH, AF_CLASSIC, AF_P2TR
|
|
from binascii import unhexlify as a2b_hex
|
|
from binascii import hexlify as b2a_hex
|
|
from utils import keypath_to_str, str_to_keypath, swab32, xfp2str
|
|
from serializations import ser_compact_size
|
|
|
|
|
|
# Wildcard marker for the final (ranged) step of a key sub-derivation,
# e.g. the "*" in "<0;1>/*".
WILDCARD = "*"

# BIP-341 "provably unspendable" NUMS point H as a 33-byte compressed SEC key
# (0x02 || 50929b74...ace803ac0). Descriptors use the 32-byte x-only form,
# i.e. PROVABLY_UNSPENDABLE[1:].
PROVABLY_UNSPENDABLE = b'\x02P\x92\x9bt\xc1\xa0IT\xb7\x8bK`5\xe9z^\x07\x8aZ\x0f(\xec\x96\xd5G\xbf\xee\x9a\xce\x80:\xc0'

# Descriptor checksum input alphabet (BIP-380). A character's position in this
# string is what descriptor_checksum() feeds into polymod(); characters not in
# it are rejected.
INPUT_CHARSET = "0123456789()[],'/*abcdefgh@:$%{}IJKLMNOPQRSTUVWXYZ&+-.;<=>?!^_|~ijklmnopqrstuvwxyzABCDEFGH`#\"\\ "

# Output alphabet of the 8-character descriptor checksum (the bech32 charset).
CHECKSUM_CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
|
|
|
|
|
|
def polymod(c, val):
    """Advance the BIP-380 descriptor-checksum state ``c`` by one 5-bit symbol.

    The top bits of the 40-bit state (bits 35..39) select which generator
    constants are XOR-ed into the shifted state after ``val`` is mixed in.
    Returns the new state as an int.
    """
    top = c >> 35
    state = ((c & 0x7ffffffff) << 5) ^ val
    generators = (0xf5dee51989, 0xa9fdca3312, 0x1bab10e32d, 0x3706b1677a, 0x644d626ffd)
    for bit, gen in enumerate(generators):
        # generator i participates when bit i of the old top bits is set
        if (top >> bit) & 1:
            state ^= gen
    return state
|
|
|
|
def descriptor_checksum(desc):
    """Return the 8-character BIP-380 checksum of descriptor string ``desc``.

    Each character contributes its low 5 bits (position in INPUT_CHARSET)
    directly; the remaining high "class" bits are packed three characters at
    a time and fed as an extra symbol.

    Raises:
        ValueError: with the offending character, if it is not in INPUT_CHARSET.
    """
    state = 1
    group = 0
    group_len = 0
    for ch in desc:
        pos = INPUT_CHARSET.find(ch)
        if pos < 0:
            raise ValueError(ch)
        state = polymod(state, pos & 31)
        group = group * 3 + (pos >> 5)
        group_len += 1
        if group_len == 3:
            state = polymod(state, group)
            group, group_len = 0, 0

    # flush a partially filled class group
    if group_len:
        state = polymod(state, group)
    # shift in 8 zero symbols to make room for the checksum itself
    for _ in range(8):
        state = polymod(state, 0)
    state ^= 1

    return "".join(CHECKSUM_CHARSET[(state >> (5 * (7 - j))) & 31] for j in range(8))
|
|
|
|
def append_checksum(desc):
    """Return ``desc`` followed by "#" and its descriptor checksum."""
    return "#".join([desc, descriptor_checksum(desc)])
|
|
|
|
|
|
def parse_desc_str(string):
    """Collapse a multi-line descriptor text into a single line.

    Every line is stripped of surrounding whitespace; blank lines and lines
    whose stripped text starts with "#" (comments) are dropped, and the
    remaining lines are concatenated without separators.
    """
    kept = []
    for raw_line in string.split("\n"):
        line = raw_line.strip()
        if line and not line.startswith("#"):
            kept.append(line)
    return "".join(kept)
|
|
|
|
|
|
def multisig_descriptor_template(xpub, path, xfp, addr_fmt):
    """Return a sortedmulti descriptor template holding one key, or None.

    The key expression is "[<xfp lowercased><path with 'm' removed>]<xpub>/0/*".
    The literal "M" and "..." are left in the template (presumably standing
    for the threshold and the other cosigners' keys).
    Returns None when ``addr_fmt`` is not one of AF_P2WSH_P2SH, AF_P2WSH,
    AF_P2SH or AF_P2TR.
    """
    key_exp = "[%s%s]%s/0/*" % (xfp.lower(), path.replace("m", ''), xpub)
    # taproot: provably unspendable BIP-0341 x-only key as internal key
    nums_xonly = b2a_hex(PROVABLY_UNSPENDABLE[1:]).decode()
    candidates = (
        (AF_P2WSH_P2SH, "sh(wsh(sortedmulti(M,%s,...)))"),
        (AF_P2WSH, "wsh(sortedmulti(M,%s,...))"),
        (AF_P2SH, "sh(sortedmulti(M,%s,...))"),
        (AF_P2TR, "tr(" + nums_xonly + ",sortedmulti_a(M,%s,...))"),
    )
    for fmt, template in candidates:
        if addr_fmt == fmt:
            return template % key_exp
    return None
|
|
|
|
|
|
def read_until(s, chars=b",)(#"):
    """Read single bytes from stream ``s`` until one of ``chars`` is seen.

    Args:
        s: binary stream with ``read(1)`` (e.g. ``io.BytesIO``).
        chars: bytes object; each byte in it is a delimiter.

    Returns:
        ``(data, delimiter)``. ``data`` is everything read before the
        delimiter, as ``bytes``. ``delimiter`` is the 1-byte ``bytes`` that
        stopped the read (it has been consumed from the stream), or ``None``
        if the stream ran out first.
    """
    # Terminates on any finite stream: read(1) returns b"" at EOF.
    res = bytearray()  # avoids quadratic bytes concatenation
    while True:
        chunk = s.read(1)
        if not chunk:
            return bytes(res), None
        if chunk in chars:
            return bytes(res), chunk
        res += chunk
|
|
|
|
|
|
class KeyOriginInfo:
    """Key origin of a descriptor key: master fingerprint plus derivation path.

    This is the ``[xfp/path]`` part of a key expression. ``fingerprint`` is
    the raw 4-byte fingerprint and ``derivation`` the list of child indexes.
    """

    def __init__(self, fingerprint: bytes, derivation: list):
        self.fingerprint = fingerprint
        self.derivation = derivation
        # fingerprint as an int passed through swab32 (presumably a 32-bit
        # byte swap) - the form used in psbt_derivation()
        self.cc_fp = swab32(int(b2a_hex(self.fingerprint).decode(), 16))

    def __eq__(self, other):
        # Non-origins (e.g. None for keys without origin info) are never equal;
        # without this, comparing against None raised AttributeError.
        if not isinstance(other, KeyOriginInfo):
            return False
        return self.psbt_derivation() == other.psbt_derivation()

    def __hash__(self):
        return hash(tuple(self.psbt_derivation()))

    def str_derivation(self):
        """Derivation path as a string starting with "m/"."""
        return keypath_to_str(self.derivation, prefix='m/', skip=0)

    def psbt_derivation(self):
        """Return ``[cc_fp, *derivation]`` as a new list."""
        return [self.cc_fp] + list(self.derivation)

    @classmethod
    def from_string(cls, s: str):
        """Parse ``"<8 hex xfp>/<path>"`` (the text between ``[`` and ``]``)."""
        arr = s.split("/")
        xfp = a2b_hex(arr[0])
        assert len(xfp) == 4
        arr[0] = "m"
        path = "/".join(arr)
        derivation = str_to_keypath(xfp, path)[1:]  # ignoring xfp here, already stored
        return cls(xfp, derivation)

    def __str__(self):
        # NOTE(review): with an empty derivation this presumably renders as
        # "<xfp>/" (trailing slash) - confirm keypath_to_str([]) output.
        return "%s/%s" % (b2a_hex(self.fingerprint).decode(),
                          keypath_to_str(self.derivation, prefix='', skip=0).replace("'", "h"))
|
|
|
|
|
|
class KeyDerivationInfo:
    """Sub-derivation that follows an extended key, e.g. ``<0;1>/*`` or ``1/*``.

    ``indexes`` holds one entry per path component: an int, the WILDCARD
    string, or a two-element ``[external, internal]`` list for a multipath
    ``<a;b>`` component. ``multi_path_index`` is the position of that list in
    ``indexes``, or None when there is no multipath component.
    """

    def __init__(self, indexes=None):
        self.indexes = indexes
        if self.indexes is None:
            # default sub-derivation: <0;1>/*
            self.indexes = [[0, 1], WILDCARD]
            self.multi_path_index = 0
        else:
            self.multi_path_index = None

    @property
    def is_int_ext(self):
        """True when the derivation has an ``<ext;int>`` multipath component."""
        return self.multi_path_index is not None

    @property
    def is_external(self):
        """True for multipath derivations or an even single branch."""
        return self.is_int_ext or self.indexes[-2] % 2 == 0

    @property
    def branches(self):
        """List of branch indexes this derivation can use."""
        if self.is_int_ext:
            return self.indexes[self.multi_path_index]
        return [self.indexes[-2]]

    @staticmethod
    def _check_index(n):
        """Return ``n`` if it is an unhardened BIP-32 child index (0 <= n < 2**31)."""
        assert 0 <= n < 0x80000000, "invalid sub derivation index"
        return n

    @classmethod
    def from_string(cls, s):
        """Parse a sub-derivation string such as ``"<0;1>/*"`` or ``"1/*"``.

        An empty string yields the default ``<0;1>/*``; ``"0/*"`` is also
        normalized to it. Exactly two components are allowed, the last one
        must be WILDCARD, at most one component may be a multipath
        ``<a;b>``, and hardened steps are rejected.

        Raises:
            AssertionError or ValueError on malformed input.
        """
        fail_msg = "Cannot use hardened sub derivation path"
        if not s:
            return cls()

        res = []
        mp = 0      # number of multipath components seen
        mpi = None  # position of the multipath component
        for idx, comp in enumerate(s.split("/")):
            if comp.find("<") != -1:
                # Locate ">" within this component (not the whole string) and
                # require the component to be exactly "<a;b>".
                end_i = comp.find(">")
                assert comp[0] == "<" and end_i == len(comp) - 1, "wrong multipath"
                inner = comp[1:end_i]
                assert ";" in inner
                inner_split = inner.split(";")
                assert len(inner_split) == 2, "wrong multipath"
                pair = [cls._check_index(int(x)) for x in inner_split]
                # BIP-389: the same index must not appear twice in a multipath
                assert pair[0] != pair[1], "wrong multipath"
                res.append(pair)
                mp += 1
                mpi = idx
            elif comp == WILDCARD:
                res.append(WILDCARD)
            else:
                assert "'" not in comp, fail_msg
                assert "h" not in comp, fail_msg
                res.append(cls._check_index(int(comp)))

        # only one <x;y> allowed in subderivation
        assert mp <= 1, "too many multipaths (%d)" % mp

        if res == [0, WILDCARD]:
            obj = cls()
        else:
            assert len(res) == 2, "Key derivation too long"
            assert res[-1] == WILDCARD, "All keys must be ranged"
            obj = cls(res)
            obj.multi_path_index = mpi

        return obj

    def to_string(self, external=True, internal=True):
        """Render as text; a multipath component collapses to one branch when
        exactly one of ``external``/``internal`` is requested."""
        res = []
        for i in self.indexes:
            if isinstance(i, list):
                if internal is True and external is False:
                    i = str(i[1])
                elif internal is False and external is True:
                    i = str(i[0])
                else:
                    i = "<%d;%d>" % (i[0], i[1])
            else:
                i = str(i)
            res.append(i)
        return "/".join(res)

    def to_int_list(self, branch_idx, idx):
        """Return ``[branch_idx, idx]`` after checking ``branch_idx`` is allowed.

        Works for both multipath (``[[a, b], *]``) and single-branch
        (``[a, *]``) derivations; the latter used to raise TypeError.
        """
        first = self.indexes[0]
        allowed = first if isinstance(first, list) else [first]
        assert branch_idx in allowed
        return [branch_idx, idx]
|
|
|
|
|
|
class Key:
    """A key expression inside a descriptor.

    ``node`` is either an ``ngu.hdnode.HDNode`` (extended key) or raw bytes
    (a bare 32-byte key), ``origin`` an optional KeyOriginInfo and
    ``derivation`` an optional KeyDerivationInfo for the ranged sub-path.
    """

    def __init__(self, node, origin, derivation=None, taproot=False, chain_type=None):
        self.origin = origin
        self.node = node
        self.derivation = derivation
        self.taproot = taproot
        # "BTC" for xpub, "XTN" for tpub (set by parse_key), else None
        self.chain_type = chain_type

    def __eq__(self, other):
        # comparing with a non-key (e.g. None) used to raise AttributeError
        if not isinstance(other, Key):
            return False
        return self.origin == other.origin \
            and self.derivation.indexes == other.derivation.indexes

    def __hash__(self):
        return hash(self.to_string())

    def __len__(self):
        # 1-byte push prefix + <33:sec> or <32:xonly>
        # (presumably equals len(self.compile()))
        return 34 - int(self.taproot)

    @property
    def fingerprint(self):
        return self.origin.fingerprint

    def serialize(self):
        return self.key_bytes()

    def compile(self):
        """Key bytes prefixed with their compact-size length."""
        d = self.serialize()
        return ser_compact_size(len(d)) + d

    @classmethod
    def parse(cls, s):
        """Parse one key expression from binary stream ``s``.

        Handles ``unspend(...)`` keys (delegated to Unspend.parse), an
        optional ``[origin]`` prefix, the key itself, and an optional
        ``/sub/derivation``. A terminating "," or ")" is pushed back onto
        the stream for the caller.
        """
        first = s.read(1)
        origin = None
        if first == b"u":
            s.seek(-1, 1)
            return Unspend.parse(s)

        if first == b"[":
            prefix, char = read_until(s, b"]")
            if char != b"]":
                raise ValueError("Invalid key - missing ] in key origin info")
            origin = KeyOriginInfo.from_string(prefix.decode())
        else:
            s.seek(-1, 1)

        k, char = read_until(s, b",)/")
        der = b""
        if char == b"/":
            der, char = read_until(s, b"<,)")
            if char == b"<":
                der += b"<"
                branch, char = read_until(s, b">")
                if char is None:
                    raise ValueError("Failed reading the key, missing >")
                der += branch + b">"
                rest, char = read_until(s, b",)")
                der += rest
        if char is not None:
            # un-read the delimiter so the caller sees it
            s.seek(-1, 1)

        # parse key
        node, chain_type = cls.parse_key(k)
        der = KeyDerivationInfo.from_string(der.decode())
        return cls(node, origin, der, chain_type=chain_type)

    @classmethod
    def parse_key(cls, key_str):
        """Decode key bytes ``key_str`` into ``(node, chain_type)``.

        xpub/tpub become an HDNode with chain_type "BTC"/"XTN" (SLIP-132
        prefixes are rejected). Anything else must be a 32-byte hex key, or
        ``r=@`` / ``r=<hex>`` meaning H + rG over the BIP-341 unspendable H.
        """
        chain_type = None
        if key_str[1:4].lower() == b"pub":
            # extended key
            # or xpub or tpub as we use descriptors (SLIP-132 NOT allowed)
            hint = key_str[0:1].lower()
            if hint == b"x":
                chain_type = "BTC"
            else:
                assert hint == b"t", "no slip"
                chain_type = "XTN"
            node = ngu.hdnode.HDNode()
            node.deserialize(key_str)
        else:
            # only unspendable keys can be bare pubkeys - for now
            H = PROVABLY_UNSPENDABLE[1:]
            if b"r=" in key_str:
                _, r = key_str.split(b"=")
                if r == b"@":
                    # pick a fresh integer r in the range 0...n-1 uniformly at random and use H + rG
                    kp = ngu.secp256k1.keypair()
                else:
                    # H + rG where r is provided from user
                    r = a2b_hex(r)
                    assert len(r) == 32, "r != 32"
                    kp = ngu.secp256k1.keypair(r)

                H_xo = ngu.secp256k1.xonly_pubkey(H)
                node = H_xo.tweak_add(kp.xonly_pubkey().to_bytes()).to_bytes()
            else:
                # (the former "== H" special case produced the same bytes)
                node = a2b_hex(key_str)

            assert len(node) == 32, "invalid pk %d %s" % (len(node), node)

        return node, chain_type

    def derive(self, idx=None, change=False):
        """Return a new key one BIP-32 (unhardened) step below this one.

        Args:
            idx: child index; a list of candidates (the first one that is an
                allowed branch wins); or None to pick the branch from the
                sub-derivation (internal branch when ``change``).
            change: only used when ``idx`` is None.

        Bare (bytes) keys are returned unchanged.
        """
        if isinstance(self.node, bytes):
            return self

        if isinstance(idx, list):
            mp_i = self.derivation.multi_path_index or 0
            allowed = self.derivation.indexes[mp_i]
            if not isinstance(allowed, list):
                # single-branch derivation like "1/*" (was `in int` TypeError)
                allowed = [allowed]
            for i in idx:
                if i in allowed:
                    idx = i
                    break
            else:
                assert False

        elif idx is None:
            # derive according to key subderivation if any
            if self.derivation is None:
                idx = 1 if change else 0
            elif self.derivation.multi_path_index is not None:
                ext, inter = self.derivation.indexes[self.derivation.multi_path_index]
                idx = inter if change else ext
            else:
                # single fixed branch (previously idx stayed None)
                idx = self.derivation.indexes[0]

        new_node = self.node.copy()
        new_node.derive(idx, False)
        if self.origin:
            origin = KeyOriginInfo(self.origin.fingerprint, self.origin.derivation + [idx])
        else:
            fp = ustruct.pack('<I', swab32(self.node.my_fp()))
            origin = KeyOriginInfo(fp, [idx])

        # consume the first sub-derivation step (no-op when there is none)
        derivation = None
        if self.derivation is not None:
            derivation = KeyDerivationInfo(self.derivation.indexes[1:])
        return type(self)(new_node, origin, derivation, taproot=self.taproot,
                          chain_type=self.chain_type)

    @classmethod
    def read_from(cls, s, taproot=False):
        # NOTE: ``taproot`` is accepted but not used here
        return cls.parse(s)

    @classmethod
    def from_cc_data(cls, xfp, deriv, xpub):
        """Build a key from (xfp int, "m/..." path, serialized xpub) with the
        default ``<0;1>/*`` sub-derivation."""
        koi = KeyOriginInfo.from_string("%s/%s" % (xfp2str(xfp), deriv.replace("m/", "")))
        node = ngu.hdnode.HDNode()
        node.deserialize(xpub)
        return cls(node, koi, KeyDerivationInfo())

    def to_cc_data(self):
        """Inverse of from_cc_data: (cc_fp, "m/..." path, serialized xpub)."""
        ch = chains.current_chain()
        return (self.origin.cc_fp,
                self.origin.str_derivation(),
                ch.serialize_public(self.node, AF_CLASSIC))

    @property
    def is_provably_unspendable(self):
        # bare keys are only ever the NUMS point or a tweak of it (see parse_key)
        if isinstance(self.node, bytes):
            return True
        return PROVABLY_UNSPENDABLE == self.node.pubkey()

    @property
    def prefix(self):
        """``"[origin]"`` or empty string when there is no origin."""
        if self.origin:
            return "[%s]" % self.origin
        return ""

    def key_bytes(self):
        """Public key bytes: 33-byte SEC, or 32-byte x-only when taproot."""
        kb = self.node
        if not isinstance(kb, bytes):
            kb = self.node.pubkey()
        if self.taproot:
            if len(kb) == 33:
                kb = kb[1:]
            assert len(kb) == 32
        return kb

    def extended_public_key(self):
        return chains.current_chain().serialize_public(self.node)

    def to_string(self, external=True, internal=True, subderiv=True):
        """Render as a descriptor key expression (origin, key, sub-derivation)."""
        key = self.prefix
        if isinstance(self.node, ngu.hdnode.HDNode):
            key += self.extended_public_key()
            if self.derivation and subderiv:
                key += "/" + self.derivation.to_string(external, internal)
        else:
            key += b2a_hex(self.node).decode()

        return key

    @classmethod
    def from_string(cls, s):
        s = BytesIO(s.encode())
        return cls.parse(s)
|
|
|
|
|
|
class Unspend(Key):
    """Taproot key written ``unspend(<64 hex chain code>)/<sub-derivation>``.

    When parsed, its node is an HDNode built from the given chain code and
    the PROVABLY_UNSPENDABLE public key. Always taproot.
    """

    def __init__(self, node, origin=None, derivation=None, taproot=True, chain_type=None):
        super().__init__(node, origin, derivation, taproot, chain_type)
        assert self.taproot

    def __eq__(self, other):
        # non-Unspend operands are never equal (previously AttributeError
        # for e.g. None or a Key with a bytes node)
        if not isinstance(other, Unspend):
            return False
        return self.node.chain_code() == other.node.chain_code() \
            and self.node.pubkey() == other.node.pubkey() \
            and self.derivation.indexes == other.derivation.indexes

    def __hash__(self):
        # Defining __eq__ without __hash__ makes instances unhashable in
        # CPython. Equal instances render to the same string.
        return hash(self.to_string())

    @classmethod
    def parse(cls, s):
        """Parse ``unspend(<chaincode>)/<subderivation>`` from stream ``s``.

        A terminating "," or ")" is pushed back onto the stream.

        Raises:
            ValueError: if the key is not ranged or a ``>`` is missing.
        """
        assert s.read(8) == b"unspend("
        chain_code, c = read_until(s, b")")
        chain_code = a2b_hex(chain_code)
        assert len(chain_code) == 32, "chain code length"
        assert c
        char = s.read(1)
        if char != b"/":
            raise ValueError("ranged unspend required")
        der, char = read_until(s, b"<,)")
        if char == b"<":
            der += b"<"
            branch, char = read_until(s, b">")
            if char is None:
                raise ValueError("Failed reading the key, missing >")
            der += branch + b">"
            rest, char = read_until(s, b",)")
            der += rest
        if char is not None:
            # un-read the delimiter so the caller sees it
            s.seek(-1, 1)

        node = ngu.hdnode.HDNode().from_chaincode_pubkey(chain_code,
                                                         PROVABLY_UNSPENDABLE)
        der = KeyDerivationInfo.from_string(der.decode())
        return cls(node, None, der, chain_type=None)

    def to_string(self, external=True, internal=True, subderiv=True):
        res = "unspend(%s)" % b2a_hex(self.node.chain_code()).decode()
        if self.derivation and subderiv:
            res += "/" + self.derivation.to_string(external, internal)

        return res

    @property
    def is_provably_unspendable(self):
        return True
|
|
|
|
|
|
def fill_policy(policy, keys, external=True, internal=True):
    """Substitute key placeholders ``@0``, ``@1``, ... in ``policy`` with keys.

    Args:
        policy: template string. Every placeholder is immediately followed by
            a "/" and a sub-derivation that contains a ``<ext;int>`` multipath.
        keys: key objects (anything with ``to_string(external, internal,
            subderiv=...)``) or strings of the form ``"[origin]key/..."``.
            Duplicates (after dropping sub-derivation) are collapsed, and
            placeholder ``@i`` refers to the i-th unique key.
        external, internal: when exactly one is true, each ``<ext;int>`` that
            follows a placeholder is collapsed to that single branch;
            otherwise it is left as-is.

    Returns:
        The policy string with placeholders replaced.
    """
    # Reduce every key to "[origin]key" (no sub-derivation) and dedupe,
    # preserving first-seen order.
    orig_keys = []
    for k in keys:
        if not isinstance(k, str):
            k_orig = k.to_string(external, internal, subderiv=False)
        else:
            _idx = k.find("]")  # end of key origin info - no more / expected besides subderivation
            assert _idx != -1
            ek = k[_idx+1:].split("/")[0]
            k_orig = k[:_idx+1] + ek

        if k_orig not in orig_keys:
            orig_keys.append(k_orig)

    # Replace from the highest index down: str.find("@1") would also match the
    # start of an unreplaced "@10". This relies on key strings never
    # containing "@".
    for i in range(len(orig_keys) - 1, -1, -1):
        k = orig_keys[i]
        ph = "@%d" % i
        ph_len = len(ph)
        while True:
            ix = policy.find(ph)
            if ix == -1:
                break

            assert policy[ix+ph_len] == "/"
            # subderivation is part of the policy
            x = ix + ph_len
            substr = policy[x:x+26]  # 26 is the longest possible subderivation allowed "/<2147483647;2147483646>/*"
            mp_start = substr.find("<")
            assert mp_start != -1
            mp_end = substr.find(">")
            mp = substr[mp_start:mp_end + 1]
            _ext, _int = mp[1:-1].split(";")
            if external and not internal:
                sub = _ext
            elif internal and not external:
                sub = _int
            else:
                sub = None
            if sub is not None:
                # collapse "<ext;int>" to the chosen branch; this only edits
                # text after the placeholder, so ix stays valid
                policy = policy[:x + mp_start] + sub + policy[x + mp_end + 1:]

            # `x` is reused here as the placeholder text (sanity check)
            x = policy[ix:ix + ph_len]
            assert x == ph
            policy = policy[:ix] + k + policy[ix + ph_len:]

    return policy
|
|
|
|
|
|
def taproot_tree_helper(scripts):
    """Build a BIP-341 script tree and its root hash.

    Args:
        scripts: a single Miniscript (a leaf), or a non-empty, possibly nested
            sequence of them. Sequences are split in half recursively into
            left/right branches.

    Returns:
        ``(leaves, h)``. ``leaves`` is a list of
        ``(leaf_version, script_bytes, path)`` tuples, where ``path`` is the
        concatenation of sibling hashes from the leaf upward. ``h`` is the
        TapLeaf/TapBranch hash of this (sub)tree.

    Raises:
        ValueError: for an empty sequence. This previously recursed until
            RecursionError, since an empty list splits into empty halves.
    """
    if isinstance(scripts, (list, tuple)) and len(scripts) == 0:
        raise ValueError("empty taproot tree")

    from miniscript import Miniscript

    if isinstance(scripts, Miniscript):
        script = scripts.compile()
        assert isinstance(script, bytes)
        h = ngu.secp256k1.tagged_sha256(b"TapLeaf", chains.tapscript_serialize(script))
        return [(chains.TAPROOT_LEAF_TAPSCRIPT, script, bytes())], h

    if len(scripts) == 1:
        return taproot_tree_helper(scripts[0])

    split_pos = len(scripts) // 2
    left, left_h = taproot_tree_helper(scripts[0:split_pos])
    right, right_h = taproot_tree_helper(scripts[split_pos:])
    # each side's leaves need the other side's hash in their merkle path
    left = [(version, script, control + right_h) for version, script, control in left]
    right = [(version, script, control + left_h) for version, script, control in right]
    # BIP-341: branch children are hashed in lexicographic order
    if right_h < left_h:
        right_h, left_h = left_h, right_h
    h = ngu.secp256k1.tagged_sha256(b"TapBranch", left_h + right_h)
    return left + right, h