Use HWWErrors instead of built-ins

This commit is contained in:
Andrew Chow 2019-01-30 21:29:36 -05:00
parent b4a95a6610
commit 4c3a6df1b7
8 changed files with 73 additions and 52 deletions

View File

@ -4,6 +4,7 @@ from .commands import backup_device, displayaddress, enumerate, find_device, \
get_client, getmasterxpub, getxpub, getkeypool, prompt_pin, restore_device, send_pin, setup_device, \
signmessage, signtx, wipe_device
from .errors import (
HWWError,
NO_DEVICE_PATH,
DEVICE_CONN_ERROR,
NO_PASSWORD,

View File

@ -1,7 +1,7 @@
# Coldcard interaction script
from ..hwwclient import HardwareWalletClient
from ..errors import UnavailableActionError
from ..errors import UnavailableActionError, DeviceFailureError
from ckcc.client import ColdcardDevice, COINKITE_VID, CKCC_PID
from ckcc.protocol import CCProtocolPacker, CCProtoError
from ckcc.constants import MAX_BLK_LEN, AF_P2WPKH, AF_CLASSIC, AF_P2WPKH_P2SH
@ -71,7 +71,7 @@ class ColdcardClient(HardwareWalletClient):
result = self.device.send_recv(CCProtocolPacker.sha256())
assert len(result) == 32
if result != expect:
raise ValueError("Wrong checksum:\nexpect: %s\n got: %s" % (b2a_hex(expect).decode('ascii'), b2a_hex(result).decode('ascii')))
raise DeviceFailureError("Wrong checksum:\nexpect: %s\n got: %s" % (b2a_hex(expect).decode('ascii'), b2a_hex(result).decode('ascii')))
# start the signing process
ok = self.device.send_recv(CCProtocolPacker.sign_transaction(sz, expect), timeout=None)
@ -89,7 +89,7 @@ class ColdcardClient(HardwareWalletClient):
break
if len(done) != 2:
raise ValueError('Failed: %r' % done)
raise DeviceFailureError('Failed: %r' % done)
result_len, result_sha = done
@ -109,7 +109,7 @@ class ColdcardClient(HardwareWalletClient):
if self.device.is_simulator:
self.device.send_recv(CCProtocolPacker.sim_keypress(b'y'))
except CCProtoError as e:
raise ValueError(str(e))
raise DeviceFailureError(str(e))
while 1:
time.sleep(0.250)
@ -120,7 +120,7 @@ class ColdcardClient(HardwareWalletClient):
break
if len(done) != 2:
raise ValueError('Failed: %r' % done)
raise DeviceFailureError('Failed: %r' % done)
addr, raw = done
@ -171,7 +171,7 @@ class ColdcardClient(HardwareWalletClient):
break
if len(done) != 2:
raise ValueError('Failed: %r' % done)
raise DeviceFailureError('Failed: %r' % done)
result_len, result_sha = done

View File

@ -15,7 +15,7 @@ import sys
import time
from ..hwwclient import HardwareWalletClient
from ..errors import NoPasswordError, UnavailableActionError
from ..errors import BadArgumentError, NoPasswordError, UnavailableActionError, DeviceFailureError
from ..serializations import CTransaction, PSBT, hash256, hash160, ser_sig_der, ser_sig_compact, ser_compact_size
from ..base58 import get_xpub_fingerprint, decode, to_address, xpub_main_2_test, get_xpub_fingerprint_hex
@ -76,7 +76,7 @@ def to_string(x, enc):
if isinstance(x, str):
return x
else:
raise TypeError("Not a string or bytes like object")
raise DeviceFailureError("Not a string or bytes like object")
class BitboxSimulator():
def __init__(self, ip, port):
@ -223,7 +223,7 @@ class DigitalbitboxClient(HardwareWalletClient):
# Retrieves the public key at the specified BIP 32 derivation path
def get_pubkey_at_path(self, path):
if '\'' not in path and 'h' not in path and 'H' not in path:
raise ValueError('The digital bitbox requires one part of the derivation path to be derived using hardened keys')
raise BadArgumentError('The digital bitbox requires one part of the derivation path to be derived using hardened keys')
reply = send_encrypt('{"xpub":"' + path + '"}', self.password, self.device)
if 'error' in reply:
return reply
@ -430,7 +430,7 @@ class DigitalbitboxClient(HardwareWalletClient):
# Need a wallet name and backup passphrase
if not label or not passphrase:
raise ValueError('THe label and backup passphrase for a new Digital Bitbox wallet must be specified and cannot be empty')
raise BadArgumentError('THe label and backup passphrase for a new Digital Bitbox wallet must be specified and cannot be empty')
# Set password
to_send = {'password': self.password}

View File

@ -1,7 +1,7 @@
# KeepKey interaction script
from ..hwwclient import HardwareWalletClient
from ..errors import DeviceAlreadyUnlockedError, UnavailableActionError, DeviceNotReadyError
from ..errors import BadArgumentError, DeviceAlreadyUnlockedError, UnavailableActionError, DeviceNotReadyError
from keepkeylib.transport_hid import HidTransport
from keepkeylib.transport_udp import UDPTransport
from keepkeylib.client import BaseClient, DebugWireMixin, DebugLinkMixin, ProtocolMixin, TextUIMixin
@ -70,7 +70,7 @@ class TxAPIPSBT(TxApi):
o.amount = psbt_in.witness_utxo.nValue
o.script_pubkey = psbt_in.witness_utxo.scriptPubKey
else:
raise ValueError('{} is not an input in this transaction'.format(txhash))
raise BadArgumentError('{} is not an input in this transaction'.format(txhash))
return t
@ -293,7 +293,7 @@ class KeepkeyClient(HardwareWalletClient):
if wit:
txoutput.address = bech32.encode(bech32_hrp, ver, prog)
else:
raise TypeError("Output is not an address")
raise BadArgumentError("Output is not an address")
# append to outputs
outputs.append(txoutput)
@ -384,7 +384,7 @@ class KeepkeyClient(HardwareWalletClient):
# Send the pin
def send_pin(self, pin):
if not pin.isdigit():
raise ValueError("Non-numeric PIN provided")
raise BadArgumentError("Non-numeric PIN provided")
resp = self.client.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure):
self.client.features = self.client.call_raw(messages.GetFeatures())

View File

@ -1,7 +1,7 @@
# Trezor interaction script
from ..hwwclient import HardwareWalletClient
from ..errors import DeviceAlreadyInitError, DeviceAlreadyUnlockedError, UnavailableActionError, DeviceNotReadyError
from ..errors import BadArgumentError, DeviceAlreadyInitError, DeviceAlreadyUnlockedError, UnavailableActionError, DeviceNotReadyError
from trezorlib.client import TrezorClient as Trezor
from trezorlib.debuglink import TrezorClientDebugLink
from trezorlib.transport import enumerate_devices, get_transport
@ -226,7 +226,7 @@ class TrezorClient(HardwareWalletClient):
if wit:
txoutput.address = bech32.encode(bech32_hrp, ver, prog)
else:
raise TypeError("Output is not an address")
raise BadArgumentError("Output is not an address")
# append to outputs
outputs.append(txoutput)
@ -352,7 +352,7 @@ class TrezorClient(HardwareWalletClient):
if self.client.features.pin_cached:
raise DeviceAlreadyUnlockedError('The PIN has already been sent to this device')
if not pin.isdigit():
raise ValueError("Non-numeric PIN provided")
raise BadArgumentError("Non-numeric PIN provided")
resp = self.client.call_raw(proto.PinMatrixAck(pin=pin))
if isinstance(resp, proto.Failure):
return {'success': False}

View File

@ -12,6 +12,8 @@ NOT_IMPLEMENTED = -8
UNAVAILABLE_ACTION = -9
DEVICE_ALREADY_INIT = -10
DEVICE_ALREADY_UNLOCKED = -11
DEVICE_NOT_READY = -12
UNKNOWN_ERROR = -13
# Exceptions
class HWWError(Exception):
@ -52,3 +54,19 @@ class DeviceAlreadyUnlockedError(HWWError):
class UnknownDeviceError(HWWError):
def __init__(self, msg):
HWWError.__init__(self, msg, UNKNWON_DEVICE_TYPE)
class NotImplementedError(HWWError):
def __init__(self, msg):
HWWError.__init__(self, msg, NOT_IMPLEMENTED)
class PSBTSerializationError(HWWError):
def __init__(self, msg):
HWWError.__init__(self, msg, INVALID_TX)
class BadArgumentError(HWWError):
def __init__(self, msg):
HWWError.__init__(self, msg, BAD_ARGUMENT)
class DeviceFailureError(HWWError):
def __init__(self, msg):
HWWError.__init__(self, msg, UNKNOWN_ERROR)

View File

@ -17,6 +17,7 @@ ser_*, deser_*: functions that handle serialization/deserialization
from io import BytesIO, BufferedReader
from codecs import encode
from .errors import PSBTSerializationError
import struct
import binascii
import hashlib
@ -493,10 +494,10 @@ class CTransaction(object):
def DeserializeHDKeypath(f, key, hd_keypaths):
if len(key) != 34 and len(key) != 66:
raise IOError("Size of key was not the expected size for the type partial signature pubkey")
raise PSBTSerializationError("Size of key was not the expected size for the type partial signature pubkey")
pubkey = key[1:]
if pubkey in hd_keypaths:
raise IOError("Duplicate key, input partial signature for pubkey already provided")
raise PSBTSerializationError("Duplicate key, input partial signature for pubkey already provided")
value = deser_string(f)
hd_keypaths[pubkey] = struct.unpack("<" + "I" * (len(value) // 4), value)
@ -551,9 +552,9 @@ class PartiallySignedInput:
if key_type == 0:
if self.non_witness_utxo:
raise IOError("Duplicate Key, input non witness utxo already provided")
raise PSBTSerializationError("Duplicate Key, input non witness utxo already provided")
elif len(key) != 1:
raise IOError("non witness utxo key is more than one byte type")
raise PSBTSerializationError("non witness utxo key is more than one byte type")
self.non_witness_utxo = CTransaction()
value = BufferedReader(BytesIO(deser_string(f)))
self.non_witness_utxo.deserialize(value)
@ -561,43 +562,43 @@ class PartiallySignedInput:
elif key_type == 1:
if self.witness_utxo:
raise IOError("Duplicate Key, input witness utxo already provided")
raise PSBTSerializationError("Duplicate Key, input witness utxo already provided")
elif len(key) != 1:
raise IOError("witness utxo key is more than one byte type")
raise PSBTSerializationError("witness utxo key is more than one byte type")
self.witness_utxo = CTxOut()
value = BufferedReader(BytesIO(deser_string(f)))
self.witness_utxo.deserialize(value)
elif key_type == 2:
if len(key) != 34 and len(key) != 66:
raise IOError("Size of key was not the expected size for the type partial signature pubkey")
raise PSBTSerializationError("Size of key was not the expected size for the type partial signature pubkey")
pubkey = key[1:]
if pubkey in self.partial_sigs:
raise IOError("Duplicate key, input partial signature for pubkey already provided")
raise PSBTSerializationError("Duplicate key, input partial signature for pubkey already provided")
sig = deser_string(f)
self.partial_sigs[pubkey] = sig;
elif key_type == 3:
if self.sighash > 0:
raise IOError("Duplicate key, input sighash type already provided")
raise PSBTSerializationError("Duplicate key, input sighash type already provided")
elif len(key) != 1:
raise IOError("sighash key is more than one byte type")
raise PSBTSerializationError("sighash key is more than one byte type")
value = deser_string(f)
self.sighash = struct.unpack("<I", value)[0]
elif key_type == 4:
if len(self.redeem_script) != 0:
raise IOError("Duplicate key, input redeemScript already provided")
raise PSBTSerializationError("Duplicate key, input redeemScript already provided")
elif len(key) != 1:
raise IOError("redeemScript key is more than one byte type")
raise PSBTSerializationError("redeemScript key is more than one byte type")
self.redeem_script = deser_string(f)
elif key_type == 5:
if len(self.witness_script) != 0:
raise IOError("Duplicate key, input witnessScript already provided")
raise PSBTSerializationError("Duplicate key, input witnessScript already provided")
elif len(key) != 1:
raise IOError("witnessScript key is more than one byte type")
raise PSBTSerializationError("witnessScript key is more than one byte type")
self.witness_script = deser_string(f)
elif key_type == 6:
@ -605,22 +606,22 @@ class PartiallySignedInput:
elif key_type == 7:
if len(self.final_script_sig) != 0:
raise IOError("Duplicate key, input final scriptSig already provided")
raise PSBTSerializationError("Duplicate key, input final scriptSig already provided")
elif len(key) != 1:
raise IOError("final scriptSig key is more than one byte type")
raise PSBTSerializationError("final scriptSig key is more than one byte type")
self.final_script_sig = deser_string(f)
elif key_type == 8:
if not self.final_script_witness.is_null():
raise IOError("Duplicate key, input final scriptWitness already provided")
raise PSBTSerializationError("Duplicate key, input final scriptWitness already provided")
elif len(key) != 1:
raise IOError("final scriptWitness key is more than one byte type")
raise PSBTSerializationError("final scriptWitness key is more than one byte type")
value = BufferedReader(BytesIO(deser_string(f)))
self.final_script_witness.deserialize(value)
else:
if key in self.unknown:
raise IOError("Duplicate key, key for unknown value already provided")
raise PSBTSerializationError("Duplicate key, key for unknown value already provided")
value = deser_string(f)
self.unknown[key] = value
@ -712,16 +713,16 @@ class PartiallySignedOutput:
if key_type == 0:
if len(self.redeem_script) != 0:
raise IOError("Duplicate key, output redeemScript already provided")
raise PSBTSerializationError("Duplicate key, output redeemScript already provided")
elif len(key) != 1:
raise IOError("Output redeemScript key is more than one byte type")
raise PSBTSerializationError("Output redeemScript key is more than one byte type")
self.redeem_script = deser_string(f)
elif key_type == 1:
if len(self.witness_script) != 0:
raise IOError("Duplicate key, output witnessScript already provided")
raise PSBTSerializationError("Duplicate key, output witnessScript already provided")
elif len(key) != 1:
raise IOError("Output witnessScript key is more than one byte type")
raise PSBTSerializationError("Output witnessScript key is more than one byte type")
self.witness_script = deser_string(f)
elif key_type == 2:
@ -729,7 +730,7 @@ class PartiallySignedOutput:
else:
if key in self.unknown:
raise IOError("Duplicate key, key for unknown value already provided")
raise PSBTSerializationError("Duplicate key, key for unknown value already provided")
value = deser_string(f)
self.unknown[key] = value
@ -772,7 +773,7 @@ class PSBT(object):
# Read the magic bytes
magic = f.read(5)
if magic != b"psbt\xff":
raise IOError("invalid magic")
raise PSBTSerializationError("invalid magic")
# Read loop
separators = 0
@ -796,9 +797,9 @@ class PSBT(object):
if key_type == 0x00:
# Checks for correctness
if not self.tx.is_null:
raise IOError("Duplicate key, unsigned tx already provided")
raise PSBTSerializationError("Duplicate key, unsigned tx already provided")
elif len(key) > 1:
raise IOError("Global unsigned tx key is more than one byte type")
raise PSBTSerializationError("Global unsigned tx key is more than one byte type")
# read in value
value = BufferedReader(BytesIO(deser_string(f)))
@ -807,17 +808,17 @@ class PSBT(object):
# Make sure that all scriptSigs and scriptWitnesses are empty
for txin in self.tx.vin:
if len(txin.scriptSig) != 0 or not self.tx.wit.is_null():
raise IOError("Unsigned tx does not have empty scriptSigs and scriptWitnesses")
raise PSBTSerializationError("Unsigned tx does not have empty scriptSigs and scriptWitnesses")
else:
if key in self.unknown:
raise IOError("Duplicate key, key for unknown value already provided")
raise PSBTSerializationError("Duplicate key, key for unknown value already provided")
value = deser_string(f)
self.unknown[key] = value
# make sure that we got an unsigned tx
if self.tx.is_null():
raise IOError("No unsigned trasaction was provided")
raise PSBTSerializationError("No unsigned trasaction was provided")
# Read input data
for txin in self.tx.vin:
@ -828,10 +829,10 @@ class PSBT(object):
self.inputs.append(input)
if input.non_witness_utxo and input.non_witness_utxo.rehash() and input.non_witness_utxo.sha256 != txin.prevout.sha256:
raise IOError("Non-witness UTXO does not match outpoint hash")
raise PSBTSerializationError("Non-witness UTXO does not match outpoint hash")
if (len(self.inputs) != len(self.tx.vin)):
raise IOError("Inputs provided does not match the number of inputs in transaction")
raise PSBTSerializationError("Inputs provided does not match the number of inputs in transaction")
# Read output data
for txout in self.tx.vout:
@ -842,10 +843,10 @@ class PSBT(object):
self.outputs.append(output)
if len(self.outputs) != len(self.tx.vout):
raise IOError("Outputs provided does not match the number of outputs in transaction")
raise PSBTSerializationError("Outputs provided does not match the number of outputs in transaction")
if not self.is_sane():
raise IOError("PSBT is not sane")
raise PSBTSerializationError("PSBT is not sane")
def serialize(self):
r = b""

View File

@ -1,6 +1,7 @@
#! /usr/bin/env python3
from hwilib.serializations import PSBT
from hwilib.errors import PSBTSerializationError
import json
import os
import unittest
@ -15,7 +16,7 @@ class TestPSBT(unittest.TestCase):
def test_invalid_psbt(self):
for invalid in self.data['invalid']:
with self.subTest(invalid=invalid):
with self.assertRaises(IOError) as cm:
with self.assertRaises(PSBTSerializationError) as cm:
psbt = PSBT()
psbt.deserialize(invalid)