improve USB validation

This commit is contained in:
scgbckbone 2026-04-19 10:41:19 +02:00 committed by doc-hex
parent 6fd256dbdc
commit d5aba396a6
4 changed files with 310 additions and 23 deletions

View File

@ -8,6 +8,7 @@ This lists the new changes that have not yet been published in a normal release.
(read more [BIP-322 Proof of Reserves documentation](../docs/proof-of-reserves-bip-322.md) )
- Enhancement: WIF Store export watch-only descriptor
- Enhancement: WIF Store address detection without the need for PSBT_IN_BIP32_DERIVATION (Electrum support)
- Enhancement: Improve USB length validation
- Bugfix: Disable Virtual Disk and NFC before activating HSM
- Bugfix: Custom address default menu position wrong
- Bugfix: Delta Mode Trick PIN was never restored from backup

View File

@ -416,10 +416,12 @@ class USBHandler:
if cmd == 'dwld':
offset, length, fileno = unpack_from('<III', args)
assert len(args) == 12, 'badlen'
return await self.handle_download(offset, length, fileno)
if cmd == 'ncry':
version, his_pubkey = unpack_from('<I64s', args)
assert len(args) == 68, 'badlen'
return self.handle_crypto_setup(version, his_pubkey)
@ -449,9 +451,9 @@ class USBHandler:
if cmd == 'smsg':
# sign message
addr_fmt, len_subpath, len_msg = unpack_from('<III', args)
assert len(args) == (12 + len_subpath + len_msg), 'badlen'
subpath = args[12:12+len_subpath]
msg = args[12+len_subpath:]
assert len(msg) == len_msg, "badlen"
from auth import sign_msg
sign_msg(msg, subpath, addr_fmt)
@ -480,6 +482,7 @@ class USBHandler:
xfp_paths = []
for i in range(N):
assert offset < len(args), 'badlen'
ln = args[offset]
assert 1 <= ln <= 16, 'badlen'
xfp_paths.append(unpack_from('<%dI' % ln, args, offset+1))
@ -495,6 +498,7 @@ class USBHandler:
from auth import usb_show_address
addr_fmt, = unpack_from('<I', args)
assert len(args) >= 4, 'badlen'
# regression patch of AFC_BECH32M flag
# fixed here https://github.com/Coldcard/ckcc-protocol/commit/a6d901f9fca50755835eca895586ca74d0ca81ed
if addr_fmt == 0x17: # old P2TR
@ -506,6 +510,7 @@ class USBHandler:
# - text config file must already be uploaded
file_len, file_sha = unpack_from('<I32s', args)
assert len(args) == 36, 'badlen'
if file_sha != self.file_checksum.digest():
return b'err_Checksum'
assert 100 < file_len <= (20*200), "badlen"
@ -520,12 +525,13 @@ class USBHandler:
# Quick check to test if we have a wallet already installed.
from multisig import MultisigWallet
M, N, xfp_xor = unpack_from('<3I', args)
assert len(args) == 12, 'badlen'
return int(MultisigWallet.quick_check(M, N, xfp_xor))
if cmd == 'stxn':
# sign transaction
txn_len, flags, txn_sha = unpack_from('<II32s', args)
assert len(args) == 40, 'badlen'
if txn_sha != self.file_checksum.digest():
return b'err_Checksum'
@ -595,6 +601,8 @@ class USBHandler:
if cmd == 'rest':
# restore backup from what is already uploaded in PSRAM
file_len, file_sha, bf = unpack_from('<I32sB', args)
assert len(args) == 37, 'badlen'
assert 0 < file_len <= MAX_TXN_LEN, "badlen"
if file_sha != self.file_checksum.digest():
return b'err_Checksum'
@ -617,6 +625,7 @@ class USBHandler:
# HSM mode "start" -- requires user approval
if args:
file_len, file_sha = unpack_from('<I32s', args)
assert len(args) == 36, 'badlen'
if file_sha != self.file_checksum.digest():
return b'err_Checksum'
assert 2 <= file_len <= (200*1000), "badlen"
@ -644,6 +653,8 @@ class USBHandler:
if cmd == 'nwur': # new user
from users import Users
auth_mode, ul, sl = unpack_from('<BBB', args)
assert len(args) == (3 + ul + sl), 'badlen'
assert ul and sl, "badlen"
username = bytes(args[3:3+ul]).decode('ascii')
secret = bytes(args[3+ul:3+ul+sl])
@ -652,6 +663,8 @@ class USBHandler:
if cmd == 'rmur': # delete user
from users import Users
ul, = unpack_from('<B', args)
assert len(args) == (1 + ul), 'badlen'
assert ul, "badlen"
username = bytes(args[1:1+ul]).decode('ascii')
return Users.delete(username)
@ -659,6 +672,8 @@ class USBHandler:
if cmd == 'user': # auth user (HSM mode)
from users import Users
totp_time, ul, tl = unpack_from('<IBB', args)
assert len(args) == (6 + ul + tl), 'badlen'
assert ul and tl, "badlen"
username = bytes(args[6:6+ul]).decode('ascii')
token = bytes(args[6+ul:6+ul+tl])
@ -747,7 +762,8 @@ class USBHandler:
length = min(length, MAX_BLK_LEN)
assert 0 <= file_number < 2, 'bad fnum'
assert 0 <= offset <= MAX_TXN_LEN, "bad offset"
assert 0 <= offset < MAX_TXN_LEN, "bad offset"
assert offset + length <= MAX_TXN_LEN, "bad offset"
assert 1 <= length, 'len'
# maintain a running SHA256 over what's sent
@ -782,7 +798,8 @@ class USBHandler:
dis.progress_sofar(offset, total_size)
assert offset % 256 == 0, 'alignment'
assert offset+len(data) <= total_size <= MAX_UPLOAD_LEN, 'long'
assert 1 <= total_size <= MAX_UPLOAD_LEN, 'long'
assert offset + len(data) <= total_size, 'long'
if hsm_active or pa.hobbled_mode:
# additional restriction in HSM mode or hobbled: must be PSBT

View File

@ -114,15 +114,13 @@ def compute_policy_hash(policy):
return b2a_hex(sha256(json_.encode()).digest()).decode()
@pytest.fixture(autouse=True)
def enable_hsm_commands(dev, sim_exec, is_q1):
def enable_hsm_commands(settings_remove, settings_set, is_q1):
if is_q1:
raise pytest.skip("Q does not have HSM support")
cmd = 'from glob import settings; settings.set("hsmcmd", 1)'
sim_exec(cmd)
settings_set("hsmcmd", 1)
yield
cmd = 'from glob import settings; settings.remove_key("hsmcmd")'
sim_exec(cmd)
settings_remove("hsmcmd")
@pytest.fixture
@ -980,17 +978,19 @@ def test_bip322_por_psbt_uses_msg_sign_policy(quick_start_hsm, change_hsm, attem
attempt_psbt(psbt, "Message signing not permitted")
@pytest.mark.parametrize("M_N", [(2,3),(1,1)]) # TODO verify https://github.com/coinkite/afirmware/pull/653 fixes 1of 1case
def test_bip322_ms_psbt_uses_msg_sign_policy(quick_start_hsm, change_hsm, attempt_psbt,
bip322_ms_txn, import_ms_wallet, clear_ms):
bip322_ms_txn, import_ms_wallet, clear_ms, M_N):
clear_ms()
deriv = "m/48h/1h/0h/2h"
M, N = M_N
def path_mapper(idx):
return [0x80000030, 0x80000001, 0x80000000, 0x80000002, 0, 0]
keys = import_ms_wallet(1, 1, name="hsm_bip322_msg", accept=True, addr_fmt=AF_P2WSH,
keys = import_ms_wallet(M, N, name="hsm_bip322_msg", accept=True, addr_fmt=AF_P2WSH,
common=deriv, do_import=True, descriptor=True)
psbt, _ = bip322_ms_txn(1, 1, keys, path_mapper=path_mapper, inp_af=AF_P2WSH,
psbt, _ = bip322_ms_txn(1, M, keys, path_mapper=path_mapper, inp_af=AF_P2WSH,
msg=b"HSM multisig BIP-322 message")
quick_start_hsm(DICT(msg_paths=[deriv + "/0/0"]))
@ -1689,4 +1689,95 @@ def test_backup_policy_worst(unit_test, start_hsm, load_hsm_users):
start_hsm(policy)
unit_test('devtest/backups.py')
# USB validation for HSM commands (hsmcmd=1 in this module)
def test_nwur_short_args(dev):
msg = b'nwur' + struct.pack('<B', 1)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_nwur_trailing_garbage(dev):
msg = b'nwur' + struct.pack('<BBB', 3, 4, 0) + b'test' + b'\xff'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_rmur_short_args(dev):
msg = b'rmur'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_rmur_trailing_garbage(dev):
msg = b'rmur' + struct.pack('<B', 4) + b'test' + b'\xff'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_user_short_args(dev):
msg = b'user' + struct.pack('<I', 0)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_user_trailing_garbage(dev):
msg = b'user' + struct.pack('<IBB', 0, 4, 6) + b'test' + b'123456' + b'\xff'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_hsms_short_args(dev):
msg = b'hsms' + struct.pack('<I', 100)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_hsms_trailing_garbage(dev):
msg = b'hsms' + struct.pack('<I', 100) + bytes(32) + b'\xff'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_nwur_ul_exceeds_payload(dev):
msg = struct.pack('<4sBBB', b'nwur', 1, 10, 0) + b'ab'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_nwur_invalid_sl(dev):
msg = struct.pack('<4sBBB', b'nwur', 1, 5, 7) + b'alice' + b'x' * 7
with pytest.raises(CCProtoError):
dev.send_recv(msg, encrypt=False)
def test_user_zero_ul(dev):
msg = struct.pack('<4sIBB', b'user', 0, 0, 6) + b'000000'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_user_zero_tl(dev):
msg = struct.pack('<4sIBB', b'user', 0, 5, 0) + b'alice'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_user_tl_exceeds_payload(dev):
msg = struct.pack('<4sIBB', b'user', 0, 5, 32) + b'alice' + b'000000'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_rmur_zero_ul(dev):
msg = struct.pack('<4sB', b'rmur', 0)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_rmur_ul_exceeds_payload(dev):
msg = struct.pack('<4sB', b'rmur', 10) + b'ab'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
# EOF

View File

@ -4,11 +4,12 @@
#
# - not working well on simulator right now, but that's not key
#
import pytest, struct
import pytest, struct, hashlib, os
from bip32 import BIP32Node
from binascii import b2a_hex
from constants import simulator_fixed_tprv
from ckcc_protocol.protocol import MAX_MSG_LEN, CCProtocolPacker, CCProtoError
from ckcc_protocol.constants import MSG_SIGNING_MAX_LENGTH
@pytest.mark.skip
def test_usb_fuzz(dev):
@ -98,7 +99,7 @@ def test_xpub_invalid(dev, path):
# some bad paths
with pytest.raises(CCProtoError):
xpub = dev.send_recv(CCProtocolPacker.get_xpub(path), timeout=None)
dev.send_recv(CCProtocolPacker.get_xpub(path), timeout=None)
def test_version(dev, is_q1):
@ -120,8 +121,6 @@ def test_version(dev, is_q1):
@pytest.mark.parametrize('data_len', [1, 24, 60, 61, 62, 63, 64, 1000])
def test_upload_short(dev, data_len):
# upload a few really short files
from hashlib import sha256
data = b'a'*data_len
@ -129,7 +128,7 @@ def test_upload_short(dev, data_len):
assert v == 0
chk = dev.send_recv(CCProtocolPacker.sha256())
assert chk == sha256(data).digest(), 'bad hash'
assert chk == hashlib.sha256(data).digest(), 'bad hash'
# clear screen / test a degerate case
dev.send_recv(CCProtocolPacker.upload(256, 256, b''))
@ -137,9 +136,6 @@ def test_upload_short(dev, data_len):
@pytest.mark.parametrize('pkt_len', [256, 1024, 2048])
def test_upload_long(dev, pkt_len, count=5, data=None):
# upload a larger "file"
from hashlib import sha256
import os
data = data or os.urandom(pkt_len * count)
@ -147,7 +143,7 @@ def test_upload_long(dev, pkt_len, count=5, data=None):
v = dev.send_recv(CCProtocolPacker.upload(pos, len(data), data[pos:pos+pkt_len]))
assert v == pos
chk = dev.send_recv(CCProtocolPacker.sha256())
assert chk == sha256(data[0:pos+pkt_len]).digest(), 'bad hash'
assert chk == hashlib.sha256(data[0:pos+pkt_len]).digest(), 'bad hash'
# clear screen / test a degerate case
dev.send_recv(CCProtocolPacker.upload(256, 256, b''))
@ -206,7 +202,6 @@ def test_mitm(dev):
assert dev.mitm_verify(sig2, dev.master_xpub) == False
def test_remote_upload(dev):
import os
dev.upload_file(b'testing')
dev.upload_file(os.urandom(3000))
@ -216,7 +211,6 @@ def test_remote_up_download(f_len, dev, mk_num):
if f_len > (384*1024) and mk_num <= 3:
raise pytest.skip('mk4+ only case')
import os
data = os.urandom(f_len)
ll, sha = dev.upload_file(data, verify=True)
assert ll == len(data) == f_len
@ -224,4 +218,188 @@ def test_remote_up_download(f_len, dev, mk_num):
rb = dev.download_file(ll, sha, file_number=0)
assert rb == data
def test_dwld_offset_at_max(dev, mk_num):
max_txn = 2*1024*1024
msg = struct.pack('<4sIII', b'dwld', max_txn, 1, 1)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'bad offset' in str(e.value)
def test_dwld_offset_one_past_max(dev, mk_num):
max_txn = 2*1024*1024
msg = struct.pack('<4sIII', b'dwld', max_txn + 1, 1, 1)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'bad offset' in str(e.value)
def test_smsg_zero_length_message(dev):
subpath = b'm'
msg = struct.pack('<4sIII', b'smsg', 0x01, len(subpath), 0) + subpath
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'msg too short (min. 2)' in str(e.value)
def test_smsg_oversized_message(dev):
subpath = b'm'
raw_msg = b'a' * (MSG_SIGNING_MAX_LENGTH + 1)
msg = struct.pack('<4sIII', b'smsg', 0x01, len(subpath), len(raw_msg)) + subpath + raw_msg
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'msg too long (max. 240)' in str(e.value)
def test_ncry_invalid_pubkey(dev):
msg = struct.pack('<4sI64s', b'ncry', 0x01, bytes(64))
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'secp256k1_ec_pubkey_parse' in str(e.value)
@pytest.mark.parametrize("file_no", [0, 1])
def test_dwld_oob_psram_read(file_no, dev, mk_num):
max_txn = 2*1024*1024
msg = struct.pack('<4sIII', b'dwld', max_txn - 1, 2, file_no)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'bad offset' in str(e.value)
def test_p2sh_truncated_xfp_paths(dev):
AF_P2SH = 0x08
header = struct.pack('<IBBH', AF_P2SH, 1, 2, 30)
script = bytes(30)
xfp0 = struct.pack('<BI', 1, 0xDEADBEEF) # one uint32
msg = b'p2sh' + header + script + xfp0
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_p2sh_xfp_path_data_too_short(dev):
AF_P2SH = 0x08
header = struct.pack('<IBBH', AF_P2SH, 1, 2, 30)
script = bytes(30)
xfp0 = struct.pack('<BI', 1, 0xDEADBEEF)
xfp1_ln = struct.pack('<B', 2)
msg = b'p2sh' + header + script + xfp0 + xfp1_ln
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_rest_zero_file_len(dev):
empty_sha = hashlib.sha256(b'').digest()
msg = b'rest' + struct.pack('<I32sB', 0, empty_sha, 0)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_rest_oversized_file_len(dev):
empty_sha = hashlib.sha256(b'').digest()
max_txn_len = 2 * 1024 * 1024
msg = b'rest' + struct.pack('<I32sB', max_txn_len + 1, empty_sha, 0)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_upld_zero_total_size(dev):
msg = struct.pack('<4sII', b'upld', 0, 0)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'long' in str(e.value)
def test_upld_short_args(dev):
msg = b'upld' + struct.pack('<I', 0)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_ncry_short_args(dev):
msg = b'ncry' + struct.pack('<I', 1)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_stxn_short_args(dev):
msg = b'stxn' + struct.pack('<II', 100, 0)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_smsg_short_args(dev):
msg = b'smsg' + struct.pack('<II', 0, 5)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_enrl_short_args(dev):
msg = b'enrl' + struct.pack('<I', 200)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_rest_short_args(dev):
msg = b'rest' + struct.pack('<I', 100)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_show_short_args(dev):
msg = b'show'
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_p2sh_short_args(dev):
msg = b'p2sh' + struct.pack('<I', 0x08)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_dwld_short_args(dev):
msg = b'dwld' + struct.pack('<II', 0, 256)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_msck_short_args(dev):
msg = b'msck' + struct.pack('<II', 1, 2)
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'buffer too small' in str(e.value)
def test_dwld_trailing_garbage(dev):
msg = b'dwld' + struct.pack('<III', 0, 256, 0) + b'\xff' # 13 bytes, need exactly 12
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_ncry_trailing_garbage(dev):
msg = b'ncry' + struct.pack('<I', 1) + bytes(64) + b'\xff' # 69 bytes, need exactly 68
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_enrl_trailing_garbage(dev):
msg = b'enrl' + struct.pack('<I', 200) + bytes(32) + b'\xff' # 37 bytes, need exactly 36
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_msck_trailing_garbage(dev):
msg = b'msck' + struct.pack('<III', 1, 2, 0xAB) + b'\xff' # 13 bytes, need exactly 12
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_stxn_trailing_garbage(dev):
msg = b'stxn' + struct.pack('<II', 100, 0) + bytes(32) + b'\xff' # 41 bytes, need exactly 40
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
def test_rest_trailing_garbage(dev):
empty_sha = hashlib.sha256(b'').digest()
msg = b'rest' + struct.pack('<I32sB', 100, empty_sha, 0) + b'\xff' # 38 bytes, need exactly 37
with pytest.raises(CCProtoError) as e:
dev.send_recv(msg, encrypt=False)
assert 'badlen' in str(e.value)
# EOF