Use ADs to protect multipart Noise messages from truncation.

This commit is contained in:
gram-signal 2026-04-15 10:58:24 -07:00 committed by GitHub
parent 2e735941e1
commit b6b8b459ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 150 additions and 28 deletions

View File

@ -12,6 +12,17 @@
namespace svr2::client {
// If this boolean is set to true, clients will utilize ADs during
// multipart noise message encryption to verify that messages are not
// truncated. Changing this, though, does require a client-side
// change, so it's currently off. Within normal SVR operation, though,
// clients should never send messages large enough to allow for truncation,
// so this default is safe (at least for current DB implementations).
//
// Note that when talking enclave-to-enclave, larger messages are the norm,
// hence for peer communication, we do always utilize length-verifying ADs.
const bool NOISE_VERIFY_LENGTH_WITH_AD = false;
static std::atomic<uint64_t> id_gen{1};
const NoiseProtocolId client_protocol = {
@ -121,7 +132,7 @@ error::Error Client::DecryptRequest(context::Context* ctx, const std::string& da
if (hs_.get() || !tx_.get() || !rx_.get()) {
return COUNTED_ERROR(Client_DecryptState);
}
auto [plaintext, err] = noise::Decrypt(rx_.get(), data);
auto [plaintext, err] = noise::Decrypt(rx_.get(), data, NOISE_VERIFY_LENGTH_WITH_AD);
if (err != error::OK) {
return err;
}
@ -141,7 +152,7 @@ std::pair<std::string, error::Error> Client::EncryptResponse(context::Context* c
if (!response.SerializeToString(&plaintext)) {
return std::make_pair("", COUNTED_ERROR(Client_EncryptSerialize));
}
return noise::Encrypt(tx_.get(), plaintext);
return noise::Encrypt(tx_.get(), plaintext, NOISE_VERIFY_LENGTH_WITH_AD);
}
std::pair<Client*, error::Error> ClientManager::NewClient(

View File

@ -17,6 +17,8 @@
namespace svr2::client {
extern const bool NOISE_VERIFY_LENGTH_WITH_AD;
class ClientManager;
typedef uint64_t ClientID;
extern const NoiseProtocolId client_protocol;

View File

@ -8,6 +8,7 @@
#include "testingcore.h"
#include "util/bytes.h"
#include "client/client.h"
#define NOISE_OK(x) \
do { \
@ -51,7 +52,7 @@ void TestingClient::RequestBackup(SecretData data, PIN pin, uint32_t tries) {
// serialize and encrypt
std::string req_str;
ASSERT_TRUE(req.SerializeToString(&req_str));
auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str);
auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str, client::NOISE_VERIFY_LENGTH_WITH_AD);
ASSERT_EQ(error::OK, encrypt_err);
ASSERT_EQ(error::OK,
core_.ExistingClientRequest(this, client_id_, ciphertext));
@ -68,7 +69,7 @@ void TestingClient::RequestExpose(SecretData data) {
// serialize and encrypt
std::string req_str;
ASSERT_TRUE(req.SerializeToString(&req_str));
auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str);
auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str, client::NOISE_VERIFY_LENGTH_WITH_AD);
ASSERT_EQ(error::OK, encrypt_err);
ASSERT_EQ(error::OK,
core_.ExistingClientRequest(this, client_id_, ciphertext));
@ -85,7 +86,7 @@ void TestingClient::RequestRestore(PIN pin) {
// serialize and encrypt
std::string req_str;
ASSERT_TRUE(req.SerializeToString(&req_str));
auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str);
auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str, client::NOISE_VERIFY_LENGTH_WITH_AD);
ASSERT_EQ(error::OK, encrypt_err);
ASSERT_EQ(error::OK,
core_.ExistingClientRequest(this, client_id_, ciphertext));
@ -101,7 +102,7 @@ void TestingClient::RequestTries() {
// serialize and encrypt
std::string req_str;
ASSERT_TRUE(req.SerializeToString(&req_str));
auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str);
auto [ciphertext, encrypt_err] = noise::Encrypt(tx_.get(), req_str, client::NOISE_VERIFY_LENGTH_WITH_AD);
ASSERT_EQ(error::OK, encrypt_err);
ASSERT_EQ(error::OK,
core_.ExistingClientRequest(this, client_id_, ciphertext));
@ -153,7 +154,7 @@ void TestingClient::FinishHandshake(ExistingClientReply ecr) {
void TestingClient::DecryptClientReply(ExistingClientReply ecr,
client::Response* rsp) {
auto [plaintext, decrypt_err] = noise::Decrypt(rx_.get(), ecr.data());
auto [plaintext, decrypt_err] = noise::Decrypt(rx_.get(), ecr.data(), client::NOISE_VERIFY_LENGTH_WITH_AD);
ASSERT_EQ(error::OK, decrypt_err);
ASSERT_TRUE(rsp->ParseFromString(plaintext));

View File

@ -57,6 +57,7 @@
#include "core/coretest/replicagroup.h"
#include "core/coretest/testingclient.h"
#include "ristretto/ristretto.h"
#include "client/client.h"
// This test is pretty large and contains a lot of code which should maybe be
// moved into some coretest library at a later date. There's a few very
@ -281,7 +282,7 @@ class CoreTest : public ::testing::Test {
{ // send the request, parse response.
std::string req_str;
ASSERT_TRUE(req.SerializeToString(&req_str));
auto [ciphertext, encrypt_err] = noise::Encrypt(txp, req_str);
auto [ciphertext, encrypt_err] = noise::Encrypt(txp, req_str, client::NOISE_VERIFY_LENGTH_WITH_AD);
ASSERT_EQ(error::OK, encrypt_err);
UntrustedMessage msg;
auto host = msg.mutable_h2e_request();
@ -298,7 +299,7 @@ class CoreTest : public ::testing::Test {
ASSERT_EQ(resp.status(), error::OK);
ASSERT_EQ(resp.inner_case(), HostToEnclaveResponse::kExistingClientReply);
auto ec2 = resp.existing_client_reply();
auto [plaintext, decrypt_err] = noise::Decrypt(rxp, ec2.data());
auto [plaintext, decrypt_err] = noise::Decrypt(rxp, ec2.data(), client::NOISE_VERIFY_LENGTH_WITH_AD);
ASSERT_EQ(error::OK, decrypt_err);
ASSERT_TRUE(cli_resp->ParseFromString(plaintext));
}

View File

@ -10,7 +10,7 @@ namespace svr2::noise {
static size_t max_message_size = 65535;
std::pair<std::string, error::Error> Encrypt(NoiseCipherState* cs, const std::string& plaintext) {
std::pair<std::string, error::Error> Encrypt(NoiseCipherState* cs, const std::string& plaintext, bool with_length_verifying_ad) {
std::string ciphertext;
size_t mac_size = noise_cipherstate_get_mac_length(cs);
size_t max_encrypt_size = max_message_size - mac_size;
@ -37,14 +37,23 @@ std::pair<std::string, error::Error> Encrypt(NoiseCipherState* cs, const std::st
plaintext_start += plaintext_size;
NoiseBuffer buf;
noise_buffer_set_inout(buf, StrU8Ptr(&ciphertext) + start, plaintext_size, plaintext_size + mac_size);
if (NOISE_ERROR_NONE != noise_cipherstate_encrypt(cs, &buf)) {
// The approach that doesn't utilize an AD has an issue where a middleman can
// truncate an incoming/outgoing message at a message boundary, and the received
// message will still decrypt and verify successfully. To avoid this, we
// use a very simple AD which is the array [0] if this message is part of a longer
// message, and [1] if this is the final piece in a longer multi-part message.
// Reordering of messages/etc should all be caught by Noise's internal MAC-ing.
const uint8_t ad[1] = {static_cast<uint8_t>(
start + max_message_size >= final_size ? 1 : 0)};
if (NOISE_ERROR_NONE != noise_cipherstate_encrypt_with_ad(
cs, ad, with_length_verifying_ad ? 1 : 0, &buf)) {
return std::make_pair("", COUNTED_ERROR(Peers_Encrypt));
}
}
return std::make_pair(ciphertext, error::OK);
}
std::pair<std::string, error::Error> Decrypt(NoiseCipherState* cs, const std::string& ciphertext) {
std::pair<std::string, error::Error> Decrypt(NoiseCipherState* cs, const std::string& ciphertext, bool with_length_verifying_ad) {
std::string plaintext(ciphertext.size(), 0);
size_t plaintext_start = 0;
// Data comes in as [ciphertext][mac][ciphertext][mac].
@ -53,7 +62,16 @@ std::pair<std::string, error::Error> Decrypt(NoiseCipherState* cs, const std::st
memcpy(StrU8Ptr(&plaintext) + plaintext_start, StrU8Ptr(ciphertext) + start, size);
NoiseBuffer buf;
noise_buffer_set_inout(buf, StrU8Ptr(&plaintext) + plaintext_start, size, size);
if (NOISE_ERROR_NONE != noise_cipherstate_decrypt(cs, &buf)) {
// The approach that doesn't utilize an AD has an issue where a middleman can
// truncate an incoming/outgoing message at a message boundary, and the received
// message will still decrypt and verify successfully. To avoid this, we
// use a very simple AD which is the array [0] if this message is part of a longer
// message, and [1] if this is the final piece in a longer multi-part message.
// Reordering of messages/etc should all be caught by Noise's internal MAC-ing.
const uint8_t ad[1] = {static_cast<uint8_t>(
start + max_message_size >= ciphertext.size() ? 1 : 0)};
if (NOISE_ERROR_NONE != noise_cipherstate_decrypt_with_ad(
cs, ad, with_length_verifying_ad ? 1 : 0, &buf)) {
return std::make_pair("", COUNTED_ERROR(Peers_Decrypt));
}
plaintext_start += buf.size;

View File

@ -68,9 +68,9 @@ inline CipherState WrapCipherState(NoiseCipherState* s) {
}
// Encrypt the given string.
std::pair<std::string, error::Error> Encrypt(NoiseCipherState* cs, const std::string& plaintext);
std::pair<std::string, error::Error> Encrypt(NoiseCipherState* cs, const std::string& plaintext, bool with_length_verifying_ad);
// Decrypt the given string.
std::pair<std::string, error::Error> Decrypt(NoiseCipherState* cs, const std::string& ciphertext);
std::pair<std::string, error::Error> Decrypt(NoiseCipherState* cs, const std::string& ciphertext, bool with_length_verifying_ad);
} // namespace svr2::noise

View File

@ -31,7 +31,7 @@ class CipherStateTest : public ::testing::Test {
env::Init(env::SIMULATED);
}
void EncryptDecrypt(const std::string& plaintext, std::string* ciphertext_out, int type) {
void EncryptDecrypt(const std::string& plaintext, std::string* ciphertext_out, int type, bool with_ad) {
std::array<uint8_t, 32> key = {1};
NoiseCipherState* s1n;
NoiseCipherState* s2n;
@ -41,9 +41,9 @@ class CipherStateTest : public ::testing::Test {
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_init_key(s2n, key.data(), key.size()));
noise::CipherState s1 = noise::WrapCipherState(s1n);
noise::CipherState s2 = noise::WrapCipherState(s2n);
auto [ciphertext, enc_err] = noise::Encrypt(s1n, plaintext);
auto [ciphertext, enc_err] = noise::Encrypt(s1n, plaintext, with_ad);
ASSERT_EQ(error::OK, enc_err);
auto [computed_plaintext, dec_err] = noise::Decrypt(s2n, ciphertext);
auto [computed_plaintext, dec_err] = noise::Decrypt(s2n, ciphertext, with_ad);
ASSERT_EQ(error::OK, dec_err);
ASSERT_EQ(plaintext, computed_plaintext);
ciphertext_out->swap(ciphertext);
@ -52,28 +52,117 @@ class CipherStateTest : public ::testing::Test {
TEST_F(CipherStateTest, EncryptDecrypt) {
std::string ciphertext;
EncryptDecrypt("", &ciphertext, NOISE_CIPHER_CHACHAPOLY);
EncryptDecrypt("", &ciphertext, NOISE_CIPHER_CHACHAPOLY, false);
ASSERT_EQ(16, ciphertext.size());
EncryptDecrypt("a", &ciphertext, NOISE_CIPHER_CHACHAPOLY);
EncryptDecrypt("a", &ciphertext, NOISE_CIPHER_CHACHAPOLY, false);
ASSERT_EQ(17, ciphertext.size());
EncryptDecrypt("this is a test of the emergency broadcast system", &ciphertext, NOISE_CIPHER_CHACHAPOLY);
EncryptDecrypt("this is a test of the emergency broadcast system", &ciphertext, NOISE_CIPHER_CHACHAPOLY, false);
std::string s;
s.resize(65535-16, 'a');
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY);
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY, false);
ASSERT_EQ(ciphertext.size(), 65535);
s.resize(65535-15, 'a');
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY);
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY, false);
ASSERT_EQ(ciphertext.size(), 65535-15+32);
s.resize((65535-16)*10, 'a');
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY);
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY, false);
ASSERT_EQ(ciphertext.size(), 65535*10);
}
TEST_F(CipherStateTest, EncryptDecryptWithLengthAD) {
std::string ciphertext;
EncryptDecrypt("", &ciphertext, NOISE_CIPHER_CHACHAPOLY, true);
ASSERT_EQ(16, ciphertext.size());
EncryptDecrypt("a", &ciphertext, NOISE_CIPHER_CHACHAPOLY, true);
ASSERT_EQ(17, ciphertext.size());
EncryptDecrypt("this is a test of the emergency broadcast system", &ciphertext, NOISE_CIPHER_CHACHAPOLY, true);
std::string s;
s.resize(65535-16, 'a');
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY, true);
ASSERT_EQ(ciphertext.size(), 65535);
s.resize(65535-15, 'a');
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY, true);
ASSERT_EQ(ciphertext.size(), 65535-15+32);
s.resize((65535-16)*10, 'a');
EncryptDecrypt(s, &ciphertext, NOISE_CIPHER_CHACHAPOLY, true);
ASSERT_EQ(ciphertext.size(), 65535*10);
}
TEST_F(CipherStateTest, DecryptTruncatedWithLengthADFails) {
std::array<uint8_t, 32> key = {1};
NoiseCipherState* s1n;
NoiseCipherState* s2n;
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_new_by_id(&s1n, NOISE_CIPHER_CHACHAPOLY));
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_init_key(s1n, key.data(), key.size()));
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_new_by_id(&s2n, NOISE_CIPHER_CHACHAPOLY));
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_init_key(s2n, key.data(), key.size()));
size_t max_message_size = 65535;
size_t mac_size = noise_cipherstate_get_mac_length(s1n);
size_t max_encrypt_size = max_message_size - mac_size;
noise::CipherState s1 = noise::WrapCipherState(s1n);
noise::CipherState s2 = noise::WrapCipherState(s2n);
bool with_ad = true;
std::string plaintext(65535, 'a');
auto [ciphertext, enc_err] = noise::Encrypt(s1n, plaintext, with_ad);
ASSERT_EQ(error::OK, enc_err);
ASSERT_EQ(ciphertext.size(), max_message_size + mac_size * 2);
// Truncate ciphertext at a message size boundary.
ciphertext.resize(max_message_size);
auto [computed_plaintext, dec_err] = noise::Decrypt(s2n, ciphertext, with_ad);
// Note that this would be error::OK if with_ad were false.
ASSERT_EQ(error::Peers_Decrypt, dec_err);
}
TEST_F(CipherStateTest, DecryptTruncatedWithoutLengthADSucceedsButBreaksCiphertext) {
std::array<uint8_t, 32> key = {1};
NoiseCipherState* s1n;
NoiseCipherState* s2n;
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_new_by_id(&s1n, NOISE_CIPHER_CHACHAPOLY));
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_init_key(s1n, key.data(), key.size()));
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_new_by_id(&s2n, NOISE_CIPHER_CHACHAPOLY));
ASSERT_EQ(NOISE_ERROR_NONE, noise_cipherstate_init_key(s2n, key.data(), key.size()));
size_t max_message_size = 65535;
size_t mac_size = noise_cipherstate_get_mac_length(s1n);
size_t max_encrypt_size = max_message_size - mac_size;
noise::CipherState s1 = noise::WrapCipherState(s1n);
noise::CipherState s2 = noise::WrapCipherState(s2n);
bool with_ad = false;
std::string plaintext(65535, 'a');
{
auto [ciphertext, enc_err] = noise::Encrypt(s1n, plaintext, with_ad);
ASSERT_EQ(error::OK, enc_err);
ASSERT_EQ(ciphertext.size(), max_message_size + mac_size * 2);
// Truncate ciphertext at a message size boundary.
ciphertext.resize(max_message_size);
auto [computed_plaintext, dec_err] = noise::Decrypt(s2n, ciphertext, with_ad);
ASSERT_EQ(error::OK, dec_err);
}
// While our truncation was successful, unless the rest of the truncation is then
// fed into the receiving ciphertext, it will be broken.
plaintext = "a";
{
auto [ciphertext, enc_err] = noise::Encrypt(s1n, plaintext, with_ad);
ASSERT_EQ(error::OK, enc_err);
auto [computed_plaintext, dec_err] = noise::Decrypt(s2n, ciphertext, with_ad);
ASSERT_EQ(error::Peers_Decrypt, dec_err);
}
}
TEST_F(CipherStateTest, BenchmarkChaChaPoly) {
std::string plaintext;
std::string ciphertext;
@ -81,7 +170,7 @@ TEST_F(CipherStateTest, BenchmarkChaChaPoly) {
auto start = util::asm_rdtsc();
int times = 100;
for (int i = 0; i < times; i++) {
EncryptDecrypt(plaintext, &ciphertext, NOISE_CIPHER_CHACHAPOLY);
EncryptDecrypt(plaintext, &ciphertext, NOISE_CIPHER_CHACHAPOLY, false);
}
LOG(INFO) << "took " << ((util::asm_rdtsc() - start) * 1.0 / (times * plaintext.size())) << " cycles/byte";
}
@ -93,7 +182,7 @@ TEST_F(CipherStateTest, BenchmarkAesGcm) {
auto start = util::asm_rdtsc();
int times = 100;
for (int i = 0; i < times; i++) {
EncryptDecrypt(plaintext, &ciphertext, NOISE_CIPHER_AESGCM);
EncryptDecrypt(plaintext, &ciphertext, NOISE_CIPHER_AESGCM, false);
}
LOG(INFO) << "took " << ((util::asm_rdtsc() - start) * 1.0 / (times * plaintext.size())) << " cycles/byte";
}

View File

@ -82,7 +82,7 @@ error::Error Peer::Send(
if (!msg.SerializeToString(&serialized)) {
return COUNTED_ERROR(Peers_EncryptSerialize);
}
auto [ciphertext, err] = noise::Encrypt(tx_.get(), serialized);
auto [ciphertext, err] = noise::Encrypt(tx_.get(), serialized, true);
if (err != error::OK) {
// An encryption error probably means bad noise state, which is unrecoverable.
InternalDisconnect();
@ -128,7 +128,7 @@ error::Error Peer::Recv(
SendRst(ctx, id_);
return COUNTED_ERROR(Peers_DataNotConnected);
}
auto [plaintext, err] = noise::Decrypt(rx_.get(), msg.data());
auto [plaintext, err] = noise::Decrypt(rx_.get(), msg.data(), true);
if (err != error::OK) {
// A decryption error probably means bad noise state, which is unrecoverable.
InternalDisconnect();