diff --git a/hwilib/devices/digitalbitbox.py b/hwilib/devices/digitalbitbox.py index 1e459fd..e214474 100644 --- a/hwilib/devices/digitalbitbox.py +++ b/hwilib/devices/digitalbitbox.py @@ -6,6 +6,7 @@ import json import base64 import pyaes import hashlib +import hmac import os import binascii import logging @@ -40,27 +41,31 @@ def aes_decrypt_with_iv(key, iv, data): s = aes.feed(data) + aes.feed() # empty aes.feed() strips pkcs padding return s - -def EncodeAES(secret, s): +def encrypt_aes(secret, s): iv = bytes(os.urandom(16)) ct = aes_encrypt_with_iv(secret, iv, s) e = iv + ct - return base64.b64encode(e) + return e - -def DecodeAES(secret, e): - e = bytes(base64.b64decode(e)) +def decrypt_aes(secret, e): iv, e = e[:16], e[16:] s = aes_decrypt_with_iv(secret, iv, e) return s - def sha256(x): return hashlib.sha256(x).digest() +def sha512(x): + return hashlib.sha512(x).digest() -def Hash(x): - return sha256(sha256(x.encode('utf-8'))) +def double_hash(x): + if type(x) is not bytearray: x=x.encode('utf-8') + return sha256(sha256(x)) + +def derive_keys(x): + h = double_hash(x) + h = sha512(h) + return (h[:len(h)//2], h[len(h)//2:]) def to_string(x, enc): if isinstance(x, (bytes, bytearray)): @@ -106,11 +111,18 @@ def read_frame(device): assert cmd == HWW_CMD, '- USB command frame mismatch' return data +def get_firmware_version(device): + serial_number = device.get_serial_number_string() + split_serial = serial_number.split(':') + firm_ver = split_serial[1][1:] # Version is vX.Y.Z, we just need X.Y.Z + split_ver = firm_ver.split('.') + return (int(split_ver[0]), int(split_ver[1]), int(split_ver[2])) # major, minor, revision + def send_plain(msg, device): reply = "" try: - serial_number = device.get_serial_number_string() - if "v2.0." in serial_number or "v1." in serial_number: + firm_ver = get_firmware_version(device) + if (firm_ver[0] == 2 and firm_ver[1] == 0) or (firm_ver[0] == 1): hidBufSize = 4096 device.write('\0' + msg + '\0' * (hidBufSize - len(msg))) r = bytearray() @@ -130,11 +142,27 @@ def send_plain(msg, device): def send_encrypt(msg, password, device): reply = "" try: - secret = Hash(password) - msg = EncodeAES(secret, msg) - reply = send_plain(msg, device) + firm_ver = get_firmware_version(device) + if firm_ver[0] >= 5: + encryption_key, authentication_key = derive_keys(password) + msg = encrypt_aes(encryption_key, msg) + hmac_digest = hmac.new(authentication_key, msg, digestmod=hashlib.sha256).digest() + authenticated_msg = base64.b64encode(msg + hmac_digest) + else: + encryption_key = double_hash(password) + authenticated_msg = base64.b64encode(encrypt_aes(encryption_key, msg)) + reply = send_plain(authenticated_msg, device) if 'ciphertext' in reply: - reply = DecodeAES(secret, ''.join(reply["ciphertext"])) + b64_unencoded = bytes(base64.b64decode(''.join(reply["ciphertext"]))) + if firm_ver[0] >= 5: + msg = b64_unencoded[:-32] + reply_hmac = b64_unencoded[-32:] + hmac_calculated = hmac.new(authentication_key, msg, digestmod=hashlib.sha256).digest() + if not hmac.compare_digest(reply_hmac, hmac_calculated): + raise Exception("Failed to validate HMAC") + else: + msg = b64_unencoded + reply = decrypt_aes(encryption_key, msg) reply = json.loads(reply.decode("utf-8")) if 'error' in reply: password = None