diff --git a/shared/desc_utils.py b/shared/desc_utils.py index d5680376..22e01ac0 100644 --- a/shared/desc_utils.py +++ b/shared/desc_utils.py @@ -186,43 +186,44 @@ class KeyDerivationInfo: return [self.indexes[-2]] @classmethod - def from_string(cls, s): - fail_msg = "Cannot use hardened sub derivation path" - if not s: - return cls() - res = [] - mp = 0 - mpi = None - for idx, i in enumerate(s.split("/")): - start_i = i.find("<") - if start_i != -1: - end_i = s.find(">") - assert end_i - inner = s[start_i+1:end_i] - assert ";" in inner - inner_split = inner.split(";") - assert len(inner_split) == 2, "wrong multipath" - res.append([int(i) for i in inner_split]) - mp += 1 - mpi = idx - else: - if i == WILDCARD: - res.append(WILDCARD) - else: - assert "'" not in i, fail_msg - assert "h" not in i, fail_msg - res.append(int(i)) + def parse(cls, s): + err = "Malformed key derivation" + multi_i = None + idxs = [] + while True: + got, char = read_until(s, b"<,)/") - # only one allowed in subderivation - assert mp <= 1, "too many multipaths (%d)" % mp + if char == b"<": + assert multi_i is None, "too many multipaths" + ext_num, char = read_until(s, b";") + assert char, err + int_num, char = read_until(s, b">") + assert char, err - if res == [0, WILDCARD]: + multi_i = len(idxs) + idxs.append([int(ext_num.decode()), int(int_num.decode())]) + + + elif got == b"*": + # every derivation has to end with wildcard (only ranged keys allowed) + idxs.append(WILDCARD) + break + + elif char == b"/" and got: + assert (b"'" not in got) and (b"h" not in got), "Cannot use hardened sub derivation path" + idxs.append(int(got.decode())) + + + if idxs == [0, WILDCARD]: obj = cls() else: - assert len(res) == 2, "Key derivation too long" - assert res[-1] == WILDCARD, "All keys must be ranged" - obj = cls(res) - obj.multi_path_index = mpi + assert idxs[-1] == WILDCARD, "All keys must be ranged" + if multi_i is not None: + assert len(idxs[multi_i]) == 2, "wrong multipath" + + obj = cls(idxs) + obj.multi_path_index = multi_i + return obj def to_string(self, external=True, internal=True): @@ -292,23 +293,14 @@ class Key: k, char = read_until(s, b",)/") der = b"" if char == b"/": - der, char = read_until(s, b"<,)") - if char == b"<": - der += b"<" - branch, char = read_until(s, b">") - if char is None: - raise ValueError("Failed reading the key, missing >") - der += branch + b">" - rest, char = read_until(s, b",)") - der += rest + der = KeyDerivationInfo.parse(s) if char is not None: s.seek(-1, 1) # parse key node, chain_type = cls.parse_key(k) - der = KeyDerivationInfo.from_string(der.decode()) if origin is None: origin = KeyOriginInfo(ustruct.pack('") - if char is None: - raise ValueError("Failed reading the key, missing >") - der += branch + b">" - rest, char = read_until(s, b",)") - der += rest + + der = KeyDerivationInfo.parse(s) if char is not None: s.seek(-1, 1) node = ngu.hdnode.HDNode().from_chaincode_pubkey(chain_code, PROVABLY_UNSPENDABLE) - der = KeyDerivationInfo.from_string(der.decode()) return cls(node, None, der, chain_type=None) def to_string(self, external=True, internal=True, subderiv=True):