add bip32 derivation fallback when retreiving signing nodes for high-index inputs

This commit is contained in:
Craig Raw 2026-03-11 17:31:13 +02:00
parent 6cbb144326
commit f5a52f9eae
2 changed files with 91 additions and 2 deletions

View File

@ -100,6 +100,11 @@ public class FinalizingPSBTWallet extends Wallet {
return signedInputNodes;
}
@Override
public Map<PSBTInput, WalletNode> getSigningNodes(PSBT psbt, boolean useDerivationFallback) {
return signedInputNodes;
}
@Override
public ECKey getPubKey(WalletNode node) {
return signedNodeKeys.get(node).get(0);

View File

@ -1523,7 +1523,7 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
}
public boolean canSignAllInputs(PSBT psbt) {
return isValid() && getSigningNodes(psbt).size() == psbt.getPsbtInputs().size();
return isValid() && getSigningNodes(psbt, false).size() == psbt.getPsbtInputs().size();
}
/**
@ -1533,6 +1533,10 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
* @return A map if the PSBT inputs and the nodes that can sign them
*/
public Map<PSBTInput, WalletNode> getSigningNodes(PSBT psbt) {
return getSigningNodes(psbt, true);
}
public Map<PSBTInput, WalletNode> getSigningNodes(PSBT psbt, boolean useDerivationFallback) {
Map<PSBTInput, WalletNode> signingNodes = new LinkedHashMap<>();
Map<Script, WalletNode> walletOutputScripts = getWalletOutputScripts();
@ -1542,6 +1546,12 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
if(utxo != null) {
Script scriptPubKey = utxo.getScript();
WalletNode signingNode = walletOutputScripts.get(scriptPubKey);
// BIP32-derivation fallback for inputs beyond the wallet's derived address range
if(signingNode == null && useDerivationFallback) {
signingNode = getSigningNodeFromDerivation(psbtInput, scriptPubKey);
}
if(signingNode != null) {
signingNodes.put(psbtInput, signingNode);
}
@ -1551,6 +1561,80 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
return signingNodes;
}
private WalletNode getSigningNodeFromDerivation(PSBTInput psbtInput, Script scriptPubKey) {
Map<ECKey, KeyDerivation> derivedPublicKeys = psbtInput.getDerivedPublicKeys();
Map<ECKey, Map<KeyDerivation, List<Sha256Hash>>> tapDerivedPublicKeys = psbtInput.getTapDerivedPublicKeys();
for(Map.Entry<ECKey, KeyDerivation> entry : derivedPublicKeys.entrySet()) {
WalletNode node = matchDerivation(entry.getValue(), scriptPubKey);
if(node != null) {
return node;
}
}
for(Map.Entry<ECKey, Map<KeyDerivation, List<Sha256Hash>>> entry : tapDerivedPublicKeys.entrySet()) {
for(KeyDerivation keyDerivation : entry.getValue().keySet()) {
WalletNode node = matchDerivation(keyDerivation, scriptPubKey);
if(node != null) {
return node;
}
}
}
return null;
}
private WalletNode matchDerivation(KeyDerivation keyDerivation, Script scriptPubKey) {
if(policyType == PolicyType.CUSTOM) {
return null;
}
for(Keystore keystore : getKeystores()) {
ECKey derivedKey = keystore.getPubKeyForDerivation(keyDerivation);
if(derivedKey == null) {
continue;
}
List<ChildNumber> fullPath = keyDerivation.getDerivation();
List<ChildNumber> keystorePath = keystore.getKeyDerivation().getDerivation();
List<ChildNumber> remaining;
if(fullPath.size() > keystorePath.size() && fullPath.subList(0, keystorePath.size()).equals(keystorePath)) {
remaining = fullPath.subList(keystorePath.size(), fullPath.size());
} else {
remaining = fullPath;
}
if(remaining.size() != 2) {
continue;
}
KeyPurpose keyPurpose = KeyPurpose.fromChildNumber(remaining.get(0));
int addressIndex = remaining.get(1).num();
if(keyPurpose == null || !getWalletKeyPurposes().contains(keyPurpose)) {
continue;
}
WalletNode purposeNode = getNode(keyPurpose);
WalletNode targetNode = null;
for(WalletNode child : purposeNode.getChildren()) {
if(child.getIndex() == addressIndex) {
targetNode = child;
break;
}
}
if(targetNode == null) {
targetNode = new WalletNode(this, keyPurpose, addressIndex);
}
Script expectedScript = getOutputScript(targetNode);
if(expectedScript.equals(scriptPubKey)) {
return targetNode;
}
}
return null;
}
public Collection<Keystore> getSigningKeystores(PSBT psbt) {
Set<Keystore> signingKeystores = new LinkedHashSet<>();
@ -1582,7 +1666,7 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
WalletNode purposeNode = copy.getNode(keyPurpose);
purposeNode.fillToIndex(purposeNode.getChildren().size() + SEARCH_LOOKAHEAD);
}
Map<PSBTInput, WalletNode> copySigningNodes = copy.getSigningNodes(psbt);
Map<PSBTInput, WalletNode> copySigningNodes = copy.getSigningNodes(psbt, false);
boolean found = false;
int gapLimit = getGapLimit();
for(KeyPurpose keyPurpose : KeyPurpose.DEFAULT_PURPOSES) {