SecureValueRecovery2/enclave/client/client.cc
2023-05-05 16:25:12 -06:00

226 lines
8.0 KiB
C++

// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
#include "client/client.h"
#include <atomic>
#include "env/env.h"
#include "util/log.h"
#include "metrics/metrics.h"
namespace svr2::client {
// Source of Client IDs. Starts at 1; every Client constructor takes the
// current value and advances it atomically, so IDs handed out within this
// process do not repeat (barring 64-bit wraparound).
static std::atomic<uint64_t> id_gen{1};
// Noise protocol used for client connections:
// Noise_NK_25519_ChaChaPoly_SHA256. Client::Init instantiates it with
// NOISE_ROLE_RESPONDER and loads the keypair it is given as the local static
// key; in the NK pattern the initiator knows that static public key ahead of
// time. hybrid_id = 0 selects no additional hybrid DH function.
const NoiseProtocolId client_protocol = {
.prefix_id = NOISE_PREFIX_STANDARD,
.pattern_id = NOISE_PATTERN_NK,
.dh_id = NOISE_DH_CURVE25519,
.cipher_id = NOISE_CIPHER_CHACHAPOLY,
.hash_id = NOISE_HASH_SHA256,
.hybrid_id = 0,
};
// Creates a client bound to `authenticated_id`.
//
// All Noise state (handshake, send cipher, receive cipher) starts out empty;
// Init() later installs a handshake state. The client's ID is the next value
// drawn from the file-level id_gen counter (atomic post-increment).
Client::Client(const std::string& authenticated_id)
    : hs_(noise::WrapHandshakeState(nullptr)),
      tx_(noise::WrapCipherState(nullptr)),
      rx_(noise::WrapCipherState(nullptr)),
      id_(id_gen++),
      authenticated_id_(authenticated_id) {}
// Nothing to do explicitly: the wrapped Noise states and other members
// release their own resources when destroyed.
Client::~Client() = default;
// Prepares this client to act as the Noise responder for client_protocol.
//
// A new handshake state is created, a copy of `dhstate` is loaded into it as
// the local keypair, and the handshake is started. hs_start_ is then filled
// with the public key of `dhstate` (test_only_pubkey) plus the evidence and
// endorsements from `attestation`. The new handshake state is installed into
// hs_ only after every step succeeds.
//
// Returns error::OK, or a counted Client_* error if a Noise call fails. On
// failure hs_ is unchanged and the partially built handshake state is freed.
// hs_start_'s pubkey field may already have been resized or written by then.
error::Error Client::Init(const noise::DHState& dhstate, const e2e::Attestation& attestation) {
  util::unique_lock lock(mu_);

  NoiseHandshakeState* raw_hs;
  if (noise_handshakestate_new_by_id(&raw_hs, &client_protocol, NOISE_ROLE_RESPONDER) != NOISE_ERROR_NONE) {
    return COUNTED_ERROR(Client_HandshakeState);
  }
  // Take ownership right away so every early return below frees it.
  auto owned_hs = noise::WrapHandshakeState(raw_hs);

  auto* local_keypair = noise_handshakestate_get_local_keypair_dh(raw_hs);
  if (noise_dhstate_copy(local_keypair, dhstate.get()) != NOISE_ERROR_NONE) {
    return COUNTED_ERROR(Client_CopyDHState);
  }
  if (noise_handshakestate_start(raw_hs) != NOISE_ERROR_NONE) {
    return COUNTED_ERROR(Client_HandshakeStart);
  }

  auto* pubkey = hs_start_.mutable_test_only_pubkey();
  pubkey->resize(32, '\0');
  if (noise_dhstate_get_public_key(dhstate.get(), noise::StrU8Ptr(pubkey), pubkey->size()) != NOISE_ERROR_NONE) {
    return COUNTED_ERROR(Client_ExtractPublicKey);
  }
  *hs_start_.mutable_evidence() = attestation.evidence();
  *hs_start_.mutable_endorsement() = attestation.endorsements();

  hs_.swap(owned_hs);
  return error::OK;
}
// Completes the responder side of the Noise handshake.
//
// `data` is the peer's handshake message. It is read into hs_, the reply
// message is written, and the handshake is split into tx_ (send) and rx_
// (receive) cipher states, after which hs_ is released.
//
// Only valid once: after Init() and before any cipher states exist, while
// the handshake is waiting to read a message. Returns {reply, error::OK} on
// success, or {"", counted error} on a state mismatch or any Noise failure.
std::pair<std::string, error::Error> Client::FinishHandshake(context::Context* ctx, const std::string& data) {
  ACQUIRE_LOCK(mu_, ctx, lock_client);
  MEASURE_CPU(ctx, cpu_client_hs_finish);
  auto* state = hs_.get();
  const bool awaiting_message = state != nullptr && !tx_.get() && !rx_.get()
      && noise_handshakestate_get_action(state) == NOISE_ACTION_READ_MESSAGE;
  if (!awaiting_message) {
    return std::make_pair("", COUNTED_ERROR(Client_HandshakeState));
  }

  // One buffer is reused: first as the read input, then as the reply output.
  std::string message = data;
  NoiseBuffer incoming = noise::BufferInputFromString(&message);
  if (noise_handshakestate_read_message(state, &incoming, nullptr) != NOISE_ERROR_NONE) {
    return std::make_pair("", COUNTED_ERROR(Client_FinishReadHandshake));
  }
  if (noise_handshakestate_get_action(state) != NOISE_ACTION_WRITE_MESSAGE) {
    return std::make_pair("", COUNTED_ERROR(Client_HandshakeState));
  }

  message.resize(noise::HANDSHAKE_INIT_SIZE, '\0');
  NoiseBuffer outgoing = noise::BufferOutputFromString(&message);
  if (noise_handshakestate_write_message(state, &outgoing, nullptr) != NOISE_ERROR_NONE) {
    return std::make_pair("", COUNTED_ERROR(Client_FinishWriteHandshake));
  }
  // Trim to the number of bytes Noise actually wrote.
  message.resize(outgoing.size);

  if (noise_handshakestate_get_action(state) != NOISE_ACTION_SPLIT) {
    return std::make_pair("", COUNTED_ERROR(Client_HandshakeState));
  }
  NoiseCipherState* send_cs;
  NoiseCipherState* recv_cs;
  if (noise_handshakestate_split(state, &send_cs, &recv_cs) != NOISE_ERROR_NONE) {
    return std::make_pair("", COUNTED_ERROR(Client_FinishSplit));
  }
  tx_.reset(send_cs);
  rx_.reset(recv_cs);
  hs_.reset(nullptr);  // handshake finished; only the cipher states remain
  return std::make_pair(message, error::OK);
}
// Decrypts `data` with the receive cipher (rx_) and parses the plaintext into
// `request`.
//
// Requires a completed handshake: hs_ released, tx_ and rx_ present.
// Returns error::OK on success. Otherwise it returns Client_DecryptState
// (wrong state), the error reported by noise::Decrypt, or Client_DecryptParse
// if the plaintext does not parse as `request`.
error::Error Client::DecryptRequest(context::Context* ctx, const std::string& data, google::protobuf::MessageLite* request) {
  ACQUIRE_LOCK(mu_, ctx, lock_client);
  MEASURE_CPU(ctx, cpu_client_decrypt);
  const bool established = !hs_.get() && tx_.get() && rx_.get();
  if (!established) {
    return COUNTED_ERROR(Client_DecryptState);
  }
  auto [clear, decrypt_err] = noise::Decrypt(rx_.get(), data);
  if (decrypt_err != error::OK) {
    return decrypt_err;
  }
  if (request->ParseFromString(clear)) {
    return error::OK;
  }
  return COUNTED_ERROR(Client_DecryptParse);
}
// Serializes `response` and encrypts it with the send cipher (tx_).
//
// Requires a completed handshake: hs_ released, tx_ and rx_ present.
// Returns the result of noise::Encrypt on success. Otherwise it returns
// {"", Client_EncryptState} for the wrong state or {"", Client_EncryptSerialize}
// if serialization fails.
std::pair<std::string, error::Error> Client::EncryptResponse(context::Context* ctx, const google::protobuf::MessageLite& response) {
  ACQUIRE_LOCK(mu_, ctx, lock_client);
  MEASURE_CPU(ctx, cpu_client_encrypt);
  if (hs_.get() != nullptr || tx_.get() == nullptr || rx_.get() == nullptr) {
    return std::make_pair("", COUNTED_ERROR(Client_EncryptState));
  }
  std::string serialized;
  if (response.SerializeToString(&serialized)) {
    return noise::Encrypt(tx_.get(), serialized);
  }
  return std::make_pair("", COUNTED_ERROR(Client_EncryptSerialize));
}
// Creates a client for `authenticated_id` and initializes it with the
// manager's current keypair and attestation. On success it registers the
// client under its ID and updates the client metrics.
//
// Returns {client, error::OK}; clients_ keeps ownership of the client.
// If Init fails, nothing is registered and {nullptr, err} is returned.
std::pair<Client*, error::Error> ClientManager::NewClient(context::Context* ctx, const std::string& authenticated_id) {
  ACQUIRE_LOCK(mu_, ctx, lock_clientmanager);
  MEASURE_CPU(ctx, cpu_client_hs_start);
  std::unique_ptr<Client> fresh(new Client(authenticated_id));
  if (error::Error init_err = fresh->Init(dhstate_, attestation_); init_err != error::OK) {
    return std::make_pair(nullptr, init_err);
  }
  Client* const registered = fresh.get();
  clients_[registered->ID()] = std::move(fresh);
  GAUGE(client, clients)->Set(clients_.size());
  COUNTER(client, created)->Increment();
  return std::make_pair(registered, error::OK);
}
// Looks up a registered client by ID; returns nullptr if there is none.
// NOTE(review): the manager lock is released on return, so nothing here
// stops a concurrent RemoveClient(id) from destroying the client while the
// caller still holds the pointer. Callers must guarantee the lifetime.
Client* ClientManager::GetClient(context::Context* ctx, ClientID id) const {
  ACQUIRE_LOCK(mu_, ctx, lock_clientmanager);
  const auto it = clients_.find(id);
  return it == clients_.end() ? nullptr : it->second.get();
}
// Unregisters and destroys the client with `id`, updating the client metrics.
// Returns true if a client was removed, false if `id` was not registered.
bool ClientManager::RemoveClient(context::Context* ctx, ClientID id) {
  ACQUIRE_LOCK(mu_, ctx, lock_clientmanager);
  auto it = clients_.find(id);
  if (it != clients_.end()) {
    clients_.erase(it);
    GAUGE(client, clients)->Set(clients_.size());
    COUNTER(client, closed)->Increment();
    return true;
  }
  return false;
}
// Generates a fresh keypair for client_protocol's DH function.
//
// Returns an empty (null) DHState if allocating the DH state or generating
// the keypair fails. A state that was allocated but not successfully
// generated is freed before returning.
noise::DHState ClientManager::NewDHState() {
  COUNTER(client, new_dh_state)->Increment();
  NoiseDHState* raw;
  if (noise_dhstate_new_by_id(&raw, client::client_protocol.dh_id) != NOISE_ERROR_NONE) {
    return noise::WrapDHState(nullptr);
  }
  noise::DHState keypair = noise::WrapDHState(raw);
  if (noise_dhstate_generate_keypair(raw) != NOISE_ERROR_NONE) {
    return noise::WrapDHState(nullptr);
  }
  return keypair;
}
// Replaces the manager's keypair with a newly generated one and obtains
// attestation evidence for its public key.
//
// Key generation and the evidence request run without holding mu_. The new
// key and its attestation are installed together under the lock only if the
// attestation succeeded; otherwise the current key/attestation stay in place
// and the error is returned.
error::Error ClientManager::RotateKeyAndRefreshAttestation(context::Context* ctx, const enclaveconfig::RaftGroupConfig& config) {
  auto next_key = NewDHState();
  auto [fresh_attestation, attest_err] = GetAttestation(next_key, config);
  if (attest_err != error::OK) {
    COUNTER(client, key_rotate_failure)->Increment();
    return attest_err;
  }
  ACQUIRE_LOCK(mu_, ctx, lock_clientmanager);
  dhstate_.swap(next_key);
  attestation_.CopyFrom(fresh_attestation);
  COUNTER(client, key_rotate_success)->Increment();
  return error::OK;
}
// Obtains fresh attestation evidence for the manager's current keypair and
// installs it.
//
// The keypair is cloned once, under the lock via DHState(ctx). Evidence is
// then requested for that exact clone without holding the lock. Finally the
// clone and its evidence are installed together under the lock. Installing
// the clone as well as the attestation keeps dhstate_ and attestation_
// consistent with each other, even if a concurrent
// RotateKeyAndRefreshAttestation() swapped the key in the meantime. In that
// case the rotation is effectively rolled back to the attested key.
//
// Returns error::OK, or the GetAttestation() error (state left unchanged).
error::Error ClientManager::RefreshAttestation(context::Context* ctx, const enclaveconfig::RaftGroupConfig& config) {
  auto dhstate = DHState(ctx);
  // Attest the very key installed below. Cloning again via DHState(ctx) here
  // could pick up a key rotated in between, pairing the new key's evidence
  // with the old key.
  auto [attestation, err] = GetAttestation(dhstate, config);
  if (err != error::OK) {
    COUNTER(client, attestation_refresh_failure)->Increment();
    return err;
  }
  ACQUIRE_LOCK(mu_, ctx, lock_clientmanager);
  attestation_.CopyFrom(attestation);
  // There's a chance that a RotateKeyAndRefreshAttestation call could have
  // happened between when we got dhstate and now; reinstall the key we
  // attested so the key and attestation always match.
  dhstate_.swap(dhstate);
  COUNTER(client, attestation_refresh_success)->Increment();
  return error::OK;
}
// Requests attestation evidence from env::environment that binds the 32-byte
// public key of `dhstate` for the given `config`.
//
// Returns whatever Evidence() returns. If the public key cannot be read from
// `dhstate`, it returns {default-constructed attestation, Peers_NewKeyPublic}.
// NOTE(review): that error code is from the Peers_ family and is not passed
// through COUNTED_ERROR, unlike the other failures in this file. Confirm this
// is intentional.
std::pair<e2e::Attestation, error::Error> ClientManager::GetAttestation(const noise::DHState& dhstate, const enclaveconfig::RaftGroupConfig& config) {
  uint8_t raw_pubkey[32];
  const bool extracted =
      NOISE_ERROR_NONE == noise_dhstate_get_public_key(dhstate.get(), raw_pubkey, sizeof(raw_pubkey));
  if (!extracted) {
    return std::make_pair(e2e::Attestation(), error::Peers_NewKeyPublic);
  }
  env::PublicKey pubkey {};
  std::copy(raw_pubkey, raw_pubkey + sizeof(raw_pubkey), std::begin(pubkey));
  return env::environment->Evidence(pubkey, config);
}
// Returns a clone (noise::CloneDHState) of the manager's current keypair,
// taken while holding the manager lock so it cannot race with a key swap.
noise::DHState ClientManager::DHState(context::Context* ctx) const {
  ACQUIRE_LOCK(mu_, ctx, lock_clientmanager);
  noise::DHState snapshot = noise::CloneDHState(dhstate_);
  return snapshot;
}
} // namespace svr2::client