Implement multisig input signing and fixes to signtx for keepkey

This commit is contained in:
Andrew Chow 2019-01-04 18:08:46 -05:00
parent 9d6949d1b0
commit 685759377d

View File

@ -8,6 +8,7 @@ from keepkeylib import messages_pb2, types_pb2 as proto
from keepkeylib.tx_api import TxApi
from ..base58 import get_xpub_fingerprint, decode, to_address, xpub_main_2_test, get_xpub_fingerprint_hex
from ..serializations import ser_uint256, uint256_from_str
from .. import bech32
import base64
import binascii
@ -17,6 +18,8 @@ import os
KEEPKEY_VENDOR_ID = 0x2B24
KEEPKEY_DEVICE_ID = 0x0001
py_enumerate = enumerate # Need to use the enumerate built-in but there's another function already named that
class TxAPIPSBT(TxApi):
def __init__(self, psbt):
@ -24,31 +27,77 @@ class TxAPIPSBT(TxApi):
self.psbt = psbt
def get_tx(self, txhash):
tx = None
for psbt_in in self.psbt.inputs:
if psbt_in.non_witness_utxo and psbt_in.non_witness_utxo.sha256 == uint256_from_str(binascii.unhexlify(txhash)[::-1]):
tx = psbt_in.non_witness_utxo
if not tx:
raise ValueError("TX {} not found in PSBT".format(txhash))
# Find index of the input
for i, input in py_enumerate(self.psbt.tx.vin):
if input.prevout.hash == uint256_from_str(binascii.unhexlify(txhash)[::-1]):
break
psbt_in = self.psbt.inputs[i]
t = proto.TransactionType()
t.version = tx.nVersion
t.lock_time = tx.nLockTime
if psbt_in.non_witness_utxo:
assert(psbt_in.non_witness_utxo.sha256 == uint256_from_str(binascii.unhexlify(txhash)[::-1]))
tx = psbt_in.non_witness_utxo
for vin in tx.vin:
i = t.inputs.add()
i.prev_hash = ser_uint256(vin.prevout.hash)[::-1]
i.prev_index = vin.prevout.n
i.script_sig = vin.scriptSig
i.sequence = vin.nSequence
t.version = tx.nVersion
t.lock_time = tx.nLockTime
for vout in tx.vout:
for vin in tx.vin:
i = t.inputs.add()
i.prev_hash = ser_uint256(vin.prevout.hash)[::-1]
i.prev_index = vin.prevout.n
i.script_sig = vin.scriptSig
i.sequence = vin.nSequence
for vout in tx.vout:
o = t.bin_outputs.add()
o.amount = vout.nValue
o.script_pubkey = vout.scriptPubKey
elif psbt_in.witness_utxo:
# HACK: the library looks up this info for all inputs. we just need to appease it for segwit stuff
t.version = 1
t.lock_time = 0
o = t.bin_outputs.add()
o.amount = vout.nValue
o.script_pubkey = vout.scriptPubKey
o.amount = psbt_in.witness_utxo.nValue
o.script_pubkey = psbt_in.witness_utxo.scriptPubKey
else:
raise ValueError('{} is not an input in this transaction'.format(txhash))
return t
# Only handles up to 15 of 15
def parse_multisig(script):
# Get m
m = script[0] - 80
if m < 1 or m > 15:
return (False, None)
# Get pubkeys and build HDNodePathType
pubkeys = []
offset = 1
while True:
pubkey_len = script[offset]
if pubkey_len != 33:
break
offset += 1
key = script[offset:offset + 33]
offset += 33
hd_node = proto.HDNodeType(depth=0, fingerprint=0, child_num=0, chain_code=b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', public_key=key)
pubkeys.append(proto.HDNodePathType(node=hd_node, address_n=[]))
# Check things at the end
n = script[offset] - 80
if n != len(pubkeys):
return (False, None)
offset += 1
op_cms = script[offset]
if op_cms != 174:
return (False, None)
# Build MultisigRedeemScriptType and return it
multisig = proto.MultisigRedeemScriptType(m=m, signatures=[b''] * n, pubkeys=pubkeys)
return (True, multisig)
# This class extends the HardwareWalletClient for Digital Bitbox specific things
class KeepkeyClient(HardwareWalletClient):
@ -85,93 +134,146 @@ class KeepkeyClient(HardwareWalletClient):
master_key = self.client.get_public_node([0])
master_fp = get_xpub_fingerprint(master_key.xpub)
# Prepare inputs
inputs = []
for psbt_in, txin in zip(tx.inputs, tx.tx.vin):
txinputtype = proto.TxInputType()
# Do multiple passes for multisig
passes = 1
p = 0
# Set the input stuff
txinputtype.prev_hash = ser_uint256(txin.prevout.hash)[::-1]
txinputtype.prev_index = txin.prevout.n
txinputtype.sequence = txin.nSequence
while p < passes:
# Prepare inputs
inputs = []
to_ignore = []
for input_num, (psbt_in, txin) in py_enumerate(list(zip(tx.inputs, tx.tx.vin))):
txinputtype = proto.TxInputType()
# Detrermine spend type
if psbt_in.non_witness_utxo:
txinputtype.script_type = 0
elif psbt_in.witness_utxo:
# Check if the output is p2sh
if psbt_in.witness_utxo.is_p2sh():
txinputtype.script_type = 3
# Set the input stuff
txinputtype.prev_hash = ser_uint256(txin.prevout.hash)[::-1]
txinputtype.prev_index = txin.prevout.n
txinputtype.sequence = txin.nSequence
# Detrermine spend type
scriptcode = b''
if psbt_in.non_witness_utxo:
utxo = psbt_in.non_witness_utxo.vout[txin.prevout.n]
txinputtype.script_type = proto.SPENDADDRESS
scriptcode = utxo.scriptPubKey
txinputtype.amount = psbt_in.non_witness_utxo.vout[txin.prevout.n].nValue
elif psbt_in.witness_utxo:
utxo = psbt_in.witness_utxo
# Check if the output is p2sh
if psbt_in.witness_utxo.is_p2sh():
txinputtype.script_type = proto.SPENDP2SHWITNESS
else:
txinputtype.script_type = proto.SPENDWITNESS
scriptcode = psbt_in.witness_utxo.scriptPubKey
txinputtype.amount = psbt_in.witness_utxo.nValue
# Set the script
if psbt_in.witness_script:
scriptcode = psbt_in.witness_script
elif psbt_in.redeem_script:
scriptcode = psbt_in.redeem_script
def ignore_input():
txinputtype.address_n.extend([0x80000000])
txinputtype.ClearField('multisig')
txinputtype.script_type = proto.SPENDWITNESS
inputs.append(txinputtype)
to_ignore.append(input_num)
# Check for multisig
is_ms, multisig = parse_multisig(scriptcode)
if is_ms:
# Add to txinputtype
txinputtype.multisig.CopyFrom(multisig)
if psbt_in.non_witness_utxo:
if utxo.is_p2sh:
txinputtype.script_type = proto.SPENDMULTISIG
else:
# Cannot sign bare multisig, ignore it
ignore_input()
continue
elif not is_ms and psbt_in.non_witness_utxo and not utxo.is_p2pkh:
# Cannot sign unknown spk, ignore it
ignore_input()
continue
elif not is_ms and psbt_in.witness_utxo and psbt_in.witness_script:
# Cannot sign unknown witness script, ignore it
ignore_input()
continue
# Find key to sign with
found = False
our_keys = 0
for key in psbt_in.hd_keypaths.keys():
keypath = psbt_in.hd_keypaths[key]
if keypath[0] == master_fp and key not in psbt_in.partial_sigs:
if not found:
txinputtype.address_n.extend(keypath[1:])
found = True
our_keys += 1
# Determine if we need to do more passes to sign everything
if our_keys > passes:
passes = our_keys
if not found:
# This input is not one of ours
ignore_input()
continue
# append to inputs
inputs.append(txinputtype)
# address version byte
if self.is_testnet:
p2pkh_version = b'\x6f'
p2sh_version = b'\xc4'
bech32_hrp = 'tb'
else:
p2pkh_version = b'\x00'
p2sh_version = b'\x05'
bech32_hrp = 'bc'
# prepare outputs
outputs = []
for out in tx.tx.vout:
txoutput = proto.TxOutputType()
txoutput.amount = out.nValue
txoutput.script_type = proto.PAYTOADDRESS
if out.is_p2pkh():
txoutput.address = to_address(out.scriptPubKey[3:23], p2pkh_version)
txoutput.script_type = 0
elif out.is_p2sh():
txoutput.address = to_address(out.scriptPubKey[2:22], p2sh_version)
txoutput.script_type = 1
else:
txinputtype.script_type = 4
wit, ver, prog = out.is_witness()
if wit:
txoutput.address = bech32.encode(bech32_hrp, ver, prog)
else:
raise TypeError("Output is not an address")
# Check for 1 key
if len(psbt_in.hd_keypaths) == 1:
# Is this key ours
pubkey = list(psbt_in.hd_keypaths.keys())[0]
fp = psbt_in.hd_keypaths[pubkey][0]
keypath = psbt_in.hd_keypaths[pubkey][1:]
if fp == master_fp:
# Set the keypath
txinputtype.address_n.extend(keypath)
# append to outputs
outputs.append(txoutput)
# Check for multisig (more than 1 key)
elif len(psbt_in.hd_keypaths) > 1:
raise TypeError("Cannot sign multisig yet")
# Sign the transaction
self.client.set_tx_api(TxAPIPSBT(tx))
if self.is_testnet:
signed_tx = self.client.sign_tx("Testnet", inputs, outputs, tx.tx.nVersion, tx.tx.nLockTime)
else:
raise TypeError("All inputs must have a key for this device")
signed_tx = self.client.sign_tx("Bitcoin", inputs, outputs, tx.tx.nVersion, tx.tx.nLockTime)
# Set the amount
if psbt_in.non_witness_utxo:
txinputtype.amount = psbt_in.non_witness_utxo.vout[txin.prevout.n].nValue
elif psbt_in.witness_utxo:
txinputtype.amount = psbt_in.witness_utxo.nValue
# Each input has one signature
for input_num, (psbt_in, sig) in py_enumerate(list(zip(tx.inputs, signed_tx[0]))):
if input_num in to_ignore:
continue
for pubkey in psbt_in.hd_keypaths.keys():
fp = psbt_in.hd_keypaths[pubkey][0]
if fp == master_fp and pubkey not in psbt_in.partial_sigs:
psbt_in.partial_sigs[pubkey] = sig + b'\x01'
break
# append to inputs
inputs.append(txinputtype)
# address version byte
if self.is_testnet:
p2pkh_version = b'\x6f'
p2sh_version = b'\xc4'
else:
p2pkh_version = b'\x00'
p2sh_version = b'\x05'
# prepare outputs
outputs = []
for out in tx.tx.vout:
txoutput = proto.TxOutputType()
txoutput.amount = out.nValue
if out.is_p2pkh():
txoutput.address = to_address(out.scriptPubKey[3:23], p2pkh_version)
txoutput.script_type = 0
elif out.is_p2sh():
txoutput.address = to_address(out.scriptPubKey[2:22], p2sh_version)
txoutput.script_type = 1
else:
# TODO: Figure out what to do here. for now, just break
break
# append to outputs
outputs.append(txoutput)
logging.debug(txoutput)
# Sign the transaction
self.client.set_tx_api(TxAPIPSBT(tx))
if self.is_testnet:
signed_tx = self.client.sign_tx("Testnet", inputs, outputs, tx.tx.nVersion, tx.tx.nLockTime)
else:
signed_tx = self.client.sign_tx("Bitcoin", inputs, outputs, tx.tx.nVersion, tx.tx.nLockTime)
signatures = signed_tx[0]
logging.debug(binascii.hexlify(signed_tx[1]))
for psbt_in in tx.inputs:
for pubkey, sig in zip(psbt_in.hd_keypaths.keys(), signatures):
fp = psbt_in.hd_keypaths[pubkey][0]
keypath = psbt_in.hd_keypaths[pubkey][1:]
if fp == master_fp:
psbt_in.partial_sigs[pubkey] = sig + b'\x01'
p += 1
return {'psbt':tx.serialize()}