diff --git a/shared/psbt.py b/shared/psbt.py index 0a3da635..31bc40dd 100644 --- a/shared/psbt.py +++ b/shared/psbt.py @@ -5,6 +5,7 @@ import stash, gc, history, sys, ngu, ckcc, chains from ustruct import unpack_from, unpack, pack from ubinascii import hexlify as b2a_hex +from ucollections import OrderedDict from utils import xfp2str, B2A, keypath_to_str from utils import seconds2human_readable, datetime_from_timestamp, datetime_to_str from chains import NLOCK_IS_TIME @@ -1761,7 +1762,21 @@ class psbtObject(psbtProxy): foreign = [] total_in = 0 + prevouts_max_len = 50 + prevouts = OrderedDict() + for i, txi in self.input_iter(): + # check for duplicate inputs + if len(prevouts) >= prevouts_max_len: + first = next(iter(prevouts)) # O(1) in mpy + del prevouts[first] + + k = txi.prevout.n, txi.prevout.hash + if k in prevouts: # O(1) + raise FatalPSBTIssue("Duplicate inputs!") + + prevouts[k] = True + inp = self.inputs[i] if inp.fully_signed: self.presigned_inputs.add(i) diff --git a/testing/test_sign.py b/testing/test_sign.py index 06f10e1e..d98540b1 100644 --- a/testing/test_sign.py +++ b/testing/test_sign.py @@ -3646,4 +3646,17 @@ def test_txn_nVersion_zero(segwit, fake_txn, start_sign, cap_story, goto_home): assert title == "Failure" assert "txn version" in story +@pytest.mark.parametrize("segwit_in", [True, False]) +@pytest.mark.parametrize("num_ins", [2, 50, 51]) +def test_duplicate_inputs(segwit_in, num_ins, fake_txn, start_sign, end_sign, cap_story): + psbt = fake_txn(num_ins, 2, segwit_in=segwit_in, dupe_ins=[num_ins-1]) + start_sign(psbt) + title, story = cap_story() + if num_ins <= 50: + # only works if duplicate is no longer than 50 inputs from original + assert "Duplicate inputs!" in story + else: + # does not work + assert title == "OK TO SEND?" + # EOF diff --git a/testing/txn.py b/testing/txn.py index 3f7937d4..17a75512 100644 --- a/testing/txn.py +++ b/testing/txn.py @@ -25,7 +25,11 @@ def fake_txn(dev, pytestconfig): outstyles=['p2pkh'], psbt_hacker=None, change_outputs=[], capture_scripts=None, add_xpub=None, op_return=None, taproot_in=False, psbt_v2=None, input_amount=1E8, unknown_out_script=None, lock_time=0, - sequences=None, sighashes=None): + sequences=None, sighashes=None, dupe_ins=[]): + + # dupe_ins cannot contain zero, as that will be the duplicated input + if dupe_ins: + assert 0 not in dupe_ins psbt = BasicPSBT() @@ -58,12 +62,17 @@ def fake_txn(dev, pytestconfig): # - each input is 1BTC # addr where the fake money will be stored. - subkey = mk.subkey_for_path(subpath % i) + if i in dupe_ins: + # always duplicate zeroth input + subkey = mk.subkey_for_path(subpath % 0) + else: + subkey = mk.subkey_for_path(subpath % i) sec = subkey.sec() assert len(sec) == 33, "expect compressed" assert subpath[0:2] == '0/' - psbt.inputs[i].bip32_paths[sec] = xfp + struct.pack('