First implementation of Restore1/2.

This commit is contained in:
gram-signal 2024-08-02 12:22:39 -07:00 committed by GitHub
parent 5f21d8c661
commit 2b823c0bb5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 457 additions and 28 deletions

View File

@ -5,6 +5,7 @@
#include <algorithm>
#include <sodium/crypto_auth_hmacsha256.h>
#include <sodium/crypto_hash_sha512.h>
#include <sodium/crypto_core_ristretto255.h>
#include "util/log.h"
@ -19,6 +20,39 @@
namespace svr2::db {
namespace {
template <class T>
const uint8_t* U8(const T& p) {
return reinterpret_cast<const uint8_t*>(p.data());
}
template <class T>
bool ValidRistrettoPoint(const T& p) {
if (p.size() != sizeof(DB4::RistrettoPoint)) {
return false;
}
return crypto_core_ristretto255_is_valid_point(U8(p));
}
template <class T>
bool ValidRistrettoScalar(const T& s) {
if (s.size() != sizeof(DB4::RistrettoScalar)) {
return false;
}
// libsodium doesn't seem to have a "is this scalar already reduced"
// call, so we instead resort to reducing and checking equality.
auto data = U8(s);
uint8_t nonreduced[crypto_core_ristretto255_NONREDUCEDSCALARBYTES] = {0};
uint8_t reduced[sizeof(DB4::RistrettoScalar)] = {0};
// Bytes are stored little-endian, so copy into the front of nonreduced.
memcpy(nonreduced, data, sizeof(DB4::RistrettoScalar));
crypto_core_ristretto255_scalar_reduce(reduced, nonreduced);
return util::ConstantTimeEqualsBytes(data, reduced, sizeof(DB4::RistrettoScalar));
}
} // namespace
const DB4::Protocol db4_protocol;
DB::Request* DB4::Protocol::RequestPB(context::Context* ctx) const {
@ -69,16 +103,31 @@ std::pair<const DB::Response*, error::Error> DB4::ClientState::ResponseFromReque
util::unique_lock lock(mu_);
restore2_client_state = std::move(restore2_);
}
if (r->inner_case() != client::Request4::kRestore2) {
// Only Restore2 is handled within this function; all other
// operations are handled by ResponseFromEffect. Returning
// (null, OK) here signals that we should continue on to that.
return std::make_pair(nullptr, error::OK);
switch (r->inner_case()) {
case client::Request4::kRestore2: {
if (authenticated_id().size() != sizeof(BackupID)) {
return std::make_pair(nullptr, COUNTED_ERROR(DB4_BackupIDSize));
}
auto resp = ctx->Protobuf<client::Response4>();
auto r2 = resp->mutable_restore2();
if (!restore2_client_state) {
r2->set_status(client::Response4::Restore2::RESTORE1_MISSING);
} else if (!ValidRistrettoScalar(r->restore2().auth_scalar()) ||
!ValidRistrettoPoint(r->restore2().auth_point())) {
r2->set_status(client::Response4::Restore2::INVALID_REQUEST);
} else {
BackupID id;
CHECK(error::OK == util::StringIntoByteArray(authenticated_id(), &id));
Restore2(ctx, id, r->restore2(), restore2_client_state.get(), r2);
}
return std::make_pair(resp, error::OK);
}
default:
// Only Restore2 is handled within this function; all other
// operations are handled by ResponseFromEffect. Returning
// (null, OK) here signals that we should continue on to that.
return std::make_pair(nullptr, error::OK);
}
if (!restore2_client_state) {
return std::make_pair(nullptr, COUNTED_ERROR(DB4_Restore2StateMissing));
}
return std::make_pair(nullptr, error::General_Unimplemented);
}
const DB::Response* DB4::ClientState::ResponseFromEffect(
@ -88,14 +137,15 @@ const DB::Response* DB4::ClientState::ResponseFromEffect(
if (e == nullptr) {
return nullptr;
}
if (e->resp().inner_case() == client::Response4::kRestore1 &&
e->resp().restore1().status() == client::Response4::Restore1::OK) {
auto resp = &e->resp();
if (resp->inner_case() == client::Response4::kRestore1 &&
resp->restore1().status() == client::Response4::Restore1::OK) {
auto restore2 = std::make_unique<client::Effect4::Restore2State>();
// TODO: fill in restore2 client state.
restore2->MergeFrom(dynamic_cast<const client::Effect4*>(&effect)->restore2_client_state());
util::unique_lock lock(mu_);
restore2_ = std::move(restore2);
}
return &e->resp();
return resp;
}
const std::string& DB4::Protocol::LogKey(const DB::Log& req) const {
@ -112,15 +162,19 @@ error::Error DB4::Protocol::ValidateClientLog(const DB::Log& log_pb) const {
case client::Request4::kCreate: {
const auto& req = log->req().create();
if (req.max_tries() > MAX_ALLOWED_MAX_TRIES ||
req.oprf_secretshare().size() != sizeof(RistrettoScalar) ||
req.auth_commitment().size() != sizeof(RistrettoPoint) ||
req.encryption_secretshare().size() != sizeof(AESKey) ||
req.zero_secretshare().size() != sizeof(RistrettoScalar)) {
!ValidRistrettoPoint(req.auth_commitment()) ||
!ValidRistrettoScalar(req.oprf_secretshare()) ||
!ValidRistrettoScalar(req.zero_secretshare()) ||
req.encryption_secretshare().size() != sizeof(AESKey)) {
return COUNTED_ERROR(DB4_RequestInvalid);
}
} break;
case client::Request4::kRestore1: {
const auto& req = log->req().restore1();
if (!ValidRistrettoPoint(req.blinded())) {
return COUNTED_ERROR(DB4_RequestInvalid);
}
} break;
case client::Request4::kRestore1:
return COUNTED_ERROR(General_Unimplemented);
case client::Request4::kRemove:
return COUNTED_ERROR(General_Unimplemented);
case client::Request4::kQuery:
@ -169,7 +223,10 @@ DB::Effect* DB4::Run(context::Context* ctx, const DB::Log& log_pb) {
} break;
case client::Request4::kRestore1: {
COUNTER(db4, ops_restore1)->Increment();
Restore1(ctx, id, log->req().restore1(), out->mutable_resp()->mutable_restore1());
Restore1(
ctx, id, log->req().restore1(),
out->mutable_resp()->mutable_restore1(),
out->mutable_restore2_client_state());
} break;
case client::Request4::kRemove: {
COUNTER(db4, ops_remove)->Increment();
@ -284,7 +341,8 @@ void DB4::Restore1(
context::Context* ctx,
const DB4::BackupID& id,
const client::Request4::Restore1& req,
client::Response4::Restore1* resp) {
client::Response4::Restore1* resp,
client::Effect4::Restore2State* state) {
auto find = rows_.find(id);
if (find == rows_.end()) {
resp->set_status(client::Response4::Restore1::MISSING);
@ -298,6 +356,41 @@ void DB4::Restore1(
}
row->tries--;
resp->set_tries_remaining(row->tries);
resp->set_status(client::Response4::Restore1::ERROR);
RistrettoPoint blinded_prime;
if (0 != crypto_scalarmult_ristretto255(
blinded_prime.data(),
row->oprf_secretshare.data(),
U8(req.blinded()))) {
goto restore1_error;
}
std::array<uint8_t, 64> sha512_hash;
crypto_hash_sha512_state sha512_state;
crypto_hash_sha512_init(&sha512_state);
crypto_hash_sha512_update(&sha512_state, id.data(), sizeof(id));
crypto_hash_sha512_update(&sha512_state, U8(req.blinded()), req.blinded().size());
crypto_hash_sha512_final(&sha512_state, sha512_hash.data());
RistrettoPoint ristretto_hash;
crypto_core_ristretto255_from_hash(ristretto_hash.data(), sha512_hash.data());
RistrettoPoint mask;
if (0 != crypto_scalarmult_ristretto255(mask.data(), row->zero_secretshare.data(), ristretto_hash.data())) {
goto restore1_error;
}
RistrettoPoint evaluated;
if (0 != crypto_core_ristretto255_add(evaluated.data(), blinded_prime.data(), mask.data())) {
goto restore1_error;
}
resp->set_element(util::ByteArrayToString(evaluated));
resp->set_status(client::Response4::Restore1::OK);
state->set_auth_commitment(util::ByteArrayToString(row->auth_commitment));
state->set_encryption_secretshare(util::ByteArrayToString(row->encryption_secretshare));
restore1_error:
if (row->tries == 0) {
rows_.erase(find);
row = nullptr; // The `row` ptr is no longer valid due to the `erase` call.
@ -305,7 +398,44 @@ void DB4::Restore1(
} else {
row->merkle_leaf_.Update(merkle::HashFrom(HashRow(id, *row)));
}
resp->set_status(client::Response4::Restore1::OK);
}
void DB4::ClientState::Restore2(
context::Context* ctx,
const BackupID& id,
const client::Request4::Restore2& req,
const client::Effect4::Restore2State* state,
client::Response4::Restore2* resp) const {
RistrettoPoint lhs1;
resp->set_status(client::Response4::Restore2::ERROR);
if (0 != crypto_scalarmult_ristretto255_base(lhs1.data(), U8(req.auth_scalar()))) {
return;
}
std::array<uint8_t, 64> sha512_hash;
crypto_hash_sha512_state sha512_state;
crypto_hash_sha512_init(&sha512_state);
crypto_hash_sha512_update(&sha512_state, U8(req.auth_point()), req.auth_point().size());
crypto_hash_sha512_final(&sha512_state, sha512_hash.data());
RistrettoScalar scalar_hash;
crypto_core_ristretto255_scalar_reduce(scalar_hash.data(), sha512_hash.data());
RistrettoPoint rhs1;
if (0 != crypto_scalarmult_ristretto255(rhs1.data(), scalar_hash.data(), U8(state->auth_commitment()))) {
return;
}
RistrettoPoint rhs2;
if (0 != crypto_core_ristretto255_add(rhs2.data(), rhs1.data(), U8(req.auth_point()))) {
return;
}
if (!util::ConstantTimeEquals(lhs1, rhs2)) {
return;
}
resp->set_status(client::Response4::Restore2::OK);
resp->set_encryption_secretshare(state->encryption_secretshare());
}
void DB4::Remove(

View File

@ -47,6 +47,13 @@ class DB4 : public DB {
// in future Restore2 requests.
virtual const Response* ResponseFromEffect(context::Context* ctx, const Effect& effect);
private:
void Restore2(
context::Context* ctx,
const BackupID& id,
const client::Request4::Restore2& req,
const client::Effect4::Restore2State* state,
client::Response4::Restore2* resp) const;
util::mutex mu_;
std::unique_ptr<client::Effect4::Restore2State> restore2_ GUARDED_BY(mu_);
};
@ -133,7 +140,8 @@ class DB4 : public DB {
context::Context* ctx,
const BackupID& id,
const client::Request4::Restore1& req,
client::Response4::Restore1* resp);
client::Response4::Restore1* resp,
client::Effect4::Restore2State* state);
void Remove(
context::Context* ctx,
const BackupID& id,

286
enclave/db/tests/db4.cc Normal file
View File

@ -0,0 +1,286 @@
// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//TESTDEP gtest
//TESTDEP merkle
//TESTDEP sip
//TESTDEP context
//TESTDEP env
//TESTDEP env/test
//TESTDEP env
//TESTDEP util
//TESTDEP metrics
//TESTDEP proto
//TESTDEP protobuf-lite
//TESTDEP libsodium
#include <gtest/gtest.h>
#include "db/db4.h"
#include "env/env.h"
#include "util/log.h"
#include "util/constant.h"
#include "util/macros.h"
#include "util/endian.h"
#include "util/bytes.h"
#include "util/hex.h"
#include "proto/client3.pb.h"
#include "proto/clientlog.pb.h"
#include <memory>
#include <sodium/crypto_scalarmult_ristretto255.h>
#include <sodium/crypto_auth_hmacsha512.h>
#define ASSERT_AND_ASSIGN(var, val) ASSERT_AND_ASSIGN_CTR1(var, val, __COUNTER__)
#define ASSERT_AND_ASSIGN_CTR1(var, val, ctr) ASSERT_AND_ASSIGN_CTR2(var, val, ctr)
#define ASSERT_AND_ASSIGN_CTR2(var, val, ctr) \
auto [var, __err##ctr] = (val); \
ASSERT_EQ(__err##ctr, error::OK);
namespace svr2::db {
class DB4Test : public ::testing::Test {
public:
DB4Test() {}
protected:
static void SetUpTestCase() {
env::Init(env::SIMULATED);
}
context::Context ctx;
merkle::Tree merk;
};
DB4::RistrettoScalar RandomScalar() {
DB4::RistrettoScalar s;
crypto_core_ristretto255_scalar_random(s.data());
return s;
}
template <class T1>
std::array<uint8_t, 64> SHA512(const T1& t1) {
std::array<uint8_t, 64> sha512_hash;
crypto_hash_sha512_state sha512_state;
crypto_hash_sha512_init(&sha512_state);
crypto_hash_sha512_update(&sha512_state, t1.data(), t1.size());
crypto_hash_sha512_final(&sha512_state, sha512_hash.data());
return sha512_hash;
}
template <class T1, class T2>
std::array<uint8_t, 64> SHA512(const T1& t1, const T2& t2) {
std::array<uint8_t, 64> sha512_hash;
crypto_hash_sha512_state sha512_state;
crypto_hash_sha512_init(&sha512_state);
crypto_hash_sha512_update(&sha512_state, t1.data(), t1.size());
crypto_hash_sha512_update(&sha512_state, t2.data(), t2.size());
crypto_hash_sha512_final(&sha512_state, sha512_hash.data());
return sha512_hash;
}
DB4::RistrettoScalar Reduce(const std::array<uint8_t, 64>& v) {
DB4::RistrettoScalar s;
crypto_core_ristretto255_scalar_reduce(s.data(), v.data());
return s;
}
template <class T1, class T2>
DB4::RistrettoScalar TestKDF(const T1& t1, const T2& t2) {
return Reduce(SHA512(t1, t2));
}
template <size_t N>
class DB4Client {
public:
DB4Client() {
CHECK(error::OK == env::environment->RandomBytes(id_.data(), id_.size()));
CHECK(error::OK == env::environment->RandomBytes(input_.data(), input_.size()));
k_oprf_ = RandomScalar();
DB4::RistrettoPoint hash_pt;
crypto_core_ristretto255_from_hash(hash_pt.data(), input_.data());
CHECK(0 == crypto_scalarmult_ristretto255(k_auth_pt_.data(), k_oprf_.data(), hash_pt.data()));
k_auth_ = TestKDF(input_, k_auth_pt_);
// We compute k_1..k_n such that SUM(k_*) == k_oprf by using the fact that
// k1 + k2 + ... + kn = k_oprf
// -k_oprf + k1 + k2 + ... + kn = 0
// -k_oprf + k2 + ... + kn = -k1
crypto_core_ristretto255_scalar_negate(k_[0].data(), k_oprf_.data());
for (int i = 1; i < N; i++) {
k_[i] = RandomScalar();
crypto_core_ristretto255_scalar_add(k_[0].data(), k_[0].data(), k_[i].data());
}
crypto_core_ristretto255_scalar_negate(k_[0].data(), k_[0].data());
// Choose random aes_enc_, and choose aes_[i] such that they xor to aes_enc_.
CHECK(error::OK == env::environment->RandomBytes(aes_enc_.data(), aes_enc_.size()));
memcpy(aes_[0].data(), aes_enc_.data(), aes_enc_.size());
for (int i = 1; i < N; i++) {
CHECK(error::OK == env::environment->RandomBytes(aes_[i].data(), aes_[i].size()));
for (int j = 0; j < aes_[0].size(); j++) {
aes_[0][j] ^= aes_[i][j];
}
}
k_enc_ = TestKDF(aes_enc_, k_auth_);
// Choose random z_i_ such that they sum to 0.
memset(z_[0].data(), 0, z_[0].size());
for (int i = 1; i < N; i++) {
z_[i] = RandomScalar();
crypto_core_ristretto255_scalar_add(z_[0].data(), z_[0].data(), z_[i].data());
}
crypto_core_ristretto255_scalar_negate(z_[0].data(), z_[0].data());
// Choose N secret and public keys.
for (int i = 0; i < N; i++) {
std::array<uint8_t, 1> index = {(uint8_t)('0' + i)};
sk_[i] = TestKDF(k_auth_, index);
crypto_scalarmult_ristretto255_base(pk_[i].data(), sk_[i].data());
}
}
client::Request4 Create(int i) {
client::Request4 r;
r.mutable_create()->set_max_tries(10);
r.mutable_create()->set_oprf_secretshare(util::ByteArrayToString(k_[i]));
r.mutable_create()->set_zero_secretshare(util::ByteArrayToString(z_[i]));
r.mutable_create()->set_auth_commitment(util::ByteArrayToString(pk_[i]));
r.mutable_create()->set_encryption_secretshare(util::ByteArrayToString(aes_[i]));
return r;
}
client::Request4 Restore1(int i, const DB4::RistrettoScalar& b) {
DB4::RistrettoPoint e;
crypto_core_ristretto255_from_hash(e.data(), input_.data());
DB4::RistrettoPoint blinded;
CHECK(0 == crypto_scalarmult_ristretto255(blinded.data(), b.data(), e.data()));
client::Request4 r;
r.mutable_restore1()->set_blinded(util::ByteArrayToString(blinded));
return r;
}
client::Request4 Restore2(
int i,
const DB4::RistrettoScalar& b,
const std::array<client::Response4::Restore1, N>& resps) {
DB4::RistrettoScalar evaluated_sum = {0};
for (int i = 0; i < N; i++) {
DB4::RistrettoScalar s;
CHECK(error::OK == util::StringIntoByteArray(resps[i].element(), &s));
crypto_core_ristretto255_scalar_add(evaluated_sum.data(), evaluated_sum.data(), s.data());
}
DB4::RistrettoScalar b_inverse;
crypto_core_ristretto255_scalar_invert(b_inverse.data(), b.data());
DB4::RistrettoScalar unblinded;
crypto_core_ristretto255_scalar_mul(unblinded.data(), b_inverse.data(), evaluated_sum.data());
// With all of this unblinding, we should have been able to recreate
// k_auth_pt.
DB4::RistrettoPoint k_auth_pt;
crypto_core_ristretto255_from_hash(k_auth_pt.data(), input_.data());
CHECK(0 == crypto_scalarmult_ristretto255(k_auth_pt.data(), k_oprf_.data(), k_auth_pt.data()));
CHECK(util::ConstantTimeEquals(k_auth_pt_, k_auth_pt));
// From k_auth_pt, we could now derive sk_ and pk_. But we already have them,
// so let's not bother.
auto rand = RandomScalar();
DB4::RistrettoPoint proof_point;
CHECK(0 == crypto_scalarmult_ristretto255_base(proof_point.data(), rand.data()));
auto c = Reduce(SHA512(proof_point));
DB4::RistrettoScalar proof_scalar;
crypto_core_ristretto255_scalar_mul(proof_scalar.data(), c.data(), sk_[i].data());
crypto_core_ristretto255_scalar_add(proof_scalar.data(), proof_scalar.data(), rand.data());
client::Request4 r;
r.mutable_restore2()->set_auth_point(util::ByteArrayToString(proof_point));
r.mutable_restore2()->set_auth_scalar(util::ByteArrayToString(proof_scalar));
return r;
}
bool EncryptionKeyMatches(int i, const client::Response4::Restore2& r) {
return util::ConstantTimeEquals(r.encryption_secretshare(), aes_[i]);
}
const DB4::BackupID& id() const { return id_; }
std::string authenticated_id() const { return util::ByteArrayToString(id_); }
private:
std::array<uint8_t, 64> input_;
DB4::BackupID id_;
DB4::RistrettoScalar k_oprf_;
DB4::RistrettoPoint k_auth_pt_;
DB4::RistrettoScalar k_auth_;
std::array<DB4::RistrettoScalar, N> k_;
std::array<DB4::RistrettoScalar, N> z_;
DB4::AESKey aes_enc_;
std::array<DB4::AESKey, N> aes_;
std::array<DB4::RistrettoScalar, N> sk_;
std::array<DB4::RistrettoPoint, N> pk_;
DB4::RistrettoScalar k_enc_;
};
TEST_F(DB4Test, SingleBackupLifecycle) {
ASSERT_EQ(1, 1);
const size_t N = 3;
DB4Client<N> client;
std::array<std::unique_ptr<DB4>, N> dbs;
std::array<std::unique_ptr<DB::ClientState>, N> states;
for (int i = 0; i < N; i++) {
dbs[i] = std::make_unique<DB4>(&merk);
states[i] = dbs[i]->P()->NewClientState(client.authenticated_id());
}
for (int i = 0; i < N; i++) {
LOG(INFO) << "Create." << i;
auto req = client.Create(i);
auto [resp, resp_err] = states[i]->ResponseFromRequest(&ctx, req);
ASSERT_EQ(nullptr, resp);
ASSERT_EQ(error::OK, resp_err);
auto [log, log_err] = states[i]->LogFromRequest(&ctx, req);
ASSERT_EQ(error::OK, log_err);
ASSERT_NE(nullptr, log);
auto effect = dbs[i]->Run(&ctx, *log);
ASSERT_NE(nullptr, effect);
auto resp2 = dynamic_cast<const client::Response4*>(states[i]->ResponseFromEffect(&ctx, *effect));
ASSERT_NE(nullptr, resp2);
ASSERT_EQ(resp2->create().status(), client::Response4::Create::OK);
ASSERT_EQ(client::Response4::kCreate, resp2->inner_case());
}
DB4::RistrettoScalar b = RandomScalar();
std::array<client::Response4::Restore1, N> restore1resp;
for (int i = 0; i < N; i++) {
LOG(INFO) << "Restore1." << i;
auto req = client.Restore1(i, b);
auto [resp, resp_err] = states[i]->ResponseFromRequest(&ctx, req);
ASSERT_EQ(nullptr, resp);
ASSERT_EQ(error::OK, resp_err);
auto [log, log_err] = states[i]->LogFromRequest(&ctx, req);
ASSERT_EQ(error::OK, log_err);
ASSERT_NE(nullptr, log);
auto effect = dbs[i]->Run(&ctx, *log);
ASSERT_NE(nullptr, effect);
auto resp2 = dynamic_cast<const client::Response4*>(states[i]->ResponseFromEffect(&ctx, *effect));
ASSERT_NE(nullptr, resp2);
ASSERT_EQ(client::Response4::kRestore1, resp2->inner_case());
ASSERT_EQ(resp2->restore1().status(), client::Response4::Restore1::OK);
ASSERT_EQ(resp2->restore1().tries_remaining(), 9);
restore1resp[i].MergeFrom(resp2->restore1());
}
for (int i = 0; i < N; i++) {
LOG(INFO) << "Restore2." << i;
auto req = client.Restore2(i, b, restore1resp);
auto [resp, resp_err] = states[i]->ResponseFromRequest(&ctx, req);
auto r = dynamic_cast<const client::Response4*>(resp);
ASSERT_NE(nullptr, r);
ASSERT_EQ(error::OK, resp_err);
ASSERT_EQ(client::Response4::kRestore2, r->inner_case());
ASSERT_EQ(r->restore2().status(), client::Response4::Restore2::OK);
ASSERT_TRUE(client.EncryptionKeyMatches(i, r->restore2()));
}
}
} // namespace svr2::db

View File

@ -34,6 +34,9 @@ message Log4 {
// Effect4 is the effect of applying a Log4 on a SVR4 (db4) database.
message Effect4 {
client.Response4 resp = 1;
message Restore2State {}
message Restore2State {
bytes auth_commitment = 1;
bytes encryption_secretshare = 2;
}
Restore2State restore2_client_state = 2;
}

View File

@ -26,10 +26,11 @@ message Request4 {
bytes zero_secretshare = 5; // z_i
}
message Restore1 {
bytes element = 1;
bytes blinded = 1;
}
message Restore2 {
bytes auth_proof = 1;
bytes auth_point = 1; // R
bytes auth_scalar = 2; // z
}
message Remove {
}
@ -75,10 +76,11 @@ message Response4 {
OK = 1;
MISSING = 2;
INVALID_REQUEST = 3;
ERROR = 4;
RESTORE1_MISSING = 4;
ERROR = 5;
}
Status status = 1;
bytes oprf_keyshare = 2;
bytes encryption_secretshare = 2;
}
message Remove {