fixes + small speed up for multi/sortedmulti scripts

This commit is contained in:
scgbckbone 2025-06-28 11:04:35 +02:00
parent f7cb241e9c
commit 0b20ef5360
3 changed files with 35 additions and 43 deletions

View File

@ -396,9 +396,8 @@ class Key:
def key_bytes(self):
kb = self.node.pubkey()
if self.taproot:
if len(kb) == 33:
kb = kb[1:]
assert len(kb) == 32
# xonly
kb = kb[1:]
return kb
def extended_public_key(self):

View File

@ -236,20 +236,22 @@ class Descriptor:
return self._keys
def derive(self, idx=None, change=False):
# derive keys first
derived_keys = OrderedDict()
for i, k in enumerate(self.keys):
if not i and self.is_taproot:
# internal key is always at index 0 in self.keys
# ik is derived few lines later
continue
dk = k.derive(idx, change=change)
dk.taproot=self.is_taproot
derived_keys[k] = dk
if self.is_taproot:
# derive keys first
# duplicate keys can be may be found in different leaves
# use map to derive each key just once
derived_keys = OrderedDict()
ikd = None
for i, k in enumerate(self.keys):
dk = k.derive(idx, change=change)
dk.taproot = self.is_taproot
derived_keys[k] = dk
if not i:
# internal key is always at index 0 in self.keys
ikd = dk
return type(self)(
self.key.derive(idx, change=change),
ikd,
tapscript=self.tapscript.derive(idx, derived_keys, change=change),
addr_fmt=self.addr_fmt,
keys=list(derived_keys.values()),
@ -257,9 +259,8 @@ class Descriptor:
if self.miniscript:
return type(self)(
None,
self.miniscript.derive(idx, derived_keys, change=change),
self.miniscript.derive(idx, change=change),
addr_fmt=self.addr_fmt,
keys=list(derived_keys.values())
)
# single-sig

View File

@ -662,9 +662,8 @@ class KeyHash(Key):
return super().parse_key(k, *args, **kwargs)
def serialize(self, *args, **kwargs):
if self.taproot:
return ngu.hash.hash160(self.node.pubkey()[1:33])
return ngu.hash.hash160(self.node.pubkey())
start = 1 if self.taproot else 0
return ngu.hash.hash160(self.node.pubkey()[start:33])
def __len__(self):
return 21 # <20:pkh>
@ -1354,8 +1353,14 @@ class Multi(Miniscript):
N_MAX = 20
def inner_compile(self):
# scr = [arg.compile() for arg in self.args[1:]]
# optimization - it is all keys with known length (xonly keys not allowed here)
scr = [b'\x21' + arg.key_bytes() for arg in self.args[1:]]
if self.NAME == "sortedmulti":
scr.sort()
return (
b"".join([arg.compile() for arg in self.args])
self.args[0].compile()
+ b"".join(scr)
+ Number(len(self.args) - 1).compile()
+ b"\xae"
)
@ -1381,13 +1386,6 @@ class Sortedmulti(Multi):
# <k> <key1> ... <keyn> <n> CHECKMULTISIG
NAME = "sortedmulti"
def inner_compile(self):
return (
self.args[0].compile()
+ b"".join(sorted([arg.compile() for arg in self.args[1:]]))
+ Number(len(self.args) - 1).compile()
+ b"\xae"
)
class Multi_a(Multi):
# <key1> CHECKSIG <key> CHECKSIGADD ... <keyn> CHECKSIGADD EQUALVERIFY
@ -1398,12 +1396,19 @@ class Multi_a(Multi):
def inner_compile(self):
from opcodes import OP_CHECKSIGADD, OP_NUMEQUAL, OP_CHECKSIG
script = b""
for i, key in enumerate(self.args[1:]):
script += key.compile()
# scr = [arg.compile() for arg in self.args[1:]]
# optimization - it is all keys with known length (only xonly keys allowed here)
scr = [b"\x20" + arg.key_bytes() for arg in self.args[1:]]
if self.NAME == "sortedmulti_a":
scr.sort()
for i, key in enumerate(scr):
script += key
if i == 0:
script += bytes([OP_CHECKSIG])
else:
script += bytes([OP_CHECKSIGADD])
script += self.args[0].compile() # M (threshold)
script += bytes([OP_NUMEQUAL])
return script
@ -1417,19 +1422,6 @@ class Sortedmulti_a(Multi_a):
# <key1> CHECKSIG <key> CHECKSIGADD ... <keyn> CHECKSIGADD EQUALVERIFY
NAME = "sortedmulti_a"
def inner_compile(self):
from opcodes import OP_CHECKSIGADD, OP_NUMEQUAL, OP_CHECKSIG
script = b""
for i, key in enumerate(sorted([arg.compile() for arg in self.args[1:]])):
script += key
if i == 0:
script += bytes([OP_CHECKSIG])
else:
script += bytes([OP_CHECKSIGADD])
script += self.args[0].compile() # M (threshold)
script += bytes([OP_NUMEQUAL])
return script
class Pk(OneArg):
# <key> CHECKSIG