enclave: add rustfmt config and run rustfmt on the code base

This commit is contained in:
Curt Brune 2020-08-11 16:39:00 -07:00
parent 6452d65fe7
commit 24ddbf697a
54 changed files with 3684 additions and 5267 deletions

View File

@ -66,6 +66,11 @@ unsafe fn realloc_fallback(alloc: &System, ptr: *mut u8, old_layout: Layout, new
#[alloc_error_handler]
pub fn handle_alloc_error(layout: Layout) -> ! {
let status = sgx_ffi::util::MemoryStatus::collect();
panic!("out of memory allocating {} bytes with {} used of {} bytes in {} chunks",
layout.size(), status.used_bytes, status.footprint_bytes, status.free_chunks);
panic!(
"out of memory allocating {} bytes with {} used of {} bytes in {} chunks",
layout.size(),
status.used_bytes,
status.footprint_bytes,
status.free_chunks
);
}

File diff suppressed because it is too large Load Diff

View File

@ -8,27 +8,19 @@
use crate::prelude::*;
use std::cell::*;
use std::ptr::{NonNull};
use std::ptr::NonNull;
use std::slice;
use prost::Message;
use sgx_ffi::untrusted_slice::{UntrustedSlice};
use sgx_ffi::untrusted_slice::UntrustedSlice;
use super::bindgen_wrapper::{kbupd_enclave_ocall_alloc, kbupd_enclave_ocall_recv_enclave_msg};
pub use super::bindgen_wrapper::{
sgxsd_server_init_args_t as StartArgs,
sgxsd_server_handle_call_args_t as CallArgs,
sgxsd_server_terminate_args_t as StopArgs,
KBUPD_REQUEST_TYPE_ANY,
KBUPD_REQUEST_TYPE_BACKUP,
KBUPD_REQUEST_TYPE_RESTORE,
KBUPD_REQUEST_TYPE_DELETE,
};
use super::bindgen_wrapper::{
kbupd_enclave_ocall_recv_enclave_msg,
kbupd_enclave_ocall_alloc,
sgxsd_server_handle_call_args_t as CallArgs, sgxsd_server_init_args_t as StartArgs, sgxsd_server_terminate_args_t as StopArgs,
KBUPD_REQUEST_TYPE_ANY, KBUPD_REQUEST_TYPE_BACKUP, KBUPD_REQUEST_TYPE_DELETE, KBUPD_REQUEST_TYPE_RESTORE,
};
use crate::protobufs::kbupd::{UntrustedMessageBatch, UntrustedMessage, EnclaveMessageBatch, EnclaveMessage};
use crate::protobufs::kbupd::{EnclaveMessage, EnclaveMessageBatch, UntrustedMessage, UntrustedMessageBatch};
pub trait KbupdService {
fn untrusted_message(&mut self, message: UntrustedMessage);
@ -38,8 +30,7 @@ const ENCLAVE_MESSAGE_BUFFER_SIZE: usize = 10240;
#[cfg(not(any(test, feature = "test")))]
pub fn with_buffer<F, R>(fun: F) -> R
where F: FnOnce(&RefCell<Option<Vec<u8>>>) -> R
{
where F: FnOnce(&RefCell<Option<Vec<u8>>>) -> R {
#[thread_local]
static ENCLAVE_MESSAGE_BUFFER: RefCell<Option<Vec<u8>>> = RefCell::new(None);
@ -48,8 +39,7 @@ where F: FnOnce(&RefCell<Option<Vec<u8>>>) -> R
#[cfg(any(test, feature = "test"))]
pub fn with_buffer<F, R>(fun: F) -> R
where F: FnOnce(&RefCell<Option<Vec<u8>>>) -> R
{
where F: FnOnce(&RefCell<Option<Vec<u8>>>) -> R {
thread_local! {
static ENCLAVE_MESSAGE_BUFFER: RefCell<Option<Vec<u8>>> = RefCell::new(None);
}
@ -59,7 +49,7 @@ where F: FnOnce(&RefCell<Option<Vec<u8>>>) -> R
pub fn kbupd_enclave_alloc_untrusted(mut size: usize) -> Result<UntrustedSlice<'static>, ()> {
let mut p_data: *mut libc::c_void = std::ptr::null_mut();
match unsafe { kbupd_enclave_ocall_alloc(&mut p_data, &mut size) } {
0 => UntrustedSlice::new(p_data as *mut u8, size),
0 => UntrustedSlice::new(p_data as *mut u8, size),
error => {
error!("ocall error allocating {} bytes from untrusted: {}", size, error);
Err(())
@ -68,8 +58,7 @@ pub fn kbupd_enclave_alloc_untrusted(mut size: usize) -> Result<UntrustedSlice<'
}
pub fn kbupd_enclave_recv_untrusted_msg<S>(service: &mut S, p_data: *const u8, data_size: usize)
where S: KbupdService,
{
where S: KbupdService {
let data = ECallSlice(NonNull::new(p_data as *mut _), data_size);
match UntrustedMessageBatch::decode(data.as_ref()) {
@ -86,13 +75,13 @@ where S: KbupdService,
}
pub fn kbupd_send(message: EnclaveMessage) {
let batch = EnclaveMessageBatch { messages: vec![message] };
let batch = EnclaveMessageBatch { messages: vec![message] };
let buffer_len = with_buffer(|buffer| buffer.borrow().as_ref().map(Vec::len).unwrap_or(0));
if buffer_len.saturating_add(batch.encoded_len()) > ENCLAVE_MESSAGE_BUFFER_SIZE {
kbupd_send_flush();
}
with_buffer(|buffer| {
let mut buffer_ref_mut = RefMut::map(buffer.borrow_mut(), |maybe_buffer: &mut Option<Vec<u8>>| {
let mut buffer_ref_mut = RefMut::map(buffer.borrow_mut(), |maybe_buffer: &mut Option<Vec<u8>>| {
maybe_buffer.get_or_insert_with(|| Vec::with_capacity(ENCLAVE_MESSAGE_BUFFER_SIZE))
});
let buffer_mut: &mut Vec<u8> = buffer_ref_mut.as_mut();
@ -101,12 +90,10 @@ pub fn kbupd_send(message: EnclaveMessage) {
}
pub fn kbupd_send_flush() {
let maybe_buffer = with_buffer(|buffer_tls| {
std::mem::replace(&mut *buffer_tls.borrow_mut(), Default::default())
});
let maybe_buffer = with_buffer(|buffer_tls| std::mem::replace(&mut *buffer_tls.borrow_mut(), Default::default()));
let mut buffer = match maybe_buffer {
Some(buffer) => buffer,
None => return,
None => return,
};
if !buffer.is_empty() {
let ocall_res = unsafe { kbupd_enclave_ocall_recv_enclave_msg(buffer.as_ptr(), buffer.len()) };
@ -140,14 +127,13 @@ impl AsRef<[u8]> for ECallSlice {
#[cfg(test)]
mod tests {
use super::*;
use super::super::mocks;
use super::*;
use mockers::*;
struct MockKbupdService {}
impl KbupdService for MockKbupdService {
fn untrusted_message(&mut self, _message: UntrustedMessage) {
}
fn untrusted_message(&mut self, _message: UntrustedMessage) {}
}
#[test]
@ -160,16 +146,34 @@ mod tests {
fn kbupd_enclave_recv_untrusted_msg_bad() {
let bad_requests: &[&[u8]] = &[
// bad tag 0, types 0..=7, truncated
&[0x00], &[0x01], &[0x02], &[0x03], &[0x04], &[0x05], &[0x06], &[0x07],
&[0x00],
&[0x01],
&[0x02],
&[0x03],
&[0x04],
&[0x05],
&[0x06],
&[0x07],
// tag 1, bad types 0..=1, truncated
&[0x08], &[0x09],
&[0x08],
&[0x09],
// tag 1, type 2, truncated
&[0x0A],
// tag 1, bad types 3..=7, truncated
&[0x0B], &[0x0C], &[0x0D], &[0x0E], &[0x0F],
&[0x0B],
&[0x0C],
&[0x0D],
&[0x0E],
&[0x0F],
// tag 2, types 0..=7, truncated
&[0x10], &[0x11], &[0x12], &[0x13], &[0x14], &[0x15], &[0x16], &[0x17],
&[0x10],
&[0x11],
&[0x12],
&[0x13],
&[0x14],
&[0x15],
&[0x16],
&[0x17],
// tag 1, bad type 0
&[0x08, 0x00],
// tag 1, type 2, length 1, truncated
@ -179,16 +183,13 @@ mod tests {
// tag 1, type 2, length 1 (overlong varint), truncated
&[0x0A, 0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00],
&[0x0A, 0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02],
// tag 2, type 0, bad varints
&[0x10, 0x80],
&[0x10, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80],
&[0x10, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00],
// bad tag 0 (overlong varint), type 0
&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00, 0x00],
&[0x80, 0x80, 0x80, 0x00, 0x00],
// bad tag 2^32, type 0
&[0x80, 0x80, 0x80, 0x80, 0x10, 0x00],
// bad tag 2^64-1, type 0
@ -200,14 +201,12 @@ mod tests {
let null_requests: &[&[u8]] = &[
// empty
&[],
// tag 1, type 2, length 0 (overlong varint)
&[0x0A, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00],
// tag 1, type 2, length 0 (overlong varint, extra bits ignored)
&[0x0A, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02],
// tag 1 (overlong varint), type 2, length 0
&[0x8A, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00, 0x00],
// tag 2 (overlong varint), type 0
&[0x90, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00, 0x00],
// tag 2 (overlong varint, extra bits ignored), type 0

View File

@ -5,19 +5,17 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use std::cell::{RefCell};
use std;
use std::cell::RefCell;
use mockers::*;
use mockers::matchers::*;
use mockers::*;
use mockers_derive::mocked;
use prost::{Message};
use prost::Message;
use crate::protobufs;
use super::bindgen_wrapper::{
sgx_status_t,
};
use super::bindgen_wrapper::sgx_status_t;
//
// mock extern "C" functions
@ -40,10 +38,19 @@ pub trait KbupdEnclaveOcallAlloc {
}
impl MatchArg<protobufs::kbupd::enclave_message::Inner> for Box<dyn MatchArg<protobufs::kbupd::enclave_message::Inner>> {
fn matches(&self, arg: &protobufs::kbupd::enclave_message::Inner) -> Result<(), String> { (**self).matches(arg) }
fn describe(&self) -> String { (**self).describe() }
fn matches(&self, arg: &protobufs::kbupd::enclave_message::Inner) -> Result<(), String> {
(**self).matches(arg)
}
fn describe(&self) -> String {
(**self).describe()
}
}
pub fn expect_enclave_messages(scenario: &Scenario, matchers: impl IntoIterator<Item = Box<dyn MatchArg<protobufs::kbupd::enclave_message::Inner>>>) {
pub fn expect_enclave_messages(
scenario: &Scenario,
matchers: impl IntoIterator<Item = Box<dyn MatchArg<protobufs::kbupd::enclave_message::Inner>>>,
)
{
let mock = test_ffi::mock_for(&KBUPD_ENCLAVE_OCALL_RECV_ENCLAVE_MSG, &scenario);
for matcher in matchers {
scenario.expect(mock.enclave_message(matcher).and_return(()));
@ -51,15 +58,13 @@ pub fn expect_enclave_messages(scenario: &Scenario, matchers: impl IntoIterator<
scenario.expect(mock.kbupd_enclave_ocall_recv_enclave_msg().and_return_clone(0).times(..));
}
pub fn expect_kbupd_enclave_ocall_alloc(scenario: &Scenario,
request_size: usize,
returned_ptr: *mut libc::c_void,
returned_size: usize) {
pub fn expect_kbupd_enclave_ocall_alloc(scenario: &Scenario, request_size: usize, returned_ptr: *mut libc::c_void, returned_size: usize) {
assert_ne!(request_size, 0);
let mock = test_ffi::mock_for(&KBUPD_ENCLAVE_OCALL_ALLOC, &scenario);
scenario.expect(mock.kbupd_enclave_ocall_alloc(
eq(request_size)
).and_return(Ok((returned_ptr, returned_size))));
scenario.expect(
mock.kbupd_enclave_ocall_alloc(eq(request_size))
.and_return(Ok((returned_ptr, returned_size))),
);
sgx_ffi::mocks::expect_sgx_is_outside_enclave(scenario, returned_ptr as *const libc::c_void, returned_size, true);
}
@ -110,14 +115,12 @@ pub mod impls {
assert!(!p_size_in_out.is_null());
let size = unsafe { *p_size_in_out };
assert_ne!(size, 0);
let res = KBUPD_ENCLAVE_OCALL_ALLOC.with(|mock| {
(mock.borrow().as_ref().expect("no mock for kbupd_enclave_ocall_alloc"))
.kbupd_enclave_ocall_alloc(size)
});
let res = KBUPD_ENCLAVE_OCALL_ALLOC
.with(|mock| (mock.borrow().as_ref().expect("no mock for kbupd_enclave_ocall_alloc")).kbupd_enclave_ocall_alloc(size));
match res {
Ok((ptr, size)) => {
unsafe {
*p_ptr_out = ptr;
*p_ptr_out = ptr;
*p_size_in_out = size;
}
0
@ -131,13 +134,18 @@ pub mod impls {
arg1: *mut ::std::os::raw::c_uchar,
arg2: *const ::std::os::raw::c_uchar,
arg3: *const ::std::os::raw::c_uchar,
) -> ::std::os::raw::c_int {
) -> ::std::os::raw::c_int
{
let arg1 = unsafe { std::slice::from_raw_parts_mut(arg1, 32) };
let arg2 = unsafe { std::slice::from_raw_parts(arg2, 32) };
let arg3 = unsafe { std::slice::from_raw_parts(arg3, 32) };
test_ffi::read_rand(arg1);
arg2.iter().for_each(|p| unsafe { std::ptr::read_volatile(p); });
arg3.iter().for_each(|p| unsafe { std::ptr::read_volatile(p); });
arg2.iter().for_each(|p| unsafe {
std::ptr::read_volatile(p);
});
arg3.iter().for_each(|p| unsafe {
std::ptr::read_volatile(p);
});
0
}
}

View File

@ -6,10 +6,12 @@
//
#[allow(dead_code, non_camel_case_types, non_upper_case_globals, non_snake_case, improper_ctypes, clippy::all, clippy::pedantic, clippy::integer_arithmetic)]
#[rustfmt::skip]
mod bindgen_wrapper;
pub mod ecalls;
#[cfg(not(any(test, feature = "test")))]
mod panic;
pub mod snow_resolver;
#[cfg(test)] pub mod mocks;
#[cfg(test)]
pub mod mocks;

View File

@ -5,10 +5,10 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use alloc::string::{ToString};
use core::panic::{PanicInfo};
use alloc::string::ToString;
use core::panic::PanicInfo;
use super::bindgen_wrapper::{kbupd_enclave_ocall_panic};
use super::bindgen_wrapper::kbupd_enclave_ocall_panic;
#[panic_handler]
fn panic(info: &PanicInfo<'_>) -> ! {

View File

@ -5,15 +5,15 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use core::convert::{TryInto};
use core::convert::TryInto;
use crate::prelude::*;
use rand_core::{RngCore, CryptoRng};
use rand_core::{CryptoRng, RngCore};
use sgxsd_ffi::*;
use snow::params::*;
use snow::resolvers::*;
use snow::types::*;
use snow::params::*;
#[derive(Default)]
pub struct SnowResolver;
@ -48,21 +48,21 @@ impl CryptoResolver for SnowResolver {
fn resolve_dh(&self, choice: &DHChoice) -> Option<Box<dyn Dh>> {
match *choice {
DHChoice::Curve25519 => Some(Box::new(SnowDh25519::default())),
_ => None,
_ => None,
}
}
fn resolve_hash(&self, choice: &HashChoice) -> Option<Box<dyn Hash>> {
match *choice {
HashChoice::SHA256 => Some(Box::new(SnowHashSHA256::default())),
_ => None,
HashChoice::SHA256 => Some(Box::new(SnowHashSHA256::default())),
_ => None,
}
}
fn resolve_cipher(&self, choice: &CipherChoice) -> Option<Box<dyn Cipher>> {
match *choice {
CipherChoice::AESGCM => Some(Box::new(SnowCipherAESGCM::default())),
_ => None,
CipherChoice::AESGCM => Some(Box::new(SnowCipherAESGCM::default())),
_ => None,
}
}
}
@ -74,10 +74,21 @@ impl CryptoResolver for SnowResolver {
impl Random for SnowRdRand {}
impl CryptoRng for SnowRdRand {}
impl RngCore for SnowRdRand {
fn next_u32(&mut self) -> u32 { RdRand.next_u32() }
fn next_u64(&mut self) -> u64 { RdRand.next_u64() }
fn fill_bytes(&mut self, dest: &mut [u8]) { RdRand.fill_bytes(dest) }
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { RdRand.try_fill_bytes(dest) }
fn next_u32(&mut self) -> u32 {
RdRand.next_u32()
}
fn next_u64(&mut self) -> u64 {
RdRand.next_u64()
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
RdRand.fill_bytes(dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
RdRand.try_fill_bytes(dest)
}
}
//
@ -85,27 +96,40 @@ impl RngCore for SnowRdRand {
//
impl Dh for SnowDh25519 {
fn name(&self) -> &'static str { "25519" }
fn pub_len(&self) -> usize { 32 }
fn priv_len(&self) -> usize { 32 }
fn name(&self) -> &'static str {
"25519"
}
fn pub_len(&self) -> usize {
32
}
fn priv_len(&self) -> usize {
32
}
fn set(&mut self, privkey: &[u8]) {
let privkey: &[u8; 32] = privkey.try_into().unwrap_or_else(|_| panic!("overflow"));
self.key.set_key(privkey);
}
fn generate(&mut self, rng: &mut dyn Random) {
self.key.generate(rng);
}
fn pubkey(&self) -> &[u8] {
self.key.pubkey()
}
fn privkey(&self) -> &[u8] {
self.key.privkey()
}
fn dh(&self, pubkey: &[u8], out: &mut [u8]) -> Result<(), ()> {
let pubkey: &[u8] = pubkey.get(..32).unwrap_or_else(|| panic!("overflow"));
let pubkey: &[u8; 32] = pubkey.try_into().unwrap_or_else(|_| static_unreachable!());
let out: &mut [u8] = out.get_mut(..32).unwrap_or_else(|| panic!("overflow"));
let out: &mut [u8; 32] = out.try_into().unwrap_or_else(|_| static_unreachable!());
let pubkey: &[u8] = pubkey.get(..32).unwrap_or_else(|| panic!("overflow"));
let pubkey: &[u8; 32] = pubkey.try_into().unwrap_or_else(|_| static_unreachable!());
let out: &mut [u8] = out.get_mut(..32).unwrap_or_else(|| panic!("overflow"));
let out: &mut [u8; 32] = out.try_into().unwrap_or_else(|_| static_unreachable!());
self.key.dh(pubkey, out);
Ok(())
}
@ -116,22 +140,26 @@ impl Dh for SnowDh25519 {
//
impl Cipher for SnowCipherAESGCM {
fn name(&self) -> &'static str { "AESGCM" }
fn name(&self) -> &'static str {
"AESGCM"
}
fn set(&mut self, key: &[u8]) {
let key: &[u8; 32] = key.try_into().unwrap_or_else(|_| panic!("overflow"));
self.key.set_key(key);
}
fn encrypt(&self, nonce: u64, authtext: &[u8], plaintext: &[u8], out: &mut [u8]) -> usize {
let (text_in_out, out) = out.split_at_mut(plaintext.len());
let out: &mut [u8] = out.get_mut(..16).unwrap_or_else(|| panic!("overflow"));
let out: &mut [u8] = out.get_mut(..16).unwrap_or_else(|| panic!("overflow"));
let out: &mut [u8; 16] = out.try_into().unwrap_or_else(|_| static_unreachable!());
text_in_out.copy_from_slice(plaintext);
let mut mac = AesGcmMac::default();
let mut iv = AesGcmIv::default();
let iv_data: &mut [u8] = iv.data.get_mut(4..).unwrap_or_else(|| static_unreachable!());
let mut mac = AesGcmMac::default();
let mut iv = AesGcmIv::default();
let iv_data: &mut [u8] = iv.data.get_mut(4..).unwrap_or_else(|| static_unreachable!());
let iv_data: &mut [u8; 8] = iv_data.try_into().unwrap_or_else(|_| static_unreachable!());
*iv_data = nonce.to_be_bytes();
*iv_data = nonce.to_be_bytes();
match self.key.encrypt(text_in_out, authtext, &iv, &mut mac) {
Ok(()) => {
@ -146,22 +174,24 @@ impl Cipher for SnowCipherAESGCM {
}
fn decrypt(&self, nonce: u64, authtext: &[u8], ciphertext: &[u8], out: &mut [u8]) -> Result<usize, ()> {
let ciphertext_len = ciphertext.len().checked_sub(16)
.unwrap_or_else(|| panic!("overflow"));
let ciphertext_len = ciphertext.len().checked_sub(16).unwrap_or_else(|| panic!("overflow"));
let (ciphertext, ciphertext_mac_data) = ciphertext.split_at(ciphertext_len);
let ciphertext_mac_data: &[u8; 16] = ciphertext_mac_data.try_into().unwrap_or_else(|_| unreachable!());
let ciphertext_mac_data: &[u8; 16] = ciphertext_mac_data.try_into().unwrap_or_else(|_| unreachable!());
let (in_out_text, _) = out.split_at_mut(ciphertext.len());
in_out_text.copy_from_slice(ciphertext);
let mac = AesGcmMac { data: *ciphertext_mac_data };
let mut iv = AesGcmIv::default();
let iv_data: &mut [u8] = iv.data.get_mut(4..).unwrap_or_else(|| static_unreachable!());
let mac = AesGcmMac {
data: *ciphertext_mac_data,
};
let mut iv = AesGcmIv::default();
let iv_data: &mut [u8] = iv.data.get_mut(4..).unwrap_or_else(|| static_unreachable!());
let iv_data: &mut [u8; 8] = iv_data.try_into().unwrap_or_else(|_| static_unreachable!());
*iv_data = nonce.to_be_bytes();
*iv_data = nonce.to_be_bytes();
self.key.decrypt(in_out_text, authtext, &iv, &mac)
.map(|()| in_out_text.len())
.map_err(drop)
self.key
.decrypt(in_out_text, authtext, &iv, &mac)
.map(|()| in_out_text.len())
.map_err(drop)
}
}
@ -170,15 +200,26 @@ impl Cipher for SnowCipherAESGCM {
//
impl Hash for SnowHashSHA256 {
fn name(&self) -> &'static str { "SHA256" }
fn block_len(&self) -> usize { 64 }
fn hash_len(&self) -> usize { SHA256Context::hash_len() }
fn name(&self) -> &'static str {
"SHA256"
}
fn block_len(&self) -> usize {
64
}
fn hash_len(&self) -> usize {
SHA256Context::hash_len()
}
fn reset(&mut self) {
self.context.reset();
}
fn input(&mut self, data: &[u8]) {
self.context.update(data);
}
fn result(&mut self, out: &mut [u8]) {
let out: &mut [u8] = out.get_mut(..SHA256Context::hash_len()).unwrap_or_else(|| panic!("overflow"));
let out: &mut [u8; SHA256Context::hash_len()] = out.try_into().unwrap_or_else(|_| static_unreachable!());
@ -190,8 +231,8 @@ impl Hash for SnowHashSHA256 {
pub mod tests {
use super::*;
use mockers::*;
use mockers::matchers::*;
use mockers::*;
#[test]
fn resolve_rng_test() {
@ -208,22 +249,22 @@ pub mod tests {
#[test]
fn resolve_dh_test() {
let dh = SnowResolver.resolve_dh(&DHChoice::Curve25519);
let dh = SnowResolver.resolve_dh(&DHChoice::Curve25519);
let rand = SnowResolver.resolve_rng();
assert!(dh.is_some());
assert!(rand.is_some());
if let (Some(mut dh), Some(mut rand)) = (dh, rand) {
dh.generate(&mut *rand);
let mut privkey = [0; 32];
let mut pubkey = [0; 32];
let mut pubkey = [0; 32];
privkey.copy_from_slice(dh.privkey());
pubkey.copy_from_slice(dh.pubkey());
assert_ne!(&privkey, &[0; 32]);
assert_ne!(&pubkey, &[0; 32]);
assert_ne!(&pubkey, &[0; 32]);
dh.set(&test_ffi::rand_bytes([0; 32]));
assert_ne!(dh.privkey(), privkey);
assert_ne!(dh.pubkey(), pubkey);
assert_ne!(dh.pubkey(), pubkey);
let mut res = vec![0; 32];
dh.dh(&test_ffi::rand_bytes(vec![0; 32]), &mut res).unwrap();
@ -274,10 +315,10 @@ pub mod tests {
}
macro_rules! eq_vec {
($vec:expr) => ({
($vec:expr) => {{
let vec: Vec<u8> = $vec.clone();
check(move |slice: &&[u8]| slice == &&vec[..])
})
}};
}
#[test]
@ -290,9 +331,9 @@ pub mod tests {
test_ffi::clear(&sgxsd_ffi::mocks::SGXSD_AES_GCM_ENCRYPT);
let authtext = vec![0; 100];
let plaintext = vec![0; 100];
let mut out = vec![0; plaintext.len() + 16];
let authtext = vec![0; 100];
let plaintext = vec![0; 100];
let mut out = vec![0; plaintext.len() + 16];
assert_eq!(cipher.encrypt(0, &authtext, &plaintext, &mut out), out.len());
assert_ne!(&vec![0; 8][..], &out[..8]);
assert_ne!(&vec![0; 8][..], &out[(out.len() - 8)..]);
@ -300,7 +341,7 @@ pub mod tests {
test_ffi::clear(&sgxsd_ffi::mocks::SGXSD_AES_GCM_DECRYPT);
let ciphertext = vec![0; plaintext.len() + 16];
let mut out = vec![0; plaintext.len()];
let mut out = vec![0; plaintext.len()];
assert_eq!(cipher.decrypt(0, &authtext, &ciphertext, &mut out), Ok(out.len()));
assert_ne!(&out[..], &vec![0; out.len()][..]);
}
@ -308,70 +349,68 @@ pub mod tests {
#[test]
fn resolve_cipher_success() {
let scenario = Scenario::new();
let scenario = Scenario::new();
let mut cipher = SnowResolver.resolve_cipher(&CipherChoice::AESGCM).unwrap();
let privkey = test_ffi::rand_bytes(vec![0; 32]);
let authtext = test_ffi::rand_bytes(vec![0; 100]);
let plaintext = test_ffi::rand_bytes(vec![0; 100]);
let privkey = test_ffi::rand_bytes(vec![0; 32]);
let authtext = test_ffi::rand_bytes(vec![0; 100]);
let plaintext = test_ffi::rand_bytes(vec![0; 100]);
let ciphertext = test_ffi::rand_bytes(vec![0; 100]);
cipher.set(&privkey);
let mock = test_ffi::mock_for(&sgxsd_ffi::mocks::SGXSD_AES_GCM_ENCRYPT, &scenario);
scenario.expect(mock.sgxsd_aes_gcm_encrypt(
eq_vec!(privkey),
eq_vec!(plaintext),
any(),
eq_vec!(authtext)
).and_return(Ok(ciphertext.clone())));
scenario.expect(
mock.sgxsd_aes_gcm_encrypt(eq_vec!(privkey), eq_vec!(plaintext), any(), eq_vec!(authtext))
.and_return(Ok(ciphertext.clone())),
);
let mut ciphertext_and_tag_out = vec![0; plaintext.len() + 16];
assert_eq!(cipher.encrypt(0, &authtext, &plaintext, &mut ciphertext_and_tag_out), ciphertext_and_tag_out.len());
assert_eq!(
cipher.encrypt(0, &authtext, &plaintext, &mut ciphertext_and_tag_out),
ciphertext_and_tag_out.len()
);
assert_eq!(&ciphertext_and_tag_out[..ciphertext.len()], &ciphertext[..]);
let mock = test_ffi::mock_for(&sgxsd_ffi::mocks::SGXSD_AES_GCM_DECRYPT, &scenario);
scenario.expect(mock.sgxsd_aes_gcm_decrypt(
eq_vec!(privkey),
eq_vec!(ciphertext),
any(),
eq_vec!(authtext)
).and_return(Ok(plaintext.clone())));
scenario.expect(
mock.sgxsd_aes_gcm_decrypt(eq_vec!(privkey), eq_vec!(ciphertext), any(), eq_vec!(authtext))
.and_return(Ok(plaintext.clone())),
);
let mut plaintext_out = vec![0; plaintext.len()];
assert_eq!(cipher.decrypt(0, &authtext, &ciphertext_and_tag_out, &mut plaintext_out), Ok(plaintext_out.len()));
assert_eq!(
cipher.decrypt(0, &authtext, &ciphertext_and_tag_out, &mut plaintext_out),
Ok(plaintext_out.len())
);
assert_eq!(&plaintext_out, &plaintext);
}
#[test]
fn resolve_cipher_fail() {
let scenario = Scenario::new();
let scenario = Scenario::new();
let mut cipher = SnowResolver.resolve_cipher(&CipherChoice::AESGCM).unwrap();
let privkey = test_ffi::rand_bytes(vec![0; 32]);
let authtext = test_ffi::rand_bytes(vec![0; 100]);
let plaintext = test_ffi::rand_bytes(vec![0; 100]);
let privkey = test_ffi::rand_bytes(vec![0; 32]);
let authtext = test_ffi::rand_bytes(vec![0; 100]);
let plaintext = test_ffi::rand_bytes(vec![0; 100]);
let ciphertext = test_ffi::rand_bytes(vec![0; 100]);
cipher.set(&privkey);
let mock = test_ffi::mock_for(&sgxsd_ffi::mocks::SGXSD_AES_GCM_ENCRYPT, &scenario);
scenario.expect(mock.sgxsd_aes_gcm_encrypt(
eq_vec!(privkey),
eq_vec!(plaintext),
any(),
eq_vec!(authtext)
).and_return(Err(())));
scenario.expect(
mock.sgxsd_aes_gcm_encrypt(eq_vec!(privkey), eq_vec!(plaintext), any(), eq_vec!(authtext))
.and_return(Err(())),
);
let mut ciphertext_and_tag_out = vec![0; plaintext.len() + 16];
assert_eq!(cipher.encrypt(0, &authtext, &plaintext, &mut ciphertext_and_tag_out), 0);
assert_eq!(&ciphertext_and_tag_out, &vec![0; ciphertext_and_tag_out.len()]);
let mock = test_ffi::mock_for(&sgxsd_ffi::mocks::SGXSD_AES_GCM_DECRYPT, &scenario);
scenario.expect(mock.sgxsd_aes_gcm_decrypt(
eq_vec!(privkey),
eq_vec!(ciphertext),
any(),
eq_vec!(authtext)
).and_return(Err(())));
scenario.expect(
mock.sgxsd_aes_gcm_decrypt(eq_vec!(privkey), eq_vec!(ciphertext), any(), eq_vec!(authtext))
.and_return(Err(())),
);
ciphertext_and_tag_out[..ciphertext.len()].copy_from_slice(&ciphertext);
let mut plaintext_out = ciphertext.clone();

View File

@ -9,8 +9,8 @@
use core::hash::{BuildHasher, SipHasher};
use rand_core::{RngCore};
use sgxsd_ffi::{RdRand};
use rand_core::RngCore;
use sgxsd_ffi::RdRand;
#[derive(Clone)]
pub struct DefaultHasher(u64, u64);
@ -23,6 +23,7 @@ impl Default for DefaultHasher {
impl BuildHasher for DefaultHasher {
type Hasher = SipHasher;
fn build_hasher(&self) -> Self::Hasher {
SipHasher::new_with_keys(self.0, self.1)
}

View File

@ -9,18 +9,14 @@
#![cfg_attr(not(any(test, feature = "test")), no_std)]
#![cfg_attr(not(any(test, feature = "test")), feature(alloc_error_handler))]
#![cfg_attr(not(any(test, feature = "test")), feature(thread_local))]
#![allow(
unused_parens,
clippy::style,
clippy::large_enum_variant,
)]
#![allow(unused_parens, clippy::style, clippy::large_enum_variant)]
#![warn(
bare_trait_objects,
elided_lifetimes_in_paths,
trivial_numeric_casts,
variant_size_differences,
clippy::integer_arithmetic,
clippy::wildcard_enum_match_arm,
clippy::wildcard_enum_match_arm
)]
#![deny(
clippy::cast_possible_truncation,
@ -48,7 +44,7 @@
clippy::unimplemented,
clippy::use_debug,
clippy::use_self,
clippy::use_underscore_binding,
clippy::use_underscore_binding
)]
extern crate alloc;
@ -73,46 +69,46 @@ mod prelude;
#[allow(clippy::all, clippy::pedantic, clippy::integer_arithmetic)]
mod protobufs;
mod protobufs_impl;
mod service;
mod storage;
mod raft;
mod remote;
mod remote_group;
mod service;
mod storage;
mod util;
pub use crate::ffi::ecalls::{
kbupd_send,
kbupd_send_flush,
};
pub use crate::ffi::ecalls::{kbupd_send, kbupd_send_flush};
pub mod external {
use sgx_ffi::sgx::{SgxStatus};
use sgxsd_ffi::ecalls::{SgxsdServer};
use sgx_ffi::sgx::SgxStatus;
use sgxsd_ffi::ecalls::SgxsdServer;
use crate::service::main;
#[no_mangle]
pub extern "C" fn sgxsd_enclave_server_init(p_args: *const <main::SgxsdState as SgxsdServer>::InitArgs,
pp_state: *mut *mut main::SgxsdState)
-> SgxStatus
pub extern "C" fn sgxsd_enclave_server_init(
p_args: *const <main::SgxsdState as SgxsdServer>::InitArgs,
pp_state: *mut *mut main::SgxsdState,
) -> SgxStatus
{
sgxsd_ffi::ecalls::sgxsd_enclave_server_init(p_args, pp_state)
}
#[no_mangle]
pub extern "C" fn sgxsd_enclave_server_handle_call(p_args: *const <main::SgxsdState as SgxsdServer>::HandleCallArgs,
msg_buf: sgxsd_ffi::ecalls::sgxsd_msg_buf_t,
mut from: sgxsd_ffi::ecalls::sgxsd_msg_from_t,
pp_state: *mut *mut main::SgxsdState)
-> SgxStatus
pub extern "C" fn sgxsd_enclave_server_handle_call(
p_args: *const <main::SgxsdState as SgxsdServer>::HandleCallArgs,
msg_buf: sgxsd_ffi::ecalls::sgxsd_msg_buf_t,
mut from: sgxsd_ffi::ecalls::sgxsd_msg_from_t,
pp_state: *mut *mut main::SgxsdState,
) -> SgxStatus
{
sgxsd_ffi::ecalls::sgxsd_enclave_server_handle_call(p_args, msg_buf, &mut from, pp_state)
}
#[no_mangle]
pub extern "C" fn sgxsd_enclave_server_terminate(p_args: *const <main::SgxsdState as SgxsdServer>::TerminateArgs,
p_state: *mut main::SgxsdState)
-> SgxStatus
pub extern "C" fn sgxsd_enclave_server_terminate(
p_args: *const <main::SgxsdState as SgxsdServer>::TerminateArgs,
p_state: *mut main::SgxsdState,
) -> SgxStatus
{
sgxsd_ffi::ecalls::sgxsd_enclave_server_terminate(p_args, p_state)
}

View File

@ -5,8 +5,8 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use intrusive_collections::LinkedList;
use intrusive_collections::*;
use intrusive_collections::{LinkedList};
use std::rc::*;
@ -34,14 +34,16 @@ impl<T> Lru<T> {
length: Default::default(),
}
}
pub fn len(&self) -> usize {
self.length
}
pub fn push_back(&mut self, item: T) -> Weak<LruEntry<T>> {
let lru_entry = Rc::new(LruEntry {
item,
token: Rc::downgrade(&self.token),
link: Default::default()
link: Default::default(),
});
let lru_entry_weak = Rc::downgrade(&lru_entry);
@ -49,6 +51,7 @@ impl<T> Lru<T> {
self.length = self.length.saturating_add(1);
lru_entry_weak
}
pub fn bump(&mut self, lru_entry_weak: &Weak<LruEntry<T>>) -> bool {
if let Some(lru_entry) = lru_entry_weak.upgrade() {
// check that lru_entry is a member of self.list
@ -64,9 +67,7 @@ impl<T> Lru<T> {
}
// safety: lru_entry must be a member of self.list
let mut lru_cursor = unsafe {
self.list.cursor_mut_from_ptr(lru_entry.as_ref())
};
let mut lru_cursor = unsafe { self.list.cursor_mut_from_ptr(lru_entry.as_ref()) };
if let Some(lru_entry_rc) = lru_cursor.remove() {
self.list.push_back(lru_entry_rc);
} else {
@ -77,6 +78,7 @@ impl<T> Lru<T> {
false
}
}
pub fn pop_front(&mut self) -> Option<Rc<LruEntry<T>>> {
if let Some(lru_entry) = self.list.pop_front() {
self.length = self.length.saturating_sub(1);
@ -99,10 +101,9 @@ impl<T> LruEntry<T> {
}
}
impl<'a, T> IntoIterator for &'a Lru<T> {
type Item = &'a LruEntry<T>;
type IntoIter = linked_list::Iter<'a, LruAdapter<T>>;
type Item = &'a LruEntry<T>;
fn into_iter(self) -> linked_list::Iter<'a, LruAdapter<T>> {
self.list.iter()

View File

@ -77,7 +77,7 @@ macro_rules! assert_match {
}
macro_rules! static_unreachable {
() => ({
() => {{
#[cfg(not(debug_assertions))]
{
extern "C" {
@ -87,5 +87,5 @@ macro_rules! static_unreachable {
}
#[cfg(debug_assertions)]
unreachable!()
})
}};
}

View File

@ -5,8 +5,8 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
pub use alloc::{format, vec};
pub use alloc::borrow::{ToOwned};
pub use alloc::boxed::{Box};
pub use alloc::borrow::ToOwned;
pub use alloc::boxed::Box;
pub use alloc::string::{String, ToString};
pub use alloc::vec::{Vec};
pub use alloc::vec::Vec;
pub use alloc::{format, vec};

File diff suppressed because it is too large Load Diff

View File

@ -1,20 +1,27 @@
//
// Copyright (C) 2019, 2020 Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0-or-later
//
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Request {
#[prost(message, optional, tag="1")]
pub backup: ::std::option::Option<BackupRequest>,
#[prost(message, optional, tag="2")]
#[prost(message, optional, tag = "1")]
pub backup: ::std::option::Option<BackupRequest>,
#[prost(message, optional, tag = "2")]
pub restore: ::std::option::Option<RestoreRequest>,
#[prost(message, optional, tag="3")]
pub delete: ::std::option::Option<DeleteRequest>,
#[prost(message, optional, tag = "3")]
pub delete: ::std::option::Option<DeleteRequest>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Response {
#[prost(message, optional, tag="1")]
pub backup: ::std::option::Option<BackupResponse>,
#[prost(message, optional, tag="2")]
#[prost(message, optional, tag = "1")]
pub backup: ::std::option::Option<BackupResponse>,
#[prost(message, optional, tag = "2")]
pub restore: ::std::option::Option<RestoreResponse>,
#[prost(message, optional, tag="3")]
pub delete: ::std::option::Option<DeleteResponse>,
#[prost(message, optional, tag = "3")]
pub delete: ::std::option::Option<DeleteResponse>,
}
//
// backup
@ -22,35 +29,35 @@ pub struct Response {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BackupRequest {
#[prost(bytes, optional, tag="1")]
#[prost(bytes, optional, tag = "1")]
pub service_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag="2")]
pub backup_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag="3")]
pub nonce: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(uint64, optional, tag="4")]
#[prost(bytes, optional, tag = "2")]
pub backup_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag = "3")]
pub nonce: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(uint64, optional, tag = "4")]
pub valid_from: ::std::option::Option<u64>,
#[prost(bytes, optional, tag="5")]
pub data: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag="6")]
pub pin: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(uint32, optional, tag="7")]
pub tries: ::std::option::Option<u32>,
#[prost(bytes, optional, tag = "5")]
pub data: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag = "6")]
pub pin: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(uint32, optional, tag = "7")]
pub tries: ::std::option::Option<u32>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BackupResponse {
#[prost(enumeration="backup_response::Status", optional, tag="1")]
#[prost(enumeration = "backup_response::Status", optional, tag = "1")]
pub status: ::std::option::Option<i32>,
#[prost(bytes, optional, tag="2")]
pub nonce: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag = "2")]
pub nonce: ::std::option::Option<std::vec::Vec<u8>>,
}
pub mod backup_response {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum Status {
Ok = 1,
Ok = 1,
AlreadyExists = 2,
NotYetValid = 3,
NotYetValid = 3,
}
}
//
@ -59,37 +66,37 @@ pub mod backup_response {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RestoreRequest {
#[prost(bytes, optional, tag="1")]
#[prost(bytes, optional, tag = "1")]
pub service_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag="2")]
pub backup_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag="3")]
pub nonce: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(uint64, optional, tag="4")]
#[prost(bytes, optional, tag = "2")]
pub backup_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag = "3")]
pub nonce: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(uint64, optional, tag = "4")]
pub valid_from: ::std::option::Option<u64>,
#[prost(bytes, optional, tag="5")]
pub pin: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag = "5")]
pub pin: ::std::option::Option<std::vec::Vec<u8>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RestoreResponse {
#[prost(enumeration="restore_response::Status", optional, tag="1")]
#[prost(enumeration = "restore_response::Status", optional, tag = "1")]
pub status: ::std::option::Option<i32>,
#[prost(bytes, optional, tag="2")]
pub nonce: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag="3")]
pub data: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(uint32, optional, tag="4")]
pub tries: ::std::option::Option<u32>,
#[prost(bytes, optional, tag = "2")]
pub nonce: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag = "3")]
pub data: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(uint32, optional, tag = "4")]
pub tries: ::std::option::Option<u32>,
}
pub mod restore_response {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum Status {
Ok = 1,
Ok = 1,
NonceMismatch = 2,
NotYetValid = 3,
Missing = 4,
PinMismatch = 5,
NotYetValid = 3,
Missing = 4,
PinMismatch = 5,
}
}
//
@ -98,11 +105,10 @@ pub mod restore_response {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DeleteRequest {
#[prost(bytes, optional, tag="1")]
#[prost(bytes, optional, tag = "1")]
pub service_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag="2")]
pub backup_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(bytes, optional, tag = "2")]
pub backup_id: ::std::option::Option<std::vec::Vec<u8>>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DeleteResponse {
}
pub struct DeleteResponse {}

View File

@ -1,6 +1,13 @@
//
// Copyright (C) 2019, 2020 Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0-or-later
//
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SecretBytes {
#[prost(bytes, required, tag="1")]
#[prost(bytes, required, tag = "1")]
pub data: std::vec::Vec<u8>,
}
//
@ -9,168 +16,168 @@ pub struct SecretBytes {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionData {
#[prost(oneof="transaction_data::Inner", tags="1, 2, 3, 4, 5, 6, 7, 8, 9")]
#[prost(oneof = "transaction_data::Inner", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9")]
pub inner: ::std::option::Option<transaction_data::Inner>,
}
pub mod transaction_data {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Inner {
#[prost(message, tag="1")]
#[prost(message, tag = "1")]
FrontendRequest(super::FrontendRequestTransaction),
#[prost(message, tag="2")]
#[prost(message, tag = "2")]
StartXfer(super::StartXferTransaction),
#[prost(message, tag="3")]
#[prost(message, tag = "3")]
SetSid(super::SetSidTransaction),
#[prost(message, tag="4")]
#[prost(message, tag = "4")]
RemoveChunk(super::RemoveChunkTransaction),
#[prost(message, tag="5")]
#[prost(message, tag = "5")]
ApplyChunk(super::ApplyChunkTransaction),
#[prost(message, tag="6")]
#[prost(message, tag = "6")]
PauseXfer(super::PauseXferTransaction),
#[prost(message, tag="7")]
#[prost(message, tag = "7")]
ResumeXfer(super::ResumeXferTransaction),
#[prost(message, tag="8")]
#[prost(message, tag = "8")]
FinishXfer(super::FinishXferTransaction),
#[prost(message, tag="9")]
#[prost(message, tag = "9")]
SetTime(super::SetTimeTransaction),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FrontendRequestTransaction {
#[prost(bytes, required, tag="1")]
#[prost(bytes, required, tag = "1")]
pub from_node_id: std::vec::Vec<u8>,
#[prost(uint64, required, tag="2")]
pub request_id: u64,
#[prost(oneof="frontend_request_transaction::Transaction", tags="3, 4, 5, 6")]
pub transaction: ::std::option::Option<frontend_request_transaction::Transaction>,
#[prost(uint64, required, tag = "2")]
pub request_id: u64,
#[prost(oneof = "frontend_request_transaction::Transaction", tags = "3, 4, 5, 6")]
pub transaction: ::std::option::Option<frontend_request_transaction::Transaction>,
}
pub mod frontend_request_transaction {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Transaction {
#[prost(message, tag="3")]
#[prost(message, tag = "3")]
Create(super::CreateBackupTransaction),
#[prost(message, tag="4")]
#[prost(message, tag = "4")]
Backup(super::BackupTransaction),
#[prost(message, tag="5")]
#[prost(message, tag = "5")]
Restore(super::RestoreTransaction),
#[prost(message, tag="6")]
#[prost(message, tag = "6")]
Delete(super::DeleteBackupTransaction),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CreateBackupTransaction {
#[prost(message, required, tag="1")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag="2")]
#[prost(message, required, tag = "1")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag = "2")]
pub new_creation_nonce: std::vec::Vec<u8>,
#[prost(bytes, required, tag="3")]
pub new_nonce: std::vec::Vec<u8>,
#[prost(bytes, required, tag = "3")]
pub new_nonce: std::vec::Vec<u8>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BackupTransaction {
#[prost(message, required, tag="1")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag="2")]
pub old_nonce: std::vec::Vec<u8>,
#[prost(bytes, required, tag="3")]
#[prost(message, required, tag = "1")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag = "2")]
pub old_nonce: std::vec::Vec<u8>,
#[prost(bytes, required, tag = "3")]
pub new_creation_nonce: std::vec::Vec<u8>,
#[prost(bytes, required, tag="4")]
pub new_nonce: std::vec::Vec<u8>,
#[prost(message, required, tag="5")]
pub data: SecretBytes,
#[prost(message, required, tag="6")]
pub pin: SecretBytes,
#[prost(uint32, required, tag="7")]
pub tries: u32,
#[prost(bytes, required, tag = "4")]
pub new_nonce: std::vec::Vec<u8>,
#[prost(message, required, tag = "5")]
pub data: SecretBytes,
#[prost(message, required, tag = "6")]
pub pin: SecretBytes,
#[prost(uint32, required, tag = "7")]
pub tries: u32,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RestoreTransaction {
#[prost(message, required, tag="1")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag="2")]
#[prost(message, required, tag = "1")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag = "2")]
pub creation_nonce: std::vec::Vec<u8>,
#[prost(bytes, required, tag="3")]
pub old_nonce: std::vec::Vec<u8>,
#[prost(bytes, required, tag="4")]
pub new_nonce: std::vec::Vec<u8>,
#[prost(message, required, tag="5")]
pub pin: SecretBytes,
#[prost(bytes, required, tag = "3")]
pub old_nonce: std::vec::Vec<u8>,
#[prost(bytes, required, tag = "4")]
pub new_nonce: std::vec::Vec<u8>,
#[prost(message, required, tag = "5")]
pub pin: SecretBytes,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DeleteBackupTransaction {
#[prost(message, required, tag="1")]
#[prost(message, required, tag = "1")]
pub backup_id: super::kbupd::BackupId,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct StartXferTransaction {
#[prost(bytes, required, tag="1")]
#[prost(bytes, required, tag = "1")]
pub from_node_id: std::vec::Vec<u8>,
#[prost(message, required, tag="2")]
#[prost(message, required, tag = "2")]
pub xfer_request: XferRequest,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SetSidTransaction {
#[prost(bytes, required, tag="1")]
#[prost(bytes, required, tag = "1")]
pub from_node_id: std::vec::Vec<u8>,
#[prost(message, required, tag="2")]
pub service_id: super::kbupd::ServiceId,
#[prost(message, required, tag = "2")]
pub service_id: super::kbupd::ServiceId,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RemoveChunkTransaction {
#[prost(bytes, required, tag="1")]
pub from_node_id: std::vec::Vec<u8>,
#[prost(message, required, tag="2")]
#[prost(bytes, required, tag = "1")]
pub from_node_id: std::vec::Vec<u8>,
#[prost(message, required, tag = "2")]
pub xfer_chunk_reply: XferChunkReply,
#[prost(message, required, tag="3")]
pub chunk_last: super::kbupd::BackupId,
#[prost(message, required, tag = "3")]
pub chunk_last: super::kbupd::BackupId,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ApplyChunkTransaction {
#[prost(bytes, required, tag="1")]
pub from_node_id: std::vec::Vec<u8>,
#[prost(message, required, tag="2")]
#[prost(bytes, required, tag = "1")]
pub from_node_id: std::vec::Vec<u8>,
#[prost(message, required, tag = "2")]
pub xfer_chunk_request: XferChunkRequest,
#[prost(message, required, tag="3")]
pub xfer_chunk_reply: XferChunkReply,
#[prost(message, required, tag = "3")]
pub xfer_chunk_reply: XferChunkReply,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PauseXferTransaction {
#[prost(uint64, required, tag="1")]
#[prost(uint64, required, tag = "1")]
pub request_id: u64,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ResumeXferTransaction {
#[prost(uint64, required, tag="1")]
#[prost(uint64, required, tag = "1")]
pub request_id: u64,
#[prost(message, required, tag="2")]
#[prost(message, required, tag = "2")]
pub chunk_last: super::kbupd::BackupId,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FinishXferTransaction {
#[prost(uint64, required, tag="1")]
#[prost(uint64, required, tag = "1")]
pub request_id: u64,
#[prost(bool, required, tag="2")]
pub force: bool,
#[prost(bool, required, tag = "2")]
pub force: bool,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SetTimeTransaction {
#[prost(uint64, required, tag="1")]
#[prost(uint64, required, tag = "1")]
pub now_secs: u64,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PeerConnectRequest {
#[prost(enumeration="NodeType", required, tag="1")]
pub node_type: i32,
#[prost(message, optional, tag="2")]
#[prost(enumeration = "NodeType", required, tag = "1")]
pub node_type: i32,
#[prost(message, optional, tag = "2")]
pub ias_report: ::std::option::Option<super::kbupd::IasReport>,
#[prost(bytes, required, tag="3")]
#[prost(bytes, required, tag = "3")]
pub noise_data: std::vec::Vec<u8>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PeerConnectReply {
#[prost(bytes, required, tag="1")]
pub sgx_quote: std::vec::Vec<u8>,
#[prost(bytes, required, tag="2")]
#[prost(bytes, required, tag = "1")]
pub sgx_quote: std::vec::Vec<u8>,
#[prost(bytes, required, tag = "2")]
pub noise_data: std::vec::Vec<u8>,
}
//
@ -178,11 +185,10 @@ pub struct PeerConnectReply {
//
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct EnclaveGetQuoteRequest {
}
pub struct EnclaveGetQuoteRequest {}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct EnclaveGetQuoteReply {
#[prost(bytes, required, tag="1")]
#[prost(bytes, required, tag = "1")]
pub sgx_quote: std::vec::Vec<u8>,
}
//
@ -191,74 +197,74 @@ pub struct EnclaveGetQuoteReply {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FrontendToReplicaMessage {
#[prost(oneof="frontend_to_replica_message::Inner", tags="1, 2")]
#[prost(oneof = "frontend_to_replica_message::Inner", tags = "1, 2")]
pub inner: ::std::option::Option<frontend_to_replica_message::Inner>,
}
pub mod frontend_to_replica_message {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Inner {
#[prost(message, tag="1")]
#[prost(message, tag = "1")]
TransactionRequest(super::TransactionRequest),
#[prost(message, tag="2")]
#[prost(message, tag = "2")]
EnclaveGetQuoteRequest(super::EnclaveGetQuoteRequest),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionRequest {
#[prost(uint64, required, tag="1")]
#[prost(uint64, required, tag = "1")]
pub request_id: u64,
#[prost(oneof="transaction_request::Data", tags="2, 3, 4, 5")]
pub data: ::std::option::Option<transaction_request::Data>,
#[prost(oneof = "transaction_request::Data", tags = "2, 3, 4, 5")]
pub data: ::std::option::Option<transaction_request::Data>,
}
pub mod transaction_request {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Data {
#[prost(message, tag="2")]
#[prost(message, tag = "2")]
Create(super::super::kbupd::CreateBackupRequest),
#[prost(message, tag="3")]
#[prost(message, tag = "3")]
Backup(super::BackupTransactionRequest),
#[prost(message, tag="4")]
#[prost(message, tag = "4")]
Restore(super::RestoreTransactionRequest),
#[prost(message, tag="5")]
#[prost(message, tag = "5")]
Delete(super::DeleteTransactionRequest),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BackupTransactionRequest {
#[prost(bytes, optional, tag="1")]
#[prost(bytes, optional, tag = "1")]
pub service_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(message, required, tag="2")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag="3")]
pub nonce: std::vec::Vec<u8>,
#[prost(uint64, required, tag="4")]
#[prost(message, required, tag = "2")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag = "3")]
pub nonce: std::vec::Vec<u8>,
#[prost(uint64, required, tag = "4")]
pub valid_from: u64,
#[prost(message, required, tag="5")]
pub data: SecretBytes,
#[prost(message, required, tag="6")]
pub pin: SecretBytes,
#[prost(uint32, required, tag="7")]
pub tries: u32,
#[prost(message, required, tag = "5")]
pub data: SecretBytes,
#[prost(message, required, tag = "6")]
pub pin: SecretBytes,
#[prost(uint32, required, tag = "7")]
pub tries: u32,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RestoreTransactionRequest {
#[prost(bytes, optional, tag="1")]
#[prost(bytes, optional, tag = "1")]
pub service_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(message, required, tag="2")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag="3")]
pub nonce: std::vec::Vec<u8>,
#[prost(uint64, required, tag="4")]
#[prost(message, required, tag = "2")]
pub backup_id: super::kbupd::BackupId,
#[prost(bytes, required, tag = "3")]
pub nonce: std::vec::Vec<u8>,
#[prost(uint64, required, tag = "4")]
pub valid_from: u64,
#[prost(message, required, tag="5")]
pub pin: SecretBytes,
#[prost(message, required, tag = "5")]
pub pin: SecretBytes,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DeleteTransactionRequest {
#[prost(bytes, optional, tag="1")]
#[prost(bytes, optional, tag = "1")]
pub service_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(message, required, tag="2")]
pub backup_id: super::kbupd::BackupId,
#[prost(message, required, tag = "2")]
pub backup_id: super::kbupd::BackupId,
}
//
// replica to frontend
@ -266,157 +272,153 @@ pub struct DeleteTransactionRequest {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ReplicaToFrontendMessage {
#[prost(oneof="replica_to_frontend_message::Inner", tags="1, 2")]
#[prost(oneof = "replica_to_frontend_message::Inner", tags = "1, 2")]
pub inner: ::std::option::Option<replica_to_frontend_message::Inner>,
}
pub mod replica_to_frontend_message {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Inner {
#[prost(message, tag="1")]
#[prost(message, tag = "1")]
TransactionReply(super::TransactionReply),
#[prost(message, tag="2")]
#[prost(message, tag = "2")]
EnclaveGetQuoteReply(super::EnclaveGetQuoteReply),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionReply {
#[prost(uint64, required, tag="1")]
#[prost(uint64, required, tag = "1")]
pub request_id: u64,
#[prost(oneof="transaction_reply::Data", tags="2, 3, 4, 5, 6, 7, 8, 9, 10")]
pub data: ::std::option::Option<transaction_reply::Data>,
#[prost(oneof = "transaction_reply::Data", tags = "2, 3, 4, 5, 6, 7, 8, 9, 10")]
pub data: ::std::option::Option<transaction_reply::Data>,
}
pub mod transaction_reply {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Data {
#[prost(message, tag="2")]
#[prost(message, tag = "2")]
ClientResponse(super::super::kbupd_client::Response),
#[prost(message, tag="3")]
#[prost(message, tag = "3")]
CreateBackupReply(super::super::kbupd::CreateBackupReply),
#[prost(message, tag="4")]
#[prost(message, tag = "4")]
DeleteBackupReply(super::super::kbupd::DeleteBackupReply),
#[prost(message, tag="5")]
#[prost(message, tag = "5")]
NotLeader(super::TransactionErrorNotLeader),
#[prost(message, tag="6")]
#[prost(message, tag = "6")]
WrongPartition(super::TransactionErrorWrongPartition),
#[prost(message, tag="7")]
#[prost(message, tag = "7")]
ServiceIdMismatch(super::TransactionErrorServiceIdMismatch),
#[prost(message, tag="8")]
#[prost(message, tag = "8")]
XferInProgress(super::TransactionErrorXferInProgress),
#[prost(message, tag="9")]
#[prost(message, tag = "9")]
InvalidRequest(super::TransactionErrorInvalidRequest),
#[prost(message, tag="10")]
#[prost(message, tag = "10")]
InternalError(super::TransactionErrorInternalError),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionErrorNotLeader {
#[prost(bytes, optional, tag="1")]
#[prost(bytes, optional, tag = "1")]
pub leader_node_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(message, required, tag="2")]
pub term: super::raft::TermId,
#[prost(message, required, tag = "2")]
pub term: super::raft::TermId,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionErrorWrongPartition {
#[prost(message, optional, tag="1")]
pub range: ::std::option::Option<super::kbupd::PartitionKeyRangePb>,
#[prost(message, optional, tag="2")]
#[prost(message, optional, tag = "1")]
pub range: ::std::option::Option<super::kbupd::PartitionKeyRangePb>,
#[prost(message, optional, tag = "2")]
pub new_partition: ::std::option::Option<super::kbupd::PartitionConfig>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionErrorServiceIdMismatch {
}
pub struct TransactionErrorServiceIdMismatch {}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionErrorXferInProgress {
}
pub struct TransactionErrorXferInProgress {}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionErrorInvalidRequest {
}
pub struct TransactionErrorInvalidRequest {}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TransactionErrorInternalError {
}
pub struct TransactionErrorInternalError {}
//
// replica to replica
//
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ReplicaToReplicaMessage {
#[prost(oneof="replica_to_replica_message::Inner", tags="1, 2, 8, 9, 3, 4, 5, 6, 7")]
#[prost(oneof = "replica_to_replica_message::Inner", tags = "1, 2, 8, 9, 3, 4, 5, 6, 7")]
pub inner: ::std::option::Option<replica_to_replica_message::Inner>,
}
pub mod replica_to_replica_message {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Inner {
#[prost(message, tag="1")]
#[prost(message, tag = "1")]
RaftMessage(super::super::raft::RaftMessage),
#[prost(message, tag="2")]
#[prost(message, tag = "2")]
CreateRaftGroupRequest(super::CreateRaftGroupRequest),
#[prost(message, tag="8")]
#[prost(message, tag = "8")]
EnclaveGetQuoteRequest(super::EnclaveGetQuoteRequest),
#[prost(message, tag="9")]
#[prost(message, tag = "9")]
EnclaveGetQuoteReply(super::EnclaveGetQuoteReply),
#[prost(message, tag="3")]
#[prost(message, tag = "3")]
XferRequest(super::XferRequest),
#[prost(message, tag="4")]
#[prost(message, tag = "4")]
XferReply(super::XferReply),
#[prost(message, tag="5")]
#[prost(message, tag = "5")]
XferChunkRequest(super::XferChunkRequest),
#[prost(message, tag="6")]
#[prost(message, tag = "6")]
XferChunkReply(super::XferChunkReply),
#[prost(message, tag="7")]
#[prost(message, tag = "7")]
XferErrorNotLeader(super::XferErrorNotLeader),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CreateRaftGroupRequest {
#[prost(message, optional, tag="1")]
pub service_id: ::std::option::Option<super::kbupd::ServiceId>,
#[prost(message, required, tag="2")]
pub group_id: super::raft::RaftGroupId,
#[prost(bytes, repeated, tag="3")]
pub node_ids: ::std::vec::Vec<std::vec::Vec<u8>>,
#[prost(message, required, tag="4")]
pub config: super::kbupd::EnclaveReplicaGroupConfig,
#[prost(message, optional, tag="5")]
#[prost(message, optional, tag = "1")]
pub service_id: ::std::option::Option<super::kbupd::ServiceId>,
#[prost(message, required, tag = "2")]
pub group_id: super::raft::RaftGroupId,
#[prost(bytes, repeated, tag = "3")]
pub node_ids: ::std::vec::Vec<std::vec::Vec<u8>>,
#[prost(message, required, tag = "4")]
pub config: super::kbupd::EnclaveReplicaGroupConfig,
#[prost(message, optional, tag = "5")]
pub source_partition: ::std::option::Option<super::kbupd::SourcePartitionConfig>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct XferRequest {
#[prost(uint32, required, tag="1")]
#[prost(uint32, required, tag = "1")]
pub chunk_size: u32,
#[prost(message, required, tag="2")]
#[prost(message, required, tag = "2")]
pub full_range: super::kbupd::PartitionKeyRangePb,
#[prost(bytes, repeated, tag="3")]
pub node_ids: ::std::vec::Vec<std::vec::Vec<u8>>,
#[prost(message, required, tag="4")]
pub group_id: super::raft::RaftGroupId,
#[prost(bytes, repeated, tag = "3")]
pub node_ids: ::std::vec::Vec<std::vec::Vec<u8>>,
#[prost(message, required, tag = "4")]
pub group_id: super::raft::RaftGroupId,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct XferReply {
#[prost(message, required, tag="1")]
#[prost(message, required, tag = "1")]
pub service: super::kbupd::ServiceId,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct XferChunkRequest {
#[prost(message, required, tag="1")]
pub data: SecretBytes,
#[prost(message, required, tag="2")]
pub chunk_range: super::kbupd::PartitionKeyRangePb,
#[prost(message, required, tag="3")]
#[prost(message, required, tag = "1")]
pub data: SecretBytes,
#[prost(message, required, tag = "2")]
pub chunk_range: super::kbupd::PartitionKeyRangePb,
#[prost(message, required, tag = "3")]
pub min_attestation: super::kbupd::AttestationParameters,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct XferChunkReply {
#[prost(message, required, tag="1")]
pub new_last: super::kbupd::BackupId,
#[prost(uint32, required, tag="2")]
#[prost(message, required, tag = "1")]
pub new_last: super::kbupd::BackupId,
#[prost(uint32, required, tag = "2")]
pub chunk_size: u32,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct XferErrorNotLeader {
#[prost(bytes, optional, tag="1")]
#[prost(bytes, optional, tag = "1")]
pub leader_node_id: ::std::option::Option<std::vec::Vec<u8>>,
#[prost(message, required, tag="2")]
pub term: super::raft::TermId,
#[prost(message, required, tag = "2")]
pub term: super::raft::TermId,
}
//
// remote enclave handshake
@ -425,7 +427,7 @@ pub struct XferErrorNotLeader {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum NodeType {
None = 0,
None = 0,
Frontend = 1,
Replica = 2,
Replica = 2,
}

View File

@ -1,3 +1,10 @@
//
// Copyright (C) 2019, 2020 Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0-or-later
//
pub mod kbupd;
pub mod kbupd_client;
pub mod kbupd_enclave;

View File

@ -1,76 +1,83 @@
//
// Copyright (C) 2019, 2020 Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0-or-later
//
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RaftMessage {
#[prost(message, required, tag="1")]
#[prost(message, required, tag = "1")]
pub group: RaftGroupId,
#[prost(message, required, tag="2")]
pub term: TermId,
#[prost(oneof="raft_message::Inner", tags="3, 4, 5, 6")]
#[prost(message, required, tag = "2")]
pub term: TermId,
#[prost(oneof = "raft_message::Inner", tags = "3, 4, 5, 6")]
pub inner: ::std::option::Option<raft_message::Inner>,
}
pub mod raft_message {
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Inner {
#[prost(message, tag="3")]
#[prost(message, tag = "3")]
VoteRequest(super::VoteRequest),
#[prost(message, tag="4")]
#[prost(message, tag = "4")]
VoteResponse(super::VoteResponse),
#[prost(message, tag="5")]
#[prost(message, tag = "5")]
AppendRequest(super::AppendRequest),
#[prost(message, tag="6")]
#[prost(message, tag = "6")]
AppendResponse(super::AppendResponse),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct VoteRequest {
#[prost(message, required, tag="2")]
pub last_log_idx: LogIdx,
#[prost(message, required, tag="3")]
#[prost(message, required, tag = "2")]
pub last_log_idx: LogIdx,
#[prost(message, required, tag = "3")]
pub last_log_term: TermId,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct VoteResponse {
#[prost(bool, required, tag="2")]
#[prost(bool, required, tag = "2")]
pub vote_granted: bool,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct AppendRequest {
#[prost(message, required, tag="1")]
pub prev_log_idx: LogIdx,
#[prost(message, required, tag="2")]
#[prost(message, required, tag = "1")]
pub prev_log_idx: LogIdx,
#[prost(message, required, tag = "2")]
pub prev_log_term: TermId,
#[prost(message, required, tag="3")]
#[prost(message, required, tag = "3")]
pub leader_commit: LogIdx,
#[prost(message, repeated, tag="4")]
pub entries: ::std::vec::Vec<LogEntry>,
#[prost(message, repeated, tag = "4")]
pub entries: ::std::vec::Vec<LogEntry>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct AppendResponse {
#[prost(bool, required, tag="1")]
pub success: bool,
#[prost(message, required, tag="2")]
pub match_idx: LogIdx,
#[prost(message, required, tag="3")]
#[prost(bool, required, tag = "1")]
pub success: bool,
#[prost(message, required, tag = "2")]
pub match_idx: LogIdx,
#[prost(message, required, tag = "3")]
pub last_log_idx: LogIdx,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct LogEntry {
#[prost(message, required, tag="1")]
#[prost(message, required, tag = "1")]
pub term: TermId,
#[prost(bytes, required, tag="2")]
#[prost(bytes, required, tag = "2")]
pub data: std::vec::Vec<u8>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RaftGroupId {
#[prost(bytes, required, tag="1")]
#[prost(bytes, required, tag = "1")]
pub id: std::vec::Vec<u8>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TermId {
#[prost(uint64, required, tag="1")]
#[prost(uint64, required, tag = "1")]
pub id: u64,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct LogIdx {
#[prost(uint64, required, tag="1")]
#[prost(uint64, required, tag = "1")]
pub id: u64,
}

View File

@ -5,16 +5,15 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use crate::prelude::*;
use std::cmp::*;
use std::fmt;
use std::ops::*;
use crate::prelude::*;
use crate::util::*;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd_enclave;
use crate::protobufs::kbupd_client;
use crate::protobufs::kbupd_enclave;
use crate::util::*;
//
// ServiceId impls
@ -33,13 +32,13 @@ impl fmt::Display for ServiceId {
impl BackupId {
pub const LENGTH: usize = 32;
pub fn valid_len() -> u32 {
32
}
pub fn try_from_slice<T>(slice: T) -> Result<Self, ()>
where T: AsRef<[u8]>,
{
where T: AsRef<[u8]> {
let slice: &[u8] = slice.as_ref();
if slice.len() == Self::LENGTH {
let id = slice.to_vec();
@ -48,6 +47,7 @@ impl BackupId {
Err(())
}
}
pub fn try_to_array(&self) -> Result<[u8; Self::LENGTH], ()> {
let mut array = [0; Self::LENGTH];
if self.id.len() == array.len() {
@ -73,6 +73,7 @@ impl AsRef<[u8]> for BackupId {
impl Deref for BackupId {
type Target = Vec<u8>;
fn deref(&self) -> &Self::Target {
&self.id
}
@ -80,10 +81,14 @@ impl Deref for BackupId {
impl Eq for BackupId {}
impl PartialOrd for BackupId {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for BackupId {
fn cmp(&self, other: &Self) -> Ordering { self.id.cmp(&other.id) }
fn cmp(&self, other: &Self) -> Ordering {
self.id.cmp(&other.id)
}
}
impl fmt::Display for BackupId {
@ -120,51 +125,60 @@ impl fmt::Display for UntrustedTransactionRequest {
impl enclave_frontend_request_transaction::Transaction {
pub fn from_reply(backup_id: BackupId, reply_data: &kbupd_enclave::transaction_reply::Data) -> Self {
use enclave_frontend_request_transaction::{Transaction};
use kbupd_enclave::transaction_reply::{Data as ReplyData};
use enclave_frontend_request_transaction::Transaction;
use kbupd_enclave::transaction_reply::Data as ReplyData;
use kbupd_enclave::*;
match reply_data {
ReplyData::CreateBackupReply(CreateBackupReply { .. }) =>
Transaction::Create(EnclaveCreateBackupTransaction { backup_id }),
ReplyData::CreateBackupReply(CreateBackupReply { .. }) => Transaction::Create(EnclaveCreateBackupTransaction { backup_id }),
ReplyData::ClientResponse(kbupd_client::Response { backup: Some(backup_response), .. }) =>
Transaction::Backup(EnclaveBackupTransaction {
backup_id,
status: backup_response.status.unwrap_or_default(),
}),
ReplyData::ClientResponse(kbupd_client::Response {
backup: Some(backup_response),
..
}) => Transaction::Backup(EnclaveBackupTransaction {
backup_id,
status: backup_response.status.unwrap_or_default(),
}),
ReplyData::ClientResponse(kbupd_client::Response { restore: Some(restore_response), .. }) =>
Transaction::Restore(EnclaveRestoreTransaction {
backup_id,
status: restore_response.status.unwrap_or_default(),
}),
ReplyData::ClientResponse(kbupd_client::Response {
restore: Some(restore_response),
..
}) => Transaction::Restore(EnclaveRestoreTransaction {
backup_id,
status: restore_response.status.unwrap_or_default(),
}),
ReplyData::ClientResponse(kbupd_client::Response { delete: Some(_), .. }) |
ReplyData::DeleteBackupReply(DeleteBackupReply { .. }) =>
Transaction::Delete(EnclaveDeleteBackupTransaction { backup_id }),
ReplyData::DeleteBackupReply(DeleteBackupReply { .. }) => Transaction::Delete(EnclaveDeleteBackupTransaction { backup_id }),
ReplyData::WrongPartition(TransactionErrorWrongPartition { new_partition, .. }) =>
ReplyData::WrongPartition(TransactionErrorWrongPartition { new_partition, .. }) => {
Transaction::WrongPartition(EnclaveTransactionErrorWrongPartition {
new_partition_unknown: new_partition.is_none(),
}),
})
}
ReplyData::XferInProgress(TransactionErrorXferInProgress {}) =>
Transaction::XferInProgress(EnclaveTransactionErrorXferInProgress {}),
ReplyData::XferInProgress(TransactionErrorXferInProgress {}) => {
Transaction::XferInProgress(EnclaveTransactionErrorXferInProgress {})
}
ReplyData::ClientResponse(kbupd_client::Response { backup: None, restore: None, delete: None }) |
ReplyData::InvalidRequest(TransactionErrorInvalidRequest {}) =>
Transaction::InvalidRequest(EnclaveTransactionErrorInvalidRequest {}),
ReplyData::ClientResponse(kbupd_client::Response {
backup: None,
restore: None,
delete: None,
}) |
ReplyData::InvalidRequest(TransactionErrorInvalidRequest {}) => {
Transaction::InvalidRequest(EnclaveTransactionErrorInvalidRequest {})
}
ReplyData::NotLeader(TransactionErrorNotLeader { .. }) |
ReplyData::ServiceIdMismatch(TransactionErrorServiceIdMismatch {}) |
ReplyData::InternalError(TransactionErrorInternalError {}) =>
Transaction::InternalError(EnclaveTransactionErrorInternalError {}),
ReplyData::InternalError(TransactionErrorInternalError {}) => {
Transaction::InternalError(EnclaveTransactionErrorInternalError {})
}
}
}
}
//
// EnclaveFrontendConfig
//
@ -193,8 +207,8 @@ impl fmt::Display for SourcePartitionConfig {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { range, node_ids } = self;
fmt.debug_struct("SourcePartitionConfig")
.field("range", &DisplayAsDebug(range))
.field("node_ids", &ListDisplay(node_ids.iter().map(|node_id| ToHex(node_id))))
.finish()
.field("range", &DisplayAsDebug(range))
.field("node_ids", &ListDisplay(node_ids.iter().map(|node_id| ToHex(node_id))))
.finish()
}
}

View File

@ -5,7 +5,7 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use sgx_ffi::util::{clear};
use sgx_ffi::util::clear;
use crate::protobufs::kbupd_client::*;

View File

@ -7,7 +7,7 @@
use std::fmt;
use sgx_ffi::util::{clear};
use sgx_ffi::util::clear;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd_enclave::*;
@ -30,11 +30,11 @@ impl Drop for SecretBytes {
impl FrontendRequestTransaction {
pub fn backup_id(&self) -> Option<&BackupId> {
match &self.transaction {
Some(frontend_request_transaction::Transaction::Backup(backup)) => Some(&backup.backup_id),
Some(frontend_request_transaction::Transaction::Backup(backup)) => Some(&backup.backup_id),
Some(frontend_request_transaction::Transaction::Restore(restore)) => Some(&restore.backup_id),
Some(frontend_request_transaction::Transaction::Create(create)) => Some(&create.backup_id),
Some(frontend_request_transaction::Transaction::Delete(delete)) => Some(&delete.backup_id),
None => None
Some(frontend_request_transaction::Transaction::Create(create)) => Some(&create.backup_id),
Some(frontend_request_transaction::Transaction::Delete(delete)) => Some(&delete.backup_id),
None => None,
}
}
}
@ -55,12 +55,17 @@ impl fmt::Display for PeerConnectRequest {
impl fmt::Display for XferRequest {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let Self { chunk_size, full_range, node_ids, group_id } = self;
let Self {
chunk_size,
full_range,
node_ids,
group_id,
} = self;
fmt.debug_struct("XferRequest")
.field("chunk_size", chunk_size)
.field("full_range", &DisplayAsDebug(full_range))
.field("node_ids", &ListDisplay(node_ids.iter().map(|node_id| ToHex(node_id))))
.field("group_id", &DisplayAsDebug(group_id))
.finish()
.field("chunk_size", chunk_size)
.field("full_range", &DisplayAsDebug(full_range))
.field("node_ids", &ListDisplay(node_ids.iter().map(|node_id| ToHex(node_id))))
.field("group_id", &DisplayAsDebug(group_id))
.finish()
}
}

File diff suppressed because it is too large Load Diff

View File

@ -18,20 +18,20 @@ use std::ops::*;
use std::rc::*;
use std::time::*;
use bytes::{BufMut};
use bytes::BufMut;
use chrono::{DateTime, NaiveDateTime, Utc};
use num_traits::{ToPrimitive};
use num_traits::ToPrimitive;
use prost::{self, Message};
use serde::{Deserialize};
use serde::Deserialize;
use sgx_ffi::sgx;
use sgx_ffi::util::{SecretValue};
use sgxsd_ffi::{SHA256Context};
use sgx_ffi::util::SecretValue;
use sgxsd_ffi::SHA256Context;
use snow;
use crate::{kbupd_send};
use crate::ffi::snow_resolver::*;
use crate::kbupd_send;
use crate::protobufs::kbupd::enclave_message::Inner as EnclaveMessageInner;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd::enclave_message::{Inner as EnclaveMessageInner};
use crate::protobufs::kbupd_enclave::*;
use crate::util::{self, deserialize_base64};
@ -59,7 +59,7 @@ pub struct NodeParams {
}
pub struct RemoteSender<M>
where M: prost::Message + 'static,
where M: prost::Message + 'static
{
id: NodeId,
shared: Rc<RefCell<Shared<M>>>,
@ -96,9 +96,10 @@ pub trait Remote: RemoteCommon {
fn attestation_reply(&mut self, ias_report: IasReport) -> Result<Option<AttestationParameters>, ()>;
}
pub struct RemoteState<M,R>
where M: prost::Message + 'static,
R: prost::Message + Default + 'static
pub struct RemoteState<M, R>
where
M: prost::Message + 'static,
R: prost::Message + Default + 'static,
{
node_params: Rc<NodeParams>,
remote_node_id: NodeId,
@ -172,21 +173,24 @@ enum SessionState {
},
}
impl<M,R> RemoteState<M,R>
where M: prost::Message + 'static,
R: prost::Message + Default + 'static
impl<M, R> RemoteState<M, R>
where
M: prost::Message + 'static,
R: prost::Message + Default + 'static,
{
pub fn new(node_params: Rc<NodeParams>,
remote_node_id: NodeId,
remote_type: NodeType,
auth_type: RemoteAuthorizationType,
noise_buffer: SharedNoiseBuffers)
-> Self {
pub fn new(
node_params: Rc<NodeParams>,
remote_node_id: NodeId,
remote_type: NodeType,
auth_type: RemoteAuthorizationType,
noise_buffer: SharedNoiseBuffers,
) -> Self
{
let shared = Rc::new(RefCell::new(Shared {
session: SessionState::Disconnected,
session: SessionState::Disconnected,
remote_node_id: remote_node_id.clone(),
noise_buffer,
_message: Default::default(),
_message: Default::default(),
}));
Self {
node_params,
@ -215,7 +219,7 @@ where M: prost::Message + 'static,
prologue_buf.put_i32_le(self.remote_type.into());
prologue_buf.put_i32_le(self.node_params.node_type.into());
}
let params = NOISE_PARAMS.parse().unwrap_or_else(|_| unreachable!());
let params = NOISE_PARAMS.parse().unwrap_or_else(|_| unreachable!());
let builder = snow::Builder::with_resolver(params, Box::new(SnowResolver))
.prologue(&prologue_buf)
.local_private_key(&self.node_params.node_key)
@ -246,7 +250,7 @@ where M: prost::Message + 'static,
fn connection_response(mut noise: snow::HandshakeState) -> Result<(snow::TransportState, Vec<u8>, HandshakeHash), snow::Error> {
let mut msg_buf = vec![0; NOISE_CHUNK_MAX_LENGTH];
let msg_len = noise.write_message(&[0;0], &mut msg_buf)?;
let msg_len = noise.write_message(&[0; 0], &mut msg_buf)?;
msg_buf.truncate(msg_len);
let handshake_hash = get_handshake_hash(&noise)?;
@ -257,12 +261,15 @@ where M: prost::Message + 'static,
}
#[allow(clippy::type_complexity)]
fn establish_connection(mut noise: snow::HandshakeState, encrypted_msg_data: &[u8])
-> Result<(snow::TransportState, Vec<u8>, HandshakeHash, HandshakeHash), snow::Error> {
fn establish_connection(
mut noise: snow::HandshakeState,
encrypted_msg_data: &[u8],
) -> Result<(snow::TransportState, Vec<u8>, HandshakeHash, HandshakeHash), snow::Error>
{
let their_handshake_hash = get_handshake_hash(&noise)?;
let mut payload_buf = vec![0; encrypted_msg_data.len()];
let payload_len = noise.read_message(encrypted_msg_data, &mut payload_buf)?;
let payload_len = noise.read_message(encrypted_msg_data, &mut payload_buf)?;
payload_buf.truncate(payload_len);
let final_handshake_hash = get_handshake_hash(&noise)?;
@ -273,7 +280,7 @@ where M: prost::Message + 'static,
pub fn recv(&mut self, msg_data: &[u8]) -> Result<R, RemoteRecvError> {
let mut shared_ref = self.shared.as_ref().borrow_mut();
let shared = &mut *shared_ref;
let shared = &mut *shared_ref;
match &mut shared.session {
session @ SessionState::Disconnected |
session @ SessionState::WaitingForAttestation { .. } |
@ -282,36 +289,37 @@ where M: prost::Message + 'static,
warn!("dropping message from {} in {} state", self.remote_node_id, session);
Err(RemoteRecvError::InvalidState)
}
session @ SessionState::Initiated { .. } => {
match PeerConnectReply::decode(msg_data) {
Ok(connect_reply) => {
let noise = match std::mem::replace(session, SessionState::Disconnected) {
SessionState::Initiated { noise } => noise,
_ => unreachable!(),
};
match Self::establish_connection(noise, &connect_reply.noise_data) {
Ok((noise, _payload, their_handshake_hash, final_handshake_hash)) => {
*session = SessionState::Connected { noise, their_handshake_hash, final_handshake_hash };
let sgx_quote = connect_reply.sgx_quote;
Err(RemoteRecvError::NeedsAttestation(GetAttestationRequest {
request_id: self.remote_node_id.to_vec(),
sgx_quote,
}))
}
Err(err) => {
warn!("error decrypting connect reply from {}: {}", self.remote_node_id, err);
Err(RemoteRecvError::DecodeError)
}
session @ SessionState::Initiated { .. } => match PeerConnectReply::decode(msg_data) {
Ok(connect_reply) => {
let noise = match std::mem::replace(session, SessionState::Disconnected) {
SessionState::Initiated { noise } => noise,
_ => unreachable!(),
};
match Self::establish_connection(noise, &connect_reply.noise_data) {
Ok((noise, _payload, their_handshake_hash, final_handshake_hash)) => {
*session = SessionState::Connected {
noise,
their_handshake_hash,
final_handshake_hash,
};
let sgx_quote = connect_reply.sgx_quote;
Err(RemoteRecvError::NeedsAttestation(GetAttestationRequest {
request_id: self.remote_node_id.to_vec(),
sgx_quote,
}))
}
Err(err) => {
warn!("error decrypting connect reply from {}: {}", self.remote_node_id, err);
Err(RemoteRecvError::DecodeError)
}
}
Err(err) => {
warn!("error decoding connect reply from {}: {}", self.remote_node_id, err);
Err(RemoteRecvError::DecodeError)
}
}
}
mut session @ SessionState::Responded { .. } |
mut session @ SessionState::Authorized { .. } => {
Err(err) => {
warn!("error decoding connect reply from {}: {}", self.remote_node_id, err);
Err(RemoteRecvError::DecodeError)
}
},
mut session @ SessionState::Responded { .. } | mut session @ SessionState::Authorized { .. } => {
let noise = match &mut session {
SessionState::Responded { noise, .. } => noise,
SessionState::Authorized { noise, .. } => noise,
@ -321,16 +329,21 @@ where M: prost::Message + 'static,
Ok(msg_data) => {
if let SessionState::Responded { .. } = &session {
*session = match std::mem::replace(session, SessionState::Disconnected) {
SessionState::Responded { noise, attestation, handshake_hash } =>
SessionState::Authorized { noise, attestation, handshake_hash },
SessionState::Responded {
noise,
attestation,
handshake_hash,
} => SessionState::Authorized {
noise,
attestation,
handshake_hash,
},
_ => unreachable!(),
};
}
match R::decode(&msg_data.get()[..]) {
Ok(reply) => {
Ok(reply)
}
Ok(reply) => Ok(reply),
Err(decode_error) => {
error!("error decoding message from {}: {}", &self.remote_node_id, decode_error);
Err(RemoteRecvError::DecodeError)
@ -347,21 +360,24 @@ where M: prost::Message + 'static,
}
}
impl<M,R> RemoteCommon for RemoteState<M,R>
where M: prost::Message + 'static,
R: prost::Message + Default + 'static
impl<M, R> RemoteCommon for RemoteState<M, R>
where
M: prost::Message + 'static,
R: prost::Message + Default + 'static,
{
fn id(&self) -> &NodeId {
&self.remote_node_id
}
fn attestation(&self) -> Option<AttestationParameters> {
self.shared.as_ref().borrow_mut().attestation()
}
}
impl<M,R> Remote for RemoteState<M,R>
where M: prost::Message + 'static,
R: prost::Message + Default + 'static
impl<M, R> Remote for RemoteState<M, R>
where
M: prost::Message + 'static,
R: prost::Message + Default + 'static,
{
fn connect(&mut self) -> Result<(), ()> {
if self.node_params.node_id == self.remote_node_id {
@ -369,55 +385,47 @@ where M: prost::Message + 'static,
}
let mut shared = self.shared.as_ref().borrow_mut();
let session = match &mut shared.session {
let session = match &mut shared.session {
session @ SessionState::Disconnected |
session @ SessionState::WaitingForAttestation { .. } |
session @ SessionState::Initiated { .. } |
session @ SessionState::Connected { .. } => {
session
}
session @ SessionState::Connected { .. } => session,
SessionState::Accepted { .. } |
SessionState::Responded { .. } |
SessionState::Authorized { .. } => {
SessionState::Accepted { .. } | SessionState::Responded { .. } | SessionState::Authorized { .. } => {
return Err(());
}
};
match self.initiate_connection() {
Ok(mut noise) => {
match self.auth_type {
RemoteAuthorizationType::Mutual | RemoteAuthorizationType::SelfOnly => {
*session = SessionState::WaitingForAttestation { noise };
}
RemoteAuthorizationType::RemoteOnly => {
match Self::connection_request(&mut noise) {
Ok(noise_data) => {
let connect_req = PeerConnectRequest {
node_type: self.node_params.node_type.into(),
ias_report: None,
noise_data,
};
let mut connect_req_data = Vec::with_capacity(connect_req.encoded_len());
assert!(connect_req.encode(&mut connect_req_data).is_ok());
kbupd_send(EnclaveMessage {
inner: Some(EnclaveMessageInner::SendMessageRequest(SendMessageRequest {
node_id: self.remote_node_id.to_vec(),
data: connect_req_data,
syn: true,
debug_msg: None,
})),
});
*session = SessionState::Initiated { noise };
}
Err(noise_error) => {
error!("noise error connecting to {}: {}", &self.remote_node_id, noise_error);
}
}
}
Ok(mut noise) => match self.auth_type {
RemoteAuthorizationType::Mutual | RemoteAuthorizationType::SelfOnly => {
*session = SessionState::WaitingForAttestation { noise };
}
}
RemoteAuthorizationType::RemoteOnly => match Self::connection_request(&mut noise) {
Ok(noise_data) => {
let connect_req = PeerConnectRequest {
node_type: self.node_params.node_type.into(),
ias_report: None,
noise_data,
};
let mut connect_req_data = Vec::with_capacity(connect_req.encoded_len());
assert!(connect_req.encode(&mut connect_req_data).is_ok());
kbupd_send(EnclaveMessage {
inner: Some(EnclaveMessageInner::SendMessageRequest(SendMessageRequest {
node_id: self.remote_node_id.to_vec(),
data: connect_req_data,
syn: true,
debug_msg: None,
})),
});
*session = SessionState::Initiated { noise };
}
Err(noise_error) => {
error!("noise error connecting to {}: {}", &self.remote_node_id, noise_error);
}
},
},
Err(noise_error) => {
error!("error initiating connection with {}: {}", self.remote_node_id, noise_error);
}
@ -431,13 +439,11 @@ where M: prost::Message + 'static,
}
let mut shared = self.shared.as_ref().borrow_mut();
let session = match &mut shared.session {
let session = match &mut shared.session {
session @ SessionState::Disconnected |
session @ SessionState::WaitingForAttestation { .. } |
session @ SessionState::Accepted { .. } |
session @ SessionState::Responded { .. } => {
session
}
session @ SessionState::Responded { .. } => session,
session @ SessionState::Initiated { .. } => {
if self.node_params.node_id < self.remote_node_id {
@ -447,34 +453,34 @@ where M: prost::Message + 'static,
return Err(());
}
}
session @ SessionState::Connected { .. } |
session @ SessionState::Authorized { .. } => {
session @ SessionState::Connected { .. } | session @ SessionState::Authorized { .. } => {
warn!("dropping connect request from {} in {} state", self.remote_node_id, session);
return Err(());
}
};
match self.accept_connection(&connect_request.noise_data) {
Ok((noise, their_handshake_hash)) => {
match self.auth_type {
RemoteAuthorizationType::Mutual | RemoteAuthorizationType::RemoteOnly => {
match validate_ias_report(connect_request.ias_report.as_ref(), &their_handshake_hash.hash) {
Ok(attestation) => {
*session = SessionState::Accepted { noise, attestation: Some(attestation) };
Ok(())
}
Err(attestation_error) => {
warn!("attestation error accepting peer {}: {}", self.remote_node_id, attestation_error);
Err(())
}
Ok((noise, their_handshake_hash)) => match self.auth_type {
RemoteAuthorizationType::Mutual | RemoteAuthorizationType::RemoteOnly => {
match validate_ias_report(connect_request.ias_report.as_ref(), &their_handshake_hash.hash) {
Ok(attestation) => {
*session = SessionState::Accepted {
noise,
attestation: Some(attestation),
};
Ok(())
}
Err(attestation_error) => {
warn!("attestation error accepting peer {}: {}", self.remote_node_id, attestation_error);
Err(())
}
}
RemoteAuthorizationType::SelfOnly => {
*session = SessionState::Accepted { noise, attestation: None };
Ok(())
}
}
}
RemoteAuthorizationType::SelfOnly => {
*session = SessionState::Accepted { noise, attestation: None };
Ok(())
}
},
Err(noise_error) => {
error!("decrypt error accepting peer {}: {}", self.remote_node_id, noise_error);
Err(())
@ -486,16 +492,11 @@ where M: prost::Message + 'static,
let shared = self.shared.as_ref().borrow();
let report_data: [u8; 32] = match &shared.session {
SessionState::WaitingForAttestation { noise, .. } |
SessionState::Accepted { noise, .. } => {
match get_handshake_hash(noise) {
Ok(our_handshake_hash) => our_handshake_hash.hash,
Err(_) => return Err(()),
}
}
SessionState::Authorized { handshake_hash, .. } => {
handshake_hash.get_hash_for_node(&self.node_params.node_id)
}
SessionState::WaitingForAttestation { noise, .. } | SessionState::Accepted { noise, .. } => match get_handshake_hash(noise) {
Ok(our_handshake_hash) => our_handshake_hash.hash,
Err(_) => return Err(()),
},
SessionState::Authorized { handshake_hash, .. } => handshake_hash.get_hash_for_node(&self.node_params.node_id),
_ => {
return Err(());
}
@ -511,12 +512,10 @@ where M: prost::Message + 'static,
config_id: &reply.config_id,
};
match sgx::create_report(&qe_target_info, &report_data) {
Ok(sgx_report) => {
Ok(GetQuoteRequest {
request_id: self.remote_node_id.to_vec(),
sgx_report,
})
}
Ok(sgx_report) => Ok(GetQuoteRequest {
request_id: self.remote_node_id.to_vec(),
sgx_report,
}),
Err(sgx_error) => {
warn!("error generating sgx report: {}", sgx_error);
Err(())
@ -527,30 +526,29 @@ where M: prost::Message + 'static,
fn get_quote_reply(&mut self, reply: GetQuoteReply) -> Result<Option<GetAttestationRequest>, Option<EnclaveGetQuoteReply>> {
let sgx_quote = reply.sgx_quote;
match &mut self.shared.as_ref().borrow_mut().session {
SessionState::WaitingForAttestation { .. } => {
Ok(Some(GetAttestationRequest {
request_id: self.remote_node_id.to_vec(),
sgx_quote,
}))
}
SessionState::WaitingForAttestation { .. } => Ok(Some(GetAttestationRequest {
request_id: self.remote_node_id.to_vec(),
sgx_quote,
})),
session @ SessionState::Accepted { .. } => {
let (noise, attestation) = match std::mem::replace(session, SessionState::Disconnected) {
SessionState::Accepted { noise, attestation } => (noise, attestation),
_ => unreachable!(),
};
let (noise, noise_data, handshake_hash) = match Self::connection_response(noise) {
Ok(result) => result,
Ok(result) => result,
Err(noise_error) => {
error!("error accepting connection request from {}: {}", self.remote_node_id, noise_error);
return Err(None);
}
};
*session = SessionState::Responded { noise, attestation, handshake_hash };
let msg = PeerConnectReply {
sgx_quote,
noise_data
*session = SessionState::Responded {
noise,
attestation,
handshake_hash,
};
let msg = PeerConnectReply { sgx_quote, noise_data };
let mut encoded_msg_data = Vec::with_capacity(msg.encoded_len());
assert!(msg.encode(&mut encoded_msg_data).is_ok());
@ -565,9 +563,7 @@ where M: prost::Message + 'static,
Ok(None)
}
SessionState::Authorized { .. } => {
Err(Some(EnclaveGetQuoteReply { sgx_quote }))
}
SessionState::Authorized { .. } => Err(Some(EnclaveGetQuoteReply { sgx_quote })),
_ => Ok(None),
}
}
@ -582,7 +578,7 @@ where M: prost::Message + 'static,
match Self::connection_request(&mut noise) {
Ok(noise_data) => {
let connect_req = PeerConnectRequest {
node_type: self.node_params.node_type.into(),
node_type: self.node_params.node_type.into(),
ias_report: Some(ias_report),
noise_data,
};
@ -608,61 +604,77 @@ where M: prost::Message + 'static,
}
session @ SessionState::Connected { .. } => {
let (noise, their_handshake_hash, final_handshake_hash) = match std::mem::replace(session, SessionState::Disconnected) {
SessionState::Connected { noise, their_handshake_hash, final_handshake_hash } =>
(noise, their_handshake_hash, final_handshake_hash),
SessionState::Connected {
noise,
their_handshake_hash,
final_handshake_hash,
} => (noise, their_handshake_hash, final_handshake_hash),
_ => unreachable!(),
};
match validate_ias_report(Some(&ias_report), &their_handshake_hash.hash) {
Ok(attestation) => {
let handshake_hash = final_handshake_hash;
*session = SessionState::Authorized { noise, attestation: Some(attestation), handshake_hash };
*session = SessionState::Authorized {
noise,
attestation: Some(attestation),
handshake_hash,
};
Ok(Some(attestation))
}
Err(attestation_error) => {
error!("error validating attestation report for {}: {}", &self.remote_node_id, attestation_error);
error!(
"error validating attestation report for {}: {}",
&self.remote_node_id, attestation_error
);
Err(())
}
}
}
SessionState::Authorized { attestation, handshake_hash, .. } => {
match validate_ias_report(Some(&ias_report), &handshake_hash.get_hash_for_node(&self.remote_node_id)) {
Ok(new_attestation) => {
verbose!("validated attestation report for {}: {}", &self.remote_node_id, &new_attestation);
*attestation = Some(new_attestation);
Ok(None)
}
Err(attestation_error) => {
error!("error validating attestation report for {}: {}", &self.remote_node_id, attestation_error);
Err(())
}
SessionState::Authorized {
attestation,
handshake_hash,
..
} => match validate_ias_report(Some(&ias_report), &handshake_hash.get_hash_for_node(&self.remote_node_id)) {
Ok(new_attestation) => {
verbose!("validated attestation report for {}: {}", &self.remote_node_id, &new_attestation);
*attestation = Some(new_attestation);
Ok(None)
}
}
_ => {
Err(())
}
Err(attestation_error) => {
error!(
"error validating attestation report for {}: {}",
&self.remote_node_id, attestation_error
);
Err(())
}
},
_ => Err(()),
}
}
}
impl<M,R> RemoteMessageSender for RemoteState<M,R>
where M: prost::Message + 'static,
R: prost::Message + Default + 'static,
impl<M, R> RemoteMessageSender for RemoteState<M, R>
where
M: prost::Message + 'static,
R: prost::Message + Default + 'static,
{
type Message = M;
fn send(&self, message: Rc<Self::Message>) -> Result<(), ()> {
self.shared.as_ref().borrow_mut().send(message)
}
}
impl<M,R> fmt::Display for RemoteState<M,R>
where M: prost::Message + 'static,
R: prost::Message + Default + 'static,
impl<M, R> fmt::Display for RemoteState<M, R>
where
M: prost::Message + 'static,
R: prost::Message + Default + 'static,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("RemoteState")
.field(&self.remote_node_id)
.field(&self.remote_type)
.finish()
.field(&self.remote_node_id)
.field(&self.remote_type)
.finish()
}
}
@ -673,9 +685,14 @@ fn get_handshake_hash(noise: &snow::HandshakeState) -> Result<HandshakeHash, sno
Ok(HandshakeHash { hash })
}
fn write_noise_message(noise: &mut snow::TransportState, noise_buffers: &SharedNoiseBuffers, payload: &[u8]) -> Result<Vec<u8>, snow::Error> {
fn write_noise_message(
noise: &mut snow::TransportState,
noise_buffers: &SharedNoiseBuffers,
payload: &[u8],
) -> Result<Vec<u8>, snow::Error>
{
let mut noise_buffer_ref = RefCell::borrow_mut(&noise_buffers.inner.write_buffer);
let chunk_buffer = &mut noise_buffer_ref.0;
let chunk_buffer = &mut noise_buffer_ref.0;
let payload_chunks = payload.chunks(65519);
let encrypted_msg_buf_len = payload_chunks.len().saturating_mul(NOISE_CHUNK_MAX_LENGTH);
@ -689,7 +706,12 @@ fn write_noise_message(noise: &mut snow::TransportState, noise_buffers: &SharedN
Ok(encrypted_msg_buf)
}
fn read_noise_message(noise: &mut snow::TransportState, shared_noise_buffers: &SharedNoiseBuffers, encrypted: &[u8]) -> Result<SecretValue<Vec<u8>>, snow::Error> {
fn read_noise_message(
noise: &mut snow::TransportState,
shared_noise_buffers: &SharedNoiseBuffers,
encrypted: &[u8],
) -> Result<SecretValue<Vec<u8>>, snow::Error>
{
let mut noise_buffer = shared_noise_buffers.inner.read_buffer.take().unwrap_or_default();
match read_noise_message_with_buffer(noise, &mut noise_buffer.get_mut().0, encrypted) {
Ok(msg_data) => {
@ -705,7 +727,12 @@ fn read_noise_message(noise: &mut snow::TransportState, shared_noise_buffers: &S
}
}
fn read_noise_message_with_buffer(noise: &mut snow::TransportState, chunk_buffer: &mut [u8; NOISE_CHUNK_MAX_LENGTH], encrypted: &[u8]) -> Result<SecretValue<Vec<u8>>, snow::Error> {
fn read_noise_message_with_buffer(
noise: &mut snow::TransportState,
chunk_buffer: &mut [u8; NOISE_CHUNK_MAX_LENGTH],
encrypted: &[u8],
) -> Result<SecretValue<Vec<u8>>, snow::Error>
{
let encrypted_chunks = encrypted.chunks(NOISE_CHUNK_MAX_LENGTH);
let msg_buf_len = encrypted_chunks.len().saturating_mul(65519);
let mut msg_buf = SecretValue::new(Vec::with_capacity(msg_buf_len));
@ -721,24 +748,43 @@ fn read_noise_message_with_buffer(noise: &mut snow::TransportState, chunk_buffer
impl fmt::Display for SessionState {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SessionState::Disconnected => write!(fmt, "Disconnected"),
SessionState::Disconnected => write!(fmt, "Disconnected"),
SessionState::WaitingForAttestation { .. } => write!(fmt, "WaitingForAttestation"),
SessionState::Initiated { .. } => write!(fmt, "Initiated"),
SessionState::Connected { .. } => write!(fmt, "Connected"),
SessionState::Accepted { .. } => write!(fmt, "Accepted"),
SessionState::Responded { .. } => write!(fmt, "Responded"),
SessionState::Authorized { .. } => write!(fmt, "Authorized"),
SessionState::Initiated { .. } => write!(fmt, "Initiated"),
SessionState::Connected { .. } => write!(fmt, "Connected"),
SessionState::Accepted { .. } => write!(fmt, "Accepted"),
SessionState::Responded { .. } => write!(fmt, "Responded"),
SessionState::Authorized { .. } => write!(fmt, "Authorized"),
}
}
}
static IAS_TRUST_ANCHORS: &webpki::TLSServerTrustAnchors<'_> = &webpki::TLSServerTrustAnchors(&[
webpki::TrustAnchor {
subject: &[49, 11, 48, 9, 6, 3, 85, 4, 6, 19, 2, 85, 83, 49, 11, 48, 9, 6, 3, 85, 4, 8, 12, 2, 67, 65, 49, 20, 48, 18, 6, 3, 85, 4, 7, 12, 11, 83, 97, 110, 116, 97, 32, 67, 108, 97, 114, 97, 49, 26, 48, 24, 6, 3, 85, 4, 10, 12, 17, 73, 110, 116, 101, 108, 32, 67, 111, 114, 112, 111, 114, 97, 116, 105, 111, 110, 49, 48, 48, 46, 6, 3, 85, 4, 3, 12, 39, 73, 110, 116, 101, 108, 32, 83, 71, 88, 32, 65, 116, 116, 101, 115, 116, 97, 116, 105, 111, 110, 32, 82, 101, 112, 111, 114, 116, 32, 83, 105, 103, 110, 105, 110, 103, 32, 67, 65],
spki: &[48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 3, 130, 1, 143, 0, 48, 130, 1, 138, 2, 130, 1, 129, 0, 159, 60, 100, 126, 181, 119, 60, 187, 81, 45, 39, 50, 192, 215, 65, 94, 187, 85, 160, 250, 158, 222, 46, 100, 145, 153, 230, 130, 29, 185, 16, 213, 49, 119, 55, 9, 119, 70, 106, 106, 94, 71, 134, 204, 210, 221, 235, 212, 20, 157, 106, 47, 99, 37, 82, 157, 209, 12, 201, 135, 55, 176, 119, 156, 26, 7, 226, 156, 71, 161, 174, 0, 73, 72, 71, 108, 72, 159, 69, 165, 161, 93, 122, 200, 236, 198, 172, 198, 69, 173, 180, 61, 135, 103, 157, 245, 156, 9, 59, 197, 162, 233, 105, 108, 84, 120, 84, 27, 151, 158, 117, 75, 87, 57, 20, 190, 85, 211, 47, 244, 192, 157, 223, 39, 33, 153, 52, 205, 153, 5, 39, 179, 249, 46, 215, 143, 191, 41, 36, 106, 190, 203, 113, 36, 14, 243, 156, 45, 113, 7, 180, 71, 84, 90, 127, 251, 16, 235, 6, 10, 104, 169, 133, 128, 33, 158, 54, 145, 9, 82, 104, 56, 146, 214, 165, 226, 168, 8, 3, 25, 62, 64, 117, 49, 64, 78, 54, 179, 21, 98, 55, 153, 170, 130, 80, 116, 64, 151, 84, 162, 223, 232, 245, 175, 213, 254, 99, 30, 31, 194, 175, 56, 8, 144, 111, 40, 167, 144, 217, 221, 159, 224, 96, 147, 155, 18, 87, 144, 197, 128, 93, 3, 125, 245, 106, 153, 83, 27, 150, 222, 105, 222, 51, 237, 34, 108, 193, 32, 125, 16, 66, 181, 201, 171, 127, 64, 79, 199, 17, 192, 254, 71, 105, 251, 149, 120, 177, 220, 14, 196, 105, 234, 26, 37, 224, 255, 153, 20, 136, 110, 242, 105, 155, 35, 91, 180, 132, 125, 214, 255, 64, 182, 6, 230, 23, 7, 147, 194, 251, 152, 179, 20, 88, 127, 156, 253, 37, 115, 98, 223, 234, 177, 11, 59, 210, 217, 118, 115, 161, 164, 189, 68, 196, 83, 170, 244, 127, 193, 242, 211, 208, 243, 132, 247, 74, 6, 248, 156, 8, 159, 13, 166, 205, 183, 252, 238, 232, 201, 130, 26, 142, 84, 242, 92, 4, 22, 209, 140, 70, 131, 154, 95, 128, 18, 251, 221, 61, 199, 77, 37, 98, 121, 173, 194, 192, 213, 90, 255, 111, 6, 34, 66, 93, 27, 2, 3, 1, 0, 1],
name_constraints: None,
},
]);
static IAS_TRUST_ANCHORS: &webpki::TLSServerTrustAnchors<'_> = &webpki::TLSServerTrustAnchors(&[webpki::TrustAnchor {
subject: &[
49, 11, 48, 9, 6, 3, 85, 4, 6, 19, 2, 85, 83, 49, 11, 48, 9, 6, 3, 85, 4, 8, 12, 2, 67, 65, 49, 20, 48, 18, 6, 3, 85, 4, 7, 12, 11,
83, 97, 110, 116, 97, 32, 67, 108, 97, 114, 97, 49, 26, 48, 24, 6, 3, 85, 4, 10, 12, 17, 73, 110, 116, 101, 108, 32, 67, 111, 114,
112, 111, 114, 97, 116, 105, 111, 110, 49, 48, 48, 46, 6, 3, 85, 4, 3, 12, 39, 73, 110, 116, 101, 108, 32, 83, 71, 88, 32, 65, 116,
116, 101, 115, 116, 97, 116, 105, 111, 110, 32, 82, 101, 112, 111, 114, 116, 32, 83, 105, 103, 110, 105, 110, 103, 32, 67, 65,
],
spki: &[
48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 3, 130, 1, 143, 0, 48, 130, 1, 138, 2, 130, 1, 129, 0, 159, 60, 100, 126,
181, 119, 60, 187, 81, 45, 39, 50, 192, 215, 65, 94, 187, 85, 160, 250, 158, 222, 46, 100, 145, 153, 230, 130, 29, 185, 16, 213,
49, 119, 55, 9, 119, 70, 106, 106, 94, 71, 134, 204, 210, 221, 235, 212, 20, 157, 106, 47, 99, 37, 82, 157, 209, 12, 201, 135, 55,
176, 119, 156, 26, 7, 226, 156, 71, 161, 174, 0, 73, 72, 71, 108, 72, 159, 69, 165, 161, 93, 122, 200, 236, 198, 172, 198, 69, 173,
180, 61, 135, 103, 157, 245, 156, 9, 59, 197, 162, 233, 105, 108, 84, 120, 84, 27, 151, 158, 117, 75, 87, 57, 20, 190, 85, 211, 47,
244, 192, 157, 223, 39, 33, 153, 52, 205, 153, 5, 39, 179, 249, 46, 215, 143, 191, 41, 36, 106, 190, 203, 113, 36, 14, 243, 156,
45, 113, 7, 180, 71, 84, 90, 127, 251, 16, 235, 6, 10, 104, 169, 133, 128, 33, 158, 54, 145, 9, 82, 104, 56, 146, 214, 165, 226,
168, 8, 3, 25, 62, 64, 117, 49, 64, 78, 54, 179, 21, 98, 55, 153, 170, 130, 80, 116, 64, 151, 84, 162, 223, 232, 245, 175, 213,
254, 99, 30, 31, 194, 175, 56, 8, 144, 111, 40, 167, 144, 217, 221, 159, 224, 96, 147, 155, 18, 87, 144, 197, 128, 93, 3, 125, 245,
106, 153, 83, 27, 150, 222, 105, 222, 51, 237, 34, 108, 193, 32, 125, 16, 66, 181, 201, 171, 127, 64, 79, 199, 17, 192, 254, 71,
105, 251, 149, 120, 177, 220, 14, 196, 105, 234, 26, 37, 224, 255, 153, 20, 136, 110, 242, 105, 155, 35, 91, 180, 132, 125, 214,
255, 64, 182, 6, 230, 23, 7, 147, 194, 251, 152, 179, 20, 88, 127, 156, 253, 37, 115, 98, 223, 234, 177, 11, 59, 210, 217, 118,
115, 161, 164, 189, 68, 196, 83, 170, 244, 127, 193, 242, 211, 208, 243, 132, 247, 74, 6, 248, 156, 8, 159, 13, 166, 205, 183, 252,
238, 232, 201, 130, 26, 142, 84, 242, 92, 4, 22, 209, 140, 70, 131, 154, 95, 128, 18, 251, 221, 61, 199, 77, 37, 98, 121, 173, 194,
192, 213, 90, 255, 111, 6, 34, 66, 93, 27, 2, 3, 1, 0, 1,
],
name_constraints: None,
}]);
static IAS_CHAIN_ALGOS: &'static [&webpki::SignatureAlgorithm] = &[
&webpki::RSA_PKCS1_2048_8192_SHA256,
&webpki::RSA_PKCS1_2048_8192_SHA384,
@ -775,10 +821,13 @@ fn parse_ias_timestamp(timestamp: &str) -> Result<u64, AttestationVerificationEr
.ok_or_else(|| AttestationVerificationError::InvalidTimestamp(timestamp.to_owned()))
}
fn validate_ias_report(maybe_ias_report: Option<&IasReport>,
expected_report_data: &[u8])
-> Result<AttestationParameters, AttestationVerificationError> {
#[cfg(feature = "insecure")] {
fn validate_ias_report(
maybe_ias_report: Option<&IasReport>,
expected_report_data: &[u8],
) -> Result<AttestationParameters, AttestationVerificationError>
{
#[cfg(feature = "insecure")]
{
match maybe_ias_report.as_ref() {
Some(ias_report) if ias_report.body.is_empty() => {
return Ok(AttestationParameters { unix_timestamp_seconds: 0 });
@ -794,19 +843,16 @@ fn validate_ias_report(maybe_ias_report: Option<&IasReport>,
}
};
let body: IasReportBody = serde_json::from_slice(&ias_report.body[..])
.map_err(AttestationVerificationError::InvalidJson)?;
let body: IasReportBody = serde_json::from_slice(&ias_report.body[..]).map_err(AttestationVerificationError::InvalidJson)?;
if body.version != 3 {
return Err(AttestationVerificationError::WrongVersion(body.version));
}
match body.isvEnclaveQuoteStatus.as_str() {
"OK" => {
}
"OK" => {}
#[cfg(feature = "insecure")]
"GROUP_OUT_OF_DATE" | "CONFIGURATION_NEEDED" => {
}
"GROUP_OUT_OF_DATE" | "CONFIGURATION_NEEDED" => {}
"SIGRL_VERSION_MISMATCH" => {
return Err(AttestationVerificationError::StaleRevocationList);
}
@ -815,21 +861,20 @@ fn validate_ias_report(maybe_ias_report: Option<&IasReport>,
}
}
let quote = SgxQuote::decode(&mut &body.isvEnclaveQuoteBody[..])
.map_err(AttestationVerificationError::InvalidQuote)?;
let quote = SgxQuote::decode(&mut &body.isvEnclaveQuoteBody[..]).map_err(AttestationVerificationError::InvalidQuote)?;
if &quote.report_data.0[0..32] != expected_report_data {
return Err(AttestationVerificationError::InvalidQuoteReportData);
}
let our_report = sgx::create_report_raw(None, &[0; 64])
.map_err(AttestationVerificationError::CreateReportError)?;
let our_report = sgx::create_report_raw(None, &[0; 64]).map_err(AttestationVerificationError::CreateReportError)?;
if quote.mrenclave != our_report.body.mr_enclave.m {
return Err(AttestationVerificationError::InvalidMrenclave(quote.mrenclave));
}
if quote.is_debug_quote() {
#[cfg(not(feature = "insecure"))] {
#[cfg(not(feature = "insecure"))]
{
return Err(AttestationVerificationError::IsDebugQuote);
}
}
@ -838,13 +883,20 @@ fn validate_ias_report(maybe_ias_report: Option<&IasReport>,
let certificate = (ias_report.certificates.get(0).ok_or(webpki::Error::BadDER))
.and_then(|certificate: &Vec<u8>| webpki::EndEntityCert::from(certificate))
.map_err(AttestationVerificationError::InvalidCertificate)?;
let chain = (ias_report.certificates.get(1..).unwrap_or_default().iter())
let chain = (ias_report.certificates.get(1..).unwrap_or_default().iter())
.map(|cert: &Vec<u8>| &cert[..])
.collect::<Vec<_>>();
certificate.verify_is_valid_tls_server_cert(IAS_CHAIN_ALGOS, IAS_TRUST_ANCHORS, &chain, webpki::Time::from_seconds_since_unix_epoch(unix_timestamp_seconds))
.map_err(AttestationVerificationError::InvalidCertificate)?;
certificate.verify_signature(&webpki::RSA_PKCS1_2048_8192_SHA256, &ias_report.body, &ias_report.signature)
.map_err(AttestationVerificationError::InvalidSignature)?;
certificate
.verify_is_valid_tls_server_cert(
IAS_CHAIN_ALGOS,
IAS_TRUST_ANCHORS,
&chain,
webpki::Time::from_seconds_since_unix_epoch(unix_timestamp_seconds),
)
.map_err(AttestationVerificationError::InvalidCertificate)?;
certificate
.verify_signature(&webpki::RSA_PKCS1_2048_8192_SHA256, &ias_report.body, &ias_report.signature)
.map_err(AttestationVerificationError::InvalidSignature)?;
Ok(AttestationParameters { unix_timestamp_seconds })
}
@ -897,9 +949,10 @@ impl fmt::Debug for NodeId {
impl Deref for NodeId {
type Target = [u8];
fn deref(&self) -> &[u8] {
match self {
NodeId::Valid(id) => id,
NodeId::Valid(id) => id,
NodeId::Invalid(id) => id,
}
}
@ -911,16 +964,17 @@ impl Deref for NodeId {
impl NodeParams {
pub fn generate(node_type: NodeType) -> Self {
let params = NOISE_PARAMS.parse().unwrap_or_else(|_| unreachable!());
let params = NOISE_PARAMS.parse().unwrap_or_else(|_| unreachable!());
let builder = snow::Builder::with_resolver(params, Box::new(SnowResolver));
let keypair = builder.generate_keypair().unwrap_or_else(|_| unreachable!());
assert_eq!(keypair.public.len(), 32);
Self {
node_key: keypair.private.into(),
node_id: keypair.public.into(),
node_key: keypair.private.into(),
node_id: keypair.public.into(),
node_type,
}
}
pub fn node_id(&self) -> &NodeId {
&self.node_id
}
@ -933,8 +987,8 @@ impl NodeParams {
impl fmt::Display for NodeType {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
NodeType::None => write!(fmt, "none"),
NodeType::Replica => write!(fmt, "replica"),
NodeType::None => write!(fmt, "none"),
NodeType::Replica => write!(fmt, "replica"),
NodeType::Frontend => write!(fmt, "frontend"),
}
}
@ -972,7 +1026,7 @@ impl AttestationParameters {
//
impl<M> Shared<M>
where M: prost::Message + 'static,
where M: prost::Message + 'static
{
fn attestation(&self) -> Option<AttestationParameters> {
match &self.session {
@ -1000,8 +1054,8 @@ where M: prost::Message + 'static,
node_id: self.remote_node_id.to_vec(),
debug_msg,
data: encrypted_msg_data,
syn: false,
}))
syn: false,
})),
});
Ok(())
}
@ -1040,27 +1094,29 @@ impl AsMut<[u8]> for Box<NoiseBuffer> {
//
impl<M> RemoteCommon for RemoteSender<M>
where M: prost::Message + 'static,
where M: prost::Message + 'static
{
fn id(&self) -> &NodeId {
&self.id
}
fn attestation(&self) -> Option<AttestationParameters> {
self.shared.as_ref().borrow_mut().attestation()
}
}
impl<M> RemoteMessageSender for RemoteSender<M>
where M: prost::Message + 'static,
where M: prost::Message + 'static
{
type Message = M;
fn send(&self, message: Rc<Self::Message>) -> Result<(), ()> {
self.shared.as_ref().borrow_mut().send(message)
}
}
impl<M> fmt::Display for RemoteSender<M>
where M: prost::Message + 'static,
where M: prost::Message + 'static
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(self.id(), fmt)
@ -1068,7 +1124,7 @@ where M: prost::Message + 'static,
}
impl<M> Clone for RemoteSender<M>
where M: prost::Message + 'static,
where M: prost::Message + 'static
{
fn clone(&self) -> Self {
Self {

View File

@ -8,13 +8,13 @@
use std::collections::*;
use std::rc::*;
use hashbrown::{HashMap, hash_map};
use prost::{Message};
use rand_core::{RngCore};
use sgxsd_ffi::{RdRand};
use hashbrown::{hash_map, HashMap};
use prost::Message;
use rand_core::RngCore;
use sgxsd_ffi::RdRand;
use crate::{kbupd_send};
use crate::hasher::{DefaultHasher};
use crate::hasher::DefaultHasher;
use crate::kbupd_send;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd_enclave::*;
use crate::remote::*;
@ -28,14 +28,14 @@ pub struct PeerManager<T> {
total_ticks: u32,
}
pub struct PeerStarter<'a,T,U> {
pub struct PeerStarter<'a, T, U> {
peer_entry: hash_map::VacantEntry<'a, NodeId, Option<T>, DefaultHasher>,
connecting_peers: &'a mut BTreeSet<ConnectingPeerState>,
connecting_peer: ConnectingPeerState,
remote: U,
}
pub struct PeerAcceptor<'a,T> {
pub struct PeerAcceptor<'a, T> {
peer_entry: hash_map::VacantEntry<'a, NodeId, Option<T>, DefaultHasher>,
node_params: Rc<NodeParams>,
noise_buffers: SharedNoiseBuffers,
@ -60,10 +60,7 @@ struct ConnectingPeerState {
enum QeInfoRequestState {
None,
Sent {
needs_qe_info: Vec<NodeId>,
ticks_elapsed: u32,
},
Sent { needs_qe_info: Vec<NodeId>, ticks_elapsed: u32 },
}
impl<T> PeerManager<T>
@ -121,18 +118,20 @@ where T: Peer
while let Some(mut connecting_peer) = self.take_connecting_peer() {
if let Some(peer) = self.peers.get_mut(&connecting_peer.node_id).and_then(Option::as_mut) {
let last_interval_ticks = connecting_peer.last_interval_ticks;
let half_interval_ticks = last_interval_ticks.min(max_timeout_ticks / 2)
.max(min_timeout_ticks);
let rand_interval_ticks = RdRand.next_u32().checked_rem(half_interval_ticks)
.unwrap_or(0);
let next_timeout_ticks = half_interval_ticks.saturating_add(rand_interval_ticks);
let half_interval_ticks = last_interval_ticks.min(max_timeout_ticks / 2).max(min_timeout_ticks);
let rand_interval_ticks = RdRand.next_u32().checked_rem(half_interval_ticks).unwrap_or(0);
let next_timeout_ticks = half_interval_ticks.saturating_add(rand_interval_ticks);
connecting_peer.last_interval_ticks = half_interval_ticks.saturating_add(half_interval_ticks);
connecting_peer.next_timeout_tick = next_timeout_ticks.saturating_add(self.total_ticks.wrapping_add(1));
connecting_peer.next_timeout_tick = next_timeout_ticks.saturating_add(self.total_ticks.wrapping_add(1));
match peer.remote_mut().connect() {
Ok(()) => {
info!("connecting to peer {} with retry in {} ticks, next interval {} ticks",
peer.remote_mut().id(), next_timeout_ticks, connecting_peer.last_interval_ticks);
info!(
"connecting to peer {} with retry in {} ticks, next interval {} ticks",
peer.remote_mut().id(),
next_timeout_ticks,
connecting_peer.last_interval_ticks
);
Self::get_qe_info(&mut self.qe_info_req, peer.remote_mut().id().clone());
new_connecting_peers.insert(connecting_peer);
}
@ -157,20 +156,26 @@ where T: Peer
}
}
pub fn start_peer<'a,M,R>(&'a mut self,
peer_node_id: NodeId,
peer_node_type: NodeType,
auth_type: RemoteAuthorizationType)
-> Result<PeerStarter<'a, T, RemoteState<M,R>>, Option<&'a mut T>>
where M: prost::Message + 'static,
R: prost::Message + Default + 'static,
pub fn start_peer<'a, M, R>(
&'a mut self,
peer_node_id: NodeId,
peer_node_type: NodeType,
auth_type: RemoteAuthorizationType,
) -> Result<PeerStarter<'a, T, RemoteState<M, R>>, Option<&'a mut T>>
where
M: prost::Message + 'static,
R: prost::Message + Default + 'static,
{
match self.peers.entry(peer_node_id) {
hash_map::Entry::Occupied(peer_entry) => {
Err(peer_entry.into_mut().as_mut())
}
hash_map::Entry::Occupied(peer_entry) => Err(peer_entry.into_mut().as_mut()),
hash_map::Entry::Vacant(peer_entry) => {
let remote = RemoteState::new(Rc::clone(&self.node_params), peer_entry.key().clone(), peer_node_type, auth_type, self.noise_buffers.clone());
let remote = RemoteState::new(
Rc::clone(&self.node_params),
peer_entry.key().clone(),
peer_node_type,
auth_type,
self.noise_buffers.clone(),
);
let connecting_peer = ConnectingPeerState {
next_timeout_tick: self.total_ticks.wrapping_add(1),
last_interval_ticks: 0,
@ -213,9 +218,7 @@ where T: Peer
}
pub fn get_qe_info_reply(&mut self, get_qe_info_reply: GetQeInfoReply) {
if let QeInfoRequestState::Sent { needs_qe_info, .. } =
std::mem::replace(&mut self.qe_info_req, QeInfoRequestState::None)
{
if let QeInfoRequestState::Sent { needs_qe_info, .. } = std::mem::replace(&mut self.qe_info_req, QeInfoRequestState::None) {
info!("generating quotes for {} peers", needs_qe_info.len());
for peer_node_id in needs_qe_info {
if let Some(peer) = self.peers.get_mut(&peer_node_id).and_then(Option::as_mut) {
@ -244,7 +247,7 @@ where T: Peer
Err(Some(enclave_get_quote_reply)) => {
let _ignore = peer.send_quote_reply(enclave_get_quote_reply);
}
Ok(None) => (),
Ok(None) => (),
Err(None) => (),
}
}
@ -266,30 +269,30 @@ where T: Peer
let peer = self.peers.get_mut(&peer_node_id)?.as_mut()?;
match peer.remote_mut().attestation_reply(get_attestation_reply.ias_report) {
Ok(Some(attestation)) => Some((peer, attestation)),
Ok(None) => None,
Err(()) => None,
Ok(None) => None,
Err(()) => None,
}
}
pub fn new_message_signal(&mut self, message: NewMessageSignal) -> Result<Option<(&mut T, <T as Peer>::Message)>, PeerAcceptor<'_,T>> {
pub fn new_message_signal(&mut self, message: NewMessageSignal) -> Result<Option<(&mut T, <T as Peer>::Message)>, PeerAcceptor<'_, T>> {
let peer_node_id: NodeId = message.node_id.into();
if message.syn {
let connect_request = match PeerConnectRequest::decode(&message.data[..]) {
Ok(connect_request) => connect_request,
Err(decode_error) => {
Err(decode_error) => {
warn!("dropping connect request from {}: {}", &peer_node_id, decode_error);
return Ok(None);
}
};
match self.peer_connect_request(connect_request, peer_node_id) {
Some(peer_acceptor) => Err(peer_acceptor),
None => Ok(None),
None => Ok(None),
}
} else {
match self.peer_message(message.data, peer_node_id) {
Ok(result) => Ok(result),
Err(()) => Ok(None),
Err(()) => Ok(None),
}
}
}
@ -298,9 +301,7 @@ where T: Peer
let peer_entry = self.peers.get_mut(&peer_node_id);
if let Some(Some(peer)) = peer_entry {
match peer.recv(&message_data) {
Ok(message) => {
Ok(Some((peer, message)))
}
Ok(message) => Ok(Some((peer, message))),
Err(RemoteRecvError::NeedsAttestation(get_attestation_request)) => {
info!("fetching attestation for peer {}", &peer_node_id);
kbupd_send(EnclaveMessage {
@ -308,10 +309,7 @@ where T: Peer
});
Ok(None)
}
Err(RemoteRecvError::DecodeError) |
Err(RemoteRecvError::InvalidState) => {
Err(())
}
Err(RemoteRecvError::DecodeError) | Err(RemoteRecvError::InvalidState) => Err(()),
}
} else if let Some(None) = peer_entry {
warn!("dropping message from evicted peer {}", &peer_node_id);
@ -322,7 +320,7 @@ where T: Peer
}
}
fn peer_connect_request(&mut self, connect_request: PeerConnectRequest, peer_node_id: NodeId) -> Option<PeerAcceptor<'_,T>> {
fn peer_connect_request(&mut self, connect_request: PeerConnectRequest, peer_node_id: NodeId) -> Option<PeerAcceptor<'_, T>> {
match self.peers.entry(peer_node_id) {
hash_map::Entry::Occupied(mut peer_entry) => {
if let Some(peer) = peer_entry.get_mut().as_mut() {
@ -341,15 +339,18 @@ where T: Peer
if let Some(remote_node_type) = NodeType::from_i32(connect_request.node_type) {
Some(PeerAcceptor {
peer_entry,
node_params: Rc::clone(&self.node_params),
noise_buffers: self.noise_buffers.clone(),
node_params: Rc::clone(&self.node_params),
noise_buffers: self.noise_buffers.clone(),
remote_node_type,
qe_info_req: &mut self.qe_info_req,
qe_info_req: &mut self.qe_info_req,
connect_request,
})
} else {
warn!("dropping connect request from {}: invalid node type {}",
peer_entry.key(), connect_request.node_type);
warn!(
"dropping connect request from {}: invalid node type {}",
peer_entry.key(),
connect_request.node_type
);
None
}
}
@ -361,31 +362,32 @@ where T: Peer
// PeerStarter impls
//
impl<'a,T,U> PeerStarter<'a,T,U>
where T: Peer,
U: Remote
impl<'a, T, U> PeerStarter<'a, T, U>
where
T: Peer,
U: Remote,
{
pub fn remote(&self) -> &U {
&self.remote
}
pub fn connect<F>(mut self, mapper: F) -> Result<&'a mut T, (Self, F)>
where F: FnOnce(U) -> T,
{
where F: FnOnce(U) -> T {
match self.remote.connect() {
Ok(()) => {
self.connecting_peers.insert(self.connecting_peer);
let peer = self.peer_entry.insert(Some(mapper(self.remote)));
Ok(peer.as_mut().unwrap_or_else(|| unreachable!()))
}
Err(()) => {
Err((self, mapper))
}
Err(()) => Err((self, mapper)),
}
}
pub fn insert(self, mapper: impl FnOnce(U) -> T) -> &'a mut T {
self.peer_entry.insert(Some(mapper(self.remote)))
.as_mut()
.unwrap_or_else(|| unreachable!())
self.peer_entry
.insert(Some(mapper(self.remote)))
.as_mut()
.unwrap_or_else(|| unreachable!())
}
}
@ -393,20 +395,29 @@ where T: Peer,
// PeerAcceptor impls
//
impl<'a,T> PeerAcceptor<'a,T>
impl<'a, T> PeerAcceptor<'a, T>
where T: Peer
{
pub fn node_id(&self) -> &NodeId {
self.peer_entry.key()
}
pub fn connect_request(&self) -> &PeerConnectRequest {
&self.connect_request
}
pub fn accept<M,R>(self, mapper: impl FnOnce(RemoteState<M,R>) -> T, auth_type: RemoteAuthorizationType) -> Result<&'a mut T, ()>
where M: prost::Message + 'static,
R: prost::Message + Default + 'static,
pub fn accept<M, R>(self, mapper: impl FnOnce(RemoteState<M, R>) -> T, auth_type: RemoteAuthorizationType) -> Result<&'a mut T, ()>
where
M: prost::Message + 'static,
R: prost::Message + Default + 'static,
{
let mut remote = RemoteState::new(Rc::clone(&self.node_params), self.peer_entry.key().clone(), self.remote_node_type, auth_type, self.noise_buffers);
let mut remote = RemoteState::new(
Rc::clone(&self.node_params),
self.peer_entry.key().clone(),
self.remote_node_type,
auth_type,
self.noise_buffers,
);
remote.accept(self.connect_request)?;
PeerManager::<T>::get_qe_info(self.qe_info_req, self.peer_entry.key().clone());
let peer = self.peer_entry.insert(Some(mapper(remote)));

View File

@ -5,14 +5,14 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use bytes::{Buf};
use num_traits::{ToPrimitive};
use bytes::Buf;
use num_traits::ToPrimitive;
const SGX_FLAGS_INITTED: u64 = 0x0000_0000_0000_0001;
const SGX_FLAGS_DEBUG: u64 = 0x0000_0000_0000_0002;
const SGX_FLAGS_MODE64BIT: u64 = 0x0000_0000_0000_0004;
const SGX_FLAGS_RESERVED: u64 = 0xFFFF_FFFF_FFFF_FFC8;
const SGX_XFRM_RESERVED: u64 = 0xFFFF_FFFF_FFFF_FFF8;
const SGX_FLAGS_INITTED: u64 = 0x0000_0000_0000_0001;
const SGX_FLAGS_DEBUG: u64 = 0x0000_0000_0000_0002;
const SGX_FLAGS_MODE64BIT: u64 = 0x0000_0000_0000_0004;
const SGX_FLAGS_RESERVED: u64 = 0xFFFF_FFFF_FFFF_FFC8;
const SGX_XFRM_RESERVED: u64 = 0xFFFF_FFFF_FFFF_FFF8;
#[derive(Default)]
pub struct SgxQuote {
@ -60,7 +60,7 @@ impl SgxQuote {
let mut quote: Self = Default::default();
quote.version = quote_buf.get_u16_le();
quote.version = quote_buf.get_u16_le();
if !(quote.version >= 1 && quote.version <= 2) {
return Err(SgxQuoteDecodeError::UnknownVersion(quote.version));
}
@ -71,8 +71,8 @@ impl SgxQuote {
}
quote.is_sig_linkable = sign_type == 1;
quote.gid = quote_buf.get_u32_le();
quote.qe_svn = quote_buf.get_u16_le();
quote.gid = quote_buf.get_u32_le();
quote.qe_svn = quote_buf.get_u16_le();
if quote.version > 1 {
quote.pce_svn = quote_buf.get_u16_le();
@ -93,10 +93,7 @@ impl SgxQuote {
Self::read_zero(quote_buf, 68, 28)?; // reserved1
quote.flags = quote_buf.get_u64_le();
if ((quote.flags & SGX_FLAGS_RESERVED ) != 0 ||
(quote.flags & SGX_FLAGS_INITTED ) == 0 ||
(quote.flags & SGX_FLAGS_MODE64BIT) == 0)
{
if ((quote.flags & SGX_FLAGS_RESERVED) != 0 || (quote.flags & SGX_FLAGS_INITTED) == 0 || (quote.flags & SGX_FLAGS_MODE64BIT) == 0) {
return Err(SgxQuoteDecodeError::InvalidFlags(quote.flags));
}
@ -110,7 +107,7 @@ impl SgxQuote {
quote_buf.copy_to_slice(&mut quote.mrsigner);
Self::read_zero(quote_buf, 208, 96)?; // reserved3
quote.isv_prod_id = quote_buf.get_u16_le();
quote.isv_svn = quote_buf.get_u16_le();
quote.isv_svn = quote_buf.get_u16_le();
Self::read_zero(quote_buf, 308, 60)?; // reserved4
quote_buf.copy_to_slice(&mut quote.report_data.0);

View File

@ -11,8 +11,8 @@ use std::collections::*;
use std::fmt;
use std::rc::*;
use rand_core::{RngCore};
use sgxsd_ffi::{RdRand};
use rand_core::RngCore;
use sgxsd_ffi::RdRand;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd_enclave::*;
@ -22,7 +22,7 @@ use crate::util::*;
pub trait RemoteGroupPendingRequest {
type RequestId: Clone + Ord + Eq;
type Message: prost::Message;
type Message: prost::Message;
fn request_id(&self) -> &Self::RequestId;
fn message(&self) -> Rc<Self::Message>;
fn min_attestation(&self) -> Option<AttestationParameters>;
@ -32,8 +32,8 @@ pub trait RemoteGroupNode {
fn request_quote(&mut self, request: EnclaveGetQuoteRequest) -> Result<(), ()>;
}
pub struct RemoteGroupState<T,R>
where R: RemoteGroupPendingRequest,
pub struct RemoteGroupState<T, R>
where R: RemoteGroupPendingRequest
{
name: String,
nodes: Box<[RemoteGroupNodeState<T, R::RequestId>]>,
@ -65,10 +65,11 @@ struct RemoteGroupNodeState<T, RequestId> {
last_sent: Option<RequestId>,
}
impl<T,R> RemoteGroupState<T,R>
where T: RemoteMessageSender<Message = R::Message> + 'static,
T: RemoteGroupNode,
R: RemoteGroupPendingRequest + 'static,
impl<T, R> RemoteGroupState<T, R>
where
T: RemoteMessageSender<Message = R::Message> + 'static,
T: RemoteGroupNode,
R: RemoteGroupPendingRequest + 'static,
{
pub fn new(name: String, remotes: Vec<T>) -> Self {
let nodes = remotes.into_iter().map(|remote: T| RemoteGroupNodeState {
@ -77,14 +78,14 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
});
Self {
name,
nodes: nodes.collect::<Vec<_>>().into(),
leader: Default::default(),
term: Default::default(),
nodes: nodes.collect::<Vec<_>>().into(),
leader: Default::default(),
term: Default::default(),
pending: Default::default(),
timeout_ticks: Default::default(),
timeout_ticks: Default::default(),
request_quote_ticks: Default::default(),
total_ticks: Default::default(),
total_ticks: Default::default(),
}
}
@ -118,10 +119,9 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
BTreeMap::new()
};
trimmed.into_iter()
.map(|(_, pending_request_state): (_, PendingRequestState<R>)| {
pending_request_state.request
})
trimmed
.into_iter()
.map(|(_, pending_request_state): (_, PendingRequestState<R>)| pending_request_state.request)
}
pub fn reset_peer(&mut self, node_id: &NodeId) {
@ -132,9 +132,10 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
None => return,
}
warn!("resetting group {} peer {}", &self.name, node_id);
let maybe_old_leader =
self.leader.and_then(|leader: usize| self.nodes.get(leader))
.map(|leader: &RemoteGroupNodeState<T, _>| leader.remote.id());
let maybe_old_leader = self
.leader
.and_then(|leader: usize| self.nodes.get(leader))
.map(|leader: &RemoteGroupNodeState<T, _>| leader.remote.id());
if maybe_old_leader == Some(node_id) {
self.leader = None;
@ -173,23 +174,19 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
}
fn get_node_mut(&mut self, node_id: &NodeId) -> Option<&mut RemoteGroupNodeState<T, R::RequestId>> {
self.nodes.iter_mut().find_map(|node: &mut RemoteGroupNodeState<T, _>| {
if node.remote.id() == node_id {
Some(node)
} else {
None
}
})
self.nodes.iter_mut().find_map(
|node: &mut RemoteGroupNodeState<T, _>| {
if node.remote.id() == node_id { Some(node) } else { None }
},
)
}
fn get_node(&self, node_id: &NodeId) -> Option<&RemoteGroupNodeState<T, R::RequestId>> {
self.nodes.iter().find_map(|node: &RemoteGroupNodeState<T, _>| {
if node.remote.id() == node_id {
Some(node)
} else {
None
}
})
self.nodes.iter().find_map(
|node: &RemoteGroupNodeState<T, _>| {
if node.remote.id() == node_id { Some(node) } else { None }
},
)
}
fn get_leader_node(&self) -> Option<&RemoteGroupNodeState<T, R::RequestId>> {
@ -201,9 +198,9 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
}
pub fn timer_tick(&mut self, max_timeout_ticks: u32, max_request_quote_ticks: u32) {
self.timeout_ticks = self.timeout_ticks.saturating_add(1);
self.timeout_ticks = self.timeout_ticks.saturating_add(1);
self.request_quote_ticks = self.request_quote_ticks.saturating_add(1);
self.total_ticks = self.total_ticks.wrapping_add(1);
self.total_ticks = self.total_ticks.wrapping_add(1);
if self.timeout_ticks >= max_timeout_ticks {
self.timeout_ticks = Default::default();
@ -239,16 +236,19 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
}
let nodes = &self.nodes[..];
let maybe_old_leader =
self.leader.and_then(|leader: usize| nodes.get(leader))
.map(|leader: &RemoteGroupNodeState<T, _>| leader.remote.id());
let maybe_old_leader = self
.leader
.and_then(|leader: usize| nodes.get(leader))
.map(|leader: &RemoteGroupNodeState<T, _>| leader.remote.id());
if term >= self.term {
self.term = term;
// prevent re-send storm from a node responding NotLeader while contradictorily asserting itself as leader
if let Some(new_leader) = maybe_new_leader.filter(|new_leader: &&NodeId| new_leader != &from_node_id) {
if Some(new_leader) != maybe_old_leader {
info!("group {} changed leader to {} at term {}", &self.name, new_leader, &self.term.id);
self.leader = nodes.iter().position(|node: &RemoteGroupNodeState<T, _>| node.remote.id() == new_leader);
self.leader = nodes
.iter()
.position(|node: &RemoteGroupNodeState<T, _>| node.remote.id() == new_leader);
}
} else if let Some(old_leader) = maybe_old_leader {
info!("group {} lost leader {} at term {}", &self.name, old_leader, &self.term.id);
@ -265,16 +265,17 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
pub fn send(&mut self, request: R) -> Result<(), RemoteGroupSendError<R>> {
let request_id = request.request_id().clone();
let message = request.message();
let message = request.message();
if Some(&request_id) < self.pending.keys().last() {
return Err(RemoteGroupSendError::AlreadySent(request));
}
let nodes = &mut self.nodes[..];
let maybe_authorized_leader =
self.leader.and_then(|leader: usize| nodes.get_mut(leader))
.filter(|leader: &&mut RemoteGroupNodeState<T, _>| leader.remote.attestation().is_some());
let maybe_authorized_leader = self
.leader
.and_then(|leader: usize| nodes.get_mut(leader))
.filter(|leader: &&mut RemoteGroupNodeState<T, _>| leader.remote.attestation().is_some());
if let btree_map::Entry::Vacant(pending_request_entry) = self.pending.entry(request_id) {
let sent_at_tick = self.total_ticks;
if let Some(authorized_leader) = maybe_authorized_leader {
@ -297,12 +298,16 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
pub fn handle_reply(&mut self, request_id: &R::RequestId) -> Option<R> {
self.timeout_ticks = Default::default();
self.pending.remove(request_id)
.map(|request_state: PendingRequestState<R>| request_state.request)
self.pending
.remove(request_id)
.map(|request_state: PendingRequestState<R>| request_state.request)
}
pub fn get_remotes(&self) -> Vec<NodeId> {
self.nodes[..].iter().map(|node: &RemoteGroupNodeState<T, _>| node.remote.id().clone()).collect()
self.nodes[..]
.iter()
.map(|node: &RemoteGroupNodeState<T, _>| node.remote.id().clone())
.collect()
}
#[allow(clippy::indexing_slicing, clippy::integer_arithmetic)]
@ -314,11 +319,9 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
nodes.swap(nodes_idx, nodes_idx.wrapping_add(rand));
let node_idx = nodes[nodes_idx];
let node = &self.nodes[node_idx];
if (node.remote.attestation().is_some() &&
self.has_unsent_to(node))
{
self.leader = Some(node_idx);
let node = &self.nodes[node_idx];
if (node.remote.attestation().is_some() && self.has_unsent_to(node)) {
self.leader = Some(node_idx);
self.timeout_ticks = Default::default();
info!("group {} chose random leader {}", &self.name, node.remote.id());
break;
@ -330,9 +333,10 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
fn flush_requests(&mut self) {
let nodes = &mut self.nodes[..];
let maybe_authorized_leader =
self.leader.and_then(|leader: usize| nodes.get_mut(leader))
.filter(|leader: &&mut RemoteGroupNodeState<T, _>| leader.remote.attestation().is_some());
let maybe_authorized_leader = self
.leader
.and_then(|leader: usize| nodes.get_mut(leader))
.filter(|leader: &&mut RemoteGroupNodeState<T, _>| leader.remote.attestation().is_some());
if let Some(authorized_leader) = maybe_authorized_leader {
let mut queue = Vec::new();
let mut not_yet_valid_count: u64 = 0;
@ -348,13 +352,21 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
}
}
if not_yet_valid_count > 0 {
info!("group {} not sending {} messages to new leader {} due to attestation timestamp {}",
&self.name, not_yet_valid_count, authorized_leader.remote.id(),
OptionDisplay(authorized_leader.remote.attestation().as_ref()));
info!(
"group {} not sending {} messages to new leader {} due to attestation timestamp {}",
&self.name,
not_yet_valid_count,
authorized_leader.remote.id(),
OptionDisplay(authorized_leader.remote.attestation().as_ref())
);
}
if !queue.is_empty() {
info!("group {} resending {} messages to new leader {}",
&self.name, queue.len(), authorized_leader.remote.id());
info!(
"group {} resending {} messages to new leader {}",
&self.name,
queue.len(),
authorized_leader.remote.id()
);
for message in queue {
authorized_leader.send(message);
}
@ -373,35 +385,38 @@ where T: RemoteMessageSender<Message = R::Message> + 'static,
}
impl<T, RequestId> RemoteGroupNodeState<T, RequestId>
where T: RemoteMessageSender + 'static,
RequestId: Clone + Ord + Eq,
where
T: RemoteMessageSender + 'static,
RequestId: Clone + Ord + Eq,
{
fn send(&self, request: Rc<T::Message>) {
let _ignore = self.remote.send(request);
}
fn mark_sent(&mut self, sent_request_id: Option<&RequestId>) {
if sent_request_id > self.last_sent.as_ref() {
self.last_sent = sent_request_id.cloned();
}
}
fn has_sent<R>(&self, request: &R) -> bool
where R: RemoteGroupPendingRequest<RequestId = RequestId> + 'static,
{
where R: RemoteGroupPendingRequest<RequestId = RequestId> + 'static {
Some(request.request_id()) <= self.last_sent.as_ref()
}
}
impl<T,R> fmt::Display for RemoteGroupState<T,R>
where T: RemoteMessageSender<Message = R::Message> + RemoteGroupNode + 'static,
R: RemoteGroupPendingRequest + 'static,
impl<T, R> fmt::Display for RemoteGroupState<T, R>
where
T: RemoteMessageSender<Message = R::Message> + RemoteGroupNode + 'static,
R: RemoteGroupPendingRequest + 'static,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("RemoteGroupState")
.field("name", &self.name)
.field("nodes", &ListDisplay(self.nodes.iter().map(|node| node.remote.id())))
.field("leader", &OptionDisplay(self.get_leader_node().map(|node| node.remote.id())))
.field("term", &DisplayAsDebug(self.term))
.finish()
.field("name", &self.name)
.field("nodes", &ListDisplay(self.nodes.iter().map(|node| node.remote.id())))
.field("leader", &OptionDisplay(self.get_leader_node().map(|node| node.remote.id())))
.field("term", &DisplayAsDebug(self.term))
.finish()
}
}

View File

@ -7,27 +7,27 @@
use crate::prelude::*;
use std::cmp::{Ordering};
use std::ops::{Add};
use std::cmp::Ordering;
use std::ops::Add;
use std::rc::*;
use std::time::*;
use hashbrown::{HashMap};
use prost::{Message};
use sgx_ffi::util::{SecretValue};
use sgxsd_ffi::ecalls::{SgxsdMsgFrom};
use hashbrown::HashMap;
use prost::Message;
use sgx_ffi::util::SecretValue;
use sgxsd_ffi::ecalls::SgxsdMsgFrom;
use crate::{kbupd_send};
use crate::ffi::ecalls::*;
use crate::hasher::{DefaultHasher};
use crate::hasher::DefaultHasher;
use crate::kbupd_send;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd_client;
use crate::protobufs::kbupd_enclave::*;
use crate::protobufs::raft::*;
use crate::util::*;
use crate::remote::*;
use crate::remote_group::*;
use crate::service::replica::{PartitionKeyRange};
use crate::service::replica::PartitionKeyRange;
use crate::util::*;
const NODE_TYPE: NodeType = NodeType::Frontend;
@ -67,9 +67,7 @@ struct PendingRequestId {
#[allow(variant_size_differences)]
enum PendingRequestFrom {
Client(PendingClientRequest),
Untrusted {
untrusted_request_id: u64,
}
Untrusted { untrusted_request_id: u64 },
}
struct PendingRequest {
@ -90,11 +88,11 @@ pub struct PendingClientRequest {
impl FrontendState {
pub fn init(request: StartFrontendRequest) -> Self {
let mut state = Self {
config: request.config,
replicas: PeerManager::new(NODE_TYPE),
partitions: Default::default(),
key_ranges: Default::default(),
last_request_id: Default::default(),
config: request.config,
replicas: PeerManager::new(NODE_TYPE),
partitions: Default::default(),
key_ranges: Default::default(),
last_request_id: Default::default(),
};
for partition_config in request.partitions {
@ -143,19 +141,22 @@ impl FrontendState {
}
fn start_replica_remote(&mut self, node_id: NodeId, group_id: RaftGroupId) -> Option<&Replica> {
match self.replicas.start_peer(node_id, NodeType::Replica, RemoteAuthorizationType::RemoteOnly) {
Ok(replica_entry) => {
match replica_entry.connect(|remote| Replica { remote, group_id }) {
Ok(replica) => Some(replica),
Err((replica_entry, mapper)) => {
warn!("inserting disconnected replica entry for {} due to connect error",
replica_entry.remote().id());
Some(replica_entry.insert(mapper))
}
match self
.replicas
.start_peer(node_id, NodeType::Replica, RemoteAuthorizationType::RemoteOnly)
{
Ok(replica_entry) => match replica_entry.connect(|remote| Replica { remote, group_id }) {
Ok(replica) => Some(replica),
Err((replica_entry, mapper)) => {
warn!(
"inserting disconnected replica entry for {} due to connect error",
replica_entry.remote().id()
);
Some(replica_entry.insert(mapper))
}
}
},
Err(Some(replica)) => Some(replica),
Err(None) => None,
Err(None) => None,
}
}
@ -165,38 +166,23 @@ impl FrontendState {
pub fn untrusted_message(&mut self, untrusted_message: UntrustedMessage) {
match untrusted_message.inner {
Some(untrusted_message::Inner::StartFrontendRequest(_)) |
Some(untrusted_message::Inner::StartReplicaRequest(_)) =>
(),
Some(untrusted_message::Inner::StartFrontendRequest(_)) | Some(untrusted_message::Inner::StartReplicaRequest(_)) => (),
Some(untrusted_message::Inner::StartReplicaGroupRequest(_)) =>
(),
Some(untrusted_message::Inner::UntrustedTransactionRequest(request)) =>
self.handle_untrusted_transaction_request(request),
Some(untrusted_message::Inner::UntrustedXferRequest(_)) =>
(),
Some(untrusted_message::Inner::GetEnclaveStatusRequest(request)) =>
self.handle_get_enclave_status_request(request),
Some(untrusted_message::Inner::StartReplicaGroupRequest(_)) => (),
Some(untrusted_message::Inner::UntrustedTransactionRequest(request)) => self.handle_untrusted_transaction_request(request),
Some(untrusted_message::Inner::UntrustedXferRequest(_)) => (),
Some(untrusted_message::Inner::GetEnclaveStatusRequest(request)) => self.handle_get_enclave_status_request(request),
Some(untrusted_message::Inner::GetQeInfoReply(reply)) =>
self.handle_get_qe_info_reply(reply),
Some(untrusted_message::Inner::GetQuoteReply(reply)) =>
self.handle_get_quote_reply(reply),
Some(untrusted_message::Inner::GetAttestationReply(reply)) =>
self.handle_get_attestation_reply(reply),
Some(untrusted_message::Inner::GetQeInfoReply(reply)) => self.handle_get_qe_info_reply(reply),
Some(untrusted_message::Inner::GetQuoteReply(reply)) => self.handle_get_quote_reply(reply),
Some(untrusted_message::Inner::GetAttestationReply(reply)) => self.handle_get_attestation_reply(reply),
Some(untrusted_message::Inner::NewMessageSignal(signal)) =>
self.handle_new_message_signal(signal),
Some(untrusted_message::Inner::TimerTickSignal(signal)) =>
self.handle_timer_tick_signal(signal),
Some(untrusted_message::Inner::SetFrontendConfigSignal(signal)) =>
self.handle_set_frontend_config_signal(signal),
Some(untrusted_message::Inner::SetReplicaConfigSignal(_)) =>
(),
Some(untrusted_message::Inner::ResetPeerSignal(signal)) =>
self.handle_reset_peer_signal(signal),
Some(untrusted_message::Inner::SetVerboseLoggingSignal(signal)) =>
self.handle_set_verbose_logging_signal(signal),
Some(untrusted_message::Inner::NewMessageSignal(signal)) => self.handle_new_message_signal(signal),
Some(untrusted_message::Inner::TimerTickSignal(signal)) => self.handle_timer_tick_signal(signal),
Some(untrusted_message::Inner::SetFrontendConfigSignal(signal)) => self.handle_set_frontend_config_signal(signal),
Some(untrusted_message::Inner::SetReplicaConfigSignal(_)) => (),
Some(untrusted_message::Inner::ResetPeerSignal(signal)) => self.handle_reset_peer_signal(signal),
Some(untrusted_message::Inner::SetVerboseLoggingSignal(signal)) => self.handle_set_verbose_logging_signal(signal),
None => (),
}
@ -218,11 +204,7 @@ impl FrontendState {
}
fn handle_get_enclave_status_request(&mut self, request: GetEnclaveStatusRequest) {
let memory_status = if request.memory_status {
Some(memory_status())
} else {
None
};
let memory_status = if request.memory_status { Some(memory_status()) } else { None };
let mut partitions = Vec::with_capacity(self.partitions.len());
for (group_id, partition) in &self.partitions {
partitions.push(EnclaveFrontendPartitionStatus {
@ -276,17 +258,22 @@ impl FrontendState {
self.replica_message(message, from_node_id);
}
Ok(None) => (),
Err(peer_entry) => {
warn!("unsolicited connect request from {}: {}", peer_entry.node_id(), peer_entry.connect_request())
}
Err(peer_entry) => warn!(
"unsolicited connect request from {}: {}",
peer_entry.node_id(),
peer_entry.connect_request()
),
}
}
fn handle_timer_tick_signal(&mut self, _signal: TimerTickSignal) {
self.replicas.timer_tick(self.config.min_connect_timeout_ticks, self.config.max_connect_timeout_ticks);
self.replicas
.timer_tick(self.config.min_connect_timeout_ticks, self.config.max_connect_timeout_ticks);
for partition in self.partitions.values_mut() {
partition.remote_group.timer_tick(self.config.replica_timeout_ticks, self.config.request_quote_ticks);
partition
.remote_group
.timer_tick(self.config.replica_timeout_ticks, self.config.request_quote_ticks);
}
}
@ -312,10 +299,12 @@ impl FrontendState {
fn replica_message(&mut self, replica_message: ReplicaToFrontendMessage, from_node_id: NodeId) {
match replica_message.inner {
Some(replica_to_frontend_message::Inner::TransactionReply(transaction_reply)) =>
self.handle_transaction_reply(transaction_reply, from_node_id),
Some(replica_to_frontend_message::Inner::EnclaveGetQuoteReply(reply)) =>
self.handle_enclave_get_quote_reply(reply, from_node_id),
Some(replica_to_frontend_message::Inner::TransactionReply(transaction_reply)) => {
self.handle_transaction_reply(transaction_reply, from_node_id)
}
Some(replica_to_frontend_message::Inner::EnclaveGetQuoteReply(reply)) => {
self.handle_enclave_get_quote_reply(reply, from_node_id)
}
None => (),
}
}
@ -324,8 +313,9 @@ impl FrontendState {
if let Some((replica, partition)) = Self::get_partition_replica_mut(&mut self.replicas, &mut self.partitions, &from_node_id) {
match transaction_reply.data {
Some(transaction_reply::Data::ClientResponse(client_reply)) => {
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
if let Some(PendingRequestFrom::Client(pending_client_request)) =
maybe_pending_request.map(|pending_request| pending_request.from)
{
@ -335,25 +325,33 @@ impl FrontendState {
}
}
Some(transaction_reply::Data::InvalidRequest(_invalid_request_error)) => {
error!("replica {} reported InvalidRequest {}", &from_node_id, &transaction_reply.request_id);
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
error!(
"replica {} reported InvalidRequest {}",
&from_node_id, &transaction_reply.request_id
);
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
if let Some(pending_request) = maybe_pending_request {
self.cancel_pending_request(pending_request.from);
}
}
Some(transaction_reply::Data::InternalError(_internal_error)) => {
warn!("replica {} reported InternalError on request {}!",
&from_node_id, &transaction_reply.request_id);
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
warn!(
"replica {} reported InternalError on request {}!",
&from_node_id, &transaction_reply.request_id
);
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
if let Some(pending_request) = maybe_pending_request {
self.cancel_pending_request(pending_request.from);
}
}
Some(transaction_reply::Data::CreateBackupReply(create_backup_reply)) => {
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
if let Some(PendingRequestFrom::Untrusted { untrusted_request_id }) =
maybe_pending_request.map(|pending_request| pending_request.from)
{
@ -364,13 +362,13 @@ impl FrontendState {
})),
});
} else {
info!("pending untrusted transaction request {} not found",
transaction_reply.request_id);
info!("pending untrusted transaction request {} not found", transaction_reply.request_id);
}
}
Some(transaction_reply::Data::DeleteBackupReply(delete_backup_reply)) => {
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
match maybe_pending_request.map(|pending_request| pending_request.from) {
Some(PendingRequestFrom::Client(pending_client_request)) => {
pending_client_request.reply(&kbupd_client::Response {
@ -394,14 +392,21 @@ impl FrontendState {
}
Some(transaction_reply::Data::NotLeader(not_leader_error_data)) => {
let new_leader: Option<NodeId> = not_leader_error_data.leader_node_id.map(NodeId::from);
partition.remote_group.remote_not_leader(not_leader_error_data.term, new_leader.as_ref(), &from_node_id);
verbose!("replica {} reported NotLeader for partition {} with new leader {} at term {}",
replica.remote.id(), &replica.group_id, OptionDisplay(new_leader.as_ref()),
&not_leader_error_data.term);
partition
.remote_group
.remote_not_leader(not_leader_error_data.term, new_leader.as_ref(), &from_node_id);
verbose!(
"replica {} reported NotLeader for partition {} with new leader {} at term {}",
replica.remote.id(),
&replica.group_id,
OptionDisplay(new_leader.as_ref()),
&not_leader_error_data.term
);
}
Some(transaction_reply::Data::WrongPartition(wrong_partition_error_data)) => {
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
if let Some(range) = &wrong_partition_error_data.range {
match PartitionKeyRange::try_from_pb(range) {
Ok(range) => {
@ -409,8 +414,10 @@ impl FrontendState {
self.key_ranges.update(&replica.group_id, range);
}
Err(()) => {
error!("partition {} reported WrongPartition with invalid range {}",
&replica.group_id, range);
error!(
"partition {} reported WrongPartition with invalid range {}",
&replica.group_id, range
);
}
}
} else {
@ -418,36 +425,49 @@ impl FrontendState {
self.key_ranges.remove(&replica.group_id);
}
if let Some(new_partition) = wrong_partition_error_data.new_partition {
info!("partition {} reported WrongPartition with new partition {} and range {}",
&replica.group_id, ToHex(&new_partition.group_id), OptionDisplay(new_partition.range.as_ref()));
info!(
"partition {} reported WrongPartition with new partition {} and range {}",
&replica.group_id,
ToHex(&new_partition.group_id),
OptionDisplay(new_partition.range.as_ref())
);
self.update_partition(new_partition);
} else {
warn!("partition {} reported WrongPartition but didn't know the right one!", &replica.group_id);
warn!(
"partition {} reported WrongPartition but didn't know the right one!",
&replica.group_id
);
}
if let Some(pending_request) = maybe_pending_request {
self.send_transaction_request(pending_request);
}
}
Some(transaction_reply::Data::ServiceIdMismatch(_service_id_mismatch_data)) => {
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
if let Some(pending_request) = maybe_pending_request {
warn!("partition {} reported ServiceIdMismatch");
self.cancel_pending_request(pending_request.from);
}
}
Some(transaction_reply::Data::XferInProgress(_xfer_in_progress_data)) => {
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
if let Some(pending_request) = maybe_pending_request {
info!("partition {} reported XferInProgress for backup id {}",
&replica.group_id, OptionDisplay(pending_request.backup_id()));
info!(
"partition {} reported XferInProgress for backup id {}",
&replica.group_id,
OptionDisplay(pending_request.backup_id())
);
self.cancel_pending_request(pending_request.from);
}
}
None => {
let maybe_pending_request: Option<PendingRequest> =
partition.remote_group.handle_reply(&PendingRequestId { id: transaction_reply.request_id });
let maybe_pending_request: Option<PendingRequest> = partition.remote_group.handle_reply(&PendingRequestId {
id: transaction_reply.request_id,
});
if let Some(pending_request) = maybe_pending_request {
self.cancel_pending_request(pending_request.from);
}
@ -472,10 +492,7 @@ impl FrontendState {
let id = self.last_request_id.clone() + 1;
self.last_request_id = id.clone();
let min_attestation = match &data {
transaction_request::Data::Create(_) |
transaction_request::Data::Delete(_) => {
None
}
transaction_request::Data::Create(_) | transaction_request::Data::Delete(_) => None,
transaction_request::Data::Backup(BackupTransactionRequest { valid_from, .. }) |
transaction_request::Data::Restore(RestoreTransactionRequest { valid_from, .. }) => {
Some(AttestationParameters::new(Duration::from_secs(*valid_from)))
@ -500,25 +517,30 @@ impl FrontendState {
let Self { partitions, .. } = self;
if pending_request.id != self.last_request_id {
pending_request.id = self.last_request_id.clone() + 1;
pending_request.id = self.last_request_id.clone() + 1;
self.last_request_id = pending_request.id.clone();
match &mut Rc::make_mut(&mut pending_request.message).inner {
Some(frontend_to_replica_message::Inner::TransactionRequest(txn_request)) => {
txn_request.request_id = pending_request.id.id;
}
Some(frontend_to_replica_message::Inner::EnclaveGetQuoteRequest(_)) |
None => (),
Some(frontend_to_replica_message::Inner::EnclaveGetQuoteRequest(_)) | None => (),
}
}
if let Some(backup_id) = pending_request.backup_id() {
let maybe_group_id: Option<&RaftGroupId> = self.key_ranges.find(backup_id);
let maybe_group_id: Option<&RaftGroupId> = self.key_ranges.find(backup_id);
let maybe_partition: Option<&mut Partition> = maybe_group_id.and_then(|group_id| partitions.get_mut(group_id));
if let Some(partition) = maybe_partition {
let trimmed = partition.remote_group.trim_to(self.config.pending_request_count.saturating_sub(1).to_usize(),
self.config.pending_request_ttl);
let trimmed = partition.remote_group.trim_to(
self.config.pending_request_count.saturating_sub(1).to_usize(),
self.config.pending_request_ttl,
);
if trimmed.len() != 0 {
info!("dropping {} old pending requests for partition {}", trimmed.len(), partition.remote_group.name());
info!(
"dropping {} old pending requests for partition {}",
trimmed.len(),
partition.remote_group.name()
);
}
if partition.remote_group.pending_len() < self.config.pending_request_count.to_usize() {
@ -564,87 +586,85 @@ impl FrontendState {
}
}
fn get_partition_replica_mut<'a, 'b>(replicas: &'a mut PeerManager<Replica>,
partitions: &'b mut HashMap<RaftGroupId, Partition, DefaultHasher>,
node_id: &NodeId)
-> Option<(&'a mut Replica, &'b mut Partition)> {
let replica = replicas.get_peer_mut(node_id)?;
fn get_partition_replica_mut<'a, 'b>(
replicas: &'a mut PeerManager<Replica>,
partitions: &'b mut HashMap<RaftGroupId, Partition, DefaultHasher>,
node_id: &NodeId,
) -> Option<(&'a mut Replica, &'b mut Partition)>
{
let replica = replicas.get_peer_mut(node_id)?;
let partition = partitions.get_mut(&replica.group_id)?;
Some((replica, partition))
}
pub fn decode_request(&self, request_type: u32, backup_id: Vec<u8>, request_data: &[u8]) -> Result<transaction_request::Data, ()> {
let request = kbupd_client::Request::decode(request_data).map_err(|_| ())?;
let request = kbupd_client::Request::decode(request_data).map_err(|_| ())?;
let backup_id = BackupId::try_from_slice(&backup_id)?;
match request {
kbupd_client::Request {
backup: Some(backup_request),
backup: Some(backup_request),
restore: None,
delete: None,
} => {
match request_type {
KBUPD_REQUEST_TYPE_ANY |
KBUPD_REQUEST_TYPE_BACKUP => {
self.validate_backup_request(backup_id, backup_request)
}
_ => Err(()),
}
}
delete: None,
} => match request_type {
KBUPD_REQUEST_TYPE_ANY | KBUPD_REQUEST_TYPE_BACKUP => self.validate_backup_request(backup_id, backup_request),
_ => Err(()),
},
kbupd_client::Request {
backup: None,
backup: None,
restore: Some(restore_request),
delete: None,
} => {
match request_type {
KBUPD_REQUEST_TYPE_ANY |
KBUPD_REQUEST_TYPE_RESTORE => {
Self::validate_restore_request(backup_id, restore_request)
}
_ => Err(()),
}
}
delete: None,
} => match request_type {
KBUPD_REQUEST_TYPE_ANY | KBUPD_REQUEST_TYPE_RESTORE => Self::validate_restore_request(backup_id, restore_request),
_ => Err(()),
},
kbupd_client::Request {
backup: None,
backup: None,
restore: None,
delete: Some(delete_request),
} => {
match request_type {
KBUPD_REQUEST_TYPE_ANY |
KBUPD_REQUEST_TYPE_DELETE => {
Self::validate_delete_request(backup_id, delete_request)
}
_ => Err(()),
}
}
delete: Some(delete_request),
} => match request_type {
KBUPD_REQUEST_TYPE_ANY | KBUPD_REQUEST_TYPE_DELETE => Self::validate_delete_request(backup_id, delete_request),
_ => Err(()),
},
_ => Err(()),
}
}
fn validate_backup_request(&self, backup_id: BackupId, mut request: kbupd_client::BackupRequest) -> Result<transaction_request::Data, ()> {
fn validate_backup_request(
&self,
backup_id: BackupId,
mut request: kbupd_client::BackupRequest,
) -> Result<transaction_request::Data, ()>
{
if let kbupd_client::BackupRequest {
service_id,
backup_id: Some(request_backup_id),
nonce: Some(nonce),
backup_id: Some(request_backup_id),
nonce: Some(nonce),
valid_from: Some(valid_from),
data: Some(data),
pin: Some(pin),
tries: Some(tries),
} = &mut request {
data: Some(data),
pin: Some(pin),
tries: Some(tries),
} = &mut request
{
if (Self::validate_request_service_id(service_id) &&
request_backup_id == &backup_id.id &&
nonce.len() == 32 &&
data.len() <= self.config.max_backup_data_length.to_usize() &&
pin.len() == 32 &&
*tries != 0 && *tries <= u16::max_value().into())
nonce.len() == 32 &&
data.len() <= self.config.max_backup_data_length.to_usize() &&
pin.len() == 32 &&
*tries != 0 &&
*tries <= u16::max_value().into())
{
Ok(transaction_request::Data::Backup(BackupTransactionRequest {
service_id: service_id.take(),
backup_id,
nonce: std::mem::replace(nonce, Vec::new()),
nonce: std::mem::replace(nonce, Vec::new()),
valid_from: *valid_from,
data: SecretBytes { data: std::mem::replace(data, Vec::new()) },
pin: SecretBytes { data: std::mem::replace(pin, Vec::new()) },
tries: *tries,
data: SecretBytes {
data: std::mem::replace(data, Vec::new()),
},
pin: SecretBytes {
data: std::mem::replace(pin, Vec::new()),
},
tries: *tries,
}))
} else {
Err(())
@ -657,22 +677,22 @@ impl FrontendState {
fn validate_restore_request(backup_id: BackupId, mut request: kbupd_client::RestoreRequest) -> Result<transaction_request::Data, ()> {
if let kbupd_client::RestoreRequest {
service_id,
backup_id: Some(request_backup_id),
nonce: Some(nonce),
backup_id: Some(request_backup_id),
nonce: Some(nonce),
valid_from: Some(valid_from),
pin: Some(pin),
} = &mut request {
if (Self::validate_request_service_id(service_id) &&
request_backup_id == &backup_id.id &&
nonce.len() == 32 &&
pin.len() == 32)
pin: Some(pin),
} = &mut request
{
if (Self::validate_request_service_id(service_id) && request_backup_id == &backup_id.id && nonce.len() == 32 && pin.len() == 32)
{
Ok(transaction_request::Data::Restore(RestoreTransactionRequest {
service_id: service_id.take(),
backup_id,
valid_from: *valid_from,
nonce: std::mem::replace(nonce, Vec::new()),
pin: SecretBytes { data: std::mem::replace(pin, Vec::new()), },
nonce: std::mem::replace(nonce, Vec::new()),
pin: SecretBytes {
data: std::mem::replace(pin, Vec::new()),
},
}))
} else {
Err(())
@ -685,11 +705,10 @@ impl FrontendState {
fn validate_delete_request(backup_id: BackupId, request: kbupd_client::DeleteRequest) -> Result<transaction_request::Data, ()> {
if let kbupd_client::DeleteRequest {
service_id,
backup_id: Some(request_backup_id),
} = request {
if (Self::validate_request_service_id(&service_id) &&
request_backup_id == backup_id.id)
{
backup_id: Some(request_backup_id),
} = request
{
if (Self::validate_request_service_id(&service_id) && request_backup_id == backup_id.id) {
Ok(transaction_request::Data::Delete(DeleteTransactionRequest {
service_id,
backup_id,
@ -714,45 +733,46 @@ impl FrontendState {
fn reject_pending_request_not_yet_valid(pending_request: PendingRequest) {
match pending_request.from {
PendingRequestFrom::Client(pending_client_request) => {
info!("rejecting not yet valid client request {} requiring {}",
&pending_request.id.id, OptionDisplay(pending_request.min_attestation.as_ref()));
info!(
"rejecting not yet valid client request {} requiring {}",
&pending_request.id.id,
OptionDisplay(pending_request.min_attestation.as_ref())
);
match &pending_request.message.as_ref().inner {
Some(frontend_to_replica_message::Inner::TransactionRequest(TransactionRequest { data, .. })) => {
match data {
Some(transaction_request::Data::Backup(_)) => {
pending_client_request.reply(&kbupd_client::Response {
backup: Some(kbupd_client::BackupResponse {
status: Some(kbupd_client::backup_response::Status::NotYetValid.into()),
nonce: None,
}),
restore: None,
delete: None,
});
}
Some(transaction_request::Data::Restore(_)) => {
pending_client_request.reply(&kbupd_client::Response {
backup: None,
restore: Some(kbupd_client::RestoreResponse {
status: Some(kbupd_client::restore_response::Status::NotYetValid.into()),
nonce: None,
data: None,
tries: None,
}),
delete: None,
});
}
Some(transaction_request::Data::Create(_)) |
Some(transaction_request::Data::Delete(_)) |
None => (),
Some(frontend_to_replica_message::Inner::TransactionRequest(TransactionRequest { data, .. })) => match data {
Some(transaction_request::Data::Backup(_)) => {
pending_client_request.reply(&kbupd_client::Response {
backup: Some(kbupd_client::BackupResponse {
status: Some(kbupd_client::backup_response::Status::NotYetValid.into()),
nonce: None,
}),
restore: None,
delete: None,
});
}
}
Some(frontend_to_replica_message::Inner::EnclaveGetQuoteRequest(_)) |
None => (),
Some(transaction_request::Data::Restore(_)) => {
pending_client_request.reply(&kbupd_client::Response {
backup: None,
restore: Some(kbupd_client::RestoreResponse {
status: Some(kbupd_client::restore_response::Status::NotYetValid.into()),
nonce: None,
data: None,
tries: None,
}),
delete: None,
});
}
Some(transaction_request::Data::Create(_)) | Some(transaction_request::Data::Delete(_)) | None => (),
},
Some(frontend_to_replica_message::Inner::EnclaveGetQuoteRequest(_)) | None => (),
}
}
PendingRequestFrom::Untrusted { untrusted_request_id } => {
info!("rejecting not yet valid untrusted request {} requiring {}",
untrusted_request_id, OptionDisplay(pending_request.min_attestation.as_ref()));
info!(
"rejecting not yet valid untrusted request {} requiring {}",
untrusted_request_id,
OptionDisplay(pending_request.min_attestation.as_ref())
);
}
}
}
@ -765,30 +785,31 @@ impl PartitionKeyRanges {
fn range_cmp(one: &PartitionKeyRange, two: &PartitionKeyRange) -> Ordering {
(one.first(), one.last()).cmp(&(two.first(), two.last()))
}
fn entry_cmp(one: &(PartitionKeyRange, RaftGroupId), two: &(PartitionKeyRange, RaftGroupId)) -> Ordering {
Self::range_cmp(&one.0, &two.0)
}
fn key_cmp(range: &PartitionKeyRange, key: &BackupId) -> Ordering {
match range.first().as_ref().cmp(&key.id) {
Ordering::Greater => Ordering::Greater,
Ordering::Less |
Ordering::Equal => {
match range.last().as_ref().cmp(&key.id) {
Ordering::Less => Ordering::Less,
Ordering::Greater |
Ordering::Equal => Ordering::Equal,
}
}
Ordering::Less | Ordering::Equal => match range.last().as_ref().cmp(&key.id) {
Ordering::Less => Ordering::Less,
Ordering::Greater | Ordering::Equal => Ordering::Equal,
},
}
}
fn update(&mut self, update_group_id: &RaftGroupId, update_range: PartitionKeyRange) {
let mut matches =
self.ranges.iter_mut()
.filter(|(_, group_id)| group_id == update_group_id)
.peekable();
let mut matches = self
.ranges
.iter_mut()
.filter(|(_, group_id)| group_id == update_group_id)
.peekable();
if matches.peek().is_none() {
match self.ranges[..].binary_search_by(|(range, _)| Self::range_cmp(range, &update_range))
.and_then(|ranges_index| self.ranges.get_mut(ranges_index).ok_or(ranges_index))
match self.ranges[..]
.binary_search_by(|(range, _)| Self::range_cmp(range, &update_range))
.and_then(|ranges_index| self.ranges.get_mut(ranges_index).ok_or(ranges_index))
{
Ok((_, group_id)) => {
*group_id = update_group_id.clone();
@ -804,9 +825,11 @@ impl PartitionKeyRanges {
}
self.ranges[..].sort_unstable_by(Self::entry_cmp);
}
fn remove(&mut self, remove_group_id: &RaftGroupId) {
self.ranges.retain(|(_range, group_id)| group_id != remove_group_id);
}
fn find<'a>(&'a self, key: &BackupId) -> Option<&'a RaftGroupId> {
self.ranges[..]
.binary_search_by(|(range, _)| Self::key_cmp(range, key))
@ -823,8 +846,11 @@ impl PartitionKeyRanges {
impl Add<u64> for PendingRequestId {
type Output = Self;
fn add(self, inc: u64) -> Self {
Self { id: self.id.checked_add(inc).unwrap_or_else(|| panic!("overflow")) }
Self {
id: self.id.checked_add(inc).unwrap_or_else(|| panic!("overflow")),
}
}
}
@ -834,12 +860,14 @@ impl Add<u64> for PendingRequestId {
impl PendingRequest {
fn backup_id(&self) -> Option<&BackupId> {
if let Some(frontend_to_replica_message::Inner::TransactionRequest(TransactionRequest { data: Some(data), .. })) = &self.message.inner {
if let Some(frontend_to_replica_message::Inner::TransactionRequest(TransactionRequest { data: Some(data), .. })) =
&self.message.inner
{
match data {
transaction_request::Data::Create(create_request) => Some(&create_request.backup_id),
transaction_request::Data::Backup(backup_request) => Some(&backup_request.backup_id),
transaction_request::Data::Create(create_request) => Some(&create_request.backup_id),
transaction_request::Data::Backup(backup_request) => Some(&backup_request.backup_id),
transaction_request::Data::Restore(restore_request) => Some(&restore_request.backup_id),
transaction_request::Data::Delete(delete_request) => Some(&delete_request.backup_id),
transaction_request::Data::Delete(delete_request) => Some(&delete_request.backup_id),
}
} else {
None
@ -848,14 +876,17 @@ impl PendingRequest {
}
impl RemoteGroupPendingRequest for PendingRequest {
type Message = FrontendToReplicaMessage;
type RequestId = PendingRequestId;
type Message = FrontendToReplicaMessage;
fn request_id(&self) -> &Self::RequestId {
&self.id
}
fn message(&self) -> Rc<Self::Message> {
Rc::clone(&self.message)
}
fn min_attestation(&self) -> Option<AttestationParameters> {
self.min_attestation
}
@ -893,12 +924,15 @@ impl PendingClientRequest {
impl Peer for Replica {
type Message = ReplicaToFrontendMessage;
fn remote_mut(&mut self) -> &mut dyn Remote {
&mut self.remote
}
fn recv(&mut self, msg_data: &[u8]) -> Result<Self::Message, RemoteRecvError> {
self.remote.recv(msg_data)
}
fn send_quote_reply(&mut self, _reply: EnclaveGetQuoteReply) -> Result<(), ()> {
Ok(())
}
@ -920,8 +954,9 @@ impl RemoteGroupNode for RemoteSender<FrontendToReplicaMessage> {
// internal
//
fn validate_untrusted_transaction_request(request_data: Option<untrusted_transaction_request::Data>)
-> Result<transaction_request::Data, ()> {
fn validate_untrusted_transaction_request(
request_data: Option<untrusted_transaction_request::Data>,
) -> Result<transaction_request::Data, ()> {
match request_data {
Some(untrusted_transaction_request::Data::CreateBackupRequest(create_backup_request)) => {
if create_backup_request.backup_id.id.len() == 32 {

View File

@ -11,8 +11,8 @@ use sgx_ffi::sgx::*;
use sgxsd_ffi::ecalls::*;
use crate::ffi::ecalls::*;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd::untrusted_message;
use crate::protobufs::kbupd::*;
use crate::service::frontend::*;
use crate::service::replica::*;
@ -27,13 +27,11 @@ pub enum ServiceState {
Replica(ReplicaState),
}
pub struct SgxsdState {
}
pub struct SgxsdState {}
#[cfg(not(any(test, feature = "test")))]
pub fn whereis<F, R>(fun: F) -> R
where F: FnOnce(&RefCell<ServiceState>) -> R
{
where F: FnOnce(&RefCell<ServiceState>) -> R {
#[thread_local]
static SERVICE: RefCell<ServiceState> = RefCell::new(ServiceState::NotStarted);
@ -42,8 +40,7 @@ where F: FnOnce(&RefCell<ServiceState>) -> R
#[cfg(any(test, feature = "test"))]
pub fn whereis<F, R>(fun: F) -> R
where F: FnOnce(&RefCell<ServiceState>) -> R
{
where F: FnOnce(&RefCell<ServiceState>) -> R {
thread_local! {
static SERVICE: RefCell<ServiceState> = RefCell::new(ServiceState::NotStarted);
}
@ -72,39 +69,37 @@ impl KbupdService for ServiceState {
warn!("node service already started");
}
}
Some(_) => {
match self {
ServiceState::Replica(replica) => {
replica.untrusted_message(msg);
}
ServiceState::Frontend(frontend) => {
frontend.untrusted_message(msg);
}
ServiceState::NotStarted => {
warn!("node service not started");
}
Some(_) => match self {
ServiceState::Replica(replica) => {
replica.untrusted_message(msg);
}
}
None => {
}
ServiceState::Frontend(frontend) => {
frontend.untrusted_message(msg);
}
ServiceState::NotStarted => {
warn!("node service not started");
}
},
None => {}
}
}
}
impl SgxsdServer for SgxsdState {
type InitArgs = StartArgs;
type HandleCallArgs = CallArgs;
type TerminateArgs = StopArgs;
type InitArgs = StartArgs;
type TerminateArgs = StopArgs;
fn init(_args: Option<&Self::InitArgs>) -> Result<Self, SgxStatus> {
Ok(Self {})
}
fn handle_call(&mut self,
args: Option<&Self::HandleCallArgs>,
request_data: &[u8],
from: SgxsdMsgFrom)
-> Result<(), (SgxStatus, SgxsdMsgFrom)>
fn handle_call(
&mut self,
args: Option<&Self::HandleCallArgs>,
request_data: &[u8],
from: SgxsdMsgFrom,
) -> Result<(), (SgxStatus, SgxsdMsgFrom)>
{
let args = match args {
Some(args) => args,

View File

@ -5,6 +5,6 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
pub mod main;
pub mod frontend;
pub mod main;
pub mod replica;

View File

@ -11,25 +11,25 @@ mod replica_group;
use crate::prelude::*;
use std::convert::{TryInto};
use std::collections::*;
use std::time::*;
use std::convert::TryInto;
use std::rc::*;
use std::time::*;
use prost::{Message};
use sgx_ffi::util::{SecretValue};
use sgxsd_ffi::{RdRand};
use prost::Message;
use sgx_ffi::util::SecretValue;
use sgxsd_ffi::RdRand;
use crate::ffi::ecalls::{kbupd_send};
use crate::ffi::ecalls::kbupd_send;
use crate::lru::*;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd_enclave::*;
use crate::protobufs::kbupd_client;
use crate::protobufs::kbupd_enclave::*;
use crate::protobufs::raft::*;
use crate::storage::*;
use crate::raft::*;
use crate::remote::*;
use crate::remote_group::*;
use crate::storage::*;
use crate::util;
use crate::util::*;
@ -66,7 +66,7 @@ enum PeerState {
Replica {
remote: RemoteState<ReplicaToReplicaMessage, ReplicaToReplicaMessage>,
authorized: bool,
}
},
}
enum PeerMessage {
@ -106,45 +106,32 @@ impl ReplicaState {
pub fn untrusted_message(&mut self, untrusted_message: UntrustedMessage) {
match untrusted_message.inner {
Some(untrusted_message::Inner::StartFrontendRequest(_)) |
Some(untrusted_message::Inner::StartReplicaRequest(_)) =>
(),
Some(untrusted_message::Inner::StartFrontendRequest(_)) | Some(untrusted_message::Inner::StartReplicaRequest(_)) => (),
Some(untrusted_message::Inner::StartReplicaGroupRequest(request)) =>
self.handle_start_replica_group_request(request),
Some(untrusted_message::Inner::UntrustedTransactionRequest(request)) =>
warn!("received untrusted transaction request: {}", request),
Some(untrusted_message::Inner::UntrustedXferRequest(request)) =>
self.handle_untrusted_xfer_request(request),
Some(untrusted_message::Inner::GetEnclaveStatusRequest(request)) =>
self.handle_get_enclave_status_request(request),
Some(untrusted_message::Inner::StartReplicaGroupRequest(request)) => self.handle_start_replica_group_request(request),
Some(untrusted_message::Inner::UntrustedTransactionRequest(request)) => {
warn!("received untrusted transaction request: {}", request)
}
Some(untrusted_message::Inner::UntrustedXferRequest(request)) => self.handle_untrusted_xfer_request(request),
Some(untrusted_message::Inner::GetEnclaveStatusRequest(request)) => self.handle_get_enclave_status_request(request),
Some(untrusted_message::Inner::GetQeInfoReply(reply)) =>
self.handle_get_qe_info_reply(reply),
Some(untrusted_message::Inner::GetQuoteReply(reply)) =>
self.handle_get_quote_reply(reply),
Some(untrusted_message::Inner::GetAttestationReply(reply)) =>
self.handle_get_attestation_reply(reply),
Some(untrusted_message::Inner::GetQeInfoReply(reply)) => self.handle_get_qe_info_reply(reply),
Some(untrusted_message::Inner::GetQuoteReply(reply)) => self.handle_get_quote_reply(reply),
Some(untrusted_message::Inner::GetAttestationReply(reply)) => self.handle_get_attestation_reply(reply),
Some(untrusted_message::Inner::NewMessageSignal(signal)) =>
self.handle_new_message_signal(signal),
Some(untrusted_message::Inner::TimerTickSignal(signal)) =>
self.handle_timer_tick_signal(signal),
Some(untrusted_message::Inner::SetFrontendConfigSignal(_)) =>
(),
Some(untrusted_message::Inner::SetReplicaConfigSignal(signal)) =>
self.handle_set_replica_config_signal(signal),
Some(untrusted_message::Inner::ResetPeerSignal(signal)) =>
self.handle_reset_peer_signal(signal),
Some(untrusted_message::Inner::SetVerboseLoggingSignal(signal)) =>
self.handle_set_verbose_logging_signal(signal),
Some(untrusted_message::Inner::NewMessageSignal(signal)) => self.handle_new_message_signal(signal),
Some(untrusted_message::Inner::TimerTickSignal(signal)) => self.handle_timer_tick_signal(signal),
Some(untrusted_message::Inner::SetFrontendConfigSignal(_)) => (),
Some(untrusted_message::Inner::SetReplicaConfigSignal(signal)) => self.handle_set_replica_config_signal(signal),
Some(untrusted_message::Inner::ResetPeerSignal(signal)) => self.handle_reset_peer_signal(signal),
Some(untrusted_message::Inner::SetVerboseLoggingSignal(signal)) => self.handle_set_verbose_logging_signal(signal),
None => (),
}
}
fn handle_start_replica_group_request(&mut self, request: StartReplicaGroupRequest) {
let group_id = generate_group_id();
let group_id = generate_group_id();
let service_id = if request.source_partition.is_some() {
None
} else {
@ -160,7 +147,7 @@ impl ReplicaState {
service_id,
group_id,
node_ids,
config: request.config,
config: request.config,
source_partition: request.source_partition,
};
let _ignore = self.create_raft_group(create_group_request);
@ -196,16 +183,12 @@ impl ReplicaState {
Some(XferControlCommand::Pause) => {
info!("requesting pause of partitioning process");
let pause_xfer_txn = TransactionData {
inner: Some(transaction_data::Inner::PauseXfer(PauseXferTransaction {
request_id,
})),
inner: Some(transaction_data::Inner::PauseXfer(PauseXferTransaction { request_id })),
};
self.request_transaction(pause_xfer_txn);
None
}
Some(XferControlCommand::Resume) => {
self.resume_partitioning(request_id).err()
}
Some(XferControlCommand::Resume) => self.resume_partitioning(request_id).err(),
Some(XferControlCommand::Finish) => {
info!("requesting finish of partitioning process");
let finish_xfer_txn = TransactionData {
@ -253,10 +236,8 @@ impl ReplicaState {
warn!("Tried to start partitioning as a non-destination replica!");
return Err(UntrustedXferReplyStatus::InvalidState);
}
}
;
let node_ids: Vec<_> =
(partition.group.raft.peers().iter())
};
let node_ids: Vec<_> = (partition.group.raft.peers().iter())
.map(|node_id| node_id.to_vec())
.chain(std::iter::once(self.peers.our_node_id().to_vec()))
.collect();
@ -265,10 +246,10 @@ impl ReplicaState {
info!("requesting xfer of range {} chunk size {}", xfer_source.desired_range(), chunk_size);
let request = PendingXferRequest {
id: PendingXferRequestId::XferRequest,
message: Rc::new(ReplicaToReplicaMessage {
id: PendingXferRequestId::XferRequest,
message: Rc::new(ReplicaToReplicaMessage {
inner: Some(replica_to_replica_message::Inner::XferRequest(XferRequest {
group_id: partition.group.id().clone(),
group_id: partition.group.id().clone(),
chunk_size,
full_range: xfer_source.desired_range().to_pb(),
node_ids,
@ -315,49 +296,47 @@ impl ReplicaState {
let partition = if let Some(partition) = &self.partition {
let mut peers = Vec::new();
for node_id in partition.group.raft.peers() {
let attestation = partition.group.get(node_id).and_then(|peer| peer.attestation());
let replication_status = partition.group.raft.replication_state(node_id).map(|replication: &ReplicationState| {
EnclavePeerReplicationStatus {
let attestation = partition.group.get(node_id).and_then(|peer| peer.attestation());
let replication_status = partition
.group
.raft
.replication_state(node_id)
.map(|replication: &ReplicationState| EnclavePeerReplicationStatus {
next_index: replication.next_idx.id,
match_index: replication.match_idx.id,
inflight_index: replication.inflight.map(|inflight_log_idx: LogIdx| inflight_log_idx.id),
probing: replication.send_probe,
}
});
});
peers.push(EnclavePeerStatus {
node_id: node_id.to_vec(),
node_id: node_id.to_vec(),
attestation,
replication_status,
is_leader: partition.group.raft.leader().0 == Some(node_id),
unsent_requests: Default::default(),
is_leader: partition.group.raft.leader().0 == Some(node_id),
unsent_requests: Default::default(),
inflight_requests: Default::default(),
});
}
Some(EnclaveReplicaPartitionStatus {
group_id: partition.group.id().id.clone(),
service_id: partition.data.service_id().map(|service_id| service_id.id.clone()),
range: partition.data.range().map(PartitionKeyRange::to_pb),
group_id: partition.group.id().id.clone(),
service_id: partition.data.service_id().map(|service_id| service_id.id.clone()),
range: partition.data.range().map(PartitionKeyRange::to_pb),
peers,
min_attestation: partition.group.attestation(),
is_leader: partition.group.raft.is_leader(),
current_term: partition.group.raft.leader().1.id,
prev_log_index: partition.group.raft.log().prev_idx().id,
min_attestation: partition.group.attestation(),
is_leader: partition.group.raft.is_leader(),
current_term: partition.group.raft.leader().1.id,
prev_log_index: partition.group.raft.log().prev_idx().id,
last_applied_index: partition.group.raft.last_applied().id,
commit_index: partition.group.raft.commit_idx().id,
last_log_index: partition.group.raft.log().last_idx().id,
last_log_term: partition.group.raft.log().last_term().id,
log_data_length: partition.group.raft.log().data_len().to_u64(),
backup_count: partition.data.storage_len().to_u64(),
xfer_status: partition.data.xfer_status(),
commit_index: partition.group.raft.commit_idx().id,
last_log_index: partition.group.raft.log().last_idx().id,
last_log_term: partition.group.raft.log().last_term().id,
log_data_length: partition.group.raft.log().data_len().to_u64(),
backup_count: partition.data.storage_len().to_u64(),
xfer_status: partition.data.xfer_status(),
})
} else {
None
};
let memory_status = if request.memory_status {
Some(memory_status())
} else {
None
};
let memory_status = if request.memory_status { Some(memory_status()) } else { None };
kbupd_send(EnclaveMessage {
inner: Some(enclave_message::Inner::GetEnclaveStatusReply(GetEnclaveStatusReply {
inner: Some(get_enclave_status_reply::Inner::ReplicaStatus(EnclaveReplicaStatus {
@ -380,7 +359,7 @@ impl ReplicaState {
match self.peers.get_attestation_reply(reply) {
Some((peer @ PeerState::Replica { .. }, attestation)) => {
let peer_node_id = peer.remote_mut().id().clone();
let _ignore = peer.authorize();
let _ignore = peer.authorize();
self.replica_authorized(attestation, peer_node_id);
}
Some((PeerState::Frontend { .. }, _attestation)) => {
@ -399,10 +378,12 @@ impl ReplicaState {
if let Some(replica) = partition.group.get(&node_id) {
if partition.group.raft.peers().contains(&node_id) {
let create_raft_group_req = Rc::new(ReplicaToReplicaMessage {
inner: Some(replica_to_replica_message::Inner::CreateRaftGroupRequest(partition.create_group_request.clone())),
inner: Some(replica_to_replica_message::Inner::CreateRaftGroupRequest(
partition.create_group_request.clone(),
)),
});
match replica.sender.send(create_raft_group_req) {
Ok(()) => (),
Ok(()) => (),
Err(()) => {
error!("error sending raft group to {}", &node_id);
}
@ -421,7 +402,8 @@ impl ReplicaState {
fn handle_timer_tick_signal(&mut self, signal: TimerTickSignal) {
let now = Duration::from_secs(signal.now_secs);
self.peers.timer_tick(self.config.min_connect_timeout_ticks, self.config.max_connect_timeout_ticks);
self.peers
.timer_tick(self.config.min_connect_timeout_ticks, self.config.max_connect_timeout_ticks);
let remote_group_timeout_ticks = self.config.election_timeout_ticks.saturating_mul(2);
@ -430,7 +412,10 @@ impl ReplicaState {
remote_group.timer_tick(remote_group_timeout_ticks, self.config.request_quote_ticks);
}
if let Some(txn_data) = partition.group.timer_tick(self.config.attestation_expiry_ticks, self.config.request_quote_ticks, now) {
if let Some(txn_data) = partition
.group
.timer_tick(self.config.attestation_expiry_ticks, self.config.request_quote_ticks, now)
{
let mut encoded_txn_data = Vec::with_capacity(txn_data.encoded_len());
if let Ok(()) = txn_data.encode(&mut encoded_txn_data) {
let _ignore = partition.group.raft.client_request(encoded_txn_data);
@ -463,31 +448,33 @@ impl ReplicaState {
self.replica_message(message, from_node_id);
}
Ok(None) => (),
Err(peer_entry) => {
match NodeType::from_i32(peer_entry.connect_request().node_type) {
Some(NodeType::Frontend) => {
info!("accepted frontend connection from {}", &peer_entry.node_id());
let frontends = &mut self.frontends;
let _ignore = peer_entry.accept(|remote| {
PeerState::Frontend {
lru_entry: frontends.push_back(remote.id().clone()),
remote,
}
}, RemoteAuthorizationType::SelfOnly);
if self.frontends.len() > self.config.max_frontend_count.to_usize() {
self.pop_frontend();
}
}
Some(NodeType::Replica) => {
info!("accepting replica connection from {}", peer_entry.node_id());
let _ignore = peer_entry.accept(PeerState::new_replica, RemoteAuthorizationType::Mutual);
}
None | Some(NodeType::None) => {
warn!("bad node type in connect request from {}: {}",
peer_entry.node_id(), peer_entry.connect_request().node_type);
Err(peer_entry) => match NodeType::from_i32(peer_entry.connect_request().node_type) {
Some(NodeType::Frontend) => {
info!("accepted frontend connection from {}", &peer_entry.node_id());
let frontends = &mut self.frontends;
let _ignore = peer_entry.accept(
|remote| PeerState::Frontend {
lru_entry: frontends.push_back(remote.id().clone()),
remote,
},
RemoteAuthorizationType::SelfOnly,
);
if self.frontends.len() > self.config.max_frontend_count.to_usize() {
self.pop_frontend();
}
}
}
Some(NodeType::Replica) => {
info!("accepting replica connection from {}", peer_entry.node_id());
let _ignore = peer_entry.accept(PeerState::new_replica, RemoteAuthorizationType::Mutual);
}
None | Some(NodeType::None) => {
warn!(
"bad node type in connect request from {}: {}",
peer_entry.node_id(),
peer_entry.connect_request().node_type
);
}
},
}
}
@ -537,25 +524,20 @@ impl ReplicaState {
fn replica_message(&mut self, replica_msg: ReplicaToReplicaMessage, from: NodeId) {
match replica_msg.inner {
Some(replica_to_replica_message::Inner::RaftMessage(raft_msg)) =>
self.handle_raft_message(raft_msg, from),
Some(replica_to_replica_message::Inner::CreateRaftGroupRequest(request)) =>
self.handle_create_raft_group_request(request, from),
Some(replica_to_replica_message::Inner::EnclaveGetQuoteRequest(request)) =>
self.handle_enclave_get_quote_request(request, from),
Some(replica_to_replica_message::Inner::EnclaveGetQuoteReply(reply)) =>
self.handle_enclave_get_quote_reply(reply, from),
Some(replica_to_replica_message::Inner::RaftMessage(raft_msg)) => self.handle_raft_message(raft_msg, from),
Some(replica_to_replica_message::Inner::CreateRaftGroupRequest(request)) => {
self.handle_create_raft_group_request(request, from)
}
Some(replica_to_replica_message::Inner::EnclaveGetQuoteRequest(request)) => {
self.handle_enclave_get_quote_request(request, from)
}
Some(replica_to_replica_message::Inner::EnclaveGetQuoteReply(reply)) => self.handle_enclave_get_quote_reply(reply, from),
Some(replica_to_replica_message::Inner::XferRequest(request)) =>
self.handle_xfer_request(request, from),
Some(replica_to_replica_message::Inner::XferReply(reply)) =>
self.handle_xfer_reply(reply, from),
Some(replica_to_replica_message::Inner::XferChunkRequest(request)) =>
self.handle_xfer_chunk_request(request, from),
Some(replica_to_replica_message::Inner::XferChunkReply(reply)) =>
self.handle_xfer_chunk_reply(reply, from),
Some(replica_to_replica_message::Inner::XferErrorNotLeader(xfer_error)) =>
self.handle_xfer_error_not_leader(xfer_error, from),
Some(replica_to_replica_message::Inner::XferRequest(request)) => self.handle_xfer_request(request, from),
Some(replica_to_replica_message::Inner::XferReply(reply)) => self.handle_xfer_reply(reply, from),
Some(replica_to_replica_message::Inner::XferChunkRequest(request)) => self.handle_xfer_chunk_request(request, from),
Some(replica_to_replica_message::Inner::XferChunkReply(reply)) => self.handle_xfer_chunk_reply(reply, from),
Some(replica_to_replica_message::Inner::XferErrorNotLeader(xfer_error)) => self.handle_xfer_error_not_leader(xfer_error, from),
None => (),
}
}
@ -579,13 +561,22 @@ impl ReplicaState {
}
fn create_raft_group(&mut self, create_group_request: CreateRaftGroupRequest) -> Result<(), ()> {
let CreateRaftGroupRequest { group_id, service_id, node_ids, config, source_partition } = create_group_request.clone();
let CreateRaftGroupRequest {
group_id,
service_id,
node_ids,
config,
source_partition,
} = create_group_request.clone();
let node_ids: BTreeSet<NodeId> = node_ids.into_iter().map(|node_id| node_id.into()).collect();
if node_ids.contains(self.node_id()) {
if let Some(partition) = &self.partition {
if partition.group.id() != &group_id {
warn!("tried to start raft group {} on replica already containing partition {}",
&group_id, partition.group.id());
warn!(
"tried to start raft group {} on replica already containing partition {}",
&group_id,
partition.group.id()
);
}
Err(())
} else {
@ -595,12 +586,29 @@ impl ReplicaState {
Some(PartitionKeyRange::new_unbounded())
};
info!("creating replica group {} service {} with range {} and nodes {}",
&group_id, OptionDisplay(service_id.as_ref()), OptionDisplay(range.as_ref()),
ListDisplay(node_ids.iter()));
info!(
"creating replica group {} service {} with range {} and nodes {}",
&group_id,
OptionDisplay(service_id.as_ref()),
OptionDisplay(range.as_ref()),
ListDisplay(node_ids.iter())
);
let raft_log = RaftLogStorage::new(config.raft_log_data_size.to_usize(), config.raft_log_index_size, self.config.raft_log_index_page_cache_size.to_usize())?;
let raft = RaftState::new(group_id, self.node_id().clone(), node_ids, raft_log, RdRand, self.config.election_timeout_ticks, self.config.heartbeat_timeout_ticks, self.config.replication_chunk_size.to_usize());
let raft_log = RaftLogStorage::new(
config.raft_log_data_size.to_usize(),
config.raft_log_index_size,
self.config.raft_log_index_page_cache_size.to_usize(),
)?;
let raft = RaftState::new(
group_id,
self.node_id().clone(),
node_ids,
raft_log,
RdRand,
self.config.election_timeout_ticks,
self.config.heartbeat_timeout_ticks,
self.config.replication_chunk_size.to_usize(),
);
let replica_group = self.connect_to_peers(raft)?;
let xfer_state = if let Some(source_partition) = source_partition {
@ -616,7 +624,7 @@ impl ReplicaState {
self.partition = Some(Partition {
group: replica_group,
data: PartitionData::new(partition_data_config, service_id, range, xfer_state),
data: PartitionData::new(partition_data_config, service_id, range, xfer_state),
create_group_request,
});
@ -629,12 +637,14 @@ impl ReplicaState {
}
}
fn connect_to_peers(&mut self, raft: RaftState<RaftLogStorage, RdRand, NodeId>)
-> Result<ReplicaGroupState, ()> {
fn connect_to_peers(&mut self, raft: RaftState<RaftLogStorage, RdRand, NodeId>) -> Result<ReplicaGroupState, ()> {
let mut remotes = Vec::new();
let our_node_id = self.peers.our_node_id().clone();
for peer_node_id in raft.peers().iter() {
match self.peers.start_peer(peer_node_id.clone(), NodeType::Replica, RemoteAuthorizationType::Mutual) {
match self
.peers
.start_peer(peer_node_id.clone(), NodeType::Replica, RemoteAuthorizationType::Mutual)
{
Ok(peer_entry) => {
let sender = peer_entry.remote().sender().clone();
if *peer_node_id < our_node_id {
@ -642,25 +652,21 @@ impl ReplicaState {
match peer_entry.connect(PeerState::new_replica) {
Ok(_peer) => (),
Err((peer_entry, _mapper)) => {
error!("aborting starting group due to error connecting to {}",
peer_entry.remote().id());
error!("aborting starting group due to error connecting to {}", peer_entry.remote().id());
return Err(());
}
}
} else {
peer_entry.insert(PeerState::new_replica);
}
remotes.push(RemoteReplicaState {
sender,
});
remotes.push(RemoteReplicaState { sender });
}
Err(Some(PeerState::Replica { remote, .. })) => {
remotes.push(RemoteReplicaState {
sender: remote.sender().clone(),
});
}
Err(Some(PeerState::Frontend { .. })) |
Err(None) => {
Err(Some(PeerState::Frontend { .. })) | Err(None) => {
error!("started group with {} when it's already connected as a frontend!", peer_node_id);
return Err(());
}
@ -673,7 +679,10 @@ impl ReplicaState {
let desired_range = match PartitionKeyRange::try_from_pb(&source_partition.range) {
Ok(desired_range) => desired_range,
Err(()) => {
error!("started replica group with source partition config containing invalid range: {}", &source_partition);
error!(
"started replica group with source partition config containing invalid range: {}",
&source_partition
);
return Err(());
}
};
@ -682,7 +691,10 @@ impl ReplicaState {
let source_node_id: NodeId = source_node_id_vec[..].into();
info!("connecting to source replica {}", &source_node_id);
let sender = match self.peers.start_peer(source_node_id.clone(), NodeType::Replica, RemoteAuthorizationType::Mutual) {
let sender = match self
.peers
.start_peer(source_node_id.clone(), NodeType::Replica, RemoteAuthorizationType::Mutual)
{
Ok(peer_entry) => {
let sender = peer_entry.remote().sender().clone();
match peer_entry.connect(PeerState::new_replica) {
@ -694,12 +706,12 @@ impl ReplicaState {
}
sender
}
Err(Some(PeerState::Replica { remote, .. })) => {
remote.sender().clone()
}
Err(Some(PeerState::Frontend { .. })) |
Err(None) => {
error!("source replica {} was already connected as a frontend!", NodeId::from(source_node_id_vec));
Err(Some(PeerState::Replica { remote, .. })) => remote.sender().clone(),
Err(Some(PeerState::Frontend { .. })) | Err(None) => {
error!(
"source replica {} was already connected as a frontend!",
NodeId::from(source_node_id_vec)
);
return Err(());
}
};
@ -725,13 +737,13 @@ impl ReplicaState {
fn handle_xfer_request(&mut self, xfer_request: XferRequest, from: NodeId) {
match &mut self.partition {
Some(_) => (),
None => {
None => {
warn!("received XferRequest from {} without having a partition: {}", &from, &xfer_request);
return;
}
}
match PartitionKeyRange::try_from_pb(&xfer_request.full_range) {
Ok(_) => (),
Ok(_) => (),
Err(()) => {
warn!("received XferRequest from {} with invalid range: {}", &from, &xfer_request);
return;
@ -844,7 +856,7 @@ impl ReplicaState {
}
fn handle_xfer_error_not_leader(&mut self, xfer_error_not_leader: XferErrorNotLeader, from: NodeId) {
let term: TermId = xfer_error_not_leader.term;
let term: TermId = xfer_error_not_leader.term;
let leader: Option<NodeId> = xfer_error_not_leader.leader_node_id.map(NodeId::from);
if let Some(partition) = &mut self.partition {
if let Some(remote_group) = partition.data.xfer_state_mut().remote_group_mut() {
@ -859,10 +871,10 @@ impl ReplicaState {
fn frontend_message(&mut self, msg: FrontendToReplicaMessage, from: NodeId) {
match msg.inner {
Some(frontend_to_replica_message::Inner::TransactionRequest(req)) =>
self.handle_transaction_request(req, from),
Some(frontend_to_replica_message::Inner::EnclaveGetQuoteRequest(request)) =>
self.handle_enclave_get_quote_request(request, from),
Some(frontend_to_replica_message::Inner::TransactionRequest(req)) => self.handle_transaction_request(req, from),
Some(frontend_to_replica_message::Inner::EnclaveGetQuoteRequest(request)) => {
self.handle_enclave_get_quote_request(request, from)
}
None => (),
}
}
@ -898,14 +910,17 @@ impl ReplicaState {
}
}
fn accept_transaction_request(&mut self, request_data: transaction_request::Data)
-> Result<frontend_request_transaction::Transaction, transaction_reply::Data> {
fn accept_transaction_request(
&mut self,
request_data: transaction_request::Data,
) -> Result<frontend_request_transaction::Transaction, transaction_reply::Data>
{
let partition = match &mut self.partition {
Some(partition) => partition,
None => {
None => {
return Err(transaction_reply::Data::NotLeader(TransactionErrorNotLeader {
leader_node_id: None,
term: Default::default(),
leader_node_id: None,
term: Default::default(),
}));
}
};
@ -938,18 +953,16 @@ impl ReplicaState {
}
}
transaction_request::Data::Backup(backup_request) => {
let min_attestation = AttestationParameters::new(Duration::from_secs(backup_request.valid_from));
let our_service_id = partition.data.service_id_bytes();
let min_attestation = AttestationParameters::new(Duration::from_secs(backup_request.valid_from));
let our_service_id = partition.data.service_id_bytes();
let request_service_id = backup_request.service_id.as_ref().map(|service_id: &Vec<u8>| &service_id[..]);
let request_nonce = Self::decode_transaction_request_nonce(backup_request.nonce)?;
let request_nonce = Self::decode_transaction_request_nonce(backup_request.nonce)?;
if (our_service_id.is_none() ||
(request_service_id.is_some() && request_service_id != our_service_id))
{
if (our_service_id.is_none() || (request_service_id.is_some() && request_service_id != our_service_id)) {
Err(transaction_reply::Data::ServiceIdMismatch(TransactionErrorServiceIdMismatch {}))
} else if min_attestation > partition.group.attestation() {
Err(transaction_reply::Data::ClientResponse(kbupd_client::Response {
backup: Some(kbupd_client::BackupResponse {
backup: Some(kbupd_client::BackupResponse {
status: Some(kbupd_client::backup_response::Status::NotYetValid.into()),
nonce: None,
}),
@ -969,14 +982,12 @@ impl ReplicaState {
}
}
transaction_request::Data::Restore(restore_request) => {
let min_attestation = AttestationParameters::new(Duration::from_secs(restore_request.valid_from));
let our_service_id = partition.data.service_id_bytes();
let min_attestation = AttestationParameters::new(Duration::from_secs(restore_request.valid_from));
let our_service_id = partition.data.service_id_bytes();
let request_service_id = restore_request.service_id.as_ref().map(|service_id: &Vec<u8>| &service_id[..]);
let request_nonce = Self::decode_transaction_request_nonce(restore_request.nonce)?;
let request_nonce = Self::decode_transaction_request_nonce(restore_request.nonce)?;
if (our_service_id.is_none() ||
(request_service_id.is_some() && request_service_id != our_service_id))
{
if (our_service_id.is_none() || (request_service_id.is_some() && request_service_id != our_service_id)) {
Err(transaction_reply::Data::ServiceIdMismatch(TransactionErrorServiceIdMismatch {}))
} else if min_attestation > partition.group.attestation() {
Err(transaction_reply::Data::ClientResponse(kbupd_client::Response {
@ -1000,11 +1011,12 @@ impl ReplicaState {
}
}
transaction_request::Data::Delete(delete_backup_request) => {
let our_service_id = partition.data.service_id_bytes();
let request_service_id = delete_backup_request.service_id.as_ref().map(|service_id: &Vec<u8>| &service_id[..]);
if (our_service_id.is_none() ||
(request_service_id.is_some() && request_service_id != our_service_id))
{
let our_service_id = partition.data.service_id_bytes();
let request_service_id = delete_backup_request
.service_id
.as_ref()
.map(|service_id: &Vec<u8>| &service_id[..]);
if (our_service_id.is_none() || (request_service_id.is_some() && request_service_id != our_service_id)) {
Err(transaction_reply::Data::ServiceIdMismatch(TransactionErrorServiceIdMismatch {}))
} else {
Ok(frontend_request_transaction::Transaction::Delete(DeleteBackupTransaction {
@ -1016,8 +1028,8 @@ impl ReplicaState {
}
fn decode_transaction_request_nonce(combined_nonce: Vec<u8>) -> Result<RequestNonce, transaction_reply::Data> {
let combined_nonce: &[u8; 32] = (&combined_nonce[..].try_into())
.map_err(|_| transaction_reply::Data::InvalidRequest(TransactionErrorInvalidRequest {}))?;
let combined_nonce: &[u8; 32] =
(&combined_nonce[..].try_into()).map_err(|_| transaction_reply::Data::InvalidRequest(TransactionErrorInvalidRequest {}))?;
Ok(RequestNonce::from_combined(*combined_nonce))
}
@ -1034,10 +1046,14 @@ impl ReplicaState {
let txn = match TransactionData::decode(&encoded_transaction.data[..]) {
Ok(transaction) => transaction,
Err(_) => panic!("error decoding committed raft transaction"),
Err(_) => panic!("error decoding committed raft transaction"),
};
let txn_info = if let Some(txn_inner) = txn.inner {
Some(partition.data.perform_transaction(txn_inner, &mut self.peers, &mut partition.group, is_leader))
Some(
partition
.data
.perform_transaction(txn_inner, &mut self.peers, &mut partition.group, is_leader),
)
} else {
None
};
@ -1090,7 +1106,7 @@ impl ReplicaState {
fn request_transaction(&mut self, transaction: TransactionData) {
if let Some(partition) = &mut self.partition {
let mut encoded_transaction = SecretValue::new(Vec::with_capacity(transaction.encoded_len()));
let mut encoded_transaction = SecretValue::new(Vec::with_capacity(transaction.encoded_len()));
let request_transaction_result = if let Ok(()) = transaction.encode(encoded_transaction.get_mut()) {
partition.group.raft.client_request(encoded_transaction.into_inner())
} else {
@ -1119,7 +1135,7 @@ impl ReplicaState {
if let Some(from) = self.peers.get_frontend(&from_node_id) {
let transaction_error = match transaction_error {
Some(transaction_error) => transaction_error,
None => transaction_reply::Data::NotLeader(TransactionErrorNotLeader {
None => transaction_reply::Data::NotLeader(TransactionErrorNotLeader {
leader_node_id: leader.and_then(|leader| leader.0).map(|leader| leader.to_vec()),
term: leader.map(|leader| leader.1).cloned().unwrap_or_default(),
}),
@ -1155,10 +1171,8 @@ impl ReplicaState {
info!("cannot finish partitioning process on non-leader");
send_untrusted_xfer_reply(request.request_id, UntrustedXferReplyStatus::NotLeader);
}
Some(transaction_data::Inner::SetTime(_request)) => {
}
None => {
}
Some(transaction_data::Inner::SetTime(_request)) => {}
None => {}
}
}
@ -1181,7 +1195,9 @@ impl ReplicaState {
//
fn generate_group_id() -> RaftGroupId {
RaftGroupId { id: RdRand.rand_bytes(vec![0; 32]) }
RaftGroupId {
id: RdRand.rand_bytes(vec![0; 32]),
}
}
//
@ -1189,7 +1205,9 @@ fn generate_group_id() -> RaftGroupId {
//
fn generate_service_id() -> ServiceId {
ServiceId { id: RdRand.rand_bytes(vec![0; 32]) }
ServiceId {
id: RdRand.rand_bytes(vec![0; 32]),
}
}
//
@ -1199,18 +1217,14 @@ fn generate_service_id() -> ServiceId {
impl PeerState {
fn new_replica(remote: RemoteState<ReplicaToReplicaMessage, ReplicaToReplicaMessage>) -> Self {
let authorized = remote.attestation().is_some();
PeerState::Replica {
remote,
authorized,
}
PeerState::Replica { remote, authorized }
}
#[must_use]
fn authorize(&mut self) -> Option<AttestationParameters> {
match self {
PeerState::Frontend { .. } => {
None
}
PeerState::Replica { remote, authorized } => {
PeerState::Frontend { .. } => None,
PeerState::Replica { remote, authorized } => {
if !*authorized {
let maybe_attestation = remote.attestation();
if maybe_attestation.is_some() {
@ -1227,34 +1241,29 @@ impl PeerState {
impl Peer for PeerState {
type Message = PeerMessage;
fn remote_mut(&mut self) -> &mut dyn Remote {
match self {
PeerState::Frontend { remote, .. } => remote,
PeerState::Replica { remote, .. } => remote,
PeerState::Replica { remote, .. } => remote,
}
}
fn recv(&mut self, msg_data: &[u8]) -> Result<PeerMessage, RemoteRecvError> {
match self {
PeerState::Frontend { remote, .. } => {
remote.recv(msg_data).map(PeerMessage::Frontend)
}
PeerState::Replica { remote, .. } => {
remote.recv(msg_data).map(PeerMessage::Replica)
}
PeerState::Frontend { remote, .. } => remote.recv(msg_data).map(PeerMessage::Frontend),
PeerState::Replica { remote, .. } => remote.recv(msg_data).map(PeerMessage::Replica),
}
}
fn send_quote_reply(&mut self, reply: EnclaveGetQuoteReply) -> Result<(), ()> {
match self {
PeerState::Frontend { remote, .. } => {
remote.send(Rc::new(ReplicaToFrontendMessage {
inner: Some(replica_to_frontend_message::Inner::EnclaveGetQuoteReply(reply))
}))
}
PeerState::Replica { remote, .. } => {
remote.send(Rc::new(ReplicaToReplicaMessage {
inner: Some(replica_to_replica_message::Inner::EnclaveGetQuoteReply(reply))
}))
}
PeerState::Frontend { remote, .. } => remote.send(Rc::new(ReplicaToFrontendMessage {
inner: Some(replica_to_frontend_message::Inner::EnclaveGetQuoteReply(reply)),
})),
PeerState::Replica { remote, .. } => remote.send(Rc::new(ReplicaToReplicaMessage {
inner: Some(replica_to_replica_message::Inner::EnclaveGetQuoteReply(reply)),
})),
}
}
}
@ -1288,10 +1297,11 @@ fn generate_nonce_16() -> Vec<u8> {
RdRand.rand_bytes(vec![0; 16])
}
fn send_transaction_reply(from: &mut dyn RemoteMessageSender<Message = ReplicaToFrontendMessage>,
request_id: u64,
data: transaction_reply::Data)
fn send_transaction_reply(
from: &mut dyn RemoteMessageSender<Message = ReplicaToFrontendMessage>,
request_id: u64,
data: transaction_reply::Data,
)
{
let _ignore = from.send(Rc::new(ReplicaToFrontendMessage {
inner: Some(replica_to_frontend_message::Inner::TransactionReply(TransactionReply {
@ -1317,8 +1327,8 @@ fn send_untrusted_xfer_reply(request_id: u64, status: UntrustedXferReplyStatus)
#[cfg(test)]
mod tests {
use super::*;
use crate::ffi::mocks;
use crate::ffi::ecalls;
use crate::ffi::mocks;
use mockers::*;
fn init(start_replica_req: StartReplicaRequest) -> ReplicaState {
@ -1327,19 +1337,16 @@ mod tests {
fn valid_range() -> PartitionKeyRangePb {
PartitionKeyRangePb {
first: BackupId { id: vec![0x00; 32] },
last: BackupId { id: vec![0xFF; 32] },
first: BackupId { id: vec![0x00; 32] },
last: BackupId { id: vec![0xFF; 32] },
}
}
#[test]
fn init_test() {
let scenario = Scenario::new();
let expected_enclave_messages: Vec<Box<dyn MatchArg<_>>> = vec![
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply {
..
})))
];
let expected_enclave_messages: Vec<Box<dyn MatchArg<_>>> =
vec![Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply { .. })))];
mocks::expect_enclave_messages(&scenario, expected_enclave_messages);
init(StartReplicaRequest {
config: Default::default(),
@ -1351,14 +1358,14 @@ mod tests {
fn start_new_group_no_peers() {
let scenario = Scenario::new();
let expected_enclave_messages: Vec<Box<dyn MatchArg<_>>> = vec![
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply {
..
}))),
Box::new(arg!(enclave_message::Inner::StartReplicaGroupReply(StartReplicaGroupReply {
service_id: Some(_),
group_id: Some(_),
..
})))
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply { .. }))),
Box::new(arg!(
enclave_message::Inner::StartReplicaGroupReply(StartReplicaGroupReply {
service_id: Some(_),
group_id: Some(_),
..
})
)),
];
mocks::expect_enclave_messages(&scenario, expected_enclave_messages);
let mut state = init(StartReplicaRequest {
@ -1376,17 +1383,15 @@ mod tests {
fn start_new_group_with_peer() {
let scenario = Scenario::new();
let expected_enclave_messages: Vec<Box<dyn MatchArg<_>>> = vec![
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply {
..
}))),
Box::new(arg!(enclave_message::Inner::StartReplicaGroupReply(StartReplicaGroupReply {
service_id: Some(_),
group_id: Some(_),
..
}))),
Box::new(arg!(enclave_message::Inner::GetQeInfoRequest(GetQeInfoRequest {
..
}))),
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply { .. }))),
Box::new(arg!(
enclave_message::Inner::StartReplicaGroupReply(StartReplicaGroupReply {
service_id: Some(_),
group_id: Some(_),
..
})
)),
Box::new(arg!(enclave_message::Inner::GetQeInfoRequest(GetQeInfoRequest { .. }))),
];
mocks::expect_enclave_messages(&scenario, expected_enclave_messages);
let mut state = init(StartReplicaRequest {
@ -1405,17 +1410,15 @@ mod tests {
fn start_new_xfer_group_no_peers() {
let scenario = Scenario::new();
let expected_enclave_messages: Vec<Box<dyn MatchArg<_>>> = vec![
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply {
..
}))),
Box::new(arg!(enclave_message::Inner::StartReplicaGroupReply(StartReplicaGroupReply {
service_id: None,
group_id: Some(_),
..
}))),
Box::new(arg!(enclave_message::Inner::GetQeInfoRequest(GetQeInfoRequest {
..
}))),
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply { .. }))),
Box::new(arg!(
enclave_message::Inner::StartReplicaGroupReply(StartReplicaGroupReply {
service_id: None,
group_id: Some(_),
..
})
)),
Box::new(arg!(enclave_message::Inner::GetQeInfoRequest(GetQeInfoRequest { .. }))),
];
mocks::expect_enclave_messages(&scenario, expected_enclave_messages);
let mut state = init(StartReplicaRequest {
@ -1424,7 +1427,7 @@ mod tests {
state.handle_start_replica_group_request(StartReplicaGroupRequest {
peer_node_ids: vec![vec![0; 32]],
source_partition: Some(SourcePartitionConfig {
range: valid_range(),
range: valid_range(),
node_ids: vec![],
}),
config: Default::default(),
@ -1437,17 +1440,15 @@ mod tests {
fn start_new_xfer_group_with_peer() {
let scenario = Scenario::new();
let expected_enclave_messages: Vec<Box<dyn MatchArg<_>>> = vec![
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply {
..
}))),
Box::new(arg!(enclave_message::Inner::GetQeInfoRequest(GetQeInfoRequest {
..
}))),
Box::new(arg!(enclave_message::Inner::StartReplicaGroupReply(StartReplicaGroupReply {
service_id: None,
group_id: Some(_),
..
})))
Box::new(arg!(enclave_message::Inner::StartReplicaReply(StartReplicaReply { .. }))),
Box::new(arg!(enclave_message::Inner::GetQeInfoRequest(GetQeInfoRequest { .. }))),
Box::new(arg!(
enclave_message::Inner::StartReplicaGroupReply(StartReplicaGroupReply {
service_id: None,
group_id: Some(_),
..
})
)),
];
mocks::expect_enclave_messages(&scenario, expected_enclave_messages);
let mut state = init(StartReplicaRequest {
@ -1456,7 +1457,7 @@ mod tests {
state.handle_start_replica_group_request(StartReplicaGroupRequest {
peer_node_ids: vec![vec![0; 32]],
source_partition: Some(SourcePartitionConfig {
range: valid_range(),
range: valid_range(),
node_ids: vec![vec![0; 32]],
}),
config: Default::default(),

View File

@ -9,26 +9,26 @@ mod backup_entry;
use crate::prelude::*;
use std::convert::{TryInto};
use std::collections::*;
use std::cmp::*;
use std::collections::*;
use std::convert::TryInto;
use std::mem;
use std::num::*;
use std::time::*;
use std::rc::*;
use std::time::*;
use bytes::*;
use num_traits::{ToPrimitive};
use num_traits::ToPrimitive;
use self::backup_entry::BackupEntrySecrets;
use super::*;
use crate::protobufs::kbupd::*;
use crate::protobufs::kbupd_enclave::*;
use crate::protobufs::kbupd_client;
use crate::protobufs::kbupd_enclave::*;
use crate::protobufs::raft::*;
use crate::remote::*;
use crate::remote_group::*;
use crate::util::*;
use self::backup_entry::BackupEntrySecrets;
use super::*;
pub(super) struct PartitionData {
storage: BTreeMap<PartitionKey, BackupEntry>,
@ -51,20 +51,20 @@ pub(super) enum XferState {
}
pub(super) struct XferSource {
remote_group: RemoteGroupState<ReplicaRemoteSender, PendingXferRequest>,
desired_range: PartitionKeyRange,
remote_group: RemoteGroupState<ReplicaRemoteSender, PendingXferRequest>,
desired_range: PartitionKeyRange,
cur_xfer_chunk_reply: Option<PendingXferRequest>,
}
pub(super) struct XferDestination {
remote_group_id: RaftGroupId,
remote_group: RemoteGroupState<ReplicaRemoteSender, PendingXferRequest>,
full_range: PartitionKeyRange,
chunk_size: u32,
remote_group_id: RaftGroupId,
remote_group: RemoteGroupState<ReplicaRemoteSender, PendingXferRequest>,
full_range: PartitionKeyRange,
chunk_size: u32,
paused: bool,
inflight: Option<PartitionKeyRange>,
paused: bool,
inflight: Option<PartitionKeyRange>,
cur_xfer_reply: Option<PendingXferRequest>,
cur_xfer_chunk_request: Option<PendingXferRequest>,
@ -74,12 +74,8 @@ pub(super) struct XferDestination {
pub(super) enum PendingXferRequestId {
XferRequest,
XferReply,
XferChunkRequest {
new_last: BackupId,
},
XferChunkReply {
new_last: BackupId,
},
XferChunkRequest { new_last: BackupId },
XferChunkReply { new_last: BackupId },
}
#[derive(Clone)]
@ -110,11 +106,12 @@ pub(super) enum FrontendRequestError {
}
impl PartitionData {
pub fn new(config: PartitionDataConfig,
service_id: Option<ServiceId>,
range: Option<PartitionKeyRange>,
xfer_state: XferState)
-> Self
pub fn new(
config: PartitionDataConfig,
service_id: Option<ServiceId>,
range: Option<PartitionKeyRange>,
xfer_state: XferState,
) -> Self
{
Self {
storage: Default::default(),
@ -151,22 +148,25 @@ impl PartitionData {
pub fn xfer_status(&self) -> Option<enclave_replica_partition_status::XferStatus> {
match &self.xfer_state {
XferState::SourcePartition(xfer_destination) => {
Some(enclave_replica_partition_status::XferStatus::OutgoingXferStatus(EnclaveOutgoingXferStatus {
XferState::SourcePartition(xfer_destination) => Some(enclave_replica_partition_status::XferStatus::OutgoingXferStatus(
EnclaveOutgoingXferStatus {
group_id: xfer_destination.remote_group_id.id.clone(),
full_xfer_range: xfer_destination.full_range.to_pb(),
current_chunk_range: xfer_destination.inflight.as_ref().map(PartitionKeyRange::to_pb),
paused: xfer_destination.paused,
min_attestation: xfer_destination.cur_xfer_chunk_request.as_ref().and_then(PendingXferRequest::min_attestation),
min_attestation: xfer_destination
.cur_xfer_chunk_request
.as_ref()
.and_then(PendingXferRequest::min_attestation),
nodes: xfer_destination.remote_group.status(),
}))
}
XferState::DestinationPartition(xfer_source) => {
Some(enclave_replica_partition_status::XferStatus::IncomingXferStatus(EnclaveIncomingXferStatus {
},
)),
XferState::DestinationPartition(xfer_source) => Some(enclave_replica_partition_status::XferStatus::IncomingXferStatus(
EnclaveIncomingXferStatus {
desired_range: xfer_source.desired_range.to_pb(),
nodes: xfer_source.remote_group.status(),
}))
}
},
)),
XferState::None => None,
}
}
@ -198,60 +198,51 @@ impl PartitionData {
}
}
pub fn perform_transaction(&mut self,
txn: transaction_data::Inner,
peers: &mut PeerManager<PeerState>,
group: &mut ReplicaGroupState,
is_leader: bool)
-> enclave_transaction_signal::Transaction
pub fn perform_transaction(
&mut self,
txn: transaction_data::Inner,
peers: &mut PeerManager<PeerState>,
group: &mut ReplicaGroupState,
is_leader: bool,
) -> enclave_transaction_signal::Transaction
{
match txn {
transaction_data::Inner::FrontendRequest(txn) =>
enclave_transaction_signal::Transaction::FrontendRequest(
self.perform_client_transaction(txn, peers, is_leader)
),
transaction_data::Inner::StartXfer(txn) =>
enclave_transaction_signal::Transaction::StartXfer(
self.perform_start_xfer_transaction(txn, peers)
),
transaction_data::Inner::SetSid(txn) =>
enclave_transaction_signal::Transaction::SetSid(
self.perform_set_sid_transaction(txn)
),
transaction_data::Inner::RemoveChunk(txn) =>
enclave_transaction_signal::Transaction::RemoveChunk(
self.perform_remove_chunk_transaction(txn, group)
),
transaction_data::Inner::ApplyChunk(txn) =>
enclave_transaction_signal::Transaction::ApplyChunk(
self.perform_apply_chunk_transaction(txn, group)
),
transaction_data::Inner::PauseXfer(txn) =>
enclave_transaction_signal::Transaction::PauseXfer(
self.perform_pause_xfer_transaction(txn)
),
transaction_data::Inner::ResumeXfer(txn) =>
enclave_transaction_signal::Transaction::ResumeXfer(
self.perform_resume_xfer_transaction(txn, group)
),
transaction_data::Inner::FinishXfer(txn) =>
enclave_transaction_signal::Transaction::FinishXfer(
self.perform_finish_xfer_transaction(txn)
),
transaction_data::Inner::SetTime(txn) =>
enclave_transaction_signal::Transaction::SetTime(
self.perform_set_time_transaction(txn, group)
),
transaction_data::Inner::FrontendRequest(txn) => {
enclave_transaction_signal::Transaction::FrontendRequest(self.perform_client_transaction(txn, peers, is_leader))
}
transaction_data::Inner::StartXfer(txn) => {
enclave_transaction_signal::Transaction::StartXfer(self.perform_start_xfer_transaction(txn, peers))
}
transaction_data::Inner::SetSid(txn) => enclave_transaction_signal::Transaction::SetSid(self.perform_set_sid_transaction(txn)),
transaction_data::Inner::RemoveChunk(txn) => {
enclave_transaction_signal::Transaction::RemoveChunk(self.perform_remove_chunk_transaction(txn, group))
}
transaction_data::Inner::ApplyChunk(txn) => {
enclave_transaction_signal::Transaction::ApplyChunk(self.perform_apply_chunk_transaction(txn, group))
}
transaction_data::Inner::PauseXfer(txn) => {
enclave_transaction_signal::Transaction::PauseXfer(self.perform_pause_xfer_transaction(txn))
}
transaction_data::Inner::ResumeXfer(txn) => {
enclave_transaction_signal::Transaction::ResumeXfer(self.perform_resume_xfer_transaction(txn, group))
}
transaction_data::Inner::FinishXfer(txn) => {
enclave_transaction_signal::Transaction::FinishXfer(self.perform_finish_xfer_transaction(txn))
}
transaction_data::Inner::SetTime(txn) => {
enclave_transaction_signal::Transaction::SetTime(self.perform_set_time_transaction(txn, group))
}
}
}
fn perform_client_transaction(&mut self,
txn: FrontendRequestTransaction,
peers: &mut PeerManager<PeerState>,
is_leader: bool)
-> EnclaveFrontendRequestTransaction
fn perform_client_transaction(
&mut self,
txn: FrontendRequestTransaction,
peers: &mut PeerManager<PeerState>,
is_leader: bool,
) -> EnclaveFrontendRequestTransaction
{
let backup_id = txn.backup_id().cloned();
let backup_id = txn.backup_id().cloned();
let txn_reply_data = if let Some(backup_id) = txn.backup_id() {
if self.check_xfer_in_progress(backup_id) {
transaction_reply::Data::XferInProgress(TransactionErrorXferInProgress {})
@ -260,8 +251,11 @@ impl PartitionData {
if let (Some(our_authoritative_range), Some(inflight)) = (&self.range, &xfer_destination.inflight) {
match PartitionKeyRange::new(*inflight.first(), *our_authoritative_range.last()) {
Ok(our_range) => Some(our_range),
Err(()) => {
error!("our authoritative range {} is less than inflight range {}!", our_authoritative_range, inflight);
Err(()) => {
error!(
"our authoritative range {} is less than inflight range {}!",
our_authoritative_range, inflight
);
None
}
}
@ -281,7 +275,7 @@ impl PartitionData {
if let Some(other_last) = our_range.first().checked_sub(1) {
match PartitionKeyRange::new(*xfer_destination.full_range.first(), other_last) {
Ok(other_range) => Some(other_range.to_pb()),
Err(()) => None,
Err(()) => None,
}
} else {
None
@ -292,7 +286,12 @@ impl PartitionData {
Some(PartitionConfig {
group_id: xfer_destination.remote_group_id.id.clone(),
range: other_range,
node_ids: xfer_destination.remote_group.get_remotes().into_iter().map(|r| r.to_vec()).collect(),
node_ids: xfer_destination
.remote_group
.get_remotes()
.into_iter()
.map(|r| r.to_vec())
.collect(),
})
} else {
info!("Wrong partition for {}, but we don't know the right one!", backup_id);
@ -300,16 +299,14 @@ impl PartitionData {
};
transaction_reply::Data::WrongPartition(TransactionErrorWrongPartition {
range: maybe_our_range.map(|r| r.to_pb()),
range: maybe_our_range.map(|r| r.to_pb()),
new_partition,
})
} else if let Some(txn_data) = txn.transaction {
match self.perform_frontend_request(txn_data) {
Ok(response_data) => response_data,
Err(FrontendRequestError::InvalidRequest) =>
transaction_reply::Data::InvalidRequest(TransactionErrorInvalidRequest {}),
Err(FrontendRequestError::StorageFull) =>
transaction_reply::Data::InternalError(TransactionErrorInternalError {}),
Err(FrontendRequestError::InvalidRequest) => transaction_reply::Data::InvalidRequest(TransactionErrorInvalidRequest {}),
Err(FrontendRequestError::StorageFull) => transaction_reply::Data::InternalError(TransactionErrorInternalError {}),
}
} else {
transaction_reply::Data::InvalidRequest(TransactionErrorInvalidRequest {})
@ -360,28 +357,29 @@ impl PartitionData {
}
fn frontend_request_field_to_array<T>(src: &[u8]) -> Result<T, FrontendRequestError>
where T: AsMut<[u8]> + Default
{
where T: AsMut<[u8]> + Default {
match util::copy_exact(src) {
Ok(array) => Ok(array),
Err(()) => Err(FrontendRequestError::InvalidRequest),
Err(()) => Err(FrontendRequestError::InvalidRequest),
}
}
fn frontend_request_backup_id(backup_id: &BackupId) -> Result<PartitionKey, FrontendRequestError> {
match PartitionKey::try_from_pb(backup_id) {
Ok(backup_id) => Ok(backup_id),
Err(()) => Err(FrontendRequestError::InvalidRequest),
Err(()) => Err(FrontendRequestError::InvalidRequest),
}
}
fn get_or_insert_backup_with(&mut self, backup_id: PartitionKey, create_fun: impl FnOnce() -> BackupEntry)
-> Result<&mut BackupEntry, FrontendRequestError> {
fn get_or_insert_backup_with(
&mut self,
backup_id: PartitionKey,
create_fun: impl FnOnce() -> BackupEntry,
) -> Result<&mut BackupEntry, FrontendRequestError>
{
let storage_len = self.storage.len();
match self.storage.entry(backup_id) {
btree_map::Entry::Occupied(backup_entry) => {
Ok(backup_entry.into_mut())
}
btree_map::Entry::Occupied(backup_entry) => Ok(backup_entry.into_mut()),
btree_map::Entry::Vacant(backup_entry) => {
if storage_len < self.config.capacity {
Ok(backup_entry.insert(create_fun()))
@ -392,16 +390,18 @@ impl PartitionData {
}
}
fn perform_frontend_request(&mut self, txn_data: frontend_request_transaction::Transaction)
-> Result<transaction_reply::Data, FrontendRequestError>
fn perform_frontend_request(
&mut self,
txn_data: frontend_request_transaction::Transaction,
) -> Result<transaction_reply::Data, FrontendRequestError>
{
match txn_data {
frontend_request_transaction::Transaction::Create(create_request) => {
let backup_id = Self::frontend_request_backup_id(&create_request.backup_id)?;
let backup_id = Self::frontend_request_backup_id(&create_request.backup_id)?;
let new_creation_nonce = Self::frontend_request_field_to_array(&create_request.new_creation_nonce)?;
let new_nonce = Self::frontend_request_field_to_array(&create_request.new_nonce)?;
let new_nonce = Self::frontend_request_field_to_array(&create_request.new_nonce)?;
let backup = self.get_or_insert_backup_with(backup_id, || BackupEntry {
nonce: RequestNonce {
nonce: RequestNonce {
creation_nonce: new_creation_nonce,
current_nonce: new_nonce,
},
@ -414,15 +414,19 @@ impl PartitionData {
}))
}
frontend_request_transaction::Transaction::Backup(backup_request) => {
let backup_id = Self::frontend_request_backup_id(&backup_request.backup_id)?;
let backup_id = Self::frontend_request_backup_id(&backup_request.backup_id)?;
let new_creation_nonce = Self::frontend_request_field_to_array(&backup_request.new_creation_nonce)?;
let new_nonce = Self::frontend_request_field_to_array(&backup_request.new_nonce)?;
let tries = backup_request.tries.to_u16().and_then(NonZeroU16::new).ok_or(FrontendRequestError::InvalidRequest)?;
let new_nonce = Self::frontend_request_field_to_array(&backup_request.new_nonce)?;
let tries = backup_request
.tries
.to_u16()
.and_then(NonZeroU16::new)
.ok_or(FrontendRequestError::InvalidRequest)?;
let max_backup_data_length = self.config.max_backup_data_length;
let backup = self.get_or_insert_backup_with(backup_id, || BackupEntry {
nonce: RequestNonce {
nonce: RequestNonce {
creation_nonce: new_creation_nonce,
current_nonce: new_nonce,
},
@ -442,14 +446,14 @@ impl PartitionData {
} else {
let backup_request_pin: &[u8; BackupEntry::PIN_LENGTH] = match backup_request.pin.data[..].try_into() {
Ok(backup_request_pin) => backup_request_pin,
Err(_) => return Err(FrontendRequestError::InvalidRequest),
Err(_) => return Err(FrontendRequestError::InvalidRequest),
};
if backup_request.data.data.len() > max_backup_data_length.to_usize() {
return Err(FrontendRequestError::InvalidRequest);
}
backup.nonce.creation_nonce = backup.nonce.current_nonce;
backup.nonce.current_nonce = new_nonce;
backup.secrets = Some(BackupEntrySecrets::new(tries, backup_request_pin, &backup_request.data.data));
backup.nonce.current_nonce = new_nonce;
backup.secrets = Some(BackupEntrySecrets::new(tries, backup_request_pin, &backup_request.data.data));
kbupd_client::BackupResponse {
status: Some(kbupd_client::backup_response::Status::Ok.into()),
@ -463,9 +467,9 @@ impl PartitionData {
}))
}
frontend_request_transaction::Transaction::Restore(restore_request) => {
let backup_id = Self::frontend_request_backup_id(&restore_request.backup_id)?;
let backup_id = Self::frontend_request_backup_id(&restore_request.backup_id)?;
let restore_response = if let btree_map::Entry::Occupied(mut storage_entry) = self.storage.entry(backup_id) {
let entry = storage_entry.get_mut();
let entry = storage_entry.get_mut();
let new_nonce = Self::frontend_request_field_to_array(&restore_request.new_nonce)?;
if restore_request.creation_nonce != entry.nonce.creation_nonce {
@ -495,9 +499,10 @@ impl PartitionData {
}
} else if let Some(tries_minus_one) = entry_secrets.tries.get().checked_sub(2) {
// decrement tries
entry_secrets.tries = tries_minus_one.checked_add(1)
.and_then(NonZeroU16::new)
.unwrap_or_else(|| unreachable!());
entry_secrets.tries = tries_minus_one
.checked_add(1)
.and_then(NonZeroU16::new)
.unwrap_or_else(|| unreachable!());
kbupd_client::RestoreResponse {
status: Some(kbupd_client::restore_response::Status::PinMismatch.into()),
tries: Some(u32::from(entry_secrets.tries.get())),
@ -546,25 +551,31 @@ impl PartitionData {
}
}
fn perform_start_xfer_transaction(&mut self,
txn: StartXferTransaction,
peers: &mut PeerManager<PeerState>)
-> EnclaveStartXferTransaction
fn perform_start_xfer_transaction(
&mut self,
txn: StartXferTransaction,
peers: &mut PeerManager<PeerState>,
) -> EnclaveStartXferTransaction
{
let from = NodeId::from(&txn.from_node_id);
match &self.xfer_state {
XferState::SourcePartition(xfer_destination) => {
if txn.xfer_request.group_id != xfer_destination.remote_group_id {
warn!("received XferRequest from {} while having xfer destination {}: {}",
&from, &xfer_destination.remote_group_id, &txn.xfer_request);
warn!(
"received XferRequest from {} while having xfer destination {}: {}",
&from, &xfer_destination.remote_group_id, &txn.xfer_request
);
} else {
verbose!("received duplicate XferRequest from {}", &from);
}
return Default::default();
}
XferState::DestinationPartition(_xfer_source) => {
warn!("received XferRequest from {} while having xfer source: {}", &from, &txn.xfer_request);
warn!(
"received XferRequest from {} while having xfer source: {}",
&from, &txn.xfer_request
);
return Default::default();
}
XferState::None => (),
@ -584,23 +595,27 @@ impl PartitionData {
return Default::default();
}
};
if (our_range.first() != full_range.first() ||
!our_range.contains_range(&full_range)) {
warn!("received XferRequest from {} with requested range outside our range {}: {}",
&from, &our_range, &txn.xfer_request);
if (our_range.first() != full_range.first() || !our_range.contains_range(&full_range)) {
warn!(
"received XferRequest from {} with requested range outside our range {}: {}",
&from, &our_range, &txn.xfer_request
);
return Default::default();
}
let service_id = match &self.service_id {
Some(service_id) => service_id,
None => {
warn!("received XferRequest from {} while having no service id: {}", &from, &txn.xfer_request);
warn!(
"received XferRequest from {} while having no service id: {}",
&from, &txn.xfer_request
);
return Default::default();
}
};
let mut attestations: Vec<(NodeId, AttestationParameters)> = Vec::new();
let mut senders: Vec<ReplicaRemoteSender> = Vec::new();
let mut senders: Vec<ReplicaRemoteSender> = Vec::new();
for node_id in &txn.xfer_request.node_ids {
match peers.start_peer(node_id.into(), NodeType::Replica, RemoteAuthorizationType::Mutual) {
Ok(peer_entry) => {
@ -613,24 +628,26 @@ impl PartitionData {
}
senders.push(remote.sender().clone());
}
Err(Some(PeerState::Frontend { .. })) |
Err(None) => {
error!("started xfer to {} when it's already connected as a frontend!", NodeId::from(node_id));
Err(Some(PeerState::Frontend { .. })) | Err(None) => {
error!(
"started xfer to {} when it's already connected as a frontend!",
NodeId::from(node_id)
);
}
}
}
let cur_xfer_reply = PendingXferRequest {
id: PendingXferRequestId::XferReply,
message: Rc::new(ReplicaToReplicaMessage {
id: PendingXferRequestId::XferReply,
message: Rc::new(ReplicaToReplicaMessage {
inner: Some(replica_to_replica_message::Inner::XferReply(XferReply {
service: service_id.clone(),
})),
}),
min_attestation: None,
};
let remote_group_id = txn.xfer_request.group_id.clone();
let group_name = format!("{}", &remote_group_id);
let remote_group_id = txn.xfer_request.group_id.clone();
let group_name = format!("{}", &remote_group_id);
let mut remote_group = RemoteGroupState::new(group_name, senders);
for (replica_node_id, _attestation) in &attestations {
remote_group.remote_authorized(replica_node_id);
@ -638,18 +655,23 @@ impl PartitionData {
let chunk_size = txn.xfer_request.chunk_size;
let display_nodes = txn.xfer_request.node_ids.iter().map(|node| util::ToHex(node));
info!("starting xfer of range {} chunk size {} to group {} with nodes {}",
&full_range, &chunk_size, &remote_group_id, ListDisplay(display_nodes));
info!(
"starting xfer of range {} chunk size {} to group {} with nodes {}",
&full_range,
&chunk_size,
&remote_group_id,
ListDisplay(display_nodes)
);
self.xfer_state = XferState::SourcePartition(XferDestination {
remote_group_id,
remote_group,
full_range,
chunk_size,
paused: true,
paused: true,
inflight: None,
cur_xfer_reply: Some(cur_xfer_reply),
cur_xfer_reply: Some(cur_xfer_reply),
cur_xfer_chunk_request: None,
});
@ -662,8 +684,10 @@ impl PartitionData {
}
if let Some(service_id) = &self.service_id {
if service_id != &txn.service_id {
error!("tried to set service id {} on partition already having service id {}",
&txn.service_id, &service_id);
error!(
"tried to set service id {} on partition already having service id {}",
&txn.service_id, &service_id
);
}
Default::default()
} else {
@ -675,7 +699,12 @@ impl PartitionData {
}
}
fn perform_remove_chunk_transaction(&mut self, txn: RemoveChunkTransaction, group: &mut ReplicaGroupState) -> EnclaveRemoveChunkTransaction {
fn perform_remove_chunk_transaction(
&mut self,
txn: RemoveChunkTransaction,
group: &mut ReplicaGroupState,
) -> EnclaveRemoveChunkTransaction
{
if let XferState::SourcePartition(xfer_destination) = &mut self.xfer_state {
xfer_destination.received_reply(&PendingXferRequestId::XferReply);
xfer_destination.received_reply(&PendingXferRequestId::XferChunkRequest {
@ -686,8 +715,10 @@ impl PartitionData {
if inflight.last() == &txn.xfer_chunk_reply.new_last {
xfer_destination.inflight = None;
} else {
warn!("dropping out of order RemoveChunkTransaction {} expecting {}",
&txn.xfer_chunk_reply.new_last, &inflight);
warn!(
"dropping out of order RemoveChunkTransaction {} expecting {}",
&txn.xfer_chunk_reply.new_last, &inflight
);
}
}
@ -761,11 +792,7 @@ impl PartitionData {
}
};
let maybe_our_new_first = if let Some(range) = &self.range {
Some(range.first())
} else {
None
};
let maybe_our_new_first = if let Some(range) = &self.range { Some(range.first()) } else { None };
let entries = Self::storage_split_to(&mut self.storage, maybe_our_new_first);
@ -784,10 +811,10 @@ impl PartitionData {
let min_attestation = group.attestation_expiration_window();
xfer_destination.cur_xfer_chunk_request = Some(PendingXferRequest {
id: PendingXferRequestId::XferChunkRequest {
id: PendingXferRequestId::XferChunkRequest {
new_last: chunk_last.to_pb(),
},
message: Rc::new(ReplicaToReplicaMessage {
message: Rc::new(ReplicaToReplicaMessage {
inner: Some(replica_to_replica_message::Inner::XferChunkRequest(XferChunkRequest {
data,
chunk_range: chunk_range.to_pb(),
@ -799,8 +826,11 @@ impl PartitionData {
Some(chunk_range)
}
fn storage_split_to(storage: &mut BTreeMap<PartitionKey, BackupEntry>, maybe_split_key: Option<&PartitionKey>)
-> impl DoubleEndedIterator<Item = (PartitionKey, BackupEntry)> + ExactSizeIterator {
fn storage_split_to(
storage: &mut BTreeMap<PartitionKey, BackupEntry>,
maybe_split_key: Option<&PartitionKey>,
) -> impl DoubleEndedIterator<Item = (PartitionKey, BackupEntry)> + ExactSizeIterator
{
let new_map = if let Some(split_key) = maybe_split_key {
storage.split_off(split_key)
} else {
@ -809,7 +839,12 @@ impl PartitionData {
std::mem::replace(storage, new_map).into_iter()
}
fn perform_apply_chunk_transaction(&mut self, txn: ApplyChunkTransaction, group: &mut ReplicaGroupState) -> EnclaveApplyChunkTransaction {
fn perform_apply_chunk_transaction(
&mut self,
txn: ApplyChunkTransaction,
group: &mut ReplicaGroupState,
) -> EnclaveApplyChunkTransaction
{
if let XferState::DestinationPartition(xfer_source) = &mut self.xfer_state {
let request = &txn.xfer_chunk_request;
info!("received xfer chunk {} length {}", &request.chunk_range, request.data.data.len());
@ -826,13 +861,18 @@ impl PartitionData {
if !self.range.map(|range| range.overlaps_range(&chunk_range)).unwrap_or(false) {
chunk_range
} else {
warn!("dropping old xfer chunk {} current range {}",
&chunk_range, OptionDisplay(self.range));
warn!(
"dropping old xfer chunk {} current range {}",
&chunk_range,
OptionDisplay(self.range)
);
return Default::default();
}
} else {
error!("dropping xfer chunk {} not in desired range {}",
&chunk_range, &xfer_source.desired_range);
error!(
"dropping xfer chunk {} not in desired range {}",
&chunk_range, &xfer_source.desired_range
);
return Default::default();
}
}
@ -852,11 +892,11 @@ impl PartitionData {
let old_range = std::mem::replace(&mut self.range, Some(new_range));
let xfer_chunk_reply = PendingXferRequest {
id: PendingXferRequestId::XferChunkReply {
id: PendingXferRequestId::XferChunkReply {
new_last: txn.xfer_chunk_reply.new_last.clone(),
},
message: Rc::new(ReplicaToReplicaMessage {
inner: Some(replica_to_replica_message::Inner::XferChunkReply(txn.xfer_chunk_reply.clone()))
message: Rc::new(ReplicaToReplicaMessage {
inner: Some(replica_to_replica_message::Inner::XferChunkReply(txn.xfer_chunk_reply.clone())),
}),
min_attestation: None,
};
@ -879,8 +919,11 @@ impl PartitionData {
self.storage.insert(entry.id, entry.entry);
entry_count += 1;
} else {
error!("dropping transferred backup id {} within current range {}",
&entry.id, OptionDisplay(self.range));
error!(
"dropping transferred backup id {} within current range {}",
&entry.id,
OptionDisplay(self.range)
);
}
}
@ -914,7 +957,10 @@ impl PartitionData {
}
};
if !xfer_destination.paused {
info!("Pausing partitioning process at {}", OptionDisplay(self.range.as_ref().map(PartitionKeyRange::first)));
info!(
"Pausing partitioning process at {}",
OptionDisplay(self.range.as_ref().map(PartitionKeyRange::first))
);
}
xfer_destination.paused = true;
@ -923,7 +969,12 @@ impl PartitionData {
EnclavePauseXferTransaction {}
}
fn perform_resume_xfer_transaction(&mut self, txn: ResumeXferTransaction, group: &mut ReplicaGroupState) -> EnclaveResumeXferTransaction {
fn perform_resume_xfer_transaction(
&mut self,
txn: ResumeXferTransaction,
group: &mut ReplicaGroupState,
) -> EnclaveResumeXferTransaction
{
let xfer_destination = match &mut self.xfer_state {
XferState::SourcePartition(xfer_destination) => xfer_destination,
_ => {
@ -985,8 +1036,11 @@ impl PartitionData {
fn perform_set_time_transaction(&mut self, txn: SetTimeTransaction, group: &mut ReplicaGroupState) -> EnclaveSetTimeTransaction {
if !group.set_attestation_time_now(Duration::from_secs(txn.now_secs)) {
warn!("tried to set attestation time backward from {} to {}",
group.get_attestation_time_now().as_secs(), txn.now_secs);
warn!(
"tried to set attestation time backward from {} to {}",
group.get_attestation_time_now().as_secs(),
txn.now_secs
);
Default::default()
} else {
EnclaveSetTimeTransaction {
@ -1005,7 +1059,7 @@ impl XferState {
match self {
XferState::DestinationPartition(xfer_source) => Some(&mut xfer_source.remote_group),
XferState::SourcePartition(xfer_destination) => Some(&mut xfer_destination.remote_group),
XferState::None => None,
XferState::None => None,
}
}
}
@ -1015,10 +1069,7 @@ impl XferState {
//
impl XferSource {
pub fn new(remote_group: RemoteGroupState<ReplicaRemoteSender, PendingXferRequest>,
desired_range: PartitionKeyRange)
-> Self
{
pub fn new(remote_group: RemoteGroupState<ReplicaRemoteSender, PendingXferRequest>, desired_range: PartitionKeyRange) -> Self {
Self {
remote_group,
desired_range,
@ -1102,14 +1153,17 @@ impl XferDestination {
//
impl RemoteGroupPendingRequest for PendingXferRequest {
type Message = ReplicaToReplicaMessage;
type RequestId = PendingXferRequestId;
type Message = ReplicaToReplicaMessage;
fn request_id(&self) -> &Self::RequestId {
&self.id
}
fn message(&self) -> Rc<Self::Message> {
Rc::clone(&self.message)
}
fn min_attestation(&self) -> Option<AttestationParameters> {
self.min_attestation
}
@ -1127,13 +1181,14 @@ impl RequestNonce {
current_nonce: *current_nonce,
}
}
pub fn to_combined(&self) -> [u8; 32] {
let mut combined = [0; 32];
let (combined_creation_nonce, combined_current_nonce) = Self::split_mut(&mut combined);
*combined_creation_nonce = self.creation_nonce;
*combined_current_nonce = self.current_nonce;
*combined_current_nonce = self.current_nonce;
combined
}
@ -1159,8 +1214,8 @@ impl RequestNonce {
fn split_mut(combined: &mut [u8; 32]) -> (&mut [u8; 16], &mut [u8; 16]) {
let (creation_nonce, current_nonce) = combined.split_at_mut(16);
let creation_nonce: &mut [u8; 16] = creation_nonce.try_into().unwrap_or_else(|_| static_unreachable!());
let current_nonce: &mut [u8; 16] = current_nonce.try_into().unwrap_or_else(|_| static_unreachable!());
let creation_nonce: &mut [u8; 16] = creation_nonce.try_into().unwrap_or_else(|_| static_unreachable!());
let current_nonce: &mut [u8; 16] = current_nonce.try_into().unwrap_or_else(|_| static_unreachable!());
(creation_nonce, current_nonce)
}
}
@ -1175,10 +1230,12 @@ impl BackupEntry {
const fn encoded_len(data_len: u32) -> u32 {
RequestNonce::encoded_len() + BackupEntrySecrets::encoded_len(data_len)
}
fn encode<B: BufMut>(&self, buf: &mut B) {
self.nonce.encode(buf);
BackupEntrySecrets::encode_opt(self.secrets.as_ref(), buf);
}
fn decode<B: Buf>(buf: &mut B) -> Self {
let nonce = RequestNonce::decode(buf);
let secrets = BackupEntrySecrets::decode(buf);
@ -1205,15 +1262,20 @@ impl XferBackupEntry {
fn encoded_len(data_len: u32) -> u32 {
BackupId::valid_len() + BackupEntry::encoded_len(data_len)
}
fn encode<B: BufMut>(&self, buf: &mut B) {
self.entry.encode(buf);
buf.put_slice(&self.id[..]);
}
fn decode<B: Buf>(buf: &mut B) -> Self {
let entry = BackupEntry::decode(buf);
let mut id = [0; 32];
buf.copy_to_slice(&mut id);
Self { id: PartitionKey::new(id), entry }
Self {
id: PartitionKey::new(id),
entry,
}
}
}

View File

@ -14,7 +14,7 @@ use std::num::NonZeroU16;
use bytes::{Buf, BufMut};
use sgx_ffi::util::{SecretValue, ToUsize};
use super::{BackupEntry};
use super::BackupEntry;
pub struct BackupEntrySecrets {
pub tries: NonZeroU16,

View File

@ -26,11 +26,7 @@ pub struct PartitionKeyRange {
impl PartitionKeyRange {
pub fn new(first: PartitionKey, last: PartitionKey) -> Result<Self, ()> {
if first <= last {
Ok(Self { first, last })
} else {
Err(())
}
if first <= last { Ok(Self { first, last }) } else { Err(()) }
}
pub fn new_unbounded() -> Self {
@ -64,8 +60,7 @@ impl PartitionKeyRange {
}
pub fn contains(&self, key: &[u8; 32]) -> bool {
(key >= &self.first &&
key <= &self.last)
(key >= &self.first && key <= &self.last)
}
pub fn contains_id(&self, key: &BackupId) -> bool {
@ -83,8 +78,7 @@ impl PartitionKeyRange {
}
pub fn overlaps_range(&self, other: &Self) -> bool {
(self.contains(&other.first) || self.contains(&other.last) ||
other.contains(&self.first) || other.contains(&self.last))
(self.contains(&other.first) || self.contains(&other.last) || other.contains(&self.first) || other.contains(&self.last))
}
pub fn split_off_inclusive(&mut self, new_last: &PartitionKey) -> Result<Option<Self>, ()> {
@ -105,6 +99,7 @@ impl RangeBounds<PartitionKey> for &PartitionKeyRange {
fn start_bound(&self) -> Bound<&PartitionKey> {
Bound::Included(&self.first)
}
fn end_bound(&self) -> Bound<&PartitionKey> {
Bound::Included(&self.last)
}
@ -122,7 +117,6 @@ impl fmt::Display for PartitionKeyRange {
}
}
//
// PartitionKey impls
//
@ -143,9 +137,7 @@ impl PartitionKey {
}
pub fn to_pb(&self) -> BackupId {
BackupId {
id: self.0.to_vec(),
}
BackupId { id: self.0.to_vec() }
}
#[allow(clippy::indexing_slicing, clippy::integer_arithmetic)]
@ -157,12 +149,9 @@ impl PartitionKey {
carry = if new_b.1 { 1 } else { 0 };
ret[31 - i] = new_b.0;
}
if carry == 0 {
Some(Self(ret))
} else {
None
}
if carry == 0 { Some(Self(ret)) } else { None }
}
#[allow(clippy::indexing_slicing, clippy::integer_arithmetic)]
pub fn checked_add(&self, rhs: u8) -> Option<Self> {
let mut ret = [0x00; 32];
@ -172,16 +161,13 @@ impl PartitionKey {
carry = if new_b.1 { 1 } else { 0 };
ret[31 - i] = new_b.0;
}
if carry == 0 {
Some(Self(ret))
} else {
None
}
if carry == 0 { Some(Self(ret)) } else { None }
}
}
impl Deref for PartitionKey {
type Target = [u8; 32];
fn deref(&self) -> &Self::Target {
&self.0
}
@ -225,69 +211,55 @@ mod tests {
#[test]
fn valid_full_range() {
PartitionKeyRange::try_from_pb(
&PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 32] },
&BackupId { id: vec![0xFF; 32] })
&PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 32] }, &BackupId { id: vec![0xFF; 32] })
.unwrap()
.to_pb()
).unwrap();
.to_pb(),
)
.unwrap();
}
#[test]
fn valid_empty_range() {
PartitionKeyRange::try_from_pb(
&PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 32] },
&BackupId { id: vec![0x00; 32] })
&PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 32] }, &BackupId { id: vec![0x00; 32] })
.unwrap()
.to_pb()
).unwrap();
.to_pb(),
)
.unwrap();
}
#[test]
fn invalid_inverted_range() {
PartitionKeyRange::from_ids(&BackupId { id: vec![0xFF; 32] },
&BackupId { id: vec![0x00; 32] })
.unwrap_err();
PartitionKeyRange::from_ids(&BackupId { id: vec![0xFF; 32] }, &BackupId { id: vec![0x00; 32] }).unwrap_err();
}
#[test]
fn invalid_range_empty_first() {
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 0] },
&BackupId { id: vec![0xFF; 32] })
.unwrap_err();
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 0] }, &BackupId { id: vec![0xFF; 32] }).unwrap_err();
}
#[test]
fn invalid_range_empty_last() {
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 32] },
&BackupId { id: vec![0xFF; 0] })
.unwrap_err();
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 32] }, &BackupId { id: vec![0xFF; 0] }).unwrap_err();
}
#[test]
fn invalid_range_short_first() {
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 31] },
&BackupId { id: vec![0xFF; 32] })
.unwrap_err();
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 31] }, &BackupId { id: vec![0xFF; 32] }).unwrap_err();
}
#[test]
fn invalid_range_short_last() {
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 31] },
&BackupId { id: vec![0xFF; 32] })
.unwrap_err();
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 31] }, &BackupId { id: vec![0xFF; 32] }).unwrap_err();
}
#[test]
fn invalid_range_long_first() {
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 33] },
&BackupId { id: vec![0xFF; 32] })
.unwrap_err();
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 33] }, &BackupId { id: vec![0xFF; 32] }).unwrap_err();
}
#[test]
fn invalid_range_long_last() {
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 32] },
&BackupId { id: vec![0xFF; 33] })
.unwrap_err();
PartitionKeyRange::from_ids(&BackupId { id: vec![0x00; 32] }, &BackupId { id: vec![0xFF; 33] }).unwrap_err();
}
#[test]
@ -323,17 +295,17 @@ mod tests {
}
macro_rules! key {
([$value:expr] + $addend:expr) => ({
PartitionKey::new([$value; 32]).checked_add($addend)
.unwrap_or_else(|| panic!("test key overflow"))
});
([$value:expr] - $subtrahend:expr) => ({
PartitionKey::new([$value; 32]).checked_sub($subtrahend)
.unwrap_or_else(|| panic!("test key underflow"))
});
([$value:expr]) => ({
([$value:expr] + $addend:expr) => {{
PartitionKey::new([$value; 32])
});
.checked_add($addend)
.unwrap_or_else(|| panic!("test key overflow"))
}};
([$value:expr] - $subtrahend:expr) => {{
PartitionKey::new([$value; 32])
.checked_sub($subtrahend)
.unwrap_or_else(|| panic!("test key underflow"))
}};
([$value:expr]) => {{ PartitionKey::new([$value; 32]) }};
}
macro_rules! range {
@ -364,85 +336,85 @@ mod tests {
assert!(range.overlaps_range(&range!([0xFF], [0xFF])));
let range = range!([0x00], [0xBE]);
assert!( range.overlaps_range(&range!()));
assert!( range.overlaps_range(&range!([0x00], [0x00])));
assert!( range.overlaps_range(&range!([0x00], [0xBE]-1)));
assert!( range.overlaps_range(&range!([0x00], [0xBE])));
assert!( range.overlaps_range(&range!([0x00], [0xFF])));
assert!( range.overlaps_range(&range!([0xBA], [0xBA])));
assert!( range.overlaps_range(&range!([0xBA], [0xBE])));
assert!( range.overlaps_range(&range!([0xBA], [0xBE]+1)));
assert!( range.overlaps_range(&range!([0xBE], [0xBE])));
assert!( range.overlaps_range(&range!([0xBE], [0xBE]+1)));
assert!(!range.overlaps_range(&range!([0xBE]+1, [0xEF])));
assert!(!range.overlaps_range(&range!([0xEF], [0xFF])));
assert!(!range.overlaps_range(&range!([0xFF], [0xFF])));
assert!(range.overlaps_range(&range!()));
assert!(range.overlaps_range(&range!([0x00], [0x00])));
assert!(range.overlaps_range(&range!([0x00], [0xBE] - 1)));
assert!(range.overlaps_range(&range!([0x00], [0xBE])));
assert!(range.overlaps_range(&range!([0x00], [0xFF])));
assert!(range.overlaps_range(&range!([0xBA], [0xBA])));
assert!(range.overlaps_range(&range!([0xBA], [0xBE])));
assert!(range.overlaps_range(&range!([0xBA], [0xBE] + 1)));
assert!(range.overlaps_range(&range!([0xBE], [0xBE])));
assert!(range.overlaps_range(&range!([0xBE], [0xBE] + 1)));
assert!(!range.overlaps_range(&range!([0xBE] + 1, [0xEF])));
assert!(!range.overlaps_range(&range!([0xEF], [0xFF])));
assert!(!range.overlaps_range(&range!([0xFF], [0xFF])));
let range = range!([0xBE], [0xEF]);
assert!( range.overlaps_range(&range!()));
assert!(!range.overlaps_range(&range!([0x00], [0x00])));
assert!(!range.overlaps_range(&range!([0x00], [0xBE]-1)));
assert!( range.overlaps_range(&range!([0x00], [0xBE])));
assert!( range.overlaps_range(&range!([0x00], [0xBF])));
assert!( range.overlaps_range(&range!([0x00], [0xEF])));
assert!( range.overlaps_range(&range!([0x00], [0xEF]+1)));
assert!( range.overlaps_range(&range!([0x00], [0xFF])));
assert!( range.overlaps_range(&range!([0xBE], [0xEF])));
assert!( range.overlaps_range(&range!([0xBE], [0xEF]+1)));
assert!( range.overlaps_range(&range!([0xBE]+1, [0xEF]-1)));
assert!( range.overlaps_range(&range!([0xBE]+1, [0xEF])));
assert!( range.overlaps_range(&range!([0xBE]+1, [0xEF]+1)));
assert!( range.overlaps_range(&range!([0xEF], [0xEF])));
assert!( range.overlaps_range(&range!([0xEF], [0xEF]+1)));
assert!(!range.overlaps_range(&range!([0xEF]+1, [0xEF]+1)));
assert!(!range.overlaps_range(&range!([0xEF]+1, [0xFF])));
assert!(range.overlaps_range(&range!()));
assert!(!range.overlaps_range(&range!([0x00], [0x00])));
assert!(!range.overlaps_range(&range!([0x00], [0xBE] - 1)));
assert!(range.overlaps_range(&range!([0x00], [0xBE])));
assert!(range.overlaps_range(&range!([0x00], [0xBF])));
assert!(range.overlaps_range(&range!([0x00], [0xEF])));
assert!(range.overlaps_range(&range!([0x00], [0xEF] + 1)));
assert!(range.overlaps_range(&range!([0x00], [0xFF])));
assert!(range.overlaps_range(&range!([0xBE], [0xEF])));
assert!(range.overlaps_range(&range!([0xBE], [0xEF] + 1)));
assert!(range.overlaps_range(&range!([0xBE] + 1, [0xEF] - 1)));
assert!(range.overlaps_range(&range!([0xBE] + 1, [0xEF])));
assert!(range.overlaps_range(&range!([0xBE] + 1, [0xEF] + 1)));
assert!(range.overlaps_range(&range!([0xEF], [0xEF])));
assert!(range.overlaps_range(&range!([0xEF], [0xEF] + 1)));
assert!(!range.overlaps_range(&range!([0xEF] + 1, [0xEF] + 1)));
assert!(!range.overlaps_range(&range!([0xEF] + 1, [0xFF])));
let range = range!([0xEF], [0xFF]);
assert!(!range.overlaps_range(&range!([0x00], [0x00])));
assert!(!range.overlaps_range(&range!([0x00], [0xEF]-1)));
assert!( range.overlaps_range(&range!([0x00], [0xEF])));
assert!( range.overlaps_range(&range!([0x00], [0xEF]+1)));
assert!( range.overlaps_range(&range!([0x00], [0xFF])));
assert!( range.overlaps_range(&range!([0xEF], [0xEF])));
assert!( range.overlaps_range(&range!([0xEF], [0xEF]+1)));
assert!( range.overlaps_range(&range!([0xEF], [0xFF])));
assert!( range.overlaps_range(&range!([0xEF]+1, [0xFF]-1)));
assert!( range.overlaps_range(&range!([0xEF]+1, [0xFF])));
assert!( range.overlaps_range(&range!([0xFF], [0xFF])));
assert!(!range.overlaps_range(&range!([0x00], [0x00])));
assert!(!range.overlaps_range(&range!([0x00], [0xEF] - 1)));
assert!(range.overlaps_range(&range!([0x00], [0xEF])));
assert!(range.overlaps_range(&range!([0x00], [0xEF] + 1)));
assert!(range.overlaps_range(&range!([0x00], [0xFF])));
assert!(range.overlaps_range(&range!([0xEF], [0xEF])));
assert!(range.overlaps_range(&range!([0xEF], [0xEF] + 1)));
assert!(range.overlaps_range(&range!([0xEF], [0xFF])));
assert!(range.overlaps_range(&range!([0xEF] + 1, [0xFF] - 1)));
assert!(range.overlaps_range(&range!([0xEF] + 1, [0xFF])));
assert!(range.overlaps_range(&range!([0xFF], [0xFF])));
}
#[test]
fn split_off_first() {
let mut range = PartitionKeyRange::new_unbounded();
let split_off = range.split_off_inclusive(&PartitionKey::new([0x00; 32])).unwrap().unwrap();
assert_eq!(range.first(), &[0x00; 32]);
assert_eq!(range.last(), &[0x00; 32]);
assert_eq!(range.first(), &[0x00; 32]);
assert_eq!(range.last(), &[0x00; 32]);
assert_eq!(split_off.first(), &range.last().checked_add(1).unwrap());
assert_eq!(split_off.last(), &[0xff; 32]);
assert_eq!(split_off.last(), &[0xff; 32]);
let mut range = PartitionKeyRange::new(PartitionKey::new([0xbe; 32]), PartitionKey::new([0xef; 32])).unwrap();
let split_off = range.split_off_inclusive(&PartitionKey::new([0xbe; 32])).unwrap().unwrap();
assert_eq!(range.first(), &[0xbe; 32]);
assert_eq!(range.last(), &[0xbe; 32]);
assert_eq!(range.first(), &[0xbe; 32]);
assert_eq!(range.last(), &[0xbe; 32]);
assert_eq!(split_off.first(), &range.last().checked_add(1).unwrap());
assert_eq!(split_off.last(), &[0xef; 32]);
assert_eq!(split_off.last(), &[0xef; 32]);
}
#[test]
fn split_off_mid() {
let mut range = PartitionKeyRange::new_unbounded();
let split_off = range.split_off_inclusive(&PartitionKey::new([0x80; 32])).unwrap().unwrap();
assert_eq!(range.first(), &[0x00; 32]);
assert_eq!(range.last(), &[0x80; 32]);
assert_eq!(range.first(), &[0x00; 32]);
assert_eq!(range.last(), &[0x80; 32]);
assert_eq!(split_off.first(), &range.last().checked_add(1).unwrap());
assert_eq!(split_off.last(), &[0xff; 32]);
assert_eq!(split_off.last(), &[0xff; 32]);
let mut range = PartitionKeyRange::new(PartitionKey::new([0xbe; 32]), PartitionKey::new([0xef; 32])).unwrap();
let split_off = range.split_off_inclusive(&PartitionKey::new([0xdd; 32])).unwrap().unwrap();
assert_eq!(range.first(), &[0xbe; 32]);
assert_eq!(range.last(), &[0xdd; 32]);
assert_eq!(range.first(), &[0xbe; 32]);
assert_eq!(range.last(), &[0xdd; 32]);
assert_eq!(split_off.first(), &range.last().checked_add(1).unwrap());
assert_eq!(split_off.last(), &[0xef; 32]);
assert_eq!(split_off.last(), &[0xef; 32]);
}
#[test]
@ -450,13 +422,13 @@ mod tests {
let mut range = PartitionKeyRange::new_unbounded();
let split_off = range.split_off_inclusive(&PartitionKey::new([0xff; 32])).unwrap();
assert_eq!(range.first(), &[0x00; 32]);
assert_eq!(range.last(), &[0xff; 32]);
assert_eq!(range.last(), &[0xff; 32]);
assert!(split_off.is_none());
let mut range = PartitionKeyRange::new(PartitionKey::new([0xbe; 32]), PartitionKey::new([0xef; 32])).unwrap();
let split_off = range.split_off_inclusive(&PartitionKey::new([0xef; 32])).unwrap();
assert_eq!(range.first(), &[0xbe; 32]);
assert_eq!(range.last(), &[0xef; 32]);
assert_eq!(range.last(), &[0xef; 32]);
assert!(split_off.is_none());
}
@ -465,19 +437,19 @@ mod tests {
let mut range = PartitionKeyRange::new(PartitionKey::new([0x00; 32]), PartitionKey::new([0x00; 32])).unwrap();
let split_off = range.split_off_inclusive(&PartitionKey::new([0x00; 32])).unwrap();
assert_eq!(range.first(), &[0x00; 32]);
assert_eq!(range.last(), &[0x00; 32]);
assert_eq!(range.last(), &[0x00; 32]);
assert!(split_off.is_none());
let mut range = PartitionKeyRange::new(PartitionKey::new([0xbe; 32]), PartitionKey::new([0xbe; 32])).unwrap();
let split_off = range.split_off_inclusive(&PartitionKey::new([0xbe; 32])).unwrap();
assert_eq!(range.first(), &[0xbe; 32]);
assert_eq!(range.last(), &[0xbe; 32]);
assert_eq!(range.last(), &[0xbe; 32]);
assert!(split_off.is_none());
let mut range = PartitionKeyRange::new(PartitionKey::new([0xff; 32]), PartitionKey::new([0xff; 32])).unwrap();
let split_off = range.split_off_inclusive(&PartitionKey::new([0xff; 32])).unwrap();
assert_eq!(range.first(), &[0xff; 32]);
assert_eq!(range.last(), &[0xff; 32]);
assert_eq!(range.last(), &[0xff; 32]);
assert!(split_off.is_none());
}

View File

@ -11,13 +11,13 @@ use std::fmt;
use std::rc::*;
use std::time::*;
use sgxsd_ffi::{RdRand};
use sgxsd_ffi::RdRand;
use crate::protobufs::kbupd_enclave::*;
use crate::raft::*;
use crate::remote::*;
use crate::storage::*;
use crate::util::{ListDisplay};
use crate::util::ListDisplay;
use super::*;
@ -26,12 +26,12 @@ const ATTESTATION_EXPIRATION_WINDOW: Duration = Duration::from_secs(86400);
pub(super) struct ReplicaGroupState {
pub raft: RaftState<RaftLogStorage, RdRand, NodeId>,
remotes: Box<[RemoteReplicaState]>,
remotes: Box<[RemoteReplicaState]>,
attestation_time_ticks: u32,
request_quote_ticks: u32,
attestation_time_now: Duration,
attestation_time_now: Duration,
}
pub(super) struct RemoteReplicaState {
@ -43,16 +43,13 @@ pub(super) struct RemoteReplicaState {
//
impl ReplicaGroupState {
pub fn new(raft: RaftState<RaftLogStorage, RdRand, NodeId>,
remotes: Box<[RemoteReplicaState]>)
-> Self
{
pub fn new(raft: RaftState<RaftLogStorage, RdRand, NodeId>, remotes: Box<[RemoteReplicaState]>) -> Self {
Self {
raft,
remotes,
attestation_time_ticks: 0,
request_quote_ticks: 0,
attestation_time_now: Duration::from_secs(0),
request_quote_ticks: 0,
attestation_time_now: Duration::from_secs(0),
}
}
@ -61,23 +58,23 @@ impl ReplicaGroupState {
}
pub fn set_config(&mut self, config: &EnclaveReplicaConfig) {
self.raft.log_mut().set_index_cache_size(config.raft_log_index_page_cache_size.to_usize());
self.raft
.log_mut()
.set_index_cache_size(config.raft_log_index_page_cache_size.to_usize());
self.raft.set_election_timeout_ticks(config.election_timeout_ticks);
self.raft.set_heartbeat_timeout_ticks(config.heartbeat_timeout_ticks);
self.raft.set_replication_chunk_size(config.replication_chunk_size.to_usize());
}
pub fn get(&self, node_id: &NodeId) -> Option<&RemoteReplicaState> {
self.remotes.iter().find(|replica: &&RemoteReplicaState| replica.sender.id() == node_id)
self.remotes
.iter()
.find(|replica: &&RemoteReplicaState| replica.sender.id() == node_id)
}
pub fn timer_tick(&mut self,
max_attestation_time_ticks: u32,
max_request_quote_ticks: u32,
now: Duration)
-> Option<TransactionData> {
pub fn timer_tick(&mut self, max_attestation_time_ticks: u32, max_request_quote_ticks: u32, now: Duration) -> Option<TransactionData> {
self.attestation_time_ticks = self.attestation_time_ticks.saturating_add(1);
self.request_quote_ticks = self.request_quote_ticks.saturating_add(1);
self.request_quote_ticks = self.request_quote_ticks.saturating_add(1);
if self.request_quote_ticks >= max_request_quote_ticks {
self.request_quote_ticks = Default::default();
@ -93,23 +90,29 @@ impl ReplicaGroupState {
Ok(()) => {
self.attestation_time_ticks = 0;
Some(TransactionData {
inner: Some(transaction_data::Inner::SetTime(SetTimeTransaction {
now_secs: now.as_secs(),
})),
inner: Some(transaction_data::Inner::SetTime(SetTimeTransaction { now_secs: now.as_secs() })),
})
}
Err(invalid) => {
if self.attestation_time_ticks.checked_rem(max_attestation_time_ticks) == Some(0) {
info!("not setting attestation time to {} with invalid attestations for {} of {} replicas: {}",
now.as_secs(), invalid.len(), self.remotes.iter().len(), ListDisplay(invalid));
info!(
"not setting attestation time to {} with invalid attestations for {} of {} replicas: {}",
now.as_secs(),
invalid.len(),
self.remotes.iter().len(),
ListDisplay(invalid)
);
}
None
}
}
} else {
if now < self.attestation_time_now {
warn!("not setting attestation time backward from {} to {}",
self.attestation_time_now.as_secs(), now.as_secs());
warn!(
"not setting attestation time backward from {} to {}",
self.attestation_time_now.as_secs(),
now.as_secs()
);
}
None
}
@ -129,13 +132,20 @@ impl ReplicaGroupState {
for replica in self.remotes.iter() {
if !replica.check_attestation_time(now) {
if let Some(attestation) = replica.attestation() {
warn!("replica {} is now invalid at {}: {}",
replica.sender.id(), now.as_secs(), attestation);
warn!(
"replica {} is now invalid at {}: {}",
replica.sender.id(),
now.as_secs(),
attestation
);
}
}
}
info!("set attestation time from {} to {}",
self.attestation_time_now.as_secs(), now.as_secs());
info!(
"set attestation time from {} to {}",
self.attestation_time_now.as_secs(),
now.as_secs()
);
self.attestation_time_now = now;
true
} else {
@ -152,13 +162,16 @@ impl ReplicaGroupState {
}
pub fn attestation_expiration_window(&self) -> AttestationParameters {
let min_unix_timestamp = self.attestation_time_now.checked_sub(ATTESTATION_EXPIRATION_WINDOW).unwrap_or_default();
let min_unix_timestamp = self
.attestation_time_now
.checked_sub(ATTESTATION_EXPIRATION_WINDOW)
.unwrap_or_default();
AttestationParameters::new(min_unix_timestamp)
}
fn check_quorum_attestation_time(&self, now: Duration) -> Result<(), Vec<&RemoteReplicaState>> {
let replicas_iter = self.remotes.iter();
let replicas_len = replicas_iter.len();
let replicas_iter = self.remotes.iter();
let replicas_len = replicas_iter.len();
let invalid: Vec<_> = replicas_iter
.filter(|replica: &&RemoteReplicaState| !replica.check_attestation_time(now))
.collect();
@ -166,8 +179,13 @@ impl ReplicaGroupState {
if invalid.is_empty() {
info!("setting attestation time to {}", now.as_secs());
} else {
warn!("setting attestation time to {} with invalid attestations for {} of {} replicas: {}",
now.as_secs(), invalid.len(), replicas_len, ListDisplay(invalid));
warn!(
"setting attestation time to {} with invalid attestations for {} of {} replicas: {}",
now.as_secs(),
invalid.len(),
replicas_len,
ListDisplay(invalid)
);
}
Ok(())
} else {
@ -185,8 +203,11 @@ impl ReplicaGroupState {
if replica.check_attestation_time(self.attestation_time_now) {
let _ignore = replica.sender.send(Rc::clone(&r2r_message));
} else {
verbose!("dropping broadcast message to {} with expired attestation {}",
replica.sender.id(), OptionDisplay(replica.attestation()));
verbose!(
"dropping broadcast message to {} with expired attestation {}",
replica.sender.id(),
OptionDisplay(replica.attestation())
);
}
}
}
@ -199,8 +220,12 @@ impl ReplicaGroupState {
let _ignore = replica.sender.send(r2r_message);
} else {
warn!("dropping message to {} with expired attestation {}: {}",
replica.sender.id(), OptionDisplay(replica.attestation()), message);
warn!(
"dropping message to {} with expired attestation {}: {}",
replica.sender.id(),
OptionDisplay(replica.attestation()),
message
);
}
}
}
@ -219,7 +244,7 @@ impl RemoteReplicaState {
fn check_attestation_time(&self, now: Duration) -> bool {
let min_unix_timestamp = now.checked_sub(ATTESTATION_EXPIRATION_WINDOW).unwrap_or_default();
let min_attestation = AttestationParameters::new(min_unix_timestamp);
let min_attestation = AttestationParameters::new(min_unix_timestamp);
self.attestation() >= Some(min_attestation)
}
}

View File

@ -5,12 +5,12 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
pub mod raft_log;
pub mod storage_array;
pub mod storage_data;
pub mod storage_page_cache;
pub mod raft_log;
pub use self::raft_log::*;
pub use self::storage_array::*;
pub use self::storage_data::*;
pub use self::storage_page_cache::*;
pub use self::raft_log::*;

View File

@ -8,7 +8,7 @@
mod raft_log_data;
mod raft_log_index;
use sgx_ffi::util::{SecretValue};
use sgx_ffi::util::SecretValue;
use crate::prelude::*;
use crate::protobufs::raft::*;
@ -36,12 +36,15 @@ impl RaftLogStorage {
cancelled: Default::default(),
})
}
pub fn take_cancelled(&mut self) -> Vec<SecretValue<Vec<u8>>> {
std::mem::replace(&mut self.cancelled, Default::default())
}
pub fn data_len(&self) -> usize {
self.data.len()
}
pub fn set_index_cache_size(&mut self, index_cache_size: usize) {
self.index.set_cache_size(index_cache_size);
}
@ -66,10 +69,9 @@ impl RaftLog for RaftLogStorage {
Err(RaftLogAppendError::OutOfSpace { log_entry })
}
}
fn pop_front(&mut self, truncate_to: LogIdx) -> Result<(), ()> {
if (!self.index.is_empty() &&
self.index.prev_log_idx() < truncate_to)
{
if (!self.index.is_empty() && self.index.prev_log_idx() < truncate_to) {
if let Some(popped_entry) = self.index.pop_front() {
if let Some(popped_data_entry) = &popped_entry.data {
self.data.pop_front_to(popped_data_entry);
@ -80,20 +82,23 @@ impl RaftLog for RaftLogStorage {
Err(())
}
}
fn cancel_from(&mut self, from_log_idx: LogIdx) -> Result<usize, ()> {
let cancelled_data_entries = self.index.cancel_from(from_log_idx)?;
self.cancelled = self.data.cancel(cancelled_data_entries);
Ok(self.cancelled.len())
}
fn get(&mut self, log_idx: LogIdx) -> Option<LogEntry> {
let index_entry = self.index.get(log_idx)?;
let entry_data = if let Some(data_entry) = &index_entry.data {
let entry_data = if let Some(data_entry) = &index_entry.data {
self.data.read(data_entry)?
} else {
SecretValue::new(Vec::new())
};
Some(LogEntry::new(index_entry.term, entry_data))
}
fn get_len(&mut self, log_idx: LogIdx) -> Option<usize> {
let index_entry = self.index.get(log_idx)?;
if let Some(data_entry) = &index_entry.data {
@ -102,6 +107,7 @@ impl RaftLog for RaftLogStorage {
Some(0)
}
}
fn get_term(&mut self, log_idx: LogIdx) -> Option<TermId> {
if log_idx == self.index.prev_log_idx() {
self.index.prev_log_term()
@ -109,12 +115,15 @@ impl RaftLog for RaftLogStorage {
self.index.get(log_idx).map(|entry| entry.term)
}
}
fn prev_idx(&self) -> LogIdx {
self.index.prev_log_idx()
}
fn last_idx(&self) -> LogIdx {
self.index.last_log_idx()
}
fn last_term(&self) -> TermId {
self.index.last_log_term()
}

View File

@ -7,8 +7,8 @@
use crate::prelude::*;
use num_traits::{ToPrimitive};
use sgx_ffi::util::{SecretValue};
use num_traits::ToPrimitive;
use sgx_ffi::util::SecretValue;
use crate::protobufs::raft::LogEntry;
use crate::raft::*;
@ -42,8 +42,7 @@ impl RaftLogData {
if let Some(len) = self.head.checked_sub(self.tail) {
len
} else {
self.storage.len().saturating_sub(self.tail)
.saturating_add(self.head)
self.storage.len().saturating_sub(self.tail).saturating_add(self.head)
}
}
@ -52,11 +51,15 @@ impl RaftLogData {
let append_len = match append_len.to_u32() {
Some(append_len) if append_len.to_usize() < self.storage.len() => append_len,
Some(_) | None => {
error!("transaction too large at {} bytes (have {} bytes of storage)", log_entry.data.len(), self.storage.len());
error!(
"transaction too large at {} bytes (have {} bytes of storage)",
log_entry.data.len(),
self.storage.len()
);
return Err(RaftLogAppendError::TooLarge { size: append_len });
}
};
let mut offset = self.head;
let mut offset = self.head;
let mut new_head = offset.saturating_add(append_len.to_usize());
if new_head >= self.storage.len() {
if self.tail > 0 {
@ -73,9 +76,11 @@ impl RaftLogData {
self.head = new_head;
match self.storage.write(offset, log_entry.into_data()) {
Ok(nonce) => {
Ok(RaftLogDataEntry { nonce, offset, length: append_len })
}
Ok(nonce) => Ok(RaftLogDataEntry {
nonce,
offset,
length: append_len,
}),
Err(()) => {
error!("wrote out of bounds to raft log at {} len {}", offset, append_len);
Err(RaftLogAppendError::InternalError)
@ -111,7 +116,10 @@ impl RaftLogData {
if let Ok(data) = self.storage.read(data_entry.offset, data_entry.length.to_usize(), data_entry.nonce) {
Some(data)
} else {
error!("error reading raft log entry at offset {} length {}", data_entry.offset, data_entry.length);
error!(
"error reading raft log entry at offset {} length {}",
data_entry.offset, data_entry.length
);
None
}
}

View File

@ -16,7 +16,7 @@ use super::raft_log_data::*;
use std::num::*;
use bytes::*;
use num_traits::{ToPrimitive};
use num_traits::ToPrimitive;
pub struct RaftLogIndex {
storage: StorageArray<RaftLogIndexEntry>,
@ -82,8 +82,10 @@ impl RaftLogIndex {
match self.storage.get(index).unwrap_or_else(|| panic!("overflow")) {
Ok(slot) => slot,
Err(storage_error) => {
error!("storage error reading from raft log index {} (current length {}): {}",
index, self.len, storage_error);
error!(
"storage error reading from raft log index {} (current length {}): {}",
index, self.len, storage_error
);
None
}
}
@ -97,14 +99,16 @@ impl RaftLogIndex {
let index = self.wrap_add(self.len);
match self.storage.get_mut(index).unwrap_or_else(|| panic!("overflow")) {
Ok(slot) => {
self.len = self.len.checked_add(1).unwrap_or_else(|| unreachable!());
self.len = self.len.checked_add(1).unwrap_or_else(|| unreachable!());
self.last_log_term = entry.term;
*slot = Some(entry);
*slot = Some(entry);
Ok(())
}
Err(storage_error) => {
error!("storage error appending to raft log index {} (current length {}): {}",
index, self.len, storage_error);
error!(
"storage error appending to raft log index {} (current length {}): {}",
index, self.len, storage_error
);
Err(())
}
}
@ -118,15 +122,17 @@ impl RaftLogIndex {
let entry = match self.storage.get(self.tail).unwrap_or_else(|| panic!("overflow")) {
Ok(entry) => entry.cloned(),
Err(storage_error) => {
error!("storage error popping from raft log index {} (current length {}): {}",
self.tail, self.len, storage_error);
error!(
"storage error popping from raft log index {} (current length {}): {}",
self.tail, self.len, storage_error
);
None
}
};
self.prev_log_idx = self.prev_log_idx + 1;
self.prev_log_idx = self.prev_log_idx + 1;
self.prev_log_term = entry.as_ref().map(|entry| entry.term);
self.tail = self.wrap_add(1);
self.len = new_len;
self.tail = self.wrap_add(1);
self.len = new_len;
entry
} else {
None
@ -134,7 +140,7 @@ impl RaftLogIndex {
}
pub fn cancel_from(&mut self, from_log_idx: LogIdx) -> Result<Vec<RaftLogDataEntry>, ()> {
let mut cancelled = Vec::new();
let mut cancelled = Vec::new();
let from_offset: u64 = from_log_idx.id.checked_sub((self.prev_log_idx + 1).id).ok_or(())?;
let from_offset: u32 = from_offset.to_u32().ok_or(())?;
if let Some(cancel_len) = self.len.checked_sub(from_offset) {
@ -167,11 +173,15 @@ impl RaftLogIndex {
error!("error reading raft log index entry {} to cancel", cancel_log_idx);
}
}
self.len = from_offset;
self.len = from_offset;
self.last_log_term = new_last_log_term;
}
} else {
warn!("tried to cancel non-existent raft log entries from {} (last is {})", from_log_idx, self.last_log_idx());
warn!(
"tried to cancel non-existent raft log entries from {} (last is {})",
from_log_idx,
self.last_log_idx()
);
}
Ok(cancelled)
}
@ -185,8 +195,9 @@ impl RaftLogIndex {
if addend <= self.storage.len() {
addend - index_rem
} else {
(addend - index_rem).checked_rem(self.storage.len())
.unwrap_or_else(|| unreachable!())
(addend - index_rem)
.checked_rem(self.storage.len())
.unwrap_or_else(|| unreachable!())
}
}
}
@ -196,6 +207,7 @@ impl StorageValue for RaftLogIndexEntry {
fn encoded_len() -> u32 {
8 + 8 + 8 + 4
}
fn encode<B: BufMut>(maybe_value: Option<&Self>, buf: &mut B) {
if let Some(value) = maybe_value {
buf.put_u64_le(value.term.id);
@ -210,6 +222,7 @@ impl StorageValue for RaftLogIndexEntry {
buf.put_u64_le(0);
}
}
fn decode<B: Buf>(buf: &mut B) -> Option<Self> {
let term = buf.get_u64_le();
if term != 0 {
@ -227,10 +240,7 @@ impl StorageValue for RaftLogIndexEntry {
}),
})
} else {
Some(Self {
term,
data: None,
})
Some(Self { term, data: None })
}
} else {
None

View File

@ -16,12 +16,13 @@ pub struct StorageArray<V> {
//
impl<V> StorageArray<V>
where V: StorageValue,
where V: StorageValue
{
pub fn new(length: u32, cache_size: usize) -> Result<Self, ()> {
let page_count = length.saturating_add(V::items_per_page() - 1)
.checked_div(V::items_per_page())
.unwrap_or_else(|| static_unreachable!());
let page_count = length
.saturating_add(V::items_per_page() - 1)
.checked_div(V::items_per_page())
.unwrap_or_else(|| static_unreachable!());
Ok(Self {
cache: StoragePageCache::with_page_count(page_count, cache_size)?,
})

View File

@ -7,14 +7,14 @@
use crate::prelude::*;
use std::convert::{TryInto};
use std::convert::TryInto;
use std::num::*;
use sgx_ffi::untrusted_slice::*;
use sgx_ffi::util::{SecretValue};
use sgx_ffi::util::SecretValue;
use sgxsd_ffi::{AesGcmIv, AesGcmKey, AesGcmMac};
use crate::ffi::ecalls::{kbupd_enclave_alloc_untrusted};
use crate::ffi::ecalls::kbupd_enclave_alloc_untrusted;
const TAG_LENGTH: usize = StorageData::tag_len() as usize;
@ -41,7 +41,7 @@ impl StorageData {
Ok(Self {
cipher: Default::default(),
data,
nonce: NonZeroU64::new(1).unwrap_or_else(|| static_unreachable!()),
nonce: NonZeroU64::new(1).unwrap_or_else(|| static_unreachable!()),
})
}
@ -57,35 +57,35 @@ impl StorageData {
let data_len = read_len.checked_sub(TAG_LENGTH).ok_or(())?;
let mut data = self.data.offset(offset).read_bytes(read_len)?;
let mac_data: &[u8] = data.get(data_len..read_len).unwrap_or_else(|| unreachable!());
let mac_data: &[u8] = data.get(data_len..read_len).unwrap_or_else(|| unreachable!());
let mac_data: &[u8; TAG_LENGTH] = mac_data.try_into().unwrap_or_else(|_| unreachable!());
let mac = AesGcmMac { data: *mac_data };
let mac = AesGcmMac { data: *mac_data };
let mut iv = AesGcmIv::default();
let iv_data: &mut [u8] = iv.data.get_mut(4..).unwrap_or_else(|| static_unreachable!());
let mut iv = AesGcmIv::default();
let iv_data: &mut [u8] = iv.data.get_mut(4..).unwrap_or_else(|| static_unreachable!());
let iv_data: &mut [u8; 8] = iv_data.try_into().unwrap_or_else(|_| static_unreachable!());
*iv_data = nonce.0.get().to_be_bytes();
*iv_data = nonce.0.get().to_be_bytes();
data.truncate(data_len);
let mut data = SecretValue::new(data);
match self.cipher.decrypt(data.get_mut(), &[], &iv, &mac) {
Ok(()) => Ok(data),
Ok(()) => Ok(data),
Err(error) => {
error!("error decrypting storage data at offset {} length {}: {}", offset, read_len, error);
Err(())
}
}
}
pub fn write(&mut self, offset: usize, mut data: SecretValue<Vec<u8>>) -> Result<StorageDataNonce, ()> {
let nonce = self.nonce;
self.nonce = NonZeroU64::new(self.nonce.get().checked_add(1).ok_or(())?)
.unwrap_or_else(|| unreachable!());
let mut iv = AesGcmIv::default();
let iv_data: &mut [u8] = iv.data.get_mut(4..).unwrap_or_else(|| static_unreachable!());
pub fn write(&mut self, offset: usize, mut data: SecretValue<Vec<u8>>) -> Result<StorageDataNonce, ()> {
let nonce = self.nonce;
self.nonce = NonZeroU64::new(self.nonce.get().checked_add(1).ok_or(())?).unwrap_or_else(|| unreachable!());
let mut iv = AesGcmIv::default();
let iv_data: &mut [u8] = iv.data.get_mut(4..).unwrap_or_else(|| static_unreachable!());
let iv_data: &mut [u8; 8] = iv_data.try_into().unwrap_or_else(|_| static_unreachable!());
*iv_data = nonce.get().to_be_bytes();
*iv_data = nonce.get().to_be_bytes();
let mut mac = AesGcmMac { data: Default::default() };
@ -98,7 +98,12 @@ impl StorageData {
Ok(StorageDataNonce(nonce))
}
Err(error) => {
error!("error encrypting storage data at offset {} length {}: {}", offset, data.get().len(), error);
error!(
"error encrypting storage data at offset {} length {}: {}",
offset,
data.get().len(),
error
);
Err(())
}
}
@ -132,20 +137,27 @@ mod test {
fn new(scenario: &Scenario, want_size: usize) -> Self {
if want_size != 0 {
let mut data_vec: Vec<u8> = Vec::with_capacity(want_size);
let data: *mut u8 = data_vec.as_mut_ptr();
let size: usize = data_vec.capacity();
let data: *mut u8 = data_vec.as_mut_ptr();
let size: usize = data_vec.capacity();
std::mem::forget(data_vec);
mocks::expect_kbupd_enclave_ocall_alloc(scenario, size, data as *mut libc::c_void, size);
let storage = StorageData::new(size).unwrap();
assert_eq!(storage.len(), size);
Self { storage: Some(storage), data: Some((data, size)) }
Self {
storage: Some(storage),
data: Some((data, size)),
}
} else {
let storage = StorageData::new(0).unwrap();
assert_eq!(storage.len(), 0);
Self { storage: Some(storage), data: None }
Self {
storage: Some(storage),
data: None,
}
}
}
fn get_mut(&mut self) -> &mut StorageData {
self.storage.as_mut().unwrap()
}
@ -162,8 +174,8 @@ mod test {
#[test]
fn test_storage_data_invalid_empty() {
let scenario = Scenario::new();
let nonce = StorageDataNonce::new(NonZeroU64::new(1).unwrap());
let scenario = Scenario::new();
let nonce = StorageDataNonce::new(NonZeroU64::new(1).unwrap());
let mut storage = TestStorageData::new(&scenario, 0);
for &offset in &[0, 1, usize::max_value()] {
@ -178,14 +190,17 @@ mod test {
#[test]
fn test_storage_data_valid() {
let scenario = Scenario::new();
let nonce = StorageDataNonce::new(NonZeroU64::new(1).unwrap());
let scenario = Scenario::new();
let nonce = StorageDataNonce::new(NonZeroU64::new(1).unwrap());
let mut storage = TestStorageData::new(&scenario, 2 + TAG_LENGTH);
for &offset in &[0, 1] {
for &length in &[0, 1] {
assert!(storage.get_mut().write(offset, SecretValue::new(vec![0; length])).is_ok());
assert_eq!(storage.get_mut().read(offset, length + TAG_LENGTH, nonce).unwrap().get().len(), length);
assert_eq!(
storage.get_mut().read(offset, length + TAG_LENGTH, nonce).unwrap().get().len(),
length
);
}
}
assert!(storage.get_mut().write(2, SecretValue::new(vec![])).is_ok());
@ -194,34 +209,34 @@ mod test {
#[test]
fn test_storage_invalid_overflow() {
let scenario = Scenario::new();
let nonce = StorageDataNonce::new(NonZeroU64::new(1).unwrap());
let scenario = Scenario::new();
let nonce = StorageDataNonce::new(NonZeroU64::new(1).unwrap());
let mut storage = TestStorageData::new(&scenario, TAG_LENGTH - 1);
assert!(storage.get_mut().read(0, 0, nonce).is_err());
assert!(storage.get_mut().read(0, 0, nonce).is_err());
assert!(storage.get_mut().read(0, TAG_LENGTH - 1, nonce).is_err());
assert!(storage.get_mut().write(0, SecretValue::new(vec![])).is_err());
assert!(storage.get_mut().read(0, TAG_LENGTH, nonce).is_err());
let scenario = Scenario::new();
let scenario = Scenario::new();
let mut storage = TestStorageData::new(&scenario, TAG_LENGTH);
assert!(storage.get_mut().write(0, SecretValue::new(vec![])).is_ok());
assert!(storage.get_mut().read(0, TAG_LENGTH, nonce).is_ok());
assert!(storage.get_mut().write(0, SecretValue::new(vec![0; 1])).is_err());
assert!(storage.get_mut().read(0, TAG_LENGTH + 1, nonce).is_err());
assert!(storage.get_mut().read(0, TAG_LENGTH + 1, nonce).is_err());
assert!(storage.get_mut().read(0, usize::max_value(), nonce).is_err());
}
#[test]
fn test_storage_invalid_integer_overflow() {
let scenario = Scenario::new();
let nonce = StorageDataNonce::new(NonZeroU64::new(1).unwrap());
let scenario = Scenario::new();
let nonce = StorageDataNonce::new(NonZeroU64::new(1).unwrap());
let data: *mut u8 = unsafe { std::ptr::NonNull::dangling().as_mut() };
let size: usize = usize::max_value() - 1;
let size: usize = usize::max_value() - 1;
mocks::expect_kbupd_enclave_ocall_alloc(&scenario, size, data as *mut libc::c_void, size);
let storage = StorageData::new(size).unwrap();
@ -229,8 +244,8 @@ mod test {
assert!(storage.read(0, usize::max_value(), nonce).is_err());
assert!(storage.read(size - TAG_LENGTH + 1, TAG_LENGTH, nonce).is_err());
assert!(storage.read(size, TAG_LENGTH, nonce).is_err());
assert!(storage.read(usize::max_value(), TAG_LENGTH, nonce).is_err());
assert!(storage.read(size, TAG_LENGTH, nonce).is_err());
assert!(storage.read(usize::max_value(), TAG_LENGTH, nonce).is_err());
assert!(storage.read(usize::max_value(), usize::max_value(), nonce).is_err());
}

View File

@ -12,7 +12,7 @@ use std::marker::*;
use std::rc::*;
use bytes::*;
use num_traits::{ToPrimitive};
use num_traits::ToPrimitive;
use crate::lru::*;
use crate::storage::storage_data::*;
@ -23,9 +23,9 @@ const PAGE_SIZE: u16 = 4096;
pub struct StoragePageCache<V> {
cache_size: usize,
pages: Box<[StoragePage<V>]>,
cached: Lru<StoragePageIndex>,
data: StorageData,
pages: Box<[StoragePage<V>]>,
cached: Lru<StoragePageIndex>,
data: StorageData,
}
#[derive(Debug)]
@ -79,18 +79,18 @@ struct UncachedStoragePage {
}
impl<V> StoragePageCache<V>
where V: StorageValue,
where V: StorageValue
{
pub fn with_page_count(page_count: u32, cache_size: usize) -> Result<Self, ()> {
let data_size = page_count.to_usize().checked_mul(PAGE_SIZE.into()).ok_or(())?;
Self::new(data_size, cache_size)
}
pub fn new(data_size: usize, cache_size: usize) -> Result<Self, ()> {
let cache_size = cache_size.max(1);
let data = StorageData::new(data_size)?;
let cache_size = cache_size.max(1);
let data = StorageData::new(data_size)?;
let max_page_count = StorageItemIndex::<V>::max_page_count().to_usize();
let page_count = (data.len() / usize::from(PAGE_SIZE)).max(1)
.min(max_page_count);
let page_count = (data.len() / usize::from(PAGE_SIZE)).max(1).min(max_page_count);
let mut pages: Vec<StoragePage<V>> = Vec::with_capacity(page_count);
pages.extend(std::iter::repeat_with(Default::default).take(page_count));
@ -98,7 +98,7 @@ where V: StorageValue,
Ok(Self {
cache_size,
pages: pages.into(),
pages: pages.into(),
cached: Default::default(),
data,
})
@ -109,10 +109,7 @@ where V: StorageValue,
}
pub fn item_index(&self, index: u32) -> Option<StorageItemIndex<V>> {
let entry = StorageItemIndex {
index,
_data: PhantomData,
};
let entry = StorageItemIndex { index, _data: PhantomData };
if u32::from(entry.page_index()) < self.page_count() {
Some(entry)
} else {
@ -136,8 +133,7 @@ where V: StorageValue,
fn read_page(&mut self, page_index: StoragePageIndex) -> Result<&mut CachedStoragePage<V>, StorageError> {
match self.pages.get(usize::from(page_index)) {
Some(StoragePage::Free) |
Some(StoragePage::Uncached(_)) => {
Some(StoragePage::Free) | Some(StoragePage::Uncached(_)) => {
while self.cached.len() >= self.cache_size {
if let Some(evict_lru_entry) = self.cached.pop_front() {
debug!("evicting storage page {}", evict_lru_entry.get());
@ -147,15 +143,12 @@ where V: StorageValue,
}
}
}
None |
Some(StoragePage::Poisoned) |
Some(StoragePage::Cached(_)) => {
}
None | Some(StoragePage::Poisoned) | Some(StoragePage::Cached(_)) => {}
}
if let Some(page) = self.pages.get_mut(usize::from(page_index)) {
let cached_page = match page.read(&self.data, &mut self.cached, page_index) {
Ok(cached_page) => cached_page,
Ok(cached_page) => cached_page,
Err(storage_error) => {
error!("fatal error reading page {}: {}", page_index, storage_error);
return Err(storage_error);
@ -206,38 +199,43 @@ impl From<StoragePageIndex> for usize {
}
impl<V> StorageItemIndex<V>
where V: StorageValue,
where V: StorageValue
{
fn page_index(&self) -> StoragePageIndex {
StoragePageIndex(self.index / V::items_per_page())
}
fn item_index(&self) -> usize {
(self.index % V::items_per_page()).to_usize()
}
fn max_page_count() -> u32 {
u32::max_value() / V::items_per_page()
}
}
impl<V> StoragePage<V>
where V: StorageValue,
where V: StorageValue
{
fn read(&mut self,
data: &StorageData,
cached: &mut Lru<StoragePageIndex>,
page_index: StoragePageIndex)
-> Result<&mut CachedStoragePage<V>, StorageError>
fn read(
&mut self,
data: &StorageData,
cached: &mut Lru<StoragePageIndex>,
page_index: StoragePageIndex,
) -> Result<&mut CachedStoragePage<V>, StorageError>
{
match self {
StoragePage::Cached(cached_page) => {
Ok(cached_page)
}
StoragePage::Cached(cached_page) => Ok(cached_page),
StoragePage::Uncached(uncached_page) => {
debug!("reading storage page {}", page_index);
let lru_entry = cached.push_back(page_index);
let mut cached_page = Box::new(CachedStoragePage::new(CachedStoragePageDirtyState::Clean(uncached_page.nonce), lru_entry));
let offset = usize::from(page_index).checked_mul(PAGE_SIZE.into())
.ok_or(StorageError::InternalError)?;
let lru_entry = cached.push_back(page_index);
let mut cached_page = Box::new(CachedStoragePage::new(
CachedStoragePageDirtyState::Clean(uncached_page.nonce),
lru_entry,
));
let offset = usize::from(page_index)
.checked_mul(PAGE_SIZE.into())
.ok_or(StorageError::InternalError)?;
match data.read(offset, PAGE_SIZE.into(), uncached_page.nonce) {
Ok(decrypted) => {
let items_data = decrypted.get()[..].chunks(V::encoded_len().to_usize());
@ -248,7 +246,7 @@ where V: StorageValue,
*self = StoragePage::Cached(cached_page);
match self {
StoragePage::Cached(cached_page) => Ok(cached_page),
_ => static_unreachable!(),
_ => static_unreachable!(),
}
}
Err(()) => {
@ -257,19 +255,18 @@ where V: StorageValue,
}
}
}
StoragePage::Poisoned => {
Err(StorageError::ReadError)
}
StoragePage::Poisoned => Err(StorageError::ReadError),
StoragePage::Free => {
let lru_entry = cached.push_back(page_index);
*self = StoragePage::Cached(Box::new(CachedStoragePage::new(CachedStoragePageDirtyState::Dirty, lru_entry)));
match self {
StoragePage::Cached(cached_page) => Ok(cached_page),
_ => static_unreachable!(),
_ => static_unreachable!(),
}
}
}
}
fn write(&mut self, data: &mut StorageData, page_index: StoragePageIndex) {
let cached_page = match self {
StoragePage::Cached(cached_page) => cached_page,
@ -277,15 +274,11 @@ where V: StorageValue,
};
*self = match cached_page.dirty {
CachedStoragePageDirtyState::Clean(nonce) => {
StoragePage::Uncached(UncachedStoragePage { nonce })
}
CachedStoragePageDirtyState::Dirty if cached_page.is_empty() => {
StoragePage::Free
}
CachedStoragePageDirtyState::Clean(nonce) => StoragePage::Uncached(UncachedStoragePage { nonce }),
CachedStoragePageDirtyState::Dirty if cached_page.is_empty() => StoragePage::Free,
CachedStoragePageDirtyState::Dirty => {
let mut secret_encoded_vec = SecretValue::new(Vec::with_capacity(PAGE_SIZE.into()));
let encoded: &mut Vec<u8> = secret_encoded_vec.get_mut();
let encoded: &mut Vec<u8> = secret_encoded_vec.get_mut();
for (item_index, item) in cached_page.items.iter().enumerate() {
encoded.resize(item_index.saturating_mul(V::encoded_len().to_usize()), 0);
V::encode(item.as_ref(), encoded);
@ -294,9 +287,7 @@ where V: StorageValue,
let offset_res = usize::from(page_index).checked_mul(PAGE_SIZE.into()).ok_or(());
match offset_res.and_then(|offset: usize| data.write(offset, secret_encoded_vec)) {
Ok(nonce) => {
StoragePage::Uncached(UncachedStoragePage { nonce })
}
Ok(nonce) => StoragePage::Uncached(UncachedStoragePage { nonce }),
Err(()) => {
error!("wrote out of bounds page {}!", page_index);
return;
@ -314,10 +305,10 @@ impl<V> Default for StoragePage<V> {
}
impl<V> CachedStoragePage<V>
where V: StorageValue,
where V: StorageValue
{
fn new(dirty: CachedStoragePageDirtyState, lru_entry: Weak<LruEntry<StoragePageIndex>>) -> Self {
let size = V::items_per_page();
let size = V::items_per_page();
let mut items = Vec::with_capacity(size.to_usize());
items.extend(std::iter::repeat_with(Default::default).take(size.to_usize()));
Self {
@ -326,16 +317,19 @@ where V: StorageValue,
items: items.into(),
}
}
fn get_item(&mut self, item_index: &StorageItemIndex<V>) -> Option<&V> {
self.items.get(item_index.item_index())
.unwrap_or_else(|| panic!("overflow"))
.as_ref()
self.items
.get(item_index.item_index())
.unwrap_or_else(|| panic!("overflow"))
.as_ref()
}
fn get_item_mut(&mut self, item_index: &StorageItemIndex<V>) -> &mut Option<V> {
self.dirty = CachedStoragePageDirtyState::Dirty;
self.items.get_mut(item_index.item_index())
.unwrap_or_else(|| panic!("overflow"))
self.items.get_mut(item_index.item_index()).unwrap_or_else(|| panic!("overflow"))
}
fn is_empty(&self) -> bool {
for item_slot in self.items.iter() {
if item_slot.is_some() {

View File

@ -10,11 +10,11 @@ use crate::prelude::*;
use std::fmt;
use base64;
use serde::de::Error;
use serde::{Deserialize, Deserializer};
use serde::de::{Error};
pub use sgx_ffi::util::*;
pub use crate::protobufs::kbupd::*;
pub use sgx_ffi::util::*;
pub struct ToHex<'a>(pub &'a [u8]);
@ -36,19 +36,19 @@ impl<'a> fmt::Debug for ToHex<'a> {
pub struct OptionDisplay<T>(pub Option<T>);
impl<T> fmt::Display for OptionDisplay<T>
where T: fmt::Display,
where T: fmt::Display
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let Self(inner) = self;
match inner {
Some(inner) => fmt::Display::fmt(inner, fmt),
None => write!(fmt, "<none>"),
None => write!(fmt, "<none>"),
}
}
}
impl<T> fmt::Debug for OptionDisplay<T>
where T: fmt::Display,
where T: fmt::Display
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
fmt::Display::fmt(self, fmt)
@ -58,20 +58,20 @@ where T: fmt::Display,
pub struct ListDisplay<T>(pub T);
impl<T> fmt::Display for ListDisplay<T>
where T: IntoIterator + Clone,
T::Item: fmt::Display,
where
T: IntoIterator + Clone,
T::Item: fmt::Display,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let Self(inner) = self;
fmt.debug_list()
.entries(inner.clone().into_iter().map(DisplayAsDebug))
.finish()
fmt.debug_list().entries(inner.clone().into_iter().map(DisplayAsDebug)).finish()
}
}
impl<T> fmt::Debug for ListDisplay<T>
where T: IntoIterator + Clone,
T::Item: fmt::Display,
where
T: IntoIterator + Clone,
T::Item: fmt::Display,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
fmt::Display::fmt(self, fmt)
@ -81,7 +81,7 @@ where T: IntoIterator + Clone,
pub struct DisplayAsDebug<T>(pub T);
impl<T> fmt::Debug for DisplayAsDebug<T>
where T: fmt::Display,
where T: fmt::Display
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let Self(inner) = self;
@ -91,14 +91,11 @@ where T: fmt::Display,
pub fn deserialize_base64<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<u8>, D::Error> {
Deserialize::deserialize(deserializer)
.and_then(|base64: &[u8]| {
base64::decode(base64).map_err(|error| D::Error::custom(error.to_string()))
})
.and_then(|base64: &[u8]| base64::decode(base64).map_err(|error| D::Error::custom(error.to_string())))
}
pub fn copy_exact<T>(src: &[u8]) -> Result<T, ()>
where T: AsMut<[u8]> + Default
{
where T: AsMut<[u8]> + Default {
let mut dst = T::default();
if src.len() == dst.as_mut().len() {
dst.as_mut().copy_from_slice(src);

View File

@ -0,0 +1,6 @@
//
// Copyright (C) {\d{4}(, \d{4})*} Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0-or-later
//

11
enclave/rustfmt.toml Normal file
View File

@ -0,0 +1,11 @@
binop_separator = "Back"
enum_discrim_align_threshold = 30
license_template_path = "rustfmt.license-template"
max_width = 140
overflow_delimited_expr = true
reorder_impl_items = true
struct_field_align_threshold = 30
unstable_features = true
use_field_init_shorthand = true
version = "Two"
where_single_line = true

File diff suppressed because it is too large Load Diff

View File

@ -6,17 +6,13 @@
//
#![cfg_attr(not(any(test, feature = "test")), no_std)]
#![allow(
unused_parens,
clippy::style,
clippy::large_enum_variant,
)]
#![allow(unused_parens, clippy::style, clippy::large_enum_variant)]
#![warn(
bare_trait_objects,
elided_lifetimes_in_paths,
trivial_numeric_casts,
variant_size_differences,
clippy::integer_arithmetic,
clippy::integer_arithmetic
)]
#![deny(
clippy::cast_possible_truncation,
@ -47,15 +43,17 @@
clippy::use_debug,
clippy::use_self,
clippy::use_underscore_binding,
clippy::wildcard_enum_match_arm,
clippy::wildcard_enum_match_arm
)]
extern crate alloc;
#[allow(dead_code, non_camel_case_types, non_upper_case_globals, non_snake_case, improper_ctypes, clippy::all, clippy::pedantic, clippy::integer_arithmetic)]
#[rustfmt::skip]
mod bindgen_wrapper;
pub mod sgx;
pub mod untrusted_slice;
pub mod util;
#[cfg(any(test, feature = "test"))] pub mod mocks;
#[cfg(any(test, feature = "test"))]
pub mod mocks;

View File

@ -7,24 +7,16 @@
#![allow(clippy::all, clippy::option_unwrap_used, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
use std::cell::{RefCell};
use std::cell::RefCell;
use mockers::*;
use mockers::matchers::*;
use mockers::*;
use mockers_derive::mocked;
use test_ffi::*;
use super::bindgen_wrapper::{
errno_t,
sgx_attributes_t,
sgx_cpu_svn_t,
sgx_key_id_t,
sgx_measurement_t,
sgx_target_info_t,
sgx_report_body_t,
sgx_report_data_t,
sgx_report_t,
sgx_status_t,
errno_t, sgx_attributes_t, sgx_cpu_svn_t, sgx_key_id_t, sgx_measurement_t, sgx_report_body_t, sgx_report_data_t, sgx_report_t,
sgx_status_t, sgx_target_info_t,
};
//
@ -40,15 +32,12 @@ pub trait SgxIsOutsideEnclave {
fn sgx_is_outside_enclave(&self, addr: *const ::std::os::raw::c_void, size: usize) -> bool;
}
pub fn expect_sgx_is_outside_enclave(scenario: &Scenario,
ptr: *const libc::c_void,
size: usize,
res: bool) {
pub fn expect_sgx_is_outside_enclave(scenario: &Scenario, ptr: *const libc::c_void, size: usize, res: bool) {
let mock = mock_for(&SGX_IS_OUTSIDE_ENCLAVE, scenario);
scenario.expect(mock.sgx_is_outside_enclave(
eq(ptr as *const libc::c_void),
eq(size)
).and_return(res));
scenario.expect(
mock.sgx_is_outside_enclave(eq(ptr as *const libc::c_void), eq(size))
.and_return(res),
);
}
//
@ -59,14 +48,9 @@ pub mod impls {
use super::*;
#[no_mangle]
pub extern "C" fn sgx_is_outside_enclave(
addr: *const ::std::os::raw::c_void,
size: usize,
) -> ::std::os::raw::c_int {
let res = SGX_IS_OUTSIDE_ENCLAVE.with(|mock| {
(mock.borrow().as_ref().expect("no mock for sgx_is_outside_enclave"))
.sgx_is_outside_enclave(addr, size)
});
pub extern "C" fn sgx_is_outside_enclave(addr: *const ::std::os::raw::c_void, size: usize) -> ::std::os::raw::c_int {
let res = SGX_IS_OUTSIDE_ENCLAVE
.with(|mock| (mock.borrow().as_ref().expect("no mock for sgx_is_outside_enclave")).sgx_is_outside_enclave(addr, size));
res as i32
}
@ -75,7 +59,8 @@ pub mod impls {
target_info: *const sgx_target_info_t,
report_data: *const sgx_report_data_t,
report: *mut sgx_report_t,
) -> sgx_status_t {
) -> sgx_status_t
{
if !target_info.is_null() {
unsafe { std::ptr::read_volatile(target_info) };
}
@ -83,41 +68,38 @@ pub mod impls {
let mut reserved4 = [0; 42];
read_rand(&mut config_id);
read_rand(&mut reserved4);
unsafe { *report = sgx_report_t {
body: sgx_report_body_t {
cpu_svn: sgx_cpu_svn_t { svn: rand() },
misc_select: rand(),
reserved1: rand(),
isv_ext_prod_id: rand(),
attributes: sgx_attributes_t {
flags: rand(),
xfrm: rand(),
unsafe {
*report = sgx_report_t {
body: sgx_report_body_t {
cpu_svn: sgx_cpu_svn_t { svn: rand() },
misc_select: rand(),
reserved1: rand(),
isv_ext_prod_id: rand(),
attributes: sgx_attributes_t {
flags: rand(),
xfrm: rand(),
},
mr_enclave: sgx_measurement_t { m: rand() },
reserved2: rand(),
mr_signer: sgx_measurement_t { m: rand() },
reserved3: rand(),
config_id,
isv_prod_id: rand(),
isv_svn: rand(),
config_svn: rand(),
reserved4,
isv_family_id: rand(),
report_data: *report_data,
},
mr_enclave: sgx_measurement_t { m: rand() },
reserved2: rand(),
mr_signer: sgx_measurement_t { m: rand() },
reserved3: rand(),
config_id,
isv_prod_id: rand(),
isv_svn: rand(),
config_svn: rand(),
reserved4,
isv_family_id: rand(),
report_data: *report_data,
},
key_id: sgx_key_id_t { id: rand() },
mac: rand(),
}};
key_id: sgx_key_id_t { id: rand() },
mac: rand(),
}
};
0
}
#[no_mangle]
pub extern "C" fn memset_s(s: *mut ::std::os::raw::c_void,
smax: usize,
c: ::std::os::raw::c_int,
n: usize)
-> errno_t
{
pub extern "C" fn memset_s(s: *mut ::std::os::raw::c_void, smax: usize, c: ::std::os::raw::c_int, n: usize) -> errno_t {
assert!(c >= 0);
assert!(n.checked_add(c as usize).unwrap() <= smax);
assert!(!s.is_null());
@ -125,10 +107,11 @@ pub mod impls {
}
#[no_mangle]
pub extern "C" fn consttime_memequal(b1: *const ::std::os::raw::c_void,
b2: *const ::std::os::raw::c_void,
len: usize)
-> ::std::os::raw::c_int
pub extern "C" fn consttime_memequal(
b1: *const ::std::os::raw::c_void,
b2: *const ::std::os::raw::c_void,
len: usize,
) -> ::std::os::raw::c_int
{
unsafe {
let b1 = std::slice::from_raw_parts(b1 as *const u8, len);

View File

@ -5,25 +5,15 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use alloc::vec::{Vec};
use core::slice;
use alloc::vec::Vec;
use core::mem;
use core::ptr;
use core::slice;
use super::bindgen_wrapper::{
sgx_create_report,
sgx_attributes_t,
sgx_measurement_t,
sgx_report_data_t,
sgx_target_info_t,
};
use super::bindgen_wrapper::{sgx_attributes_t, sgx_create_report, sgx_measurement_t, sgx_report_data_t, sgx_target_info_t};
pub use super::bindgen_wrapper::{
sgx_status_t as SgxStatus,
sgx_report_t as SgxReport,
sgx_report_t as SgxReport, sgx_status_t as SgxStatus, SGX_ERROR_INVALID_PARAMETER, SGX_ERROR_INVALID_STATE, SGX_ERROR_UNEXPECTED,
SGX_SUCCESS,
SGX_ERROR_INVALID_PARAMETER,
SGX_ERROR_INVALID_STATE,
SGX_ERROR_UNEXPECTED,
};
pub struct SgxTargetInfo<'a> {
@ -35,13 +25,9 @@ pub struct SgxTargetInfo<'a> {
pub config_id: &'a [u8],
}
pub fn create_report(qe_target_info: &SgxTargetInfo<'_>,
report_data_in: &[u8])
-> Result<Vec<u8>, SgxStatus> {
pub fn create_report(qe_target_info: &SgxTargetInfo<'_>, report_data_in: &[u8]) -> Result<Vec<u8>, SgxStatus> {
let mut sgx_qe_target_info = sgx_target_info_t {
mr_enclave: sgx_measurement_t {
m: [0; 32],
},
mr_enclave: sgx_measurement_t { m: [0; 32] },
attributes: sgx_attributes_t {
flags: qe_target_info.flags,
xfrm: qe_target_info.xfrm,
@ -61,18 +47,15 @@ pub fn create_report(qe_target_info: &SgxTargetInfo<'_>,
}
sgx_qe_target_info.mr_enclave.m.copy_from_slice(qe_target_info.mrenclave);
sgx_qe_target_info.config_id.copy_from_slice(qe_target_info.config_id);
let report = create_report_raw(Some(&sgx_qe_target_info), report_data_in)?;
let report = create_report_raw(Some(&sgx_qe_target_info), report_data_in)?;
let report_ref = &report;
unsafe {
let report_slice = slice::from_raw_parts(report_ref as *const SgxReport as *const u8,
mem::size_of::<SgxReport>());
let report_slice = slice::from_raw_parts(report_ref as *const SgxReport as *const u8, mem::size_of::<SgxReport>());
Ok(report_slice.to_vec())
}
}
pub fn create_report_raw(qe_target_info: Option<&sgx_target_info_t>,
report_data_in: &[u8])
-> Result<SgxReport, SgxStatus> {
pub fn create_report_raw(qe_target_info: Option<&sgx_target_info_t>, report_data_in: &[u8]) -> Result<SgxReport, SgxStatus> {
let mut report_data = sgx_report_data_t { d: [0; 64] };
if let Some(()) = report_data.d.get_mut(..report_data_in.len()).map(|report_data_part| {
report_data_part.copy_from_slice(report_data_in);
@ -85,11 +68,7 @@ pub fn create_report_raw(qe_target_info: Option<&sgx_target_info_t>,
sgx_create_report(ptr::null(), &report_data, &mut report)
}
};
if res == SGX_SUCCESS {
return Ok(report)
} else {
return Err(res)
}
if res == SGX_SUCCESS { return Ok(report) } else { return Err(res) }
} else {
return Err(SGX_ERROR_INVALID_PARAMETER);
}
@ -97,17 +76,17 @@ pub fn create_report_raw(qe_target_info: Option<&sgx_target_info_t>,
#[cfg(test)]
pub mod tests {
use test_ffi::{rand_bytes};
use test_ffi::rand_bytes;
use super::*;
fn target_info<'a>(mrenclave: &'a [u8], config_id: &'a [u8]) -> SgxTargetInfo<'a> {
SgxTargetInfo {
mrenclave,
flags: Default::default(),
xfrm: Default::default(),
flags: Default::default(),
xfrm: Default::default(),
misc_select: Default::default(),
config_svn: Default::default(),
config_svn: Default::default(),
config_id,
}
}
@ -116,15 +95,27 @@ pub mod tests {
fn create_report_bad_args() {
let dummy_target_info = sgx_target_info_t::default();
let qe_mrenclave = rand_bytes(vec![0; std::mem::size_of_val(&dummy_target_info.mr_enclave)]);
let qe_config_id = rand_bytes(vec![0; std::mem::size_of_val(&dummy_target_info.config_id)]);
let qe_target_info = target_info(&qe_mrenclave, &qe_config_id);
let qe_mrenclave = rand_bytes(vec![0; std::mem::size_of_val(&dummy_target_info.mr_enclave)]);
let qe_config_id = rand_bytes(vec![0; std::mem::size_of_val(&dummy_target_info.config_id)]);
let qe_target_info = target_info(&qe_mrenclave, &qe_config_id);
let bad_report_data = rand_bytes(vec![0; std::mem::size_of::<sgx_report_data_t>() + 1]);
assert_eq!(Err(SGX_ERROR_INVALID_PARAMETER), create_report(&target_info(&[], &qe_config_id), &[]));
assert_eq!(Err(SGX_ERROR_INVALID_PARAMETER), create_report(&target_info(&[0], &qe_config_id), &[]));
assert_eq!(Err(SGX_ERROR_INVALID_PARAMETER), create_report(&target_info(&qe_mrenclave, &[]), &[]));
assert_eq!(Err(SGX_ERROR_INVALID_PARAMETER), create_report(&target_info(&qe_mrenclave, &[0]), &[]));
assert_eq!(
Err(SGX_ERROR_INVALID_PARAMETER),
create_report(&target_info(&[], &qe_config_id), &[])
);
assert_eq!(
Err(SGX_ERROR_INVALID_PARAMETER),
create_report(&target_info(&[0], &qe_config_id), &[])
);
assert_eq!(
Err(SGX_ERROR_INVALID_PARAMETER),
create_report(&target_info(&qe_mrenclave, &[]), &[])
);
assert_eq!(
Err(SGX_ERROR_INVALID_PARAMETER),
create_report(&target_info(&qe_mrenclave, &[0]), &[])
);
assert_eq!(Err(SGX_ERROR_INVALID_PARAMETER), create_report(&qe_target_info, &bad_report_data));
assert!(create_report(&qe_target_info, &[]).is_ok());
}
@ -133,21 +124,21 @@ pub mod tests {
fn create_report_valid() {
let dummy_target_info = sgx_target_info_t::default();
let qe_mrenclave = rand_bytes(vec![0; std::mem::size_of_val(&dummy_target_info.mr_enclave)]);
let qe_config_id = rand_bytes(vec![0; std::mem::size_of_val(&dummy_target_info.config_id)]);
let qe_mrenclave = rand_bytes(vec![0; std::mem::size_of_val(&dummy_target_info.mr_enclave)]);
let qe_config_id = rand_bytes(vec![0; std::mem::size_of_val(&dummy_target_info.config_id)]);
let qe_target_info = target_info(&qe_mrenclave, &qe_config_id);
let report_data = rand_bytes(vec![0; std::mem::size_of::<sgx_report_data_t>()]);
let report_data = rand_bytes(vec![0; std::mem::size_of::<sgx_report_data_t>()]);
for report_data_len in 0..=report_data.len() {
let res = create_report(&qe_target_info, &report_data[..report_data_len]);
assert!(res.is_ok());
if let Ok(report_bytes) = res {
let report: SgxReport = unsafe {
std::ptr::read_unaligned(report_bytes.as_ptr() as *const SgxReport)
};
let report: SgxReport = unsafe { std::ptr::read_unaligned(report_bytes.as_ptr() as *const SgxReport) };
assert_eq!(&report_data[..report_data_len], &report.body.report_data.d[..report_data_len]);
assert_eq!(&report.body.report_data.d[report_data_len..],
&vec![0; report.body.report_data.d.len() - report_data_len][..]);
assert_eq!(
&report.body.report_data.d[report_data_len..],
&vec![0; report.body.report_data.d.len() - report_data_len][..]
);
}
}
}

View File

@ -5,14 +5,12 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use alloc::vec::{Vec};
use alloc::vec::Vec;
use core::marker::*;
use core::num::*;
use core::ptr::*;
use super::bindgen_wrapper::{
sgx_is_outside_enclave,
};
use super::bindgen_wrapper::sgx_is_outside_enclave;
pub enum UntrustedSlice<'a> {
NonEmpty {
@ -46,7 +44,11 @@ impl<'a> UntrustedSlice<'a> {
if data.as_ptr().wrapping_add(size.get()) < data.as_ptr() {
return Err(());
}
Ok(UntrustedSlice::NonEmpty { data, size, _phantom: &PhantomData })
Ok(UntrustedSlice::NonEmpty {
data,
size,
_phantom: &PhantomData,
})
} else {
Ok(UntrustedSlice::Empty)
}
@ -62,7 +64,7 @@ impl<'a> UntrustedSlice<'a> {
pub fn as_ptr(&self) -> *const u8 {
match self {
Self::NonEmpty { data, .. } => data.as_ptr(),
Self::Empty => null(),
Self::Empty => null(),
}
}
@ -72,7 +74,11 @@ impl<'a> UntrustedSlice<'a> {
if let Some(size) = size.get().checked_sub(offset) {
if let Some(size) = NonZeroUsize::new(size) {
let data = unsafe { NonNull::new_unchecked(data.as_ptr().add(offset)) };
UntrustedSlice::NonEmpty { data, size, _phantom: &PhantomData }
UntrustedSlice::NonEmpty {
data,
size,
_phantom: &PhantomData,
}
} else {
UntrustedSlice::Empty
}
@ -140,10 +146,10 @@ impl<'a> Default for UntrustedSlice<'a> {
#[cfg(test)]
mod test {
use mockers::*;
use test_ffi::{rand_bytes};
use test_ffi::rand_bytes;
use super::*;
use super::super::mocks;
use super::*;
struct TestVec {
ptr: *mut u8,
@ -152,8 +158,8 @@ mod test {
impl TestVec {
fn new(size: usize) -> Self {
let mut data_vec: Vec<u8> = rand_bytes(vec![0; size]);
let ptr: *mut u8 = data_vec.as_mut_ptr();
let size: usize = data_vec.capacity();
let ptr: *mut u8 = data_vec.as_mut_ptr();
let size: usize = data_vec.capacity();
std::mem::forget(data_vec);
Self { ptr, size }
}

View File

@ -5,12 +5,12 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use core::ffi::{c_void};
use core::ffi::c_void;
use core::mem;
use super::bindgen_wrapper::{dlmallinfo};
use super::bindgen_wrapper::dlmallinfo;
pub use super::bindgen_wrapper::{memset_s, consttime_memequal};
pub use super::bindgen_wrapper::{consttime_memequal, memset_s};
pub fn clear(buf: &mut [u8]) {
let res = unsafe { memset_s(buf.as_ptr() as *mut c_void, buf.len(), 0, buf.len()) };
@ -21,11 +21,7 @@ pub fn consttime_eq(left: impl AsRef<[u8]>, right: impl AsRef<[u8]>) -> bool {
let left = left.as_ref();
let right = right.as_ref();
if left.len() == right.len() {
let res = unsafe {
consttime_memequal(left.as_ptr() as *const c_void,
right.as_ptr() as *const c_void,
left.len())
};
let res = unsafe { consttime_memequal(left.as_ptr() as *const c_void, right.as_ptr() as *const c_void, left.len()) };
res != 0
} else {
false

File diff suppressed because it is too large Load Diff

View File

@ -7,22 +7,16 @@
#![allow(clippy::all, clippy::option_unwrap_used, clippy::cast_sign_loss)]
use alloc::boxed::{Box};
use alloc::boxed::Box;
use core::ptr;
use core::slice;
use num_traits::{ToPrimitive};
use num_traits::ToPrimitive;
pub use super::bindgen_wrapper::{
sgxsd_msg_buf_t,
sgxsd_msg_from_t,
};
use super::bindgen_wrapper::{
sgxsd_enclave_server_noreply,
sgxsd_enclave_server_reply,
};
use super::bindgen_wrapper::{sgxsd_enclave_server_noreply, sgxsd_enclave_server_reply};
pub use super::bindgen_wrapper::{sgxsd_msg_buf_t, sgxsd_msg_from_t};
use sgx_ffi::sgx::*;
use sgx_ffi::util::{clear};
use sgx_ffi::util::clear;
pub trait SgxsdServer: Send + Sized {
type InitArgs;
@ -30,7 +24,12 @@ pub trait SgxsdServer: Send + Sized {
type TerminateArgs;
fn init(_args: Option<&Self::InitArgs>) -> Result<Self, SgxStatus>;
fn handle_call(&mut self, args: Option<&Self::HandleCallArgs>, request_data: &[u8], from: SgxsdMsgFrom) -> Result<(), (SgxStatus, SgxsdMsgFrom)>;
fn handle_call(
&mut self,
args: Option<&Self::HandleCallArgs>,
request_data: &[u8],
from: SgxsdMsgFrom,
) -> Result<(), (SgxStatus, SgxsdMsgFrom)>;
fn terminate(self, _args: Option<&Self::TerminateArgs>) -> Result<(), SgxStatus>;
}
@ -50,18 +49,27 @@ impl SgxsdMsgFrom {
clear(&mut from.server_key.data[..]);
res
}
#[cfg(any(test, feature = "test"))]
pub fn mock() -> Self {
Self::new(&mut sgxsd_msg_from_t { tag: Default::default(), valid: true, server_key: Default::default() })
Self::new(&mut sgxsd_msg_from_t {
tag: Default::default(),
valid: true,
server_key: Default::default(),
})
}
pub fn reply(mut self, msg: &mut [u8]) -> Result<(), SgxStatus> {
if let Some(size) = msg.len().to_u32() {
let msg_buf = sgxsd_msg_buf_t { data: msg.as_mut_ptr(), size };
let msg_buf = sgxsd_msg_buf_t {
data: msg.as_mut_ptr(),
size,
};
if let Some(mut msg_from) = self.0.take() {
let msg_from_ref = &mut *msg_from;
match unsafe { sgxsd_enclave_server_reply(msg_buf, msg_from_ref) } {
0 => Ok(()),
err => Err(err)
err => Err(err),
}
} else {
Err(SGX_ERROR_INVALID_STATE)
@ -70,6 +78,7 @@ impl SgxsdMsgFrom {
Err(SGX_ERROR_UNEXPECTED)
}
}
fn forget(mut self) {
if let Some(mut from) = self.0.take() {
from.valid = false;
@ -86,31 +95,30 @@ impl Drop for SgxsdMsgFrom {
}
}
pub fn sgxsd_enclave_server_init<S>(p_args: *const S::InitArgs,
pp_state: *mut *mut S)
-> SgxStatus
where S: SgxsdServer,
{
pub fn sgxsd_enclave_server_init<S>(p_args: *const S::InitArgs, pp_state: *mut *mut S) -> SgxStatus
where S: SgxsdServer {
let args = unsafe { p_args.as_ref() };
match S::init(args) {
Ok(new_state) => {
unsafe { *pp_state = Box::into_raw(Box::new(new_state)) };
0
}
Err(err) => err
Err(err) => err,
}
}
pub fn sgxsd_enclave_server_handle_call<S>(p_args: *const S::HandleCallArgs,
msg_buf: sgxsd_msg_buf_t,
from: &mut sgxsd_msg_from_t,
pp_state: *mut *mut S)
-> SgxStatus
where S: SgxsdServer,
pub fn sgxsd_enclave_server_handle_call<S>(
p_args: *const S::HandleCallArgs,
msg_buf: sgxsd_msg_buf_t,
from: &mut sgxsd_msg_from_t,
pp_state: *mut *mut S,
) -> SgxStatus
where
S: SgxsdServer,
{
let args = unsafe { p_args.as_ref() };
let args = unsafe { p_args.as_ref() };
let mut state = unsafe { Box::from_raw(*pp_state) };
let msg = ECallSlice(ptr::NonNull::new(msg_buf.data as *mut _), msg_buf.size as usize);
let msg = ECallSlice(ptr::NonNull::new(msg_buf.data as *mut _), msg_buf.size as usize);
match state.handle_call(args, msg.as_ref(), SgxsdMsgFrom::new(from)) {
Ok(()) => {
unsafe { *pp_state = Box::into_raw(state) };
@ -124,16 +132,13 @@ where S: SgxsdServer,
}
}
pub fn sgxsd_enclave_server_terminate<S>(p_args: *const S::TerminateArgs,
p_state: *mut S)
-> SgxStatus
where S: SgxsdServer,
{
pub fn sgxsd_enclave_server_terminate<S>(p_args: *const S::TerminateArgs, p_state: *mut S) -> SgxStatus
where S: SgxsdServer {
let args = unsafe { p_args.as_ref() };
let state = unsafe { Box::from_raw(p_state) };
match state.terminate(args) {
Ok(()) => 0,
Err(err) => err
Err(err) => err,
}
}
@ -154,15 +159,11 @@ impl AsRef<[u8]> for ECallSlice {
#[cfg(test)]
mod tests {
use super::*;
use super::super::mocks;
use mockers::{*, matchers::*};
use super::*;
use mockers::{matchers::*, *};
use super::super::bindgen_wrapper::{
sgxsd_server_init_args_t,
sgxsd_server_handle_call_args_t,
sgxsd_server_terminate_args_t,
};
use super::super::bindgen_wrapper::{sgxsd_server_handle_call_args_t, sgxsd_server_init_args_t, sgxsd_server_terminate_args_t};
fn expect_msg_from_drop(scenario: &Scenario, msg_from: &sgxsd_msg_from_t) {
let msg_from = *msg_from;
@ -213,15 +214,24 @@ mod tests {
struct MockSgxsdServer {}
impl SgxsdServer for MockSgxsdServer {
type InitArgs = sgxsd_server_init_args_t;
type HandleCallArgs = sgxsd_server_handle_call_args_t;
type TerminateArgs = sgxsd_server_terminate_args_t;
type InitArgs = sgxsd_server_init_args_t;
type TerminateArgs = sgxsd_server_terminate_args_t;
fn init(_args: Option<&Self::InitArgs>) -> Result<Self, SgxStatus> {
Ok(Self {})
}
fn handle_call(&mut self, _args: Option<&Self::HandleCallArgs>, _request_data: &[u8], _from: SgxsdMsgFrom) -> Result<(), (SgxStatus, SgxsdMsgFrom)> {
fn handle_call(
&mut self,
_args: Option<&Self::HandleCallArgs>,
_request_data: &[u8],
_from: SgxsdMsgFrom,
) -> Result<(), (SgxStatus, SgxsdMsgFrom)>
{
Ok(())
}
fn terminate(self, _args: Option<&Self::TerminateArgs>) -> Result<(), SgxStatus> {
Ok(())
}
@ -246,11 +256,7 @@ mod tests {
let mut pp_state = mock_sgxsd_server();
expect_msg_from_drop(&scenario, &msg_from);
sgxsd_enclave_server_handle_call(
std::ptr::null(),
mocks::valid_msg_buf(),
&mut msg_from, &mut *pp_state
);
sgxsd_enclave_server_handle_call(std::ptr::null(), mocks::valid_msg_buf(), &mut msg_from, &mut *pp_state);
unsafe { Box::from_raw(*pp_state) };
@ -266,11 +272,18 @@ mod tests {
let mut pp_state = mock_sgxsd_server();
expect_msg_from_drop(&scenario, &msg_from);
assert_eq!(sgxsd_enclave_server_handle_call(
std::ptr::null(),
sgxsd_msg_buf_t { data: std::ptr::null_mut(), size: 0 },
&mut msg_from, &mut *pp_state
), 0);
assert_eq!(
sgxsd_enclave_server_handle_call(
std::ptr::null(),
sgxsd_msg_buf_t {
data: std::ptr::null_mut(),
size: 0,
},
&mut msg_from,
&mut *pp_state
),
0
);
unsafe { Box::from_raw(*pp_state) };
@ -286,11 +299,10 @@ mod tests {
let mut pp_state = mock_sgxsd_server();
expect_msg_from_drop(&scenario, &msg_from);
assert_eq!(sgxsd_enclave_server_handle_call(
std::ptr::null(),
mocks::valid_msg_buf(),
&mut msg_from, &mut *pp_state
), 0);
assert_eq!(
sgxsd_enclave_server_handle_call(std::ptr::null(), mocks::valid_msg_buf(), &mut msg_from, &mut *pp_state),
0
);
unsafe { Box::from_raw(*pp_state) };

View File

@ -6,17 +6,13 @@
//
#![cfg_attr(not(any(test, feature = "test")), no_std)]
#![allow(
unused_parens,
clippy::style,
clippy::large_enum_variant,
)]
#![allow(unused_parens, clippy::style, clippy::large_enum_variant)]
#![warn(
bare_trait_objects,
elided_lifetimes_in_paths,
trivial_numeric_casts,
variant_size_differences,
clippy::integer_arithmetic,
clippy::integer_arithmetic
)]
#![deny(
clippy::cast_possible_truncation,
@ -47,40 +43,31 @@
clippy::use_debug,
clippy::use_self,
clippy::use_underscore_binding,
clippy::wildcard_enum_match_arm,
clippy::wildcard_enum_match_arm
)]
extern crate alloc;
#[allow(dead_code, non_camel_case_types, non_upper_case_globals, non_snake_case, improper_ctypes, clippy::all, clippy::pedantic, clippy::integer_arithmetic)]
#[rustfmt::skip]
mod bindgen_wrapper;
pub mod ecalls;
#[cfg(any(test, feature = "test"))] pub mod mocks;
#[cfg(any(test, feature = "test"))]
pub mod mocks;
use core::ffi::{c_void};
use core::ffi::c_void;
use core::num;
use core::ptr;
use core::sync;
use num_traits::{ToPrimitive};
use num_traits::ToPrimitive;
use rand_core::{CryptoRng, RngCore};
use sgx_ffi::util::{clear, SecretValue};
use bindgen_wrapper::{
br_sha256_context,
br_sha256_out,
br_sha256_init,
br_sha224_update,
br_sha256_SIZE,
curve25519_donna,
sgxsd_aes_gcm_decrypt,
sgxsd_aes_gcm_encrypt,
sgxsd_enclave_read_rand,
sgxsd_rand_buf,
sgx_status_t as SgxStatus,
SGX_SUCCESS,
SGX_ERROR_INVALID_PARAMETER,
br_sha224_update, br_sha256_SIZE, br_sha256_context, br_sha256_init, br_sha256_out, curve25519_donna, sgx_status_t as SgxStatus,
sgxsd_aes_gcm_decrypt, sgxsd_aes_gcm_encrypt, sgxsd_enclave_read_rand, sgxsd_rand_buf, SGX_ERROR_INVALID_PARAMETER, SGX_SUCCESS,
};
//
@ -120,26 +107,29 @@ impl RngCore for RdRand {
let random_bytes = self.rand_bytes([0; 4]);
u32::from_ne_bytes(random_bytes)
}
fn next_u64(&mut self) -> u64 {
let random_bytes = self.rand_bytes([0; 8]);
u64::from_ne_bytes(random_bytes)
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
while let Err(_) = self.try_fill_bytes(dest) {
sync::atomic::spin_loop_hint();
}
}
fn try_fill_bytes(&mut self, mut dest: &mut [u8]) -> Result<(), rand_core::Error> {
let mut rand_buf = sgxsd_rand_buf::default();
while !dest.is_empty() {
match num::NonZeroU32::new(unsafe { sgxsd_enclave_read_rand(&mut rand_buf) }) {
None => (),
None => (),
Some(error) => {
clear(&mut rand_buf.x);
return Err(error.into());
}
}
let dest_part_len = rand_buf.x.len().min(dest.len());
let dest_part_len = rand_buf.x.len().min(dest.len());
let (dest_part, dest_rest) = dest.split_at_mut(dest_part_len);
dest_part.copy_from_slice(rand_buf.x.get(..dest_part_len).unwrap_or_else(|| unreachable!()));
dest = dest_rest;
@ -149,7 +139,8 @@ impl RngCore for RdRand {
}
}
impl RdRand {
pub fn rand_bytes<T>(&mut self, mut buf: T) -> T where T: AsMut<[u8]> {
pub fn rand_bytes<T>(&mut self, mut buf: T) -> T
where T: AsMut<[u8]> {
self.fill_bytes(buf.as_mut());
buf
}
@ -197,49 +188,45 @@ impl AesGcmKey {
self.key.get_mut().data = *data;
}
pub fn decrypt(
&self,
data: &mut [u8],
aad: &[u8],
iv: &AesGcmIv,
mac: &AesGcmMac,
) -> Result<(), SgxStatus>
{
pub fn decrypt(&self, data: &mut [u8], aad: &[u8], iv: &AesGcmIv, mac: &AesGcmMac) -> Result<(), SgxStatus> {
let data_len = data.len().to_u32().ok_or(SGX_ERROR_INVALID_PARAMETER)?;
let aad_len = aad.len().to_u32().ok_or(SGX_ERROR_INVALID_PARAMETER)?;
let aad_len = aad.len().to_u32().ok_or(SGX_ERROR_INVALID_PARAMETER)?;
match unsafe {
sgxsd_aes_gcm_decrypt(self.key.get(), data.as_ptr() as *const c_void, data_len,
data.as_mut_ptr() as *mut c_void,
iv,
aad.as_ptr() as *const c_void, aad_len,
mac)
sgxsd_aes_gcm_decrypt(
self.key.get(),
data.as_ptr() as *const c_void,
data_len,
data.as_mut_ptr() as *mut c_void,
iv,
aad.as_ptr() as *const c_void,
aad_len,
mac,
)
} {
SGX_SUCCESS => Ok(()),
error => Err(error),
error => Err(error),
}
}
pub fn encrypt(
&self,
data: &mut [u8],
aad: &[u8],
iv: &AesGcmIv,
mac: &mut AesGcmMac,
) -> Result<(), SgxStatus>
{
pub fn encrypt(&self, data: &mut [u8], aad: &[u8], iv: &AesGcmIv, mac: &mut AesGcmMac) -> Result<(), SgxStatus> {
let data_len = data.len().to_u32().ok_or(SGX_ERROR_INVALID_PARAMETER)?;
let aad_len = aad.len().to_u32().ok_or(SGX_ERROR_INVALID_PARAMETER)?;
let aad_len = aad.len().to_u32().ok_or(SGX_ERROR_INVALID_PARAMETER)?;
match unsafe {
sgxsd_aes_gcm_encrypt(self.key.get(), data.as_ptr() as *const c_void, data_len,
data.as_mut_ptr() as *mut c_void,
iv,
aad.as_ptr() as *const c_void, aad_len,
mac)
sgxsd_aes_gcm_encrypt(
self.key.get(),
data.as_ptr() as *const c_void,
data_len,
data.as_mut_ptr() as *mut c_void,
iv,
aad.as_ptr() as *const c_void,
aad_len,
mac,
)
} {
SGX_SUCCESS => Ok(()),
error => Err(error),
error => Err(error),
}
}
@ -257,12 +244,15 @@ impl SHA256Context {
pub const fn hash_len() -> usize {
br_sha256_SIZE as usize
}
pub fn reset(&mut self) {
unsafe { br_sha256_init(&mut self.context) };
}
pub fn update(&mut self, data: &[u8]) {
unsafe { br_sha224_update(&mut self.context, data.as_ptr() as *const c_void, data.len()) };
}
pub fn result(&mut self, out: &mut [u8; Self::hash_len()]) {
unsafe { br_sha256_out(&self.context, out.as_mut_ptr() as *mut c_void) }
}
@ -294,21 +284,25 @@ impl Curve25519Key {
*self.privkey.get_mut() = *privkey;
curve25519_base(&mut self.pubkey, self.privkey.get());
}
#[allow(clippy::indexing_slicing)]
pub fn generate(&mut self, mut rng: impl RngCore) {
let privkey = self.privkey.get_mut();
rng.fill_bytes(privkey);
privkey[0] &= 248;
privkey[0] &= 248;
privkey[31] &= 127;
privkey[31] |= 64;
curve25519_base(&mut self.pubkey, self.privkey.get());
}
pub const fn pubkey(&self) -> &[u8; 32] {
&self.pubkey
}
pub fn privkey(&self) -> &[u8; 32] {
self.privkey.get()
}
pub fn dh(&self, pubkey: &[u8; 32], out: &mut [u8; 32]) {
curve25519(out, self.privkey.get(), pubkey);
}
@ -337,14 +331,12 @@ fn curve25519(mypublic: &mut [u8; 32], mysecret: &[u8; 32], basepoint: &[u8; 32]
#[cfg(test)]
pub mod tests {
use super::*;
use super::mocks;
use super::*;
use crate::bindgen_wrapper::{
SGX_ERROR_UNEXPECTED,
};
use crate::bindgen_wrapper::SGX_ERROR_UNEXPECTED;
use mockers::{Scenario};
use mockers::Scenario;
const ASSERT_RANDOM_WINDOW_SIZE: usize = 2;
@ -359,7 +351,7 @@ pub mod tests {
#[test]
#[should_panic]
fn test_assert_random() {
let src = test_ffi::rand_bytes(vec![0; 100]);
let src = test_ffi::rand_bytes(vec![0; 100]);
let mut data = src.clone();
RdRand.fill_bytes(&mut data[..(src.len() - ASSERT_RANDOM_WINDOW_SIZE)]);
assert_random(&src, &data);
@ -367,7 +359,7 @@ pub mod tests {
#[test]
fn fill_bytes_ok() {
let src = test_ffi::rand_bytes(vec![0; 100]);
let src = test_ffi::rand_bytes(vec![0; 100]);
let mut data = src.clone();
RdRand.fill_bytes(&mut data[..0]);
assert_eq!(data[..], src[..]);
@ -421,7 +413,7 @@ pub mod tests {
let read_rand_mock = test_ffi::mock_for(&mocks::SGXSD_ENCLAVE_READ_RAND, &scenario);
scenario.expect(read_rand_mock.sgxsd_enclave_read_rand().and_return_clone(SGX_SUCCESS).times(1));
let src = test_ffi::rand_bytes(vec![0; std::mem::size_of::<sgxsd_rand_buf>()]);
let src = test_ffi::rand_bytes(vec![0; std::mem::size_of::<sgxsd_rand_buf>()]);
let mut data = src.clone();
assert!(RdRand.try_fill_bytes(&mut data).is_ok());
assert_random(&src, &data);
@ -436,7 +428,7 @@ pub mod tests {
let read_rand_mock = test_ffi::mock_for(&mocks::SGXSD_ENCLAVE_READ_RAND, &scenario);
scenario.expect(read_rand_mock.sgxsd_enclave_read_rand().and_return_clone(SGX_SUCCESS).times(4));
let src = test_ffi::rand_bytes(vec![0; std::mem::size_of::<sgxsd_rand_buf>() * 4 - 1]);
let src = test_ffi::rand_bytes(vec![0; std::mem::size_of::<sgxsd_rand_buf>() * 4 - 1]);
let mut data = src.clone();
assert!(RdRand.try_fill_bytes(&mut data).is_ok());
assert_random(&src, &data);
@ -460,7 +452,7 @@ pub mod tests {
#[test]
fn rand_bytes_valid() {
let src = test_ffi::rand_bytes(vec![0; 100]);
let src = test_ffi::rand_bytes(vec![0; 100]);
let data = RdRand.rand_bytes(src.clone());
assert_eq!(data.len(), src.len());
assert_random(&src, &data);

View File

@ -7,29 +7,18 @@
#![allow(clippy::all, clippy::option_unwrap_used, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
use std::cell::{RefCell};
use std::cell::RefCell;
use mockers_derive::mocked;
use rand::*;
use rand::distributions::*;
use rand::*;
use test_ffi::*;
pub use super::bindgen_wrapper::{
sgxsd_aes_gcm_key_t,
sgxsd_msg_buf_t,
sgxsd_msg_from_t,
};
pub use super::bindgen_wrapper::{sgxsd_aes_gcm_key_t, sgxsd_msg_buf_t, sgxsd_msg_from_t};
use super::bindgen_wrapper::{
br_sha224_context,
br_sha256_context,
SGXSD_SHA256_HASH_SIZE,
sgxsd_aes_gcm_iv_t,
sgxsd_aes_gcm_mac_t,
sgxsd_msg_tag_t,
sgxsd_msg_tag__bindgen_ty_1,
sgxsd_rand_buf_t,
sgx_status_t,
br_sha224_context, br_sha256_context, sgx_status_t, sgxsd_aes_gcm_iv_t, sgxsd_aes_gcm_mac_t, sgxsd_msg_tag__bindgen_ty_1,
sgxsd_msg_tag_t, sgxsd_rand_buf_t, SGXSD_SHA256_HASH_SIZE,
};
//
@ -81,7 +70,9 @@ pub trait SgxsdEnclaveReadRand {
impl Distribution<sgxsd_msg_tag_t> for Standard {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> sgxsd_msg_tag_t {
sgxsd_msg_tag_t { __bindgen_anon_1: sgxsd_msg_tag__bindgen_ty_1 { tag: rng.sample(self) } }
sgxsd_msg_tag_t {
__bindgen_anon_1: sgxsd_msg_tag__bindgen_ty_1 { tag: rng.sample(self) },
}
}
}
impl Distribution<sgxsd_aes_gcm_key_t> for Standard {
@ -92,8 +83,8 @@ impl Distribution<sgxsd_aes_gcm_key_t> for Standard {
impl Distribution<sgxsd_msg_from_t> for Standard {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> sgxsd_msg_from_t {
sgxsd_msg_from_t {
tag: rng.sample(self),
valid: true,
tag: rng.sample(self),
valid: true,
server_key: rng.sample(self),
}
}
@ -109,7 +100,10 @@ lazy_static::lazy_static! {
pub fn valid_msg_buf() -> sgxsd_msg_buf_t {
let msg = &VALID_MSG_BUF;
sgxsd_msg_buf_t { data: msg.as_ptr() as *mut _, size: msg.len() as u32 }
sgxsd_msg_buf_t {
data: msg.as_ptr() as *mut _,
size: msg.len() as u32,
}
}
//
@ -122,7 +116,9 @@ pub mod impls {
#[no_mangle]
pub extern "C" fn sgxsd_enclave_server_noreply(from: *mut sgxsd_msg_from_t) -> sgx_status_t {
SGXSD_ENCLAVE_SERVER_NOREPLY.with(|mock| {
mock.borrow().as_ref().expect("no mock for sgxsd_enclave_server_noreply")
mock.borrow()
.as_ref()
.expect("no mock for sgxsd_enclave_server_noreply")
.sgxsd_enclave_server_noreply(unsafe { *from })
})
}
@ -133,12 +129,13 @@ pub mod impls {
assert_ne!(reply_buf.size, 0);
let reply_buf = unsafe { std::slice::from_raw_parts_mut(reply_buf.data, reply_buf.size as usize) };
SGXSD_ENCLAVE_SERVER_REPLY.with(|mock| {
mock.borrow().as_ref().expect("no mock for sgxsd_enclave_server_reply")
mock.borrow()
.as_ref()
.expect("no mock for sgxsd_enclave_server_reply")
.sgxsd_enclave_server_reply(reply_buf, unsafe { *from })
})
}
#[no_mangle]
pub extern "C" fn sgxsd_aes_gcm_encrypt(
p_key: *const sgxsd_aes_gcm_key_t,
@ -149,7 +146,8 @@ pub mod impls {
p_aad: *const ::std::os::raw::c_void,
aad_len: u32,
p_out_mac: *mut sgxsd_aes_gcm_mac_t,
) -> sgx_status_t {
) -> sgx_status_t
{
let key = unsafe { std::ptr::read_volatile(p_key) };
assert_ne!(&key.data[..], &vec![0; key.data.len()][..]);
assert!(!p_iv.is_null());
@ -160,12 +158,16 @@ pub mod impls {
assert!(!p_src.is_null());
assert!(!p_dst.is_null());
let src = unsafe { std::slice::from_raw_parts(p_src as *const u8, src_len as usize) };
src.iter().for_each(|p| unsafe { std::ptr::read_volatile(p); });
src.iter().for_each(|p| unsafe {
std::ptr::read_volatile(p);
});
}
let aad = if aad_len != 0 {
assert!(!p_aad.is_null());
let aad = unsafe { std::slice::from_raw_parts(p_aad as *const u8, aad_len as usize) };
aad.iter().for_each(|p| unsafe { std::ptr::read_volatile(p); });
aad.iter().for_each(|p| unsafe {
std::ptr::read_volatile(p);
});
aad
} else {
&[]
@ -208,7 +210,8 @@ pub mod impls {
p_aad: *const ::std::os::raw::c_void,
aad_len: u32,
p_in_mac: *const sgxsd_aes_gcm_mac_t,
) -> sgx_status_t {
) -> sgx_status_t
{
let key = unsafe { std::ptr::read_volatile(p_key) };
assert_ne!(&key.data[..], &vec![0; key.data.len()][..]);
assert!(!p_iv.is_null());
@ -219,12 +222,16 @@ pub mod impls {
assert!(!p_src.is_null());
assert!(!p_dst.is_null());
let src = unsafe { std::slice::from_raw_parts(p_src as *const u8, src_len as usize) };
src.iter().for_each(|p| unsafe { std::ptr::read_volatile(p); });
src.iter().for_each(|p| unsafe {
std::ptr::read_volatile(p);
});
}
let aad = if aad_len != 0 {
assert!(!p_aad.is_null());
let aad = unsafe { std::slice::from_raw_parts(p_aad as *const u8, aad_len as usize) };
aad.iter().for_each(|p| unsafe { std::ptr::read_volatile(p); });
aad.iter().for_each(|p| unsafe {
std::ptr::read_volatile(p);
});
aad
} else {
&[]
@ -278,7 +285,9 @@ pub mod impls {
if len != 0 {
assert!(!data.is_null());
let data = unsafe { std::slice::from_raw_parts(data as *const u8, len) };
data.iter().for_each(|p| unsafe { std::ptr::read_volatile(p); });
data.iter().for_each(|p| unsafe {
std::ptr::read_volatile(p);
});
}
}

View File

@ -5,13 +5,13 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
//
use std::cell::{RefCell};
use std::thread::{LocalKey};
use std::cell::RefCell;
use std::thread::LocalKey;
use mockers::*;
use rand::*;
use rand::distributions::*;
use rand_chacha::{ChaChaRng};
use rand::*;
use rand_chacha::ChaChaRng;
//
// mock extern "C" functions
@ -25,10 +25,8 @@ pub fn clear<T>(key: &'static LocalKey<RefCell<Option<T>>>) {
key.with(|key| *key.borrow_mut() = None);
}
pub fn mock_for<T>(key: &'static LocalKey<RefCell<Option<T>>>, scenario: &Scenario)
-> <T as Mock>::Handle where
T: Mock,
{
pub fn mock_for<T>(key: &'static LocalKey<RefCell<Option<T>>>, scenario: &Scenario) -> <T as Mock>::Handle
where T: Mock {
let (mock, handle) = scenario.create_mock::<T>();
set(key, mock);
handle
@ -38,11 +36,13 @@ pub fn mock_for<T>(key: &'static LocalKey<RefCell<Option<T>>>, scenario: &Scenar
// random mock values
//
pub fn rand_bytes<T>(mut buf: T) -> T where T: AsMut<[u8]> {
pub fn rand_bytes<T>(mut buf: T) -> T
where T: AsMut<[u8]> {
read_rand(buf.as_mut());
buf
}
pub fn rand<T>() -> T where Standard: Distribution<T> {
pub fn rand<T>() -> T
where Standard: Distribution<T> {
RAND_STATE.with(|rand| rand.borrow_mut().gen())
}
pub fn read_rand(buf: &mut [u8]) {