Replace trezorlib with our own stripped down version

This commit is contained in:
Andrew Chow 2019-02-07 22:31:00 -05:00
parent 754adc93dc
commit aa3665c971
98 changed files with 5103 additions and 14 deletions

View File

@ -2,13 +2,13 @@
from ..hwwclient import HardwareWalletClient
from ..errors import ActionCanceledError, BadArgumentError, DeviceAlreadyInitError, DeviceAlreadyUnlockedError, DeviceConnectionError, UnavailableActionError, DeviceNotReadyError
from trezorlib.client import TrezorClient as Trezor
from trezorlib.debuglink import DebugLink, DebugUI, TrezorClientDebugLink
from trezorlib.exceptions import Cancelled
from trezorlib.transport import enumerate_devices, get_transport
from trezorlib.ui import ClickUI, mnemonic_words, PIN_MATRIX_DESCRIPTION
from trezorlib import protobuf, tools, btc, device
from trezorlib import messages as proto
from .trezorlib.client import TrezorClient as Trezor
from .trezorlib.debuglink import DebugLink, DebugUI, TrezorClientDebugLink
from .trezorlib.exceptions import Cancelled
from .trezorlib.transport import enumerate_devices, get_transport
from .trezorlib.ui import PassphraseUI, mnemonic_words, PIN_MATRIX_DESCRIPTION
from .trezorlib import protobuf, tools, btc, device
from .trezorlib import messages as proto
from ..base58 import get_xpub_fingerprint, decode, to_address, xpub_main_2_test, get_xpub_fingerprint_hex
from ..serializations import ser_uint256, uint256_from_str
from .. import bech32
@ -99,14 +99,13 @@ class TrezorClient(HardwareWalletClient):
transport = get_transport(path)
self.client = TrezorDebugNoInit(transport=transport)
else:
self.client = TrezorNoInit(transport=get_transport(path), ui=ClickUI())
self.client = TrezorNoInit(transport=get_transport(path), ui=PassphraseUI(password))
# if it wasn't able to find a client, throw an error
if not self.client:
raise IOError("no Device")
self.password = password
os.environ['PASSPHRASE'] = password
self.client.open()
def _check_unlocked(self):

View File

@ -0,0 +1,11 @@
# Python Trezor Library
This is a stripped down version of the official [python-trezor](https://github.com/trezor/python-trezor) library.
This stripped down version was made at commit [d5c2636f0d1b7da3cb94a4eff6169d77f58cefc1](https://github.com/trezor/python-trezor/tree/d5c2636f0d1b7da3cb94a4eff6169d77f58cefc1).
## Changes
- Removed altcoin support
- Include the compiled protobuf definitions instead of making them on install
- Removed functions that HWI does not use or plan to use

View File

@ -0,0 +1,8 @@
__version__ = "0.11.1"
# fmt: off
MINIMUM_FIRMWARE_VERSION = {
"1": (1, 6, 1),
"T": (2, 0, 10),
}
# fmt: on

View File

@ -0,0 +1,172 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from . import messages
from .tools import CallException, expect, normalize_nfc, session
@expect(messages.PublicKey)
def get_public_node(
client,
n,
ecdsa_curve_name=None,
show_display=False,
coin_name=None,
script_type=messages.InputScriptType.SPENDADDRESS,
):
return client.call(
messages.GetPublicKey(
address_n=n,
ecdsa_curve_name=ecdsa_curve_name,
show_display=show_display,
coin_name=coin_name,
script_type=script_type,
)
)
@expect(messages.Address, field="address")
def get_address(
client,
coin_name,
n,
show_display=False,
multisig=None,
script_type=messages.InputScriptType.SPENDADDRESS,
):
return client.call(
messages.GetAddress(
address_n=n,
coin_name=coin_name,
show_display=show_display,
multisig=multisig,
script_type=script_type,
)
)
@expect(messages.MessageSignature)
def sign_message(
client, coin_name, n, message, script_type=messages.InputScriptType.SPENDADDRESS
):
message = normalize_nfc(message)
return client.call(
messages.SignMessage(
coin_name=coin_name, address_n=n, message=message, script_type=script_type
)
)
@session
def sign_tx(client, coin_name, inputs, outputs, details=None, prev_txes=None):
# set up a transactions dict
txes = {None: messages.TransactionType(inputs=inputs, outputs=outputs)}
# preload all relevant transactions ahead of time
for inp in inputs:
if inp.script_type not in (
messages.InputScriptType.SPENDP2SHWITNESS,
messages.InputScriptType.SPENDWITNESS,
messages.InputScriptType.EXTERNAL,
):
try:
prev_tx = prev_txes[inp.prev_hash]
except Exception as e:
raise ValueError("Could not retrieve prev_tx") from e
if not isinstance(prev_tx, messages.TransactionType):
raise ValueError("Invalid value for prev_tx") from None
txes[inp.prev_hash] = prev_tx
if details is None:
signtx = messages.SignTx()
else:
signtx = details
signtx.coin_name = coin_name
signtx.inputs_count = len(inputs)
signtx.outputs_count = len(outputs)
res = client.call(signtx)
# Prepare structure for signatures
signatures = [None] * len(inputs)
serialized_tx = b""
def copy_tx_meta(tx):
tx_copy = messages.TransactionType()
tx_copy.CopyFrom(tx)
# clear fields
tx_copy.inputs_cnt = len(tx.inputs)
tx_copy.inputs = []
tx_copy.outputs_cnt = len(tx.bin_outputs or tx.outputs)
tx_copy.outputs = []
tx_copy.bin_outputs = []
tx_copy.extra_data_len = len(tx.extra_data or b"")
tx_copy.extra_data = None
return tx_copy
R = messages.RequestType
while isinstance(res, messages.TxRequest):
# If there's some part of signed transaction, let's add it
if res.serialized:
if res.serialized.serialized_tx:
serialized_tx += res.serialized.serialized_tx
if res.serialized.signature_index is not None:
idx = res.serialized.signature_index
sig = res.serialized.signature
if signatures[idx] is not None:
raise ValueError("Signature for index %d already filled" % idx)
signatures[idx] = sig
if res.request_type == R.TXFINISHED:
break
# Device asked for one more information, let's process it.
current_tx = txes[res.details.tx_hash]
if res.request_type == R.TXMETA:
msg = copy_tx_meta(current_tx)
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXINPUT:
msg = messages.TransactionType()
msg.inputs = [current_tx.inputs[res.details.request_index]]
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXOUTPUT:
msg = messages.TransactionType()
if res.details.tx_hash:
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
else:
msg.outputs = [current_tx.outputs[res.details.request_index]]
res = client.call(messages.TxAck(tx=msg))
elif res.request_type == R.TXEXTRADATA:
o, l = res.details.extra_data_offset, res.details.extra_data_len
msg = messages.TransactionType()
msg.extra_data = current_tx.extra_data[o : o + l]
res = client.call(messages.TxAck(tx=msg))
if isinstance(res, messages.Failure):
raise CallException("Signing failed")
if not isinstance(res, messages.TxRequest):
raise CallException("Unexpected message")
if None in signatures:
raise RuntimeError("Some signatures are missing!")
return signatures, serialized_tx

View File

@ -0,0 +1,247 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import sys
import warnings
from mnemonic import Mnemonic
from . import MINIMUM_FIRMWARE_VERSION, exceptions, messages, tools
if sys.version_info.major < 3:
raise Exception("Trezorlib does not support Python 2 anymore.")
LOG = logging.getLogger(__name__)
VENDORS = ("bitcointrezor.com", "trezor.io")
MAX_PASSPHRASE_LENGTH = 50
DEPRECATION_ERROR = """
Incompatible Trezor library detected.
(Original error: {})
""".strip()
OUTDATED_FIRMWARE_ERROR = """
Your Trezor firmware is out of date. Update it with the following command:
trezorctl firmware-update
Or visit https://wallet.trezor.io/
""".strip()
def get_buttonrequest_value(code):
# Converts integer code to its string representation of ButtonRequestType
return [
k
for k in dir(messages.ButtonRequestType)
if getattr(messages.ButtonRequestType, k) == code
][0]
class TrezorClient:
"""Trezor client, a connection to a Trezor device.
This class allows you to manage connection state, send and receive protobuf
messages, handle user interactions, and perform some generic tasks
(send a cancel message, initialize or clear a session, ping the device).
You have to provide a transport, i.e., a raw connection to the device. You can use
`trezorlib.transport.get_transport` to find one.
You have to provide an UI implementation for the three kinds of interaction:
- button request (notify the user that their interaction is needed)
- PIN request (on T1, ask the user to input numbers for a PIN matrix)
- passphrase request (ask the user to enter a passphrase)
See `trezorlib.ui` for details.
You can supply a `state` you saved in the previous session. If you do,
the user might not need to enter their passphrase again.
"""
def __init__(self, transport, ui=None, state=None):
LOG.info("creating client instance for device: {}".format(transport.get_path()))
self.transport = transport
self.ui = ui
self.state = state
if ui is None:
warnings.warn("UI class not supplied. This will probably crash soon.")
self.session_counter = 0
self.init_device()
def open(self):
if self.session_counter == 0:
self.transport.begin_session()
self.session_counter += 1
def close(self):
if self.session_counter == 1:
self.transport.end_session()
self.session_counter -= 1
def cancel(self):
self._raw_write(messages.Cancel())
def call_raw(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg)
return self._raw_read()
def _raw_write(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self.transport.write(msg)
def _raw_read(self):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
return self.transport.read()
def _callback_pin(self, msg):
try:
pin = self.ui.get_pin(msg.type)
except exceptions.Cancelled:
self.call_raw(messages.Cancel())
raise
if not pin.isdigit():
self.call_raw(messages.Cancel())
raise ValueError("Non-numeric PIN provided")
resp = self.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise exceptions.PinException(resp.code, resp.message)
else:
return resp
def _callback_passphrase(self, msg):
if msg.on_device:
passphrase = None
else:
try:
passphrase = self.ui.get_passphrase()
except exceptions.Cancelled:
self.call_raw(messages.Cancel())
raise
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
self.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
resp = self.call_raw(messages.PassphraseAck(passphrase=passphrase))
if isinstance(resp, messages.PassphraseStateRequest):
self.state = resp.state
return self.call_raw(messages.PassphraseStateAck())
else:
return resp
def _callback_button(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
self._raw_write(messages.ButtonAck())
self.ui.button_request(msg.code)
return self._raw_read()
@tools.session
def call(self, msg):
self.check_firmware_version()
resp = self.call_raw(msg)
while True:
if isinstance(resp, messages.PinMatrixRequest):
resp = self._callback_pin(resp)
elif isinstance(resp, messages.PassphraseRequest):
resp = self._callback_passphrase(resp)
elif isinstance(resp, messages.ButtonRequest):
resp = self._callback_button(resp)
elif isinstance(resp, messages.Failure):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
else:
return resp
@tools.session
def init_device(self):
resp = self.call_raw(messages.Initialize(state=self.state))
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected initial response")
else:
self.features = resp
if self.features.vendor not in VENDORS:
raise RuntimeError("Unsupported device")
# A side-effect of this is a sanity check for broken protobuf definitions.
# If the `vendor` field doesn't exist, you probably have a mismatched
# checkout of trezor-common.
self.version = (
self.features.major_version,
self.features.minor_version,
self.features.patch_version,
)
self.check_firmware_version(warn_only=True)
def is_outdated(self):
if self.features.bootloader_mode:
return False
model = self.features.model or "1"
required_version = MINIMUM_FIRMWARE_VERSION[model]
return self.version < required_version
def check_firmware_version(self, warn_only=False):
if self.is_outdated():
if warn_only:
warnings.warn(OUTDATED_FIRMWARE_ERROR, stacklevel=2)
else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
@tools.expect(messages.Success, field="message")
def ping(
self,
msg,
button_protection=False,
pin_protection=False,
passphrase_protection=False,
):
# We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old.
# So we short-circuit the simplest variant of ping with call_raw.
if not button_protection and not pin_protection and not passphrase_protection:
# XXX this should be: `with self:`
try:
self.open()
return self.call_raw(messages.Ping(message=msg))
finally:
self.close()
msg = messages.Ping(
message=msg,
button_protection=button_protection,
pin_protection=pin_protection,
passphrase_protection=passphrase_protection,
)
return self.call(msg)
def get_device_id(self):
return self.features.device_id
@tools.expect(messages.Success, field="message")
@tools.session
def clear_session(self):
return self.call_raw(messages.ClearSession())

View File

@ -0,0 +1,505 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from copy import deepcopy
from mnemonic import Mnemonic
from . import messages as proto, protobuf, tools
from .client import TrezorClient
from .tools import expect
EXPECTED_RESPONSES_CONTEXT_LINES = 3
class DebugLink:
def __init__(self, transport, auto_interact=True):
self.transport = transport
self.allow_interactions = auto_interact
def open(self):
self.transport.begin_session()
def close(self):
self.transport.end_session()
def _call(self, msg, nowait=False):
self.transport.write(msg)
if nowait:
return None
ret = self.transport.read()
return ret
def state(self):
return self._call(proto.DebugLinkGetState())
def read_pin(self):
state = self.state()
return state.pin, state.matrix
def read_pin_encoded(self):
return self.encode_pin(*self.read_pin())
def encode_pin(self, pin, matrix=None):
"""Transform correct PIN according to the displayed matrix."""
if matrix is None:
_, matrix = self.read_pin()
return "".join([str(matrix.index(p) + 1) for p in pin])
def read_layout(self):
obj = self._call(proto.DebugLinkGetState())
return obj.layout
def read_mnemonic(self):
obj = self._call(proto.DebugLinkGetState())
return obj.mnemonic
def read_recovery_word(self):
obj = self._call(proto.DebugLinkGetState())
return (obj.recovery_fake_word, obj.recovery_word_pos)
def read_reset_word(self):
obj = self._call(proto.DebugLinkGetState())
return obj.reset_word
def read_reset_word_pos(self):
obj = self._call(proto.DebugLinkGetState())
return obj.reset_word_pos
def read_reset_entropy(self):
obj = self._call(proto.DebugLinkGetState())
return obj.reset_entropy
def read_passphrase_protection(self):
obj = self._call(proto.DebugLinkGetState())
return obj.passphrase_protection
def input(self, word=None, button=None, swipe=None):
if not self.allow_interactions:
return
decision = proto.DebugLinkDecision()
if button is not None:
decision.yes_no = button
elif word is not None:
decision.input = word
elif swipe is not None:
decision.up_down = swipe
else:
raise ValueError("You need to provide input data.")
self._call(decision, nowait=True)
def press_button(self, yes_no):
self._call(proto.DebugLinkDecision(yes_no=yes_no), nowait=True)
def press_yes(self):
self.input(button=True)
def press_no(self):
self.input(button=False)
def swipe_up(self):
self.input(swipe=True)
def swipe_down(self):
self.input(swipe=False)
def stop(self):
self._call(proto.DebugLinkStop(), nowait=True)
@expect(proto.DebugLinkMemory, field="memory")
def memory_read(self, address, length):
return self._call(proto.DebugLinkMemoryRead(address=address, length=length))
def memory_write(self, address, memory, flash=False):
self._call(
proto.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash),
nowait=True,
)
def flash_erase(self, sector):
self._call(proto.DebugLinkFlashErase(sector=sector), nowait=True)
class NullDebugLink(DebugLink):
def __init__(self):
super().__init__(None)
def open(self):
pass
def close(self):
pass
def _call(self, msg, nowait=False):
if not nowait:
if isinstance(msg, proto.DebugLinkGetState):
return proto.DebugLinkState()
else:
raise RuntimeError("unexpected call to a fake debuglink")
class DebugUI:
INPUT_FLOW_DONE = object()
def __init__(self, debuglink: DebugLink):
self.debuglink = debuglink
self.pin = None
self.passphrase = "sphinx of black quartz, judge my wov"
self.input_flow = None
def button_request(self, code):
if self.input_flow is None:
self.debuglink.press_yes()
elif self.input_flow is self.INPUT_FLOW_DONE:
raise AssertionError("input flow ended prematurely")
else:
try:
self.input_flow.send(code)
except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self, code=None):
if self.pin:
return self.pin
else:
return self.debuglink.read_pin_encoded()
def get_passphrase(self):
return self.passphrase
class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses
# and other functionality for unit tests
# for various callbacks, created in order
# to automatically pass unit tests.
#
# This mixing should be used only for purposes
# of unit testing, because it will fail to work
# without special DebugLink interface provided
# by the device.
def __init__(self, transport, auto_interact=True):
try:
debug_transport = transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact)
except Exception:
if not auto_interact:
self.debug = NullDebugLink()
else:
raise
self.ui = DebugUI(self.debug)
self.in_with_statement = 0
self.screenshot_id = 0
self.filters = {}
# Always press Yes and provide correct pin
self.setup_debuglink(True, True)
# Do not expect any specific response from device
self.expected_responses = None
self.current_response = None
# Use blank passphrase
self.set_passphrase("")
super().__init__(transport, ui=self.ui)
def open(self):
super().open()
self.debug.open()
def close(self):
self.debug.close()
super().close()
def set_filter(self, message_type, callback):
self.filters[message_type] = callback
def _filter_message(self, msg):
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg
def set_input_flow(self, input_flow):
if input_flow is None:
self.ui.input_flow = None
return
if callable(input_flow):
input_flow = input_flow()
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow
next(input_flow) # can't send before first yield
def __enter__(self):
# For usage in with/expected_responses
self.in_with_statement += 1
return self
def __exit__(self, _type, value, traceback):
self.in_with_statement -= 1
if _type is not None:
# Another exception raised
return False
if self.expected_responses is None:
# no need to check anything else
return False
# return isinstance(value, TypeError)
# Evaluate missed responses in 'with' statement
if self.current_response < len(self.expected_responses):
self._raise_unexpected_response(None)
# Cleanup
self.expected_responses = None
self.current_response = None
return False
def set_expected_responses(self, expected):
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.expected_responses = expected
self.current_response = 0
def setup_debuglink(self, button, pin_correct):
# self.button = button # True -> YES button, False -> NO button
if pin_correct:
self.ui.pin = None
else:
self.ui.pin = "444222"
def set_passphrase(self, passphrase):
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def set_mnemonic(self, mnemonic):
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def _raw_read(self):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# if SCREENSHOT and self.debug:
# from PIL import Image
# layout = self.debug.state().layout
# im = Image.new("RGB", (128, 64))
# pix = im.load()
# for x in range(128):
# for y in range(64):
# rx, ry = 127 - x, 63 - y
# if (ord(layout[rx + (ry / 8) * 128]) & (1 << (ry % 8))) > 0:
# pix[x, y] = (255, 255, 255)
# im.save("scr%05d.png" % self.screenshot_id)
# self.screenshot_id += 1
resp = super()._raw_read()
resp = self._filter_message(resp)
self._check_request(resp)
return resp
def _raw_write(self, msg):
return super()._raw_write(self._filter_message(msg))
def _raise_unexpected_response(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
start_at = max(self.current_response - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(
self.current_response + EXPECTED_RESPONSES_CONTEXT_LINES + 1,
len(self.expected_responses),
)
output = []
output.append("Expected responses:")
if start_at > 0:
output.append(" (...{} previous responses omitted)".format(start_at))
for i in range(start_at, stop_at):
exp = self.expected_responses[i]
prefix = " " if i != self.current_response else ">>> "
set_fields = {
key: value
for key, value in exp.__dict__.items()
if value is not None and value != []
}
oneline_str = ", ".join("{}={!r}".format(*i) for i in set_fields.items())
if len(oneline_str) < 60:
output.append(
"{}{}({})".format(prefix, exp.__class__.__name__, oneline_str)
)
else:
item = []
item.append("{}{}(".format(prefix, exp.__class__.__name__))
for key, value in set_fields.items():
item.append("{} {}={!r}".format(prefix, key, value))
item.append("{})".format(prefix))
output.append("\n".join(item))
if stop_at < len(self.expected_responses):
omitted = len(self.expected_responses) - stop_at
output.append(" (...{} following responses omitted)".format(omitted))
output.append("")
if msg is not None:
output.append("Actually received:")
output.append(protobuf.format_message(msg))
else:
output.append("This message was never received.")
raise AssertionError("\n".join(output))
def _check_request(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if self.expected_responses is None:
return
if self.current_response >= len(self.expected_responses):
raise AssertionError(
"No more messages were expected, but we got:\n"
+ protobuf.format_message(msg)
)
expected = self.expected_responses[self.current_response]
if msg.__class__ != expected.__class__:
self._raise_unexpected_response(msg)
for field, value in expected.__dict__.items():
if value is None or value == []:
continue
if getattr(msg, field) != value:
self._raise_unexpected_response(msg)
self.current_response += 1
def mnemonic_callback(self, _):
word, pos = self.debug.read_recovery_word()
if word != "":
return word
if pos != 0:
return self.mnemonic[pos - 1]
raise RuntimeError("Unexpected call")
@expect(proto.Success, field="message")
def load_device_by_mnemonic(
client,
mnemonic,
pin,
passphrase_protection,
label,
language="english",
skip_checksum=False,
expand=False,
):
# Convert mnemonic to UTF8 NKFD
mnemonic = Mnemonic.normalize_string(mnemonic)
# Convert mnemonic to ASCII stream
mnemonic = mnemonic.encode()
m = Mnemonic("english")
if expand:
mnemonic = m.expand(mnemonic)
if not skip_checksum and not m.check(mnemonic):
raise ValueError("Invalid mnemonic checksum")
if client.features.initialized:
raise RuntimeError(
"Device is initialized already. Call device.wipe() and try again."
)
resp = client.call(
proto.LoadDevice(
mnemonic=mnemonic,
pin=pin,
passphrase_protection=passphrase_protection,
language=language,
label=label,
skip_checksum=skip_checksum,
)
)
client.init_device()
return resp
@expect(proto.Success, field="message")
def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, language):
if client.features.initialized:
raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again."
)
if xprv[0:4] not in ("xprv", "tprv"):
raise ValueError("Unknown type of xprv")
if not 100 < len(xprv) < 112: # yes this is correct in Python
raise ValueError("Invalid length of xprv")
node = proto.HDNodeType()
data = tools.b58decode(xprv, None).hex()
if data[90:92] != "00":
raise ValueError("Contain invalid private key")
checksum = (tools.btc_hash(bytes.fromhex(data[:156]))[:4]).hex()
if checksum != data[156:]:
raise ValueError("Checksum doesn't match")
# version 0488ade4
# depth 00
# fingerprint 00000000
# child_num 00000000
# chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508
# privkey 00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35
# checksum e77e9d71
node.depth = int(data[8:10], 16)
node.fingerprint = int(data[10:18], 16)
node.child_num = int(data[18:26], 16)
node.chain_code = bytes.fromhex(data[26:90])
node.private_key = bytes.fromhex(data[92:156]) # skip 0x00 indicating privkey
resp = client.call(
proto.LoadDevice(
node=node,
pin=pin,
passphrase_protection=passphrase_protection,
language=language,
label=label,
)
)
client.init_device()
return resp
@expect(proto.Success, field="message")
def self_test(client):
if client.features.bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode")
return client.call(
proto.SelfTest(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
)
)

View File

@ -0,0 +1,201 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os
import time
import warnings
from . import messages as proto
from .exceptions import Cancelled
from .tools import expect, session
from .transport import enumerate_devices, get_transport
RECOVERY_BACK = "\x08" # backspace character, sent literally
class TrezorDevice:
"""
This class is deprecated. (There is no reason for it to exist in the first
place, it is nothing but a collection of two functions.)
Instead, please use functions from the ``trezorlib.transport`` module.
"""
@classmethod
def enumerate(cls):
warnings.warn("TrezorDevice is deprecated.", DeprecationWarning)
return enumerate_devices()
@classmethod
def find_by_path(cls, path):
warnings.warn("TrezorDevice is deprecated.", DeprecationWarning)
return get_transport(path, prefix_search=False)
@expect(proto.Success, field="message")
def apply_settings(
client,
label=None,
language=None,
use_passphrase=None,
homescreen=None,
passphrase_source=None,
auto_lock_delay_ms=None,
):
settings = proto.ApplySettings()
if label is not None:
settings.label = label
if language:
settings.language = language
if use_passphrase is not None:
settings.use_passphrase = use_passphrase
if homescreen is not None:
settings.homescreen = homescreen
if passphrase_source is not None:
settings.passphrase_source = passphrase_source
if auto_lock_delay_ms is not None:
settings.auto_lock_delay_ms = auto_lock_delay_ms
out = client.call(settings)
client.init_device() # Reload Features
return out
@expect(proto.Success, field="message")
def apply_flags(client, flags):
out = client.call(proto.ApplyFlags(flags=flags))
client.init_device() # Reload Features
return out
@expect(proto.Success, field="message")
def change_pin(client, remove=False):
ret = client.call(proto.ChangePin(remove=remove))
client.init_device() # Re-read features
return ret
@expect(proto.Success, field="message")
def wipe(client):
ret = client.call(proto.WipeDevice())
client.init_device()
return ret
@expect(proto.Success, field="message")
def recover(
client,
word_count=24,
passphrase_protection=False,
pin_protection=True,
label=None,
language="english",
input_callback=None,
type=proto.RecoveryDeviceType.ScrambledWords,
dry_run=False,
u2f_counter=None,
):
if client.features.model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One")
if word_count not in (12, 18, 24):
raise ValueError("Invalid word count. Use 12/18/24")
if client.features.initialized and not dry_run:
raise RuntimeError(
"Device already initialized. Call device.wipe() and try again."
)
if u2f_counter is None:
u2f_counter = int(time.time())
res = client.call(
proto.RecoveryDevice(
word_count=word_count,
passphrase_protection=bool(passphrase_protection),
pin_protection=bool(pin_protection),
label=label,
language=language,
enforce_wordlist=True,
type=type,
dry_run=dry_run,
u2f_counter=u2f_counter,
)
)
while isinstance(res, proto.WordRequest):
try:
inp = input_callback(res.type)
res = client.call(proto.WordAck(word=inp))
except Cancelled:
res = client.call(proto.Cancel())
client.init_device()
return res
@expect(proto.Success, field="message")
@session
def reset(
client,
display_random=False,
strength=None,
passphrase_protection=False,
pin_protection=True,
label=None,
language="english",
u2f_counter=0,
skip_backup=False,
no_backup=False,
):
if client.features.initialized:
raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again."
)
if strength is None:
if client.features.model == "1":
strength = 256
else:
strength = 128
# Begin with device reset workflow
msg = proto.ResetDevice(
display_random=bool(display_random),
strength=strength,
passphrase_protection=bool(passphrase_protection),
pin_protection=bool(pin_protection),
language=language,
label=label,
u2f_counter=u2f_counter,
skip_backup=bool(skip_backup),
no_backup=bool(no_backup),
)
resp = client.call(msg)
if not isinstance(resp, proto.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest")
external_entropy = os.urandom(32)
# LOG.debug("Computer generated entropy: " + external_entropy.hex())
ret = client.call(proto.EntropyAck(entropy=external_entropy))
client.init_device()
return ret
@expect(proto.Success, field="message")
def backup(client):
ret = client.call(proto.BackupDevice())
return ret

View File

@ -0,0 +1,51 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
class TrezorException(Exception):
pass
class TrezorFailure(TrezorException):
def __init__(self, failure):
self.failure = failure
# TODO: this is backwards compatibility with tests. it should be changed
super().__init__(self.failure.code, self.failure.message)
def __str__(self):
from .messages import FailureType
types = {
getattr(FailureType, name): name
for name in dir(FailureType)
if not name.startswith("_")
}
if self.failure.message is not None:
return "{}: {}".format(types[self.failure.code], self.failure.message)
else:
return types[self.failure.code]
class PinException(TrezorException):
pass
class Cancelled(TrezorException):
pass
class OutdatedFirmwareError(TrezorException):
pass

View File

@ -0,0 +1,321 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import hashlib
from enum import Enum
from typing import NewType, Tuple
import construct as c
import ecdsa
import pyblake2
from . import cosi, messages, tools
V1_SIGNATURE_SLOTS = 3
V1_BOOTLOADER_KEYS = {
1: "04d571b7f148c5e4232c3814f777d8faeaf1a84216c78d569b71041ffc768a5b2d810fc3bb134dd026b57e65005275aedef43e155f48fc11a32ec790a93312bd58",
2: "0463279c0c0866e50c05c799d32bd6bab0188b6de06536d1109d2ed9ce76cb335c490e55aee10cc901215132e853097d5432eda06b792073bd7740c94ce4516cb1",
3: "0443aedbb6f7e71c563f8ed2ef64ec9981482519e7ef4f4aa98b27854e8c49126d4956d300ab45fdc34cd26bc8710de0a31dbdf6de7435fd0b492be70ac75fde58",
4: "04877c39fd7c62237e038235e9c075dab261630f78eeb8edb92487159fffedfdf6046c6f8b881fa407c4a4ce6c28de0b19c1f4e29f1fcbc5a58ffd1432a3e0938a",
5: "047384c51ae81add0a523adbb186c91b906ffb64c2c765802bf26dbd13bdf12c319e80c2213a136c8ee03d7874fd22b70d68e7dee469decfbbb510ee9a460cda45",
}
V2_BOOTLOADER_KEYS = [
bytes.fromhex("c2c87a49c5a3460977fbb2ec9dfe60f06bd694db8244bd4981fe3b7a26307f3f"),
bytes.fromhex("80d036b08739b846f4cb77593078deb25dc9487aedcf52e30b4fb7cd7024178a"),
bytes.fromhex("b8307a71f552c60a4cbb317ff48b82cdbf6b6bb5f04c920fec7badf017883751"),
]
V2_BOOTLOADER_M = 2
V2_BOOTLOADER_N = 3
V2_CHUNK_SIZE = 1024 * 128
def _transform_vendor_trust(data: bytes) -> bytes:
"""Byte-swap and bit-invert the VendorTrust field.
Vendor trust is interpreted as a bitmask in a 16-bit little-endian integer,
with the added twist that 0 means set and 1 means unset.
We feed it to a `BitStruct` that expects a big-endian sequence where bits have
the traditional meaning. We must therefore do a bitwise negation of each byte,
and return them in reverse order. This is the same transformation both ways,
fortunately, so we don't need two separate functions.
"""
return bytes(~b & 0xFF for b in data)[::-1]
# fmt: off
Toif = c.Struct(
"magic" / c.Const(b"TOI"),
"format" / c.Enum(c.Byte, full_color=b"f", grayscale=b"g"),
"width" / c.Int16ul,
"height" / c.Int16ul,
"data" / c.Prefixed(c.Int32ul, c.GreedyBytes),
)
VendorTrust = c.Transformed(c.BitStruct(
"reserved" / c.Default(c.BitsInteger(9), 0),
"show_vendor_string" / c.Flag,
"require_user_click" / c.Flag,
"red_background" / c.Flag,
"delay" / c.BitsInteger(4),
), _transform_vendor_trust, 2, _transform_vendor_trust, 2)
VendorHeader = c.Struct(
"_start_offset" / c.Tell,
"magic" / c.Const(b"TRZV"),
"_header_len" / c.Padding(4),
"expiry" / c.Int32ul,
"version" / c.Struct(
"major" / c.Int8ul,
"minor" / c.Int8ul,
),
"vendor_sigs_required" / c.Int8ul,
"vendor_sigs_n" / c.Rebuild(c.Int8ul, c.len_(c.this.pubkeys)),
"vendor_trust" / VendorTrust,
"reserved" / c.Padding(14),
"pubkeys" / c.Bytes(32)[c.this.vendor_sigs_n],
"vendor_string" / c.Aligned(4, c.PascalString(c.Int8ul, "utf-8")),
"vendor_image" / Toif,
"_data_end_offset" / c.Tell,
c.Padding(-(c.this._data_end_offset + 65) % 512),
"sigmask" / c.Byte,
"signature" / c.Bytes(64),
"_end_offset" / c.Tell,
"header_len" / c.Pointer(
c.this._start_offset + 4,
c.Rebuild(c.Int32ul, c.this._end_offset - c.this._start_offset)
),
)
VersionLong = c.Struct(
"major" / c.Int8ul,
"minor" / c.Int8ul,
"patch" / c.Int8ul,
"build" / c.Int8ul,
)
FirmwareHeader = c.Struct(
"_start_offset" / c.Tell,
"magic" / c.Const(b"TRZF"),
"_header_len" / c.Padding(4),
"expiry" / c.Int32ul,
"code_length" / c.Rebuild(
c.Int32ul,
lambda this:
len(this._.code) if "code" in this._
else (this.code_length or 0)
),
"version" / VersionLong,
"fix_version" / VersionLong,
"reserved" / c.Padding(8),
"hashes" / c.Bytes(32)[16],
"reserved" / c.Padding(415),
"sigmask" / c.Byte,
"signature" / c.Bytes(64),
"_end_offset" / c.Tell,
"header_len" / c.Pointer(
c.this._start_offset + 4,
c.Rebuild(c.Int32ul, c.this._end_offset - c.this._start_offset)
),
)
Firmware = c.Struct(
"vendor_header" / VendorHeader,
"firmware_header" / FirmwareHeader,
"_code_offset" / c.Tell,
"code" / c.Bytes(c.this.firmware_header.code_length),
c.Terminated,
)
FirmwareV1 = c.Struct(
"magic" / c.Const(b"TRZR"),
"code_length" / c.Rebuild(c.Int32ul, c.len_(c.this.code)),
"key_indexes" / c.Int8ul[V1_SIGNATURE_SLOTS], # pylint: disable=E1136
"flags" / c.BitStruct(
c.Padding(7),
"restore_storage" / c.Flag,
),
"reserved" / c.Padding(52),
"signatures" / c.Bytes(64)[V1_SIGNATURE_SLOTS],
"code" / c.Bytes(c.this.code_length),
c.Terminated,
)
# fmt: on
class FirmwareFormat(Enum):
TREZOR_ONE = 1
TREZOR_T = 2
FirmwareType = NewType("FirmwareType", c.Container)
ParsedFirmware = Tuple[FirmwareFormat, FirmwareType]
def parse(data: bytes) -> ParsedFirmware:
if data[:4] == b"TRZR":
version = FirmwareFormat.TREZOR_ONE
cls = FirmwareV1
elif data[:4] == b"TRZV":
version = FirmwareFormat.TREZOR_T
cls = Firmware
else:
raise ValueError("Unrecognized firmware image type")
try:
fw = cls.parse(data)
except Exception as e:
raise ValueError("Invalid firmware image") from e
return version, FirmwareType(fw)
def digest_v1(fw: FirmwareType) -> bytes:
return hashlib.sha256(fw.code).digest()
def check_sig_v1(fw: FirmwareType, idx: int) -> bool:
key_idx = fw.key_indexes[idx]
signature = fw.signatures[idx]
if key_idx == 0:
# no signature = invalid signature
return False
if key_idx not in V1_BOOTLOADER_KEYS:
# unknown pubkey
return False
pubkey = bytes.fromhex(V1_BOOTLOADER_KEYS[key_idx])[1:]
verify = ecdsa.VerifyingKey.from_string(
pubkey, curve=ecdsa.curves.SECP256k1, hashfunc=hashlib.sha256
)
try:
verify.verify(signature, fw.code)
return True
except ecdsa.BadSignatureError:
return False
def _header_digest(header: c.Container, header_type: c.Construct) -> bytes:
stripped_header = header.copy()
stripped_header.sigmask = 0
stripped_header.signature = b"\0" * 64
header_bytes = header_type.build(stripped_header)
return pyblake2.blake2s(header_bytes).digest()
def digest(fw: FirmwareType) -> bytes:
return _header_digest(fw.firmware_header, FirmwareHeader)
def validate(fw: FirmwareType, skip_vendor_header=False) -> bool:
vendor_fingerprint = _header_digest(fw.vendor_header, VendorHeader)
fingerprint = digest(fw)
if not skip_vendor_header:
try:
# if you want to validate a custom vendor header, you can modify
# the global variables to match your keys and m-of-n scheme
cosi.verify_m_of_n(
fw.vendor_header.signature,
vendor_fingerprint,
V2_BOOTLOADER_M,
V2_BOOTLOADER_N,
fw.vendor_header.sigmask,
V2_BOOTLOADER_KEYS,
)
except Exception:
raise ValueError("Invalid vendor header signature.")
# XXX expiry is not used now
# now = time.gmtime()
# if time.gmtime(fw.vendor_header.expiry) < now:
# raise ValueError("Vendor header expired.")
try:
cosi.verify_m_of_n(
fw.firmware_header.signature,
fingerprint,
fw.vendor_header.vendor_sigs_required,
fw.vendor_header.vendor_sigs_n,
fw.firmware_header.sigmask,
fw.vendor_header.pubkeys,
)
except Exception:
raise ValueError("Invalid firmware signature.")
# XXX expiry is not used now
# if time.gmtime(fw.firmware_header.expiry) < now:
# raise ValueError("Firmware header expired.")
for i, expected_hash in enumerate(fw.firmware_header.hashes):
if i == 0:
# Because first chunk is sent along with headers, there is less code in it.
chunk = fw.code[: V2_CHUNK_SIZE - fw._code_offset]
else:
# Subsequent chunks are shifted by the "missing header" size.
ptr = i * V2_CHUNK_SIZE - fw._code_offset
chunk = fw.code[ptr : ptr + V2_CHUNK_SIZE]
if not chunk and expected_hash == b"\0" * 32:
continue
chunk_hash = pyblake2.blake2s(chunk).digest()
if chunk_hash != expected_hash:
raise ValueError("Invalid firmware data.")
return True
# ====== Client functions ====== #
@tools.session
def update(client, data):
if client.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode")
resp = client.call(messages.FirmwareErase(length=len(data)))
# TREZORv1 method
if isinstance(resp, messages.Success):
resp = client.call(messages.FirmwareUpload(payload=data))
if isinstance(resp, messages.Success):
return
else:
raise RuntimeError("Unexpected result %s" % resp)
# TREZORv2 method
while isinstance(resp, messages.FirmwareRequest):
payload = data[resp.offset : resp.offset + resp.length]
digest = pyblake2.blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
if isinstance(resp, messages.Success):
return
else:
raise RuntimeError("Unexpected message %s" % resp)

View File

@ -0,0 +1,51 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
from typing import Optional, Set, Type
from . import protobuf
OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]]
class PrettyProtobufFormatter(logging.Formatter):
def format(self, record: logging.LogRecord) -> str:
time = self.formatTime(record)
message = "[{time}] {source} {level}: {msg}".format(
time=time,
level=record.levelname.upper(),
source=record.name,
msg=super().format(record),
)
if hasattr(record, "protobuf"):
if type(record.protobuf) in OMITTED_MESSAGES:
message += " ({} bytes)".format(record.protobuf.ByteSize())
else:
message += "\n" + protobuf.format_message(record.protobuf)
return message
def enable_debug_output(handler: Optional[logging.Handler] = None):
if handler is None:
handler = logging.StreamHandler()
formatter = PrettyProtobufFormatter()
handler.setFormatter(formatter)
logger = logging.getLogger("trezorlib")
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)

View File

@ -0,0 +1,62 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from . import messages
map_type_to_class = {}
map_class_to_type = {}
def build_map():
for msg_name in dir(messages.MessageType):
if msg_name.startswith("__"):
continue
try:
msg_class = getattr(messages, msg_name)
except AttributeError:
raise ValueError(
"Implementation of protobuf message '%s' is missing" % msg_name
)
if msg_class.MESSAGE_WIRE_TYPE != getattr(messages.MessageType, msg_name):
raise ValueError(
"Inconsistent wire type and MessageType record for '%s'" % msg_class
)
register_message(msg_class)
def register_message(msg_class):
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
raise Exception(
"Message for wire type %s is already registered by %s"
% (msg_class.MESSAGE_WIRE_TYPE, get_class(msg_class.MESSAGE_WIRE_TYPE))
)
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
def get_type(msg):
return map_class_to_type[msg.__class__]
def get_class(t):
return map_type_to_class[t]
build_map()

View File

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class Address(p.MessageType):
MESSAGE_WIRE_TYPE = 30
def __init__(
self,
address: str = None,
) -> None:
self.address = address
@classmethod
def get_fields(cls):
return {
1: ('address', p.UnicodeType, 0), # required
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class ApplyFlags(p.MessageType):
MESSAGE_WIRE_TYPE = 28
def __init__(
self,
flags: int = None,
) -> None:
self.flags = flags
@classmethod
def get_fields(cls):
return {
1: ('flags', p.UVarintType, 0),
}

View File

@ -0,0 +1,34 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class ApplySettings(p.MessageType):
MESSAGE_WIRE_TYPE = 25
def __init__(
self,
language: str = None,
label: str = None,
use_passphrase: bool = None,
homescreen: bytes = None,
passphrase_source: int = None,
auto_lock_delay_ms: int = None,
) -> None:
self.language = language
self.label = label
self.use_passphrase = use_passphrase
self.homescreen = homescreen
self.passphrase_source = passphrase_source
self.auto_lock_delay_ms = auto_lock_delay_ms
@classmethod
def get_fields(cls):
return {
1: ('language', p.UnicodeType, 0),
2: ('label', p.UnicodeType, 0),
3: ('use_passphrase', p.BoolType, 0),
4: ('homescreen', p.BytesType, 0),
5: ('passphrase_source', p.UVarintType, 0),
6: ('auto_lock_delay_ms', p.UVarintType, 0),
}

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class BackupDevice(p.MessageType):
MESSAGE_WIRE_TYPE = 34

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class ButtonAck(p.MessageType):
MESSAGE_WIRE_TYPE = 27

View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class ButtonRequest(p.MessageType):
MESSAGE_WIRE_TYPE = 26
def __init__(
self,
code: int = None,
data: str = None,
) -> None:
self.code = code
self.data = data
@classmethod
def get_fields(cls):
return {
1: ('code', p.UVarintType, 0),
2: ('data', p.UnicodeType, 0),
}

View File

@ -0,0 +1,17 @@
# Automatically generated by pb2py
# fmt: off
Other = 1
FeeOverThreshold = 2
ConfirmOutput = 3
ResetDevice = 4
ConfirmWord = 5
WipeDevice = 6
ProtectCall = 7
SignTx = 8
FirmwareCheck = 9
Address = 10
PublicKey = 11
MnemonicWordCount = 12
MnemonicInput = 13
PassphraseType = 14
UnknownDerivationPath = 15

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class Cancel(p.MessageType):
MESSAGE_WIRE_TYPE = 20

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class ChangePin(p.MessageType):
MESSAGE_WIRE_TYPE = 4
def __init__(
self,
remove: bool = None,
) -> None:
self.remove = remove
@classmethod
def get_fields(cls):
return {
1: ('remove', p.BoolType, 0),
}

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class ClearSession(p.MessageType):
MESSAGE_WIRE_TYPE = 24

View File

@ -0,0 +1,25 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class DebugLinkDecision(p.MessageType):
MESSAGE_WIRE_TYPE = 100
def __init__(
self,
yes_no: bool = None,
up_down: bool = None,
input: str = None,
) -> None:
self.yes_no = yes_no
self.up_down = up_down
self.input = input
@classmethod
def get_fields(cls):
return {
1: ('yes_no', p.BoolType, 0),
2: ('up_down', p.BoolType, 0),
3: ('input', p.UnicodeType, 0),
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class DebugLinkFlashErase(p.MessageType):
MESSAGE_WIRE_TYPE = 113
def __init__(
self,
sector: int = None,
) -> None:
self.sector = sector
@classmethod
def get_fields(cls):
return {
1: ('sector', p.UVarintType, 0),
}

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class DebugLinkGetState(p.MessageType):
MESSAGE_WIRE_TYPE = 101

View File

@ -0,0 +1,25 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class DebugLinkLog(p.MessageType):
MESSAGE_WIRE_TYPE = 104
def __init__(
self,
level: int = None,
bucket: str = None,
text: str = None,
) -> None:
self.level = level
self.bucket = bucket
self.text = text
@classmethod
def get_fields(cls):
return {
1: ('level', p.UVarintType, 0),
2: ('bucket', p.UnicodeType, 0),
3: ('text', p.UnicodeType, 0),
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class DebugLinkMemory(p.MessageType):
MESSAGE_WIRE_TYPE = 111
def __init__(
self,
memory: bytes = None,
) -> None:
self.memory = memory
@classmethod
def get_fields(cls):
return {
1: ('memory', p.BytesType, 0),
}

View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class DebugLinkMemoryRead(p.MessageType):
MESSAGE_WIRE_TYPE = 110
def __init__(
self,
address: int = None,
length: int = None,
) -> None:
self.address = address
self.length = length
@classmethod
def get_fields(cls):
return {
1: ('address', p.UVarintType, 0),
2: ('length', p.UVarintType, 0),
}

View File

@ -0,0 +1,25 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class DebugLinkMemoryWrite(p.MessageType):
MESSAGE_WIRE_TYPE = 112
def __init__(
self,
address: int = None,
memory: bytes = None,
flash: bool = None,
) -> None:
self.address = address
self.memory = memory
self.flash = flash
@classmethod
def get_fields(cls):
return {
1: ('address', p.UVarintType, 0),
2: ('memory', p.BytesType, 0),
3: ('flash', p.BoolType, 0),
}

View File

@ -0,0 +1,51 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .HDNodeType import HDNodeType
class DebugLinkState(p.MessageType):
MESSAGE_WIRE_TYPE = 102
def __init__(
self,
layout: bytes = None,
pin: str = None,
matrix: str = None,
mnemonic: str = None,
node: HDNodeType = None,
passphrase_protection: bool = None,
reset_word: str = None,
reset_entropy: bytes = None,
recovery_fake_word: str = None,
recovery_word_pos: int = None,
reset_word_pos: int = None,
) -> None:
self.layout = layout
self.pin = pin
self.matrix = matrix
self.mnemonic = mnemonic
self.node = node
self.passphrase_protection = passphrase_protection
self.reset_word = reset_word
self.reset_entropy = reset_entropy
self.recovery_fake_word = recovery_fake_word
self.recovery_word_pos = recovery_word_pos
self.reset_word_pos = reset_word_pos
@classmethod
def get_fields(cls):
return {
1: ('layout', p.BytesType, 0),
2: ('pin', p.UnicodeType, 0),
3: ('matrix', p.UnicodeType, 0),
4: ('mnemonic', p.UnicodeType, 0),
5: ('node', HDNodeType, 0),
6: ('passphrase_protection', p.BoolType, 0),
7: ('reset_word', p.UnicodeType, 0),
8: ('reset_entropy', p.BytesType, 0),
9: ('recovery_fake_word', p.UnicodeType, 0),
10: ('recovery_word_pos', p.UVarintType, 0),
11: ('reset_word_pos', p.UVarintType, 0),
}

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class DebugLinkStop(p.MessageType):
MESSAGE_WIRE_TYPE = 103

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class Entropy(p.MessageType):
MESSAGE_WIRE_TYPE = 10
def __init__(
self,
entropy: bytes = None,
) -> None:
self.entropy = entropy
@classmethod
def get_fields(cls):
return {
1: ('entropy', p.BytesType, 0), # required
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class EntropyAck(p.MessageType):
MESSAGE_WIRE_TYPE = 36
def __init__(
self,
entropy: bytes = None,
) -> None:
self.entropy = entropy
@classmethod
def get_fields(cls):
return {
1: ('entropy', p.BytesType, 0),
}

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class EntropyRequest(p.MessageType):
MESSAGE_WIRE_TYPE = 35

View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class Failure(p.MessageType):
MESSAGE_WIRE_TYPE = 3
def __init__(
self,
code: int = None,
message: str = None,
) -> None:
self.code = code
self.message = message
@classmethod
def get_fields(cls):
return {
1: ('code', p.UVarintType, 0),
2: ('message', p.UnicodeType, 0),
}

View File

@ -0,0 +1,15 @@
# Automatically generated by pb2py
# fmt: off
UnexpectedMessage = 1
ButtonExpected = 2
DataError = 3
ActionCancelled = 4
PinExpected = 5
PinCancelled = 6
PinInvalid = 7
InvalidSignature = 8
ProcessError = 9
NotEnoughFunds = 10
NotInitialized = 11
PinMismatch = 12
FirmwareError = 99

View File

@ -0,0 +1,97 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class Features(p.MessageType):
MESSAGE_WIRE_TYPE = 17
def __init__(
self,
vendor: str = None,
major_version: int = None,
minor_version: int = None,
patch_version: int = None,
bootloader_mode: bool = None,
device_id: str = None,
pin_protection: bool = None,
passphrase_protection: bool = None,
language: str = None,
label: str = None,
initialized: bool = None,
revision: bytes = None,
bootloader_hash: bytes = None,
imported: bool = None,
pin_cached: bool = None,
passphrase_cached: bool = None,
firmware_present: bool = None,
needs_backup: bool = None,
flags: int = None,
model: str = None,
fw_major: int = None,
fw_minor: int = None,
fw_patch: int = None,
fw_vendor: str = None,
fw_vendor_keys: bytes = None,
unfinished_backup: bool = None,
no_backup: bool = None,
) -> None:
self.vendor = vendor
self.major_version = major_version
self.minor_version = minor_version
self.patch_version = patch_version
self.bootloader_mode = bootloader_mode
self.device_id = device_id
self.pin_protection = pin_protection
self.passphrase_protection = passphrase_protection
self.language = language
self.label = label
self.initialized = initialized
self.revision = revision
self.bootloader_hash = bootloader_hash
self.imported = imported
self.pin_cached = pin_cached
self.passphrase_cached = passphrase_cached
self.firmware_present = firmware_present
self.needs_backup = needs_backup
self.flags = flags
self.model = model
self.fw_major = fw_major
self.fw_minor = fw_minor
self.fw_patch = fw_patch
self.fw_vendor = fw_vendor
self.fw_vendor_keys = fw_vendor_keys
self.unfinished_backup = unfinished_backup
self.no_backup = no_backup
@classmethod
def get_fields(cls):
return {
1: ('vendor', p.UnicodeType, 0),
2: ('major_version', p.UVarintType, 0),
3: ('minor_version', p.UVarintType, 0),
4: ('patch_version', p.UVarintType, 0),
5: ('bootloader_mode', p.BoolType, 0),
6: ('device_id', p.UnicodeType, 0),
7: ('pin_protection', p.BoolType, 0),
8: ('passphrase_protection', p.BoolType, 0),
9: ('language', p.UnicodeType, 0),
10: ('label', p.UnicodeType, 0),
12: ('initialized', p.BoolType, 0),
13: ('revision', p.BytesType, 0),
14: ('bootloader_hash', p.BytesType, 0),
15: ('imported', p.BoolType, 0),
16: ('pin_cached', p.BoolType, 0),
17: ('passphrase_cached', p.BoolType, 0),
18: ('firmware_present', p.BoolType, 0),
19: ('needs_backup', p.BoolType, 0),
20: ('flags', p.UVarintType, 0),
21: ('model', p.UnicodeType, 0),
22: ('fw_major', p.UVarintType, 0),
23: ('fw_minor', p.UVarintType, 0),
24: ('fw_patch', p.UVarintType, 0),
25: ('fw_vendor', p.UnicodeType, 0),
26: ('fw_vendor_keys', p.BytesType, 0),
27: ('unfinished_backup', p.BoolType, 0),
28: ('no_backup', p.BoolType, 0),
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class FirmwareErase(p.MessageType):
MESSAGE_WIRE_TYPE = 6
def __init__(
self,
length: int = None,
) -> None:
self.length = length
@classmethod
def get_fields(cls):
return {
1: ('length', p.UVarintType, 0),
}

View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class FirmwareRequest(p.MessageType):
MESSAGE_WIRE_TYPE = 8
def __init__(
self,
offset: int = None,
length: int = None,
) -> None:
self.offset = offset
self.length = length
@classmethod
def get_fields(cls):
return {
1: ('offset', p.UVarintType, 0),
2: ('length', p.UVarintType, 0),
}

View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class FirmwareUpload(p.MessageType):
MESSAGE_WIRE_TYPE = 7
def __init__(
self,
payload: bytes = None,
hash: bytes = None,
) -> None:
self.payload = payload
self.hash = hash
@classmethod
def get_fields(cls):
return {
1: ('payload', p.BytesType, 0), # required
2: ('hash', p.BytesType, 0),
}

View File

@ -0,0 +1,39 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .MultisigRedeemScriptType import MultisigRedeemScriptType
if __debug__:
try:
from typing import List
except ImportError:
List = None # type: ignore
class GetAddress(p.MessageType):
MESSAGE_WIRE_TYPE = 29
def __init__(
self,
address_n: List[int] = None,
coin_name: str = None,
show_display: bool = None,
multisig: MultisigRedeemScriptType = None,
script_type: int = None,
) -> None:
self.address_n = address_n if address_n is not None else []
self.coin_name = coin_name
self.show_display = show_display
self.multisig = multisig
self.script_type = script_type
@classmethod
def get_fields(cls):
return {
1: ('address_n', p.UVarintType, p.FLAG_REPEATED),
2: ('coin_name', p.UnicodeType, 0), # default=Bitcoin
3: ('show_display', p.BoolType, 0),
4: ('multisig', MultisigRedeemScriptType, 0),
5: ('script_type', p.UVarintType, 0), # default=SPENDADDRESS
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class GetEntropy(p.MessageType):
MESSAGE_WIRE_TYPE = 9
def __init__(
self,
size: int = None,
) -> None:
self.size = size
@classmethod
def get_fields(cls):
return {
1: ('size', p.UVarintType, 0), # required
}

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class GetFeatures(p.MessageType):
MESSAGE_WIRE_TYPE = 55

View File

@ -0,0 +1,37 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
if __debug__:
try:
from typing import List
except ImportError:
List = None # type: ignore
class GetPublicKey(p.MessageType):
MESSAGE_WIRE_TYPE = 11
def __init__(
self,
address_n: List[int] = None,
ecdsa_curve_name: str = None,
show_display: bool = None,
coin_name: str = None,
script_type: int = None,
) -> None:
self.address_n = address_n if address_n is not None else []
self.ecdsa_curve_name = ecdsa_curve_name
self.show_display = show_display
self.coin_name = coin_name
self.script_type = script_type
@classmethod
def get_fields(cls):
return {
1: ('address_n', p.UVarintType, p.FLAG_REPEATED),
2: ('ecdsa_curve_name', p.UnicodeType, 0),
3: ('show_display', p.BoolType, 0),
4: ('coin_name', p.UnicodeType, 0), # default=Bitcoin
5: ('script_type', p.UVarintType, 0), # default=SPENDADDRESS
}

View File

@ -0,0 +1,29 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .HDNodeType import HDNodeType
if __debug__:
try:
from typing import List
except ImportError:
List = None # type: ignore
class HDNodePathType(p.MessageType):
def __init__(
self,
node: HDNodeType = None,
address_n: List[int] = None,
) -> None:
self.node = node
self.address_n = address_n if address_n is not None else []
@classmethod
def get_fields(cls):
return {
1: ('node', HDNodeType, 0), # required
2: ('address_n', p.UVarintType, p.FLAG_REPEATED),
}

View File

@ -0,0 +1,33 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class HDNodeType(p.MessageType):
def __init__(
self,
depth: int = None,
fingerprint: int = None,
child_num: int = None,
chain_code: bytes = None,
private_key: bytes = None,
public_key: bytes = None,
) -> None:
self.depth = depth
self.fingerprint = fingerprint
self.child_num = child_num
self.chain_code = chain_code
self.private_key = private_key
self.public_key = public_key
@classmethod
def get_fields(cls):
return {
1: ('depth', p.UVarintType, 0), # required
2: ('fingerprint', p.UVarintType, 0), # required
3: ('child_num', p.UVarintType, 0), # required
4: ('chain_code', p.BytesType, 0), # required
5: ('private_key', p.BytesType, 0),
6: ('public_key', p.BytesType, 0),
}

View File

@ -0,0 +1,33 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class IdentityType(p.MessageType):
def __init__(
self,
proto: str = None,
user: str = None,
host: str = None,
port: str = None,
path: str = None,
index: int = None,
) -> None:
self.proto = proto
self.user = user
self.host = host
self.port = port
self.path = path
self.index = index
@classmethod
def get_fields(cls):
return {
1: ('proto', p.UnicodeType, 0),
2: ('user', p.UnicodeType, 0),
3: ('host', p.UnicodeType, 0),
4: ('port', p.UnicodeType, 0),
5: ('path', p.UnicodeType, 0),
6: ('index', p.UVarintType, 0), # default=0
}

View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class Initialize(p.MessageType):
MESSAGE_WIRE_TYPE = 0
def __init__(
self,
state: bytes = None,
skip_passphrase: bool = None,
) -> None:
self.state = state
self.skip_passphrase = skip_passphrase
@classmethod
def get_fields(cls):
return {
1: ('state', p.BytesType, 0),
2: ('skip_passphrase', p.BoolType, 0),
}

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
SPENDADDRESS = 0
SPENDMULTISIG = 1
EXTERNAL = 2
SPENDWITNESS = 3
SPENDP2SHWITNESS = 4

View File

@ -0,0 +1,42 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .HDNodeType import HDNodeType
class LoadDevice(p.MessageType):
MESSAGE_WIRE_TYPE = 13
def __init__(
self,
mnemonic: str = None,
node: HDNodeType = None,
pin: str = None,
passphrase_protection: bool = None,
language: str = None,
label: str = None,
skip_checksum: bool = None,
u2f_counter: int = None,
) -> None:
self.mnemonic = mnemonic
self.node = node
self.pin = pin
self.passphrase_protection = passphrase_protection
self.language = language
self.label = label
self.skip_checksum = skip_checksum
self.u2f_counter = u2f_counter
@classmethod
def get_fields(cls):
return {
1: ('mnemonic', p.UnicodeType, 0),
2: ('node', HDNodeType, 0),
3: ('pin', p.UnicodeType, 0),
4: ('passphrase_protection', p.BoolType, 0),
5: ('language', p.UnicodeType, 0), # default=english
6: ('label', p.UnicodeType, 0),
7: ('skip_checksum', p.BoolType, 0),
8: ('u2f_counter', p.UVarintType, 0),
}

View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class MessageSignature(p.MessageType):
MESSAGE_WIRE_TYPE = 40
def __init__(
self,
address: str = None,
signature: bytes = None,
) -> None:
self.address = address
self.signature = signature
@classmethod
def get_fields(cls):
return {
1: ('address', p.UnicodeType, 0),
2: ('signature', p.BytesType, 0),
}

View File

@ -0,0 +1,57 @@
# Automatically generated by pb2py
# fmt: off
Initialize = 0
Ping = 1
Success = 2
Failure = 3
ChangePin = 4
WipeDevice = 5
GetEntropy = 9
Entropy = 10
LoadDevice = 13
ResetDevice = 14
Features = 17
PinMatrixRequest = 18
PinMatrixAck = 19
Cancel = 20
ClearSession = 24
ApplySettings = 25
ButtonRequest = 26
ButtonAck = 27
ApplyFlags = 28
BackupDevice = 34
EntropyRequest = 35
EntropyAck = 36
PassphraseRequest = 41
PassphraseAck = 42
PassphraseStateRequest = 77
PassphraseStateAck = 78
RecoveryDevice = 45
WordRequest = 46
WordAck = 47
GetFeatures = 55
FirmwareErase = 6
FirmwareUpload = 7
FirmwareRequest = 8
SelfTest = 32
GetPublicKey = 11
PublicKey = 12
SignTx = 15
TxRequest = 21
TxAck = 22
GetAddress = 29
Address = 30
SignMessage = 38
VerifyMessage = 39
MessageSignature = 40
SignIdentity = 53
SignedIdentity = 54
DebugLinkDecision = 100
DebugLinkGetState = 101
DebugLinkState = 102
DebugLinkStop = 103
DebugLinkLog = 104
DebugLinkMemoryRead = 110
DebugLinkMemory = 111
DebugLinkMemoryWrite = 112
DebugLinkFlashErase = 113

View File

@ -0,0 +1,32 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .HDNodePathType import HDNodePathType
if __debug__:
try:
from typing import List
except ImportError:
List = None # type: ignore
class MultisigRedeemScriptType(p.MessageType):
def __init__(
self,
pubkeys: List[HDNodePathType] = None,
signatures: List[bytes] = None,
m: int = None,
) -> None:
self.pubkeys = pubkeys if pubkeys is not None else []
self.signatures = signatures if signatures is not None else []
self.m = m
@classmethod
def get_fields(cls):
return {
1: ('pubkeys', HDNodePathType, p.FLAG_REPEATED),
2: ('signatures', p.BytesType, p.FLAG_REPEATED),
3: ('m', p.UVarintType, 0),
}

View File

@ -0,0 +1,8 @@
# Automatically generated by pb2py
# fmt: off
PAYTOADDRESS = 0
PAYTOSCRIPTHASH = 1
PAYTOMULTISIG = 2
PAYTOOPRETURN = 3
PAYTOWITNESS = 4
PAYTOP2SHWITNESS = 5

View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class PassphraseAck(p.MessageType):
MESSAGE_WIRE_TYPE = 42
def __init__(
self,
passphrase: str = None,
state: bytes = None,
) -> None:
self.passphrase = passphrase
self.state = state
@classmethod
def get_fields(cls):
return {
1: ('passphrase', p.UnicodeType, 0),
2: ('state', p.BytesType, 0),
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class PassphraseRequest(p.MessageType):
MESSAGE_WIRE_TYPE = 41
def __init__(
self,
on_device: bool = None,
) -> None:
self.on_device = on_device
@classmethod
def get_fields(cls):
return {
1: ('on_device', p.BoolType, 0),
}

View File

@ -0,0 +1,5 @@
# Automatically generated by pb2py
# fmt: off
ASK = 0
DEVICE = 1
HOST = 2

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class PassphraseStateAck(p.MessageType):
MESSAGE_WIRE_TYPE = 78

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class PassphraseStateRequest(p.MessageType):
MESSAGE_WIRE_TYPE = 77
def __init__(
self,
state: bytes = None,
) -> None:
self.state = state
@classmethod
def get_fields(cls):
return {
1: ('state', p.BytesType, 0),
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class PinMatrixAck(p.MessageType):
MESSAGE_WIRE_TYPE = 19
def __init__(
self,
pin: str = None,
) -> None:
self.pin = pin
@classmethod
def get_fields(cls):
return {
1: ('pin', p.UnicodeType, 0), # required
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class PinMatrixRequest(p.MessageType):
MESSAGE_WIRE_TYPE = 18
def __init__(
self,
type: int = None,
) -> None:
self.type = type
@classmethod
def get_fields(cls):
return {
1: ('type', p.UVarintType, 0),
}

View File

@ -0,0 +1,5 @@
# Automatically generated by pb2py
# fmt: off
Current = 1
NewFirst = 2
NewSecond = 3

View File

@ -0,0 +1,28 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class Ping(p.MessageType):
MESSAGE_WIRE_TYPE = 1
def __init__(
self,
message: str = None,
button_protection: bool = None,
pin_protection: bool = None,
passphrase_protection: bool = None,
) -> None:
self.message = message
self.button_protection = button_protection
self.pin_protection = pin_protection
self.passphrase_protection = passphrase_protection
@classmethod
def get_fields(cls):
return {
1: ('message', p.UnicodeType, 0),
2: ('button_protection', p.BoolType, 0),
3: ('pin_protection', p.BoolType, 0),
4: ('passphrase_protection', p.BoolType, 0),
}

View File

@ -0,0 +1,24 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .HDNodeType import HDNodeType
class PublicKey(p.MessageType):
MESSAGE_WIRE_TYPE = 12
def __init__(
self,
node: HDNodeType = None,
xpub: str = None,
) -> None:
self.node = node
self.xpub = xpub
@classmethod
def get_fields(cls):
return {
1: ('node', HDNodeType, 0), # required
2: ('xpub', p.UnicodeType, 0),
}

View File

@ -0,0 +1,43 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class RecoveryDevice(p.MessageType):
MESSAGE_WIRE_TYPE = 45
def __init__(
self,
word_count: int = None,
passphrase_protection: bool = None,
pin_protection: bool = None,
language: str = None,
label: str = None,
enforce_wordlist: bool = None,
type: int = None,
u2f_counter: int = None,
dry_run: bool = None,
) -> None:
self.word_count = word_count
self.passphrase_protection = passphrase_protection
self.pin_protection = pin_protection
self.language = language
self.label = label
self.enforce_wordlist = enforce_wordlist
self.type = type
self.u2f_counter = u2f_counter
self.dry_run = dry_run
@classmethod
def get_fields(cls):
return {
1: ('word_count', p.UVarintType, 0),
2: ('passphrase_protection', p.BoolType, 0),
3: ('pin_protection', p.BoolType, 0),
4: ('language', p.UnicodeType, 0), # default=english
5: ('label', p.UnicodeType, 0),
6: ('enforce_wordlist', p.BoolType, 0),
8: ('type', p.UVarintType, 0),
9: ('u2f_counter', p.UVarintType, 0),
10: ('dry_run', p.BoolType, 0),
}

View File

@ -0,0 +1,4 @@
# Automatically generated by pb2py
# fmt: off
ScrambledWords = 0
Matrix = 1

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
TXINPUT = 0
TXOUTPUT = 1
TXMETA = 2
TXFINISHED = 3
TXEXTRADATA = 4

View File

@ -0,0 +1,43 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class ResetDevice(p.MessageType):
MESSAGE_WIRE_TYPE = 14
def __init__(
self,
display_random: bool = None,
strength: int = None,
passphrase_protection: bool = None,
pin_protection: bool = None,
language: str = None,
label: str = None,
u2f_counter: int = None,
skip_backup: bool = None,
no_backup: bool = None,
) -> None:
self.display_random = display_random
self.strength = strength
self.passphrase_protection = passphrase_protection
self.pin_protection = pin_protection
self.language = language
self.label = label
self.u2f_counter = u2f_counter
self.skip_backup = skip_backup
self.no_backup = no_backup
@classmethod
def get_fields(cls):
return {
1: ('display_random', p.BoolType, 0),
2: ('strength', p.UVarintType, 0), # default=256
3: ('passphrase_protection', p.BoolType, 0),
4: ('pin_protection', p.BoolType, 0),
5: ('language', p.UnicodeType, 0), # default=english
6: ('label', p.UnicodeType, 0),
7: ('u2f_counter', p.UVarintType, 0),
8: ('skip_backup', p.BoolType, 0),
9: ('no_backup', p.BoolType, 0),
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class SelfTest(p.MessageType):
MESSAGE_WIRE_TYPE = 32
def __init__(
self,
payload: bytes = None,
) -> None:
self.payload = payload
@classmethod
def get_fields(cls):
return {
1: ('payload', p.BytesType, 0),
}

View File

@ -0,0 +1,30 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .IdentityType import IdentityType
class SignIdentity(p.MessageType):
MESSAGE_WIRE_TYPE = 53
def __init__(
self,
identity: IdentityType = None,
challenge_hidden: bytes = None,
challenge_visual: str = None,
ecdsa_curve_name: str = None,
) -> None:
self.identity = identity
self.challenge_hidden = challenge_hidden
self.challenge_visual = challenge_visual
self.ecdsa_curve_name = ecdsa_curve_name
@classmethod
def get_fields(cls):
return {
1: ('identity', IdentityType, 0),
2: ('challenge_hidden', p.BytesType, 0),
3: ('challenge_visual', p.UnicodeType, 0),
4: ('ecdsa_curve_name', p.UnicodeType, 0),
}

View File

@ -0,0 +1,34 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
if __debug__:
try:
from typing import List
except ImportError:
List = None # type: ignore
class SignMessage(p.MessageType):
MESSAGE_WIRE_TYPE = 38
def __init__(
self,
address_n: List[int] = None,
message: bytes = None,
coin_name: str = None,
script_type: int = None,
) -> None:
self.address_n = address_n if address_n is not None else []
self.message = message
self.coin_name = coin_name
self.script_type = script_type
@classmethod
def get_fields(cls):
return {
1: ('address_n', p.UVarintType, p.FLAG_REPEATED),
2: ('message', p.BytesType, 0), # required
3: ('coin_name', p.UnicodeType, 0), # default=Bitcoin
4: ('script_type', p.UVarintType, 0), # default=SPENDADDRESS
}

View File

@ -0,0 +1,43 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class SignTx(p.MessageType):
MESSAGE_WIRE_TYPE = 15
def __init__(
self,
outputs_count: int = None,
inputs_count: int = None,
coin_name: str = None,
version: int = None,
lock_time: int = None,
expiry: int = None,
overwintered: bool = None,
version_group_id: int = None,
timestamp: int = None,
) -> None:
self.outputs_count = outputs_count
self.inputs_count = inputs_count
self.coin_name = coin_name
self.version = version
self.lock_time = lock_time
self.expiry = expiry
self.overwintered = overwintered
self.version_group_id = version_group_id
self.timestamp = timestamp
@classmethod
def get_fields(cls):
return {
1: ('outputs_count', p.UVarintType, 0), # required
2: ('inputs_count', p.UVarintType, 0), # required
3: ('coin_name', p.UnicodeType, 0), # default=Bitcoin
4: ('version', p.UVarintType, 0), # default=1
5: ('lock_time', p.UVarintType, 0), # default=0
6: ('expiry', p.UVarintType, 0),
7: ('overwintered', p.BoolType, 0),
8: ('version_group_id', p.UVarintType, 0),
9: ('timestamp', p.UVarintType, 0),
}

View File

@ -0,0 +1,25 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class SignedIdentity(p.MessageType):
MESSAGE_WIRE_TYPE = 54
def __init__(
self,
address: str = None,
public_key: bytes = None,
signature: bytes = None,
) -> None:
self.address = address
self.public_key = public_key
self.signature = signature
@classmethod
def get_fields(cls):
return {
1: ('address', p.UnicodeType, 0),
2: ('public_key', p.BytesType, 0),
3: ('signature', p.BytesType, 0),
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class Success(p.MessageType):
MESSAGE_WIRE_TYPE = 2
def __init__(
self,
message: str = None,
) -> None:
self.message = message
@classmethod
def get_fields(cls):
return {
1: ('message', p.UnicodeType, 0),
}

View File

@ -0,0 +1,64 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .TxInputType import TxInputType
from .TxOutputBinType import TxOutputBinType
from .TxOutputType import TxOutputType
if __debug__:
try:
from typing import List
except ImportError:
List = None # type: ignore
class TransactionType(p.MessageType):
def __init__(
self,
version: int = None,
inputs: List[TxInputType] = None,
bin_outputs: List[TxOutputBinType] = None,
lock_time: int = None,
outputs: List[TxOutputType] = None,
inputs_cnt: int = None,
outputs_cnt: int = None,
extra_data: bytes = None,
extra_data_len: int = None,
expiry: int = None,
overwintered: bool = None,
version_group_id: int = None,
timestamp: int = None,
) -> None:
self.version = version
self.inputs = inputs if inputs is not None else []
self.bin_outputs = bin_outputs if bin_outputs is not None else []
self.lock_time = lock_time
self.outputs = outputs if outputs is not None else []
self.inputs_cnt = inputs_cnt
self.outputs_cnt = outputs_cnt
self.extra_data = extra_data
self.extra_data_len = extra_data_len
self.expiry = expiry
self.overwintered = overwintered
self.version_group_id = version_group_id
self.timestamp = timestamp
@classmethod
def get_fields(cls):
return {
1: ('version', p.UVarintType, 0),
2: ('inputs', TxInputType, p.FLAG_REPEATED),
3: ('bin_outputs', TxOutputBinType, p.FLAG_REPEATED),
4: ('lock_time', p.UVarintType, 0),
5: ('outputs', TxOutputType, p.FLAG_REPEATED),
6: ('inputs_cnt', p.UVarintType, 0),
7: ('outputs_cnt', p.UVarintType, 0),
8: ('extra_data', p.BytesType, 0),
9: ('extra_data_len', p.UVarintType, 0),
10: ('expiry', p.UVarintType, 0),
11: ('overwintered', p.BoolType, 0),
12: ('version_group_id', p.UVarintType, 0),
13: ('timestamp', p.UVarintType, 0),
}

View File

@ -0,0 +1,21 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .TransactionType import TransactionType
class TxAck(p.MessageType):
MESSAGE_WIRE_TYPE = 22
def __init__(
self,
tx: TransactionType = None,
) -> None:
self.tx = tx
@classmethod
def get_fields(cls):
return {
1: ('tx', TransactionType, 0),
}

View File

@ -0,0 +1,59 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .MultisigRedeemScriptType import MultisigRedeemScriptType
if __debug__:
try:
from typing import List
except ImportError:
List = None # type: ignore
class TxInputType(p.MessageType):
def __init__(
self,
address_n: List[int] = None,
prev_hash: bytes = None,
prev_index: int = None,
script_sig: bytes = None,
sequence: int = None,
script_type: int = None,
multisig: MultisigRedeemScriptType = None,
amount: int = None,
decred_tree: int = None,
decred_script_version: int = None,
prev_block_hash_bip115: bytes = None,
prev_block_height_bip115: int = None,
) -> None:
self.address_n = address_n if address_n is not None else []
self.prev_hash = prev_hash
self.prev_index = prev_index
self.script_sig = script_sig
self.sequence = sequence
self.script_type = script_type
self.multisig = multisig
self.amount = amount
self.decred_tree = decred_tree
self.decred_script_version = decred_script_version
self.prev_block_hash_bip115 = prev_block_hash_bip115
self.prev_block_height_bip115 = prev_block_height_bip115
@classmethod
def get_fields(cls):
return {
1: ('address_n', p.UVarintType, p.FLAG_REPEATED),
2: ('prev_hash', p.BytesType, 0), # required
3: ('prev_index', p.UVarintType, 0), # required
4: ('script_sig', p.BytesType, 0),
5: ('sequence', p.UVarintType, 0), # default=4294967295
6: ('script_type', p.UVarintType, 0), # default=SPENDADDRESS
7: ('multisig', MultisigRedeemScriptType, 0),
8: ('amount', p.UVarintType, 0),
9: ('decred_tree', p.UVarintType, 0),
10: ('decred_script_version', p.UVarintType, 0),
11: ('prev_block_hash_bip115', p.BytesType, 0),
12: ('prev_block_height_bip115', p.UVarintType, 0),
}

View File

@ -0,0 +1,24 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class TxOutputBinType(p.MessageType):
def __init__(
self,
amount: int = None,
script_pubkey: bytes = None,
decred_script_version: int = None,
) -> None:
self.amount = amount
self.script_pubkey = script_pubkey
self.decred_script_version = decred_script_version
@classmethod
def get_fields(cls):
return {
1: ('amount', p.UVarintType, 0), # required
2: ('script_pubkey', p.BytesType, 0), # required
3: ('decred_script_version', p.UVarintType, 0),
}

View File

@ -0,0 +1,50 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .MultisigRedeemScriptType import MultisigRedeemScriptType
if __debug__:
try:
from typing import List
except ImportError:
List = None # type: ignore
class TxOutputType(p.MessageType):
def __init__(
self,
address: str = None,
address_n: List[int] = None,
amount: int = None,
script_type: int = None,
multisig: MultisigRedeemScriptType = None,
op_return_data: bytes = None,
decred_script_version: int = None,
block_hash_bip115: bytes = None,
block_height_bip115: int = None,
) -> None:
self.address = address
self.address_n = address_n if address_n is not None else []
self.amount = amount
self.script_type = script_type
self.multisig = multisig
self.op_return_data = op_return_data
self.decred_script_version = decred_script_version
self.block_hash_bip115 = block_hash_bip115
self.block_height_bip115 = block_height_bip115
@classmethod
def get_fields(cls):
return {
1: ('address', p.UnicodeType, 0),
2: ('address_n', p.UVarintType, p.FLAG_REPEATED),
3: ('amount', p.UVarintType, 0), # required
4: ('script_type', p.UVarintType, 0), # required
5: ('multisig', MultisigRedeemScriptType, 0),
6: ('op_return_data', p.BytesType, 0),
7: ('decred_script_version', p.UVarintType, 0),
8: ('block_hash_bip115', p.BytesType, 0),
9: ('block_height_bip115', p.UVarintType, 0),
}

View File

@ -0,0 +1,28 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
from .TxRequestDetailsType import TxRequestDetailsType
from .TxRequestSerializedType import TxRequestSerializedType
class TxRequest(p.MessageType):
MESSAGE_WIRE_TYPE = 21
def __init__(
self,
request_type: int = None,
details: TxRequestDetailsType = None,
serialized: TxRequestSerializedType = None,
) -> None:
self.request_type = request_type
self.details = details
self.serialized = serialized
@classmethod
def get_fields(cls):
return {
1: ('request_type', p.UVarintType, 0),
2: ('details', TxRequestDetailsType, 0),
3: ('serialized', TxRequestSerializedType, 0),
}

View File

@ -0,0 +1,27 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class TxRequestDetailsType(p.MessageType):
def __init__(
self,
request_index: int = None,
tx_hash: bytes = None,
extra_data_len: int = None,
extra_data_offset: int = None,
) -> None:
self.request_index = request_index
self.tx_hash = tx_hash
self.extra_data_len = extra_data_len
self.extra_data_offset = extra_data_offset
@classmethod
def get_fields(cls):
return {
1: ('request_index', p.UVarintType, 0),
2: ('tx_hash', p.BytesType, 0),
3: ('extra_data_len', p.UVarintType, 0),
4: ('extra_data_offset', p.UVarintType, 0),
}

View File

@ -0,0 +1,24 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class TxRequestSerializedType(p.MessageType):
def __init__(
self,
signature_index: int = None,
signature: bytes = None,
serialized_tx: bytes = None,
) -> None:
self.signature_index = signature_index
self.signature = signature
self.serialized_tx = serialized_tx
@classmethod
def get_fields(cls):
return {
1: ('signature_index', p.UVarintType, 0),
2: ('signature', p.BytesType, 0),
3: ('serialized_tx', p.BytesType, 0),
}

View File

@ -0,0 +1,28 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class VerifyMessage(p.MessageType):
MESSAGE_WIRE_TYPE = 39
def __init__(
self,
address: str = None,
signature: bytes = None,
message: bytes = None,
coin_name: str = None,
) -> None:
self.address = address
self.signature = signature
self.message = message
self.coin_name = coin_name
@classmethod
def get_fields(cls):
return {
1: ('address', p.UnicodeType, 0),
2: ('signature', p.BytesType, 0),
3: ('message', p.BytesType, 0),
4: ('coin_name', p.UnicodeType, 0), # default=Bitcoin
}

View File

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class WipeDevice(p.MessageType):
MESSAGE_WIRE_TYPE = 5

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class WordAck(p.MessageType):
MESSAGE_WIRE_TYPE = 47
def __init__(
self,
word: str = None,
) -> None:
self.word = word
@classmethod
def get_fields(cls):
return {
1: ('word', p.UnicodeType, 0), # required
}

View File

@ -0,0 +1,19 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
class WordRequest(p.MessageType):
MESSAGE_WIRE_TYPE = 46
def __init__(
self,
type: int = None,
) -> None:
self.type = type
@classmethod
def get_fields(cls):
return {
1: ('type', p.UVarintType, 0),
}

View File

@ -0,0 +1,5 @@
# Automatically generated by pb2py
# fmt: off
Plain = 0
Matrix9 = 1
Matrix6 = 2

View File

@ -0,0 +1,78 @@
# Automatically generated by pb2py
# fmt: off
from .Address import Address
from .ApplyFlags import ApplyFlags
from .ApplySettings import ApplySettings
from .BackupDevice import BackupDevice
from .ButtonAck import ButtonAck
from .ButtonRequest import ButtonRequest
from .Cancel import Cancel
from .ChangePin import ChangePin
from .ClearSession import ClearSession
from .DebugLinkDecision import DebugLinkDecision
from .DebugLinkFlashErase import DebugLinkFlashErase
from .DebugLinkGetState import DebugLinkGetState
from .DebugLinkLog import DebugLinkLog
from .DebugLinkMemory import DebugLinkMemory
from .DebugLinkMemoryRead import DebugLinkMemoryRead
from .DebugLinkMemoryWrite import DebugLinkMemoryWrite
from .DebugLinkState import DebugLinkState
from .DebugLinkStop import DebugLinkStop
from .Entropy import Entropy
from .EntropyAck import EntropyAck
from .EntropyRequest import EntropyRequest
from .Failure import Failure
from .Features import Features
from .FirmwareErase import FirmwareErase
from .FirmwareRequest import FirmwareRequest
from .FirmwareUpload import FirmwareUpload
from .GetAddress import GetAddress
from .GetEntropy import GetEntropy
from .GetFeatures import GetFeatures
from .GetPublicKey import GetPublicKey
from .HDNodePathType import HDNodePathType
from .HDNodeType import HDNodeType
from .IdentityType import IdentityType
from .Initialize import Initialize
from .LoadDevice import LoadDevice
from .MessageSignature import MessageSignature
from .MultisigRedeemScriptType import MultisigRedeemScriptType
from .PassphraseAck import PassphraseAck
from .PassphraseRequest import PassphraseRequest
from .PassphraseStateAck import PassphraseStateAck
from .PassphraseStateRequest import PassphraseStateRequest
from .PinMatrixAck import PinMatrixAck
from .PinMatrixRequest import PinMatrixRequest
from .Ping import Ping
from .PublicKey import PublicKey
from .RecoveryDevice import RecoveryDevice
from .ResetDevice import ResetDevice
from .SelfTest import SelfTest
from .SignIdentity import SignIdentity
from .SignMessage import SignMessage
from .SignTx import SignTx
from .SignedIdentity import SignedIdentity
from .Success import Success
from .TransactionType import TransactionType
from .TxAck import TxAck
from .TxInputType import TxInputType
from .TxOutputBinType import TxOutputBinType
from .TxOutputType import TxOutputType
from .TxRequest import TxRequest
from .TxRequestDetailsType import TxRequestDetailsType
from .TxRequestSerializedType import TxRequestSerializedType
from .VerifyMessage import VerifyMessage
from .WipeDevice import WipeDevice
from .WordAck import WordAck
from .WordRequest import WordRequest
from . import ButtonRequestType
from . import FailureType
from . import InputScriptType
from . import MessageType
from . import OutputScriptType
from . import PassphraseSourceType
from . import PinMatrixRequestType
from . import RecoveryDeviceType
from . import RequestType
from . import WordRequestType

View File

@ -0,0 +1,425 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
'''
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields.
For de-sererializing (loading) protobuf types, object with `Reader`
interface is required:
>>> class Reader:
>>> def readinto(self, buffer):
>>> """
>>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`.
>>> """
For serializing (dumping) protobuf types, object with `Writer` interface is
required:
>>> class Writer:
>>> def write(self, buffer):
>>> """
>>> Writes all bytes from `buffer`, or raises `EOFError`.
>>> """
'''
from io import BytesIO
from typing import Any, Optional
_UVARINT_BUFFER = bytearray(1)
def load_uvarint(reader):
buffer = _UVARINT_BUFFER
result = 0
shift = 0
byte = 0x80
while byte & 0x80:
if reader.readinto(buffer) == 0:
raise EOFError
byte = buffer[0]
result += (byte & 0x7F) << shift
shift += 7
return result
def dump_uvarint(writer, n):
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER
shifted = True
while shifted:
shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
writer.write(buffer)
n = shifted
# protobuf interleaved signed encoding:
# https://developers.google.com/protocol-buffers/docs/encoding#structure
# the idea is to save the sign in LSbit instead of twos-complement.
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
#
# To achieve this with a twos-complement number:
# 1. shift left by 1, leaving LSbit free
# 2. if the number is negative, do bitwise negation.
# This keeps positive number the same, and converts negative from twos-complement
# to the appropriate value, while setting the sign bit.
#
# The original algorithm makes use of the fact that arithmetic (signed) shift
# keeps the sign bits, so for a n-bit number, (x >> n) gets us "all sign bits".
# Then you can take "number XOR all-sign-bits", which is XOR 0 (identity) for positive
# and XOR 1 (bitwise negation) for negative. Cute and efficient.
#
# But this is harder in Python because we don't natively know the bit size of the number.
# So we have to branch on whether the number is negative.
def sint_to_uint(sint):
res = sint << 1
if sint < 0:
res = ~res
return res
def uint_to_sint(uint):
sign = uint & 1
res = uint >> 1
if sign:
res = ~res
return res
class UVarintType:
WIRE_TYPE = 0
class SVarintType:
WIRE_TYPE = 0
class BoolType:
WIRE_TYPE = 0
class BytesType:
WIRE_TYPE = 2
class UnicodeType:
WIRE_TYPE = 2
class MessageType:
WIRE_TYPE = 2
@classmethod
def get_fields(cls):
return {}
def __init__(self, **kwargs):
for kw in kwargs:
setattr(self, kw, kwargs[kw])
self._fill_missing()
def __eq__(self, rhs):
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self):
d = {}
for key, value in self.__dict__.items():
if value is None or value == []:
continue
d[key] = value
return "<%s: %s>" % (self.__class__.__name__, d)
def __iter__(self):
return iter(self.keys())
def keys(self):
return (name for name, _, _ in self.get_fields().values())
def __getitem__(self, key):
return getattr(self, key)
def _fill_missing(self):
# fill missing fields
for fname, ftype, fflags in self.get_fields().values():
if not hasattr(self, fname):
if fflags & FLAG_REPEATED:
setattr(self, fname, [])
else:
setattr(self, fname, None)
def CopyFrom(self, obj):
self.__dict__ = obj.__dict__.copy()
def ByteSize(self):
data = BytesIO()
dump_message(data, self)
return len(data.getvalue())
class LimitedReader:
def __init__(self, reader, limit):
self.reader = reader
self.limit = limit
def readinto(self, buf):
if self.limit < len(buf):
raise EOFError
else:
nread = self.reader.readinto(buf)
self.limit -= nread
return nread
class CountingWriter:
def __init__(self):
self.size = 0
def write(self, buf):
nwritten = len(buf)
self.size += nwritten
return nwritten
FLAG_REPEATED = 1
def load_message(reader, msg_type):
fields = msg_type.get_fields()
msg = msg_type()
while True:
try:
fkey = load_uvarint(reader)
except EOFError:
break # no more fields to load
ftag = fkey >> 3
wtype = fkey & 7
field = fields.get(ftag, None)
if field is None: # unknown field, skip it
if wtype == 0:
load_uvarint(reader)
elif wtype == 2:
ivalue = load_uvarint(reader)
reader.readinto(bytearray(ivalue))
else:
raise ValueError
continue
fname, ftype, fflags = field
if wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema
ivalue = load_uvarint(reader)
if ftype is UVarintType:
fvalue = ivalue
elif ftype is SVarintType:
fvalue = uint_to_sint(ivalue)
elif ftype is BoolType:
fvalue = bool(ivalue)
elif ftype is BytesType:
buf = bytearray(ivalue)
reader.readinto(buf)
fvalue = bytes(buf)
elif ftype is UnicodeType:
buf = bytearray(ivalue)
reader.readinto(buf)
fvalue = buf.decode()
elif issubclass(ftype, MessageType):
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
else:
raise TypeError # field type is unknown
if fflags & FLAG_REPEATED:
pvalue = getattr(msg, fname)
pvalue.append(fvalue)
fvalue = pvalue
setattr(msg, fname, fvalue)
return msg
def dump_message(writer, msg):
repvalue = [0]
mtype = msg.__class__
fields = mtype.get_fields()
for ftag in fields:
fname, ftype, fflags = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
fkey = (ftag << 3) | ftype.WIRE_TYPE
if not fflags & FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
for svalue in fvalue:
dump_uvarint(writer, fkey)
if ftype is UVarintType:
dump_uvarint(writer, svalue)
elif ftype is SVarintType:
dump_uvarint(writer, sint_to_uint(svalue))
elif ftype is BoolType:
dump_uvarint(writer, int(svalue))
elif ftype is BytesType:
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif ftype is UnicodeType:
if not isinstance(svalue, bytes):
svalue = svalue.encode()
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif issubclass(ftype, MessageType):
counter = CountingWriter()
dump_message(counter, svalue)
dump_uvarint(writer, counter.size)
dump_message(writer, svalue)
else:
raise TypeError
def format_message(
pb: MessageType,
indent: int = 0,
sep: str = " " * 4,
truncate_after: Optional[int] = 256,
truncate_to: Optional[int] = 64,
) -> str:
def mostly_printable(bytes):
if not bytes:
return True
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
return printable / len(bytes) > 0.8
def pformat_value(value: Any, indent: int) -> str:
level = sep * indent
leadin = sep * (indent + 1)
if isinstance(value, MessageType):
return format_message(value, indent, sep)
if isinstance(value, list):
# short list of simple values
if not value or not isinstance(value[0], MessageType):
return repr(value)
# long list, one line per entry
lines = ["[", level + "]"]
lines[1:1] = [leadin + pformat_value(x, indent + 1) + "," for x in value]
return "\n".join(lines)
if isinstance(value, dict):
lines = ["{"]
for key, val in sorted(value.items()):
if val is None or val == []:
continue
lines.append(leadin + key + ": " + pformat_value(val, indent + 1) + ",")
lines.append(level + "}")
return "\n".join(lines)
if isinstance(value, (bytes, bytearray)):
length = len(value)
suffix = ""
if truncate_after and length > truncate_after:
suffix = "..."
value = value[: truncate_to or 0]
if mostly_printable(value):
output = repr(value)
else:
output = "0x" + value.hex()
return "{} bytes {}{}".format(length, output, suffix)
return repr(value)
return "{name} ({size} bytes) {content}".format(
name=pb.__class__.__name__,
size=pb.ByteSize(),
content=pformat_value(pb.__dict__, indent),
)
def value_to_proto(ftype, value):
if issubclass(ftype, MessageType):
raise TypeError("value_to_proto only converts simple values")
if ftype in (UVarintType, SVarintType):
return int(value)
if ftype is BoolType:
return bool(value)
if ftype is UnicodeType:
return str(value)
if ftype is BytesType:
if isinstance(value, str):
return bytes.fromhex(value)
elif isinstance(value, bytes):
return value
else:
raise TypeError("can't convert {} value to bytes".format(type(value)))
def dict_to_proto(message_type, d):
params = {}
for fname, ftype, fflags in message_type.get_fields().values():
repeated = fflags & FLAG_REPEATED
value = d.get(fname)
if value is None:
continue
if not repeated:
value = [value]
if issubclass(ftype, MessageType):
function = dict_to_proto
else:
function = value_to_proto
newvalue = [function(ftype, v) for v in value]
if not repeated:
newvalue = newvalue[0]
params[fname] = newvalue
return message_type(**params)
def to_dict(msg):
res = {}
for key, value in msg.__dict__.items():
if value is None or value == []:
continue
if isinstance(value, MessageType):
value = to_dict(value)
res[key] = value
return res

View File

@ -0,0 +1,265 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import functools
import hashlib
import re
import struct
import unicodedata
from typing import List, NewType
from .exceptions import TrezorFailure
CallException = TrezorFailure
HARDENED_FLAG = 1 << 31
Address = NewType("Address", List[int])
def H_(x: int) -> int:
"""
Shortcut function that "hardens" a number in a BIP44 path.
"""
return x | HARDENED_FLAG
def btc_hash(data):
"""
Double-SHA256 hash as used in BTC
"""
return hashlib.sha256(hashlib.sha256(data).digest()).digest()
def hash_160(public_key):
md = hashlib.new("ripemd160")
md.update(hashlib.sha256(public_key).digest())
return md.digest()
def hash_160_to_bc_address(h160, address_type):
vh160 = struct.pack("<B", address_type) + h160
h = btc_hash(vh160)
addr = vh160 + h[0:4]
return b58encode(addr)
def compress_pubkey(public_key):
if public_key[0] == 4:
return bytes((public_key[64] & 1) + 2) + public_key[1:33]
raise ValueError("Pubkey is already compressed")
def public_key_to_bc_address(public_key, address_type, compress=True):
if public_key[0] == "\x04" and compress:
public_key = compress_pubkey(public_key)
h160 = hash_160(public_key)
return hash_160_to_bc_address(h160, address_type)
__b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
__b58base = len(__b58chars)
def b58encode(v):
""" encode v, which is a string of bytes, to base58."""
long_value = 0
for c in v:
long_value = long_value * 256 + c
result = ""
while long_value >= __b58base:
div, mod = divmod(long_value, __b58base)
result = __b58chars[mod] + result
long_value = div
result = __b58chars[long_value] + result
# Bitcoin does a little leading-zero-compression:
# leading 0-bytes in the input become leading-1s
nPad = 0
for c in v:
if c == 0:
nPad += 1
else:
break
return (__b58chars[0] * nPad) + result
def b58decode(v, length=None):
""" decode v into a string of len bytes."""
if isinstance(v, bytes):
v = v.decode()
for c in v:
if c not in __b58chars:
raise ValueError("invalid Base58 string")
long_value = 0
for (i, c) in enumerate(v[::-1]):
long_value += __b58chars.find(c) * (__b58base ** i)
result = b""
while long_value >= 256:
div, mod = divmod(long_value, 256)
result = struct.pack("B", mod) + result
long_value = div
result = struct.pack("B", long_value) + result
nPad = 0
for c in v:
if c == __b58chars[0]:
nPad += 1
else:
break
result = b"\x00" * nPad + result
if length is not None and len(result) != length:
return None
return result
def b58check_encode(v):
checksum = btc_hash(v)[:4]
return b58encode(v + checksum)
def b58check_decode(v, length=None):
dec = b58decode(v, length)
data, checksum = dec[:-4], dec[-4:]
if btc_hash(data)[:4] != checksum:
raise ValueError("invalid checksum")
return data
def parse_path(nstr: str) -> Address:
"""
Convert BIP32 path string to list of uint32 integers with hardened flags.
Several conventions are supported to set the hardened flag: -1, 1', 1h
e.g.: "0/1h/1" -> [0, 0x80000001, 1]
:param nstr: path string
:return: list of integers
"""
if not nstr:
return []
n = nstr.split("/")
# m/a/b/c => a/b/c
if n[0] == "m":
n = n[1:]
def str_to_harden(x: str) -> int:
if x.startswith("-"):
return H_(abs(int(x)))
elif x.endswith(("h", "'")):
return H_(int(x[:-1]))
else:
return int(x)
try:
return [str_to_harden(x) for x in n]
except Exception:
raise ValueError("Invalid BIP32 path", nstr)
def normalize_nfc(txt):
"""
Normalize message to NFC and return bytes suitable for protobuf.
This seems to be bitcoin-qt standard of doing things.
"""
if isinstance(txt, bytes):
txt = txt.decode()
return unicodedata.normalize("NFC", txt).encode()
class expect:
# Decorator checks if the method
# returned one of expected protobuf messages
# or raises an exception
def __init__(self, expected, field=None):
self.expected = expected
self.field = field
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
ret = f(*args, **kwargs)
if not isinstance(ret, self.expected):
raise RuntimeError(
"Got %s, expected %s" % (ret.__class__, self.expected)
)
if self.field is not None:
return getattr(ret, self.field)
else:
return ret
return wrapped_f
def session(f):
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(client, *args, **kwargs):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open()
try:
return f(client, *args, **kwargs)
finally:
client.close()
return wrapped_f
# de-camelcasifier
# https://stackoverflow.com/a/1176023/222189
FIRST_CAP_RE = re.compile("(.)([A-Z][a-z]+)")
ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])")
def from_camelcase(s):
s = FIRST_CAP_RE.sub(r"\1_\2", s)
return ALL_CAP_RE.sub(r"\1_\2", s).lower()
def dict_from_camelcase(d, renames=None):
if not isinstance(d, dict):
return d
if renames is None:
renames = {}
res = {}
for key, value in d.items():
newkey = from_camelcase(key)
renamed_key = renames.get(newkey) or renames.get(key)
if renamed_key:
newkey = renamed_key
if isinstance(value, list):
res[newkey] = [dict_from_camelcase(v, renames) for v in value]
else:
res[newkey] = dict_from_camelcase(value, renames)
return res

View File

@ -0,0 +1,149 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
from typing import Iterable, List, Type
from ..exceptions import TrezorException
from ..protobuf import MessageType
LOG = logging.getLogger(__name__)
# USB vendor/product IDs for Trezors
DEV_TREZOR1 = (0x534C, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53C1)
DEV_TREZOR2_BL = (0x1209, 0x53C0)
TREZORS = {DEV_TREZOR1, DEV_TREZOR2, DEV_TREZOR2_BL}
UDEV_RULES_STR = """
Do you have udev rules installed?
https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip()
class TransportException(TrezorException):
pass
class Transport:
"""Raw connection to a Trezor device.
Transport subclass represents a kind of communication link: WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
PATH_PREFIX = None # type: str
ENABLED = False
def __str__(self) -> str:
return self.get_path()
def get_path(self) -> str:
raise NotImplementedError
def begin_session(self) -> None:
raise NotImplementedError
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessageType:
raise NotImplementedError
def write(self, message: MessageType) -> None:
raise NotImplementedError
@classmethod
def enumerate(cls) -> Iterable["Transport"]:
raise NotImplementedError
@classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport":
for device in cls.enumerate():
if (
path is None
or device.get_path() == path
or (prefix_search and device.get_path().startswith(path))
):
return device
raise TransportException(
"{} device not found: {}".format(cls.PATH_PREFIX, path)
)
def all_transports() -> Iterable[Type[Transport]]:
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport
return set(
cls
for cls in (HidTransport, UdpTransport, WebUsbTransport)
if cls.ENABLED
)
def enumerate_devices() -> Iterable[Transport]:
devices = [] # type: List[Transport]
for transport in all_transports():
name = transport.__name__
try:
found = list(transport.enumerate())
LOG.info("Enumerating {}: found {} devices".format(name, len(found)))
devices.extend(found)
except NotImplementedError:
LOG.error("{} does not implement device enumeration".format(name))
except Exception as e:
excname = e.__class__.__name__
LOG.error("Failed to enumerate {}. {}: {}".format(name, excname, e))
return devices
def get_transport(path: str = None, prefix_search: bool = False) -> Transport:
if path is None:
try:
return next(iter(enumerate_devices()))
except StopIteration:
raise TransportException("No TREZOR device found") from None
# Find whether B is prefix of A (transport name is part of the path)
# or A is prefix of B (path is a prefix, or a name, of transport).
# This naively expects that no two transports have a common prefix.
def match_prefix(a: str, b: str) -> bool:
return a.startswith(b) or b.startswith(a)
LOG.info(
"looking for device by {}: {}".format(
"prefix" if prefix_search else "full path", path
)
)
transports = [t for t in all_transports() if match_prefix(path, t.PATH_PREFIX)]
if transports:
return transports[0].find_by_path(path, prefix_search=prefix_search)
raise TransportException("Could not find device by path: {}".format(path))

View File

@ -0,0 +1,161 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import sys
import time
from typing import Any, Dict, Iterable
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__)
try:
import hid
except Exception as e:
LOG.info("HID transport is disabled: {}".format(e))
hid = None
HidDevice = Dict[str, Any]
HidDeviceHandle = Any
class HidHandle:
def __init__(
self, path: bytes, serial: str, probe_hid_version: bool = False
) -> None:
self.path = path
self.serial = serial
self.handle = None # type: HidDeviceHandle
self.hid_version = None if probe_hid_version else 2
def open(self) -> None:
self.handle = hid.device()
try:
self.handle.open_path(self.path)
except (IOError, OSError) as e:
if sys.platform.startswith("linux"):
e.args = e.args + (UDEV_RULES_STR,)
raise e
# On some platforms, HID path stays the same over device reconnects.
# That means that someone could unplug a Trezor, plug a different one
# and we wouldn't even know.
# So we check that the serial matches what we expect.
serial = self.handle.get_serial_number_string()
if serial != self.serial:
self.handle.close()
self.handle = None
raise TransportException(
"Unexpected device {} on path {}".format(serial, self.path.decode())
)
self.handle.set_nonblocking(True)
if self.hid_version is None:
self.hid_version = self.probe_hid_version()
def close(self) -> None:
if self.handle is not None:
# reload serial, because device.wipe() can reset it
self.serial = self.handle.get_serial_number_string()
self.handle.close()
self.handle = None
def write_chunk(self, chunk: bytes) -> None:
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
if self.hid_version == 2:
self.handle.write(b"\0" + bytearray(chunk))
else:
self.handle.write(chunk)
def read_chunk(self) -> bytes:
while True:
chunk = self.handle.read(64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytes(chunk)
def probe_hid_version(self) -> int:
n = self.handle.write([0, 63] + [0xFF] * 63)
if n == 65:
return 2
n = self.handle.write([63] + [0xFF] * 63)
if n == 64:
return 1
raise TransportException("Unknown HID version")
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = hid is not None
def __init__(self, device: HidDevice) -> None:
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
protocol = ProtocolV1(self.handle)
super().__init__(protocol=protocol)
def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
@classmethod
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
devices = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id != DEV_TREZOR1:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
if self.protocol.VERSION >= 2:
# use the same device
return self
else:
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
def is_debuglink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1

View File

@ -0,0 +1,206 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import os
import struct
from io import BytesIO
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . import Transport
from .. import mapping, protobuf
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None:
...
def close(self) -> None:
...
def read_chunk(self) -> bytes:
...
def write_chunk(self, chunk: bytes) -> None:
...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
We declare a protocol version (we have implementations of v1 and v2).
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
VERSION = None # type: int
def __init__(self, handle: Handle) -> None:
self.handle = handle
self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
self.session_counter += 1
def end_session(self) -> None:
if self.session_counter == 1:
self.handle.close()
self.session_counter -= 1
def read(self) -> protobuf.MessageType:
raise NotImplementedError
def write(self, message: protobuf.MessageType) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message: protobuf.MessageType) -> None:
self.protocol.write(message)
def read(self) -> protobuf.MessageType:
return self.protocol.read()
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
VERSION = 1
def write(self, msg: protobuf.MessageType) -> None:
LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
data = BytesIO()
protobuf.dump_message(data, msg)
ser = data.getvalue()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
buffer = bytearray(b"##" + header + ser)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> protobuf.MessageType:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
# Strip padding
data = BytesIO(buffer[:datalen])
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + headerlen :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]
def get_protocol(handle: Handle, want_v2: bool) -> Protocol:
"""Make a Protocol instance for the given handle.
Each transport can have a preference for using a particular protocol version.
This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable,
which forces the library to use V1 anyways.
As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible
to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask
for it (i.e., USB transports for Trezor T).
"""
return ProtocolV1(handle)

View File

@ -0,0 +1,122 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import socket
from typing import Iterable, Optional, cast
from . import TransportException
from .protocol import ProtocolBasedTransport, get_protocol
class UdpTransport(ProtocolBasedTransport):
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324
PATH_PREFIX = "udp"
ENABLED = True
def __init__(self, device: str = None) -> None:
if not device:
host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT
else:
devparts = device.split(":")
host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port)
self.socket = None # type: Optional[socket.socket]
protocol = get_protocol(self, want_v2=False)
super().__init__(protocol=protocol)
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport("{}:{}".format(host, port + 1))
@classmethod
def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path)
try:
d.open()
if d._ping():
return d
else:
raise TransportException(
"No TREZOR device found at address {}".format(path)
)
finally:
d.close()
@classmethod
def enumerate(cls) -> Iterable["UdpTransport"]:
default_path = "{}:{}".format(cls.DEFAULT_HOST, cls.DEFAULT_PORT)
try:
return [cls._try_path(default_path)]
except TransportException:
return []
@classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
if prefix_search:
return cast(UdpTransport, super().find_by_path(path, prefix_search))
# This is *technically* type-able: mark `find_by_path` as returning
# the same type from which `cls` comes from.
# Mypy can't handle that though, so here we are.
else:
path = path.replace("{}:".format(cls.PATH_PREFIX), "")
return cls._try_path(path)
def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device)
self.socket.settimeout(10)
def close(self) -> None:
if self.socket is not None:
self.socket.close()
self.socket = None
def _ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"
def write_chunk(self, chunk: bytes) -> None:
assert self.socket is not None
if len(chunk) != 64:
raise TransportException("Unexpected data length")
self.socket.sendall(chunk)
def read_chunk(self) -> bytes:
assert self.socket is not None
while True:
try:
chunk = self.socket.recv(64)
break
except socket.timeout:
continue
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytearray(chunk)

View File

@ -0,0 +1,156 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import atexit
import logging
import sys
import time
from typing import Iterable, Optional
from . import TREZORS, UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__)
try:
import usb1
except Exception as e:
LOG.warning("WebUSB transport is disabled: {}".format(e))
usb1 = None
INTERFACE = 0
ENDPOINT = 1
DEBUG_INTERFACE = 1
DEBUG_ENDPOINT = 2
class WebUsbHandle:
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
self.device = device
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0
self.handle = None # type: Optional[usb1.USBDeviceHandle]
def open(self) -> None:
self.handle = self.device.open()
if self.handle is None:
if sys.platform.startswith("linux"):
args = (UDEV_RULES_STR,)
else:
args = ()
raise IOError("Cannot open device", *args)
self.handle.claimInterface(self.interface)
def close(self) -> None:
if self.handle is not None:
self.handle.releaseInterface(self.interface)
self.handle.close()
self.handle = None
def write_chunk(self, chunk: bytes) -> None:
assert self.handle is not None
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
self.handle.interruptWrite(self.endpoint, chunk)
def read_chunk(self) -> bytes:
assert self.handle is not None
endpoint = 0x80 | self.endpoint
while True:
chunk = self.handle.interruptRead(endpoint, 64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return chunk
class WebUsbTransport(ProtocolBasedTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = usb1 is not None
context = None
def __init__(
self, device: str, handle: WebUsbHandle = None, debug: bool = False
) -> None:
if handle is None:
handle = WebUsbHandle(device, debug)
self.device = device
self.handle = handle
self.debug = debug
super().__init__(protocol=ProtocolV1(handle))
def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
@classmethod
def enumerate(cls) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
devices = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in TREZORS:
continue
if not is_vendor_class(dev):
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
return devices
def find_debug(self) -> "WebUsbTransport":
if self.protocol.VERSION >= 2:
# TODO test this
# XXX this is broken right now because sessions don't really work
# For v2 protocol, use the same WebUSB interface with a different session
return WebUsbTransport(self.device, self.handle)
else:
# For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True)
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
configurationId = 0
altSettingId = 0
return (
dev[configurationId][INTERFACE][altSettingId].getClass()
== usb1.libusb1.LIBUSB_CLASS_VENDOR_SPEC
)
def dev_to_str(dev: "usb1.USBDevice") -> str:
return ":".join(
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
)

View File

@ -0,0 +1,101 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os
import sys
from mnemonic import Mnemonic
from . import device
from .exceptions import Cancelled
from .messages import PinMatrixRequestType, WordRequestType
PIN_MATRIX_DESCRIPTION = """
Use the numeric keypad to describe number positions. The layout is:
7 8 9
4 5 6
1 2 3
""".strip()
RECOVERY_MATRIX_DESCRIPTION = """
Use the numeric keypad to describe positions.
For the word list use only left and right keys.
Use backspace to correct an entry.
The keypad layout is:
7 8 9 7 | 9
4 5 6 4 | 6
1 2 3 1 | 3
""".strip()
PIN_GENERIC = None
PIN_CURRENT = PinMatrixRequestType.Current
PIN_NEW = PinMatrixRequestType.NewFirst
PIN_CONFIRM = PinMatrixRequestType.NewSecond
def echo(msg):
print(msg, file=sys.stderr)
def prompt(msg):
return input(msg)
class PassphraseUI:
def __init__(self, passphrase):
self.passphrase = passphrase
self.pinmatrix_shown = False
self.prompt_shown = False
self.always_prompt = False
def button_request(self, code):
if not self.prompt_shown:
echo("Please confirm action on your Trezor device")
if not self.always_prompt:
self.prompt_shown = True
def get_pin(self, code=None):
raise NotImplementedError('get_pin is not needed')
def get_passphrase(self):
return self.passphrase
def mnemonic_words(expand=False, language="english"):
if expand:
wordlist = Mnemonic(language).wordlist
else:
wordlist = set()
def expand_word(word):
if not expand:
return word
if word in wordlist:
return word
matches = [w for w in wordlist if w.startswith(word)]
if len(matches) == 1:
return word
echo("Choose one of: " + ", ".join(matches))
raise KeyError(word)
def get_word(type):
assert type == WordRequestType.Plain
while True:
try:
word = prompt("Enter one word of mnemonic")
return expand_word(word)
except KeyError:
pass
return get_word

View File

@ -15,12 +15,14 @@ setuptools.setup(
packages=setuptools.find_packages(exclude=['docs', 'test']),
install_requires=[
'hidapi', # HID API needed in general
'trezor>=0.11.0', # Trezor One
'btchip-python', # Ledger Nano S
'keepkey>=6.0.1', # KeepKey
'ckcc-protocol[cli]', # Coldcard
'pyaes',
'ecdsa', # Needed for Ledger but their library does not install it
'typing_extensions>=3.7',
'mnemonic>=0.18.0',
'libusb1'
],
python_requires='>=3',
classifiers=[

View File

@ -11,10 +11,10 @@ import time
import unittest
from bitcoinrpc.authproxy import AuthServiceProxy, JSONRPCException
from trezorlib.transport import enumerate_devices
from trezorlib.transport.udp import UdpTransport
from trezorlib.debuglink import DebugUI, TrezorClientDebugLink, load_device_by_mnemonic, load_device_by_xprv
from trezorlib import device, messages
from hwilib.devices.trezorlib.transport import enumerate_devices
from hwilib.devices.trezorlib.transport.udp import UdpTransport
from hwilib.devices.trezorlib.debuglink import DebugUI, TrezorClientDebugLink, load_device_by_mnemonic, load_device_by_xprv
from hwilib.devices.trezorlib import device, messages
from test_device import DeviceEmulator, DeviceTestCase, start_bitcoind, TestDeviceConnect, TestDisplayAddress, TestGetKeypool, TestSignMessage, TestSignTx
from hwilib.cli import process_commands