rework key derivation parsing

This commit is contained in:
scgbckbone 2025-06-18 14:42:24 +02:00
parent cd344cb646
commit d4e9549f24

View File

@ -186,43 +186,44 @@ class KeyDerivationInfo:
return [self.indexes[-2]]
@classmethod
def from_string(cls, s):
fail_msg = "Cannot use hardened sub derivation path"
if not s:
return cls()
res = []
mp = 0
mpi = None
for idx, i in enumerate(s.split("/")):
start_i = i.find("<")
if start_i != -1:
end_i = s.find(">")
assert end_i
inner = s[start_i+1:end_i]
assert ";" in inner
inner_split = inner.split(";")
assert len(inner_split) == 2, "wrong multipath"
res.append([int(i) for i in inner_split])
mp += 1
mpi = idx
else:
if i == WILDCARD:
res.append(WILDCARD)
else:
assert "'" not in i, fail_msg
assert "h" not in i, fail_msg
res.append(int(i))
def parse(cls, s):
err = "Malformed key derivation"
multi_i = None
idxs = []
while True:
got, char = read_until(s, b"<,)/")
# only one <x;y> allowed in subderivation
assert mp <= 1, "too many multipaths (%d)" % mp
if char == b"<":
assert multi_i is None, "too many multipaths"
ext_num, char = read_until(s, b";")
assert char, err
int_num, char = read_until(s, b">")
assert char, err
if res == [0, WILDCARD]:
multi_i = len(idxs)
idxs.append([int(ext_num.decode()), int(int_num.decode())])
elif got == b"*":
# every derivation has to end with wildcard (only ranged keys allowed)
idxs.append(WILDCARD)
break
elif char == b"/" and got:
assert (b"'" not in got) and (b"h" not in got), "Cannot use hardened sub derivation path"
idxs.append(int(got.decode()))
if idxs == [0, WILDCARD]:
obj = cls()
else:
assert len(res) == 2, "Key derivation too long"
assert res[-1] == WILDCARD, "All keys must be ranged"
obj = cls(res)
obj.multi_path_index = mpi
assert idxs[-1] == WILDCARD, "All keys must be ranged"
if multi_i is not None:
assert len(idxs[multi_i]) == 2, "wrong multipath"
obj = cls(idxs)
obj.multi_path_index = multi_i
return obj
def to_string(self, external=True, internal=True):
@ -292,23 +293,14 @@ class Key:
k, char = read_until(s, b",)/")
der = b""
if char == b"/":
der, char = read_until(s, b"<,)")
if char == b"<":
der += b"<"
branch, char = read_until(s, b">")
if char is None:
raise ValueError("Failed reading the key, missing >")
der += branch + b">"
rest, char = read_until(s, b",)")
der += rest
der = KeyDerivationInfo.parse(s)
if char is not None:
s.seek(-1, 1)
# parse key
node, chain_type = cls.parse_key(k)
der = KeyDerivationInfo.from_string(der.decode())
if origin is None:
origin = KeyOriginInfo(ustruct.pack('<I', swab32(node.my_fp())), [])
return cls(node, origin, der, chain_type=chain_type)
return cls(node, origin, der or KeyDerivationInfo(), chain_type=chain_type)
@classmethod
def parse_key(cls, key_str):
@ -432,21 +424,13 @@ class Unspend(Key):
char = s.read(1)
if char != b"/":
raise ValueError("ranged unspend required")
der, char = read_until(s, b"<,)")
if char == b"<":
der += b"<"
branch, char = read_until(s, b">")
if char is None:
raise ValueError("Failed reading the key, missing >")
der += branch + b">"
rest, char = read_until(s, b",)")
der += rest
der = KeyDerivationInfo.parse(s)
if char is not None:
s.seek(-1, 1)
node = ngu.hdnode.HDNode().from_chaincode_pubkey(chain_code,
PROVABLY_UNSPENDABLE)
der = KeyDerivationInfo.from_string(der.decode())
return cls(node, None, der, chain_type=None)
def to_string(self, external=True, internal=True, subderiv=True):