diff --git a/Cargo.toml b/Cargo.toml index 79e918f..6621eea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,7 @@ [workspace] -members = ["sgx_sdk_ffi"] +members = [ + "sgx_sdk_ffi", + "ias_client", + "kbupd_util", + ] diff --git a/ias_client/Cargo.toml b/ias_client/Cargo.toml new file mode 100644 index 0000000..adcb330 --- /dev/null +++ b/ias_client/Cargo.toml @@ -0,0 +1,19 @@ +[package] +authors = ["Open Whisper Systems"] +name = "ias_client" +version = "0.1.0" +license = "AGPL-3.0-or-later" +description = "Intel Attestation Services client" +edition = "2018" + +[dependencies] +failure = "0.1" +futures = "0.1" +http = "0.1" +hyper = "0.12" +kbupd_util = { path = "../kbupd_util" } +serde = "1.0" +serde_derive = "1.0" +serde_json = "1.0" +sgx_sdk_ffi = { path = "../sgx_sdk_ffi" } +try_future = "0.1" diff --git a/ias_client/src/lib.rs b/ias_client/src/lib.rs new file mode 100644 index 0000000..5dd4f72 --- /dev/null +++ b/ias_client/src/lib.rs @@ -0,0 +1,324 @@ +/* + * Copyright (C) 2019 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +use std::fmt; +use std::mem; + +use failure::{format_err, ResultExt}; +use futures::prelude::*; +use http::header::HeaderValue; +use http::uri::PathAndQuery; +use http::{self, HeaderMap, Uri}; +use hyper::client::connect::Connect; +use hyper::{Body, Chunk, Client, Method, Request, Response}; +use kbupd_util::base64; +use serde_derive::{Deserialize, Serialize}; +use serde_json; +use sgx_sdk_ffi::SgxQuote; +use try_future::{try_future, TryFuture}; + +pub struct IasClient { + base_uri: Uri, + api_key: Option, + client: Client, +} + +#[derive(Debug, failure::Fail)] +pub enum GetQuoteSignatureError { + #[fail(display = "error fetching signed quote: {:?}", _0)] + FetchError(#[cause] failure::Error), + #[fail(display = "quote verification error: {:?}", _0)] + QuoteVerificationError(#[cause] QuoteVerificationError), +} + +#[derive(Clone, Default)] +pub struct SignatureRevocationList(pub Vec); + +impl IasClient +where + C: Connect + 'static, + C::Transport: 'static, + C::Future: 'static, +{ + pub fn new(base_uri: &str, api_key: Option<&str>, connector: C) -> Result { + let base_uri = if api_key.is_some() { + uri_path_join(base_uri.parse()?, format_args!("/attestation/v3"))? + } else { + uri_path_join(base_uri.parse()?, format_args!("/attestation/sgx/v3"))? + }; + let client = Client::builder().build(connector); + let api_key = match api_key { + Some(api_key) => Some(HeaderValue::from_bytes(api_key.as_bytes()).context("invalid IAS API key value")?), + None => None, + }; + Ok(Self { base_uri, api_key, client }) + } + + pub fn get_signature_revocation_list(&self, gid: u32) -> impl Future { + let uri = try_future!(self.request_uri(format_args!("/sigrl/{:08x}", gid))); + + let mut hyper_request = Request::new(Body::empty()); + + *hyper_request.method_mut() = Method::GET; + *hyper_request.uri_mut() = uri; + + if let Some(api_key) = &self.api_key { + hyper_request.headers_mut().insert("Ocp-Apim-Subscription-Key", api_key.clone()); + } + + let response = self.client.request(hyper_request); + let response_data = response.from_err().and_then(|response: Response| { + if !response.status().is_success() { + return TryFuture::from_error(format_err!("HTTP error: {}", response.status().as_str())); + } + response.into_body().concat2().from_err().into() + }); + + let decoded_response = + response_data.and_then(|data: Chunk| base64::decode(&data).map(SignatureRevocationList).into_future().from_err()); + + decoded_response.into() + } + + fn fetch_quote_signature(&self, quote: &[u8]) -> impl Future { + let uri = try_future!(self.request_uri(format_args!("/report"))); + + let request = QuoteSignatureRequest { isvEnclaveQuote: quote }; + let encoded_request = try_future!(serde_json::to_vec(&request)); + let mut hyper_request = Request::new(Body::from(encoded_request)); + + *hyper_request.method_mut() = Method::POST; + *hyper_request.uri_mut() = uri; + hyper_request + .headers_mut() + .insert("Content-Type", HeaderValue::from_static("application/json")); + + if let Some(api_key) = &self.api_key { + hyper_request.headers_mut().insert("Ocp-Apim-Subscription-Key", api_key.clone()); + } + + let response = self.client.request(hyper_request); + let full_response = response.and_then(move |response: Response| { + let (response_parts, response_body) = response.into_parts(); + + let response_data = response_body.concat2(); + + response_data.map(|response_data| (response_parts, response_data)) + }); + full_response.from_err().into() + } + + pub fn get_quote_signature( + &self, + quote: Vec, + accept_group_out_of_date: bool, + ) -> impl Future + { + let response = self.fetch_quote_signature("e); + let signed_quote = response.then(move |response_result: Result<(http::response::Parts, Chunk), failure::Error>| { + let (response_parts, response_data) = response_result.map_err(GetQuoteSignatureError::FetchError)?; + + let signed_quote_result = validate_quote_signature(response_parts, response_data, quote, accept_group_out_of_date); + signed_quote_result.map_err(GetQuoteSignatureError::QuoteVerificationError) + }); + + signed_quote + } + + fn request_uri(&self, request_path: fmt::Arguments<'_>) -> Result { + uri_path_join(self.base_uri.clone(), request_path) + } +} + +impl Clone for IasClient { + fn clone(&self) -> Self { + Self { + base_uri: self.base_uri.clone(), + api_key: self.api_key.clone(), + client: self.client.clone(), + } + } +} + +fn uri_path_join(uri: Uri, append_path: fmt::Arguments<'_>) -> Result { + let mut parts = uri.into_parts(); + let path_base = parts + .path_and_query + .as_ref() + .map(PathAndQuery::path) + .unwrap_or_default() + .trim_end_matches('/'); + parts.path_and_query = Some(format!("{}{}", path_base, append_path).parse::()?); + let uri = Uri::from_parts(parts)?; + Ok(uri) +} + +fn validate_quote_signature( + response_parts: http::response::Parts, + response_body_data: Chunk, + quote: Vec, + accept_group_out_of_date: bool, +) -> Result +{ + if !response_parts.status.is_success() { + let response_body_string = String::from_utf8_lossy(&response_body_data).to_string(); + return Err(QuoteVerificationError::HttpError( + response_parts.status, + response_parts, + response_body_string, + )); + } + + let base64_signature = get_header_str(&response_parts.headers, "X-IASReport-Signature")?; + let pem_certificates = get_header_str(&response_parts.headers, "X-IASReport-Signing-Certificate")?; + + let signature = + base64::decode(base64_signature.as_bytes()).map_err(|_| QuoteVerificationError::InvalidSignature(base64_signature.to_string()))?; + + let certificates = kbupd_util::pem::decode(&kbupd_util::percent::decode(pem_certificates.as_bytes())); + + if certificates.is_empty() { + return Err(QuoteVerificationError::InvalidCertificates(pem_certificates.to_string())); + } + + let body = response_body_data.to_vec(); + + let parsed_body: QuoteSignatureResponseBody = + serde_json::from_slice(&body).map_err(|parse_error| QuoteVerificationError::InvalidJson(parse_error.into()))?; + + if parsed_body.version != 3 { + return Err(QuoteVerificationError::WrongVersion(parsed_body.version)); + } + + if Some(&parsed_body.isvEnclaveQuoteBody[..]) != quote.get(..mem::size_of::() - 4) { + return Err(QuoteVerificationError::WrongQuote); + } + + match parsed_body.isvEnclaveQuoteStatus.as_str() { + "OK" => {} + "GROUP_OUT_OF_DATE" | "CONFIGURATION_NEEDED" => { + if !accept_group_out_of_date { + return Err(QuoteVerificationError::GroupOutOfDate( + parsed_body.isvEnclaveQuoteStatus.clone(), + parsed_body, + )); + } + } + "GROUP_REVOKED" => { + return Err(QuoteVerificationError::GroupOutOfDate( + parsed_body.isvEnclaveQuoteStatus.clone(), + parsed_body, + )); + } + "SIGRL_VERSION_MISMATCH" => { + return Err(QuoteVerificationError::StaleRevocationList); + } + _ => { + return Err(QuoteVerificationError::AttestationError(parsed_body.isvEnclaveQuoteStatus)); + } + } + + // XXX validate timestamp + + Ok(SignedQuote { + quote, + body, + signature, + certificates, + }) +} + +fn get_header_str<'a>(headers: &'a HeaderMap, name: &'static str) -> Result<&'a str, QuoteVerificationError> { + if let Some(header) = headers.get(name) { + match header.to_str() { + Ok(header) => Ok(header), + Err(_) => Err(QuoteVerificationError::InvalidHeaderValue(name, header.clone())), + } + } else { + Err(QuoteVerificationError::MissingHeader(name)) + } +} + +impl std::ops::Deref for SignatureRevocationList { + type Target = [u8]; + + fn deref(&self) -> &[u8] { + &self.0 + } +} + +#[derive(failure::Fail)] +pub enum QuoteVerificationError { + #[fail(display = "attestation http error: {}", _0)] + HttpError(http::status::StatusCode, http::response::Parts, String), + #[fail(display = "missing attestation http header {}", _0)] + MissingHeader(&'static str), + #[fail(display = "invalid attestation http header value for {}: {:?}", _0, _1)] + InvalidHeaderValue(&'static str, HeaderValue), + #[fail(display = "invalid attestation signature: {}", _0)] + InvalidSignature(String), + #[fail(display = "invalid attestation certificates: {}", _0)] + InvalidCertificates(String), + #[fail(display = "invalid attestation report json: {}", _0)] + InvalidJson(#[cause] failure::Error), + #[fail(display = "invalid attestation report version: {}", _0)] + WrongVersion(u64), + #[fail(display = "wrong attestation report quote")] + WrongQuote, + #[fail(display = "stale attestation revocation list")] + StaleRevocationList, + #[fail(display = "attestation group out of date: {}", _0)] + GroupOutOfDate(String, QuoteSignatureResponseBody), + #[fail(display = "attestation error: {}", _0)] + AttestationError(String), +} + +impl fmt::Debug for QuoteVerificationError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, fmt) + } +} + +#[derive(Debug)] +pub struct SignedQuote { + pub quote: Vec, + pub body: Vec, + pub signature: Vec, + pub certificates: Vec>, +} + +#[allow(non_snake_case)] +#[derive(Serialize)] +pub struct QuoteSignatureRequest<'a> { + #[serde(with = "base64")] + pub isvEnclaveQuote: &'a [u8], +} + +#[allow(non_snake_case)] +#[derive(Deserialize, Debug)] +pub struct QuoteSignatureResponseBody { + pub isvEnclaveQuoteStatus: String, + + #[serde(with = "base64")] + pub isvEnclaveQuoteBody: Vec, + + pub version: u64, + + pub timestamp: String, + + pub platformInfoBlob: Option, +} diff --git a/kbupd_util/Cargo.toml b/kbupd_util/Cargo.toml new file mode 100644 index 0000000..bd1a718 --- /dev/null +++ b/kbupd_util/Cargo.toml @@ -0,0 +1,16 @@ +[package] +authors = ["Open Whisper Systems"] +name = "kbupd_util" +version = "0.1.0" +license = "AGPL-3.0-or-later" +description = "Key Backup Service Daemon Utility Library" +edition = "2018" + +[dependencies] +base64 = "0.10" +bytes = "0.4" +hex = "0.4" +rand = "0.6" +serde = "1.0" +serde_derive = "1.0" +regex = "1.1" diff --git a/kbupd_util/src/base64.rs b/kbupd_util/src/base64.rs new file mode 100644 index 0000000..f3c430f --- /dev/null +++ b/kbupd_util/src/base64.rs @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2019 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +use std::fmt; +use std::marker::{PhantomData}; + +use base64; +use serde::{Deserializer, Serializer}; + +pub fn decode(encoded: &[u8]) -> Result, base64::DecodeError> { + let space_regex = regex::bytes::Regex::new(r"[ \t\r\n]").unwrap(); + let base64_data = space_regex.replace_all(encoded, &b""[..]); + let config = base64::Config::new(base64::CharacterSet::Standard, true); + base64::decode_config(&base64_data, config) +} + +pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result, D::Error> { + deserializer.deserialize_bytes(Base64Visitor) +} + +pub fn serialize(data: impl AsRef<[u8]>, serializer: S) -> Result { + serializer.serialize_str(&base64::encode(data.as_ref())) +} + +// +// Base64Visitor impls +// + +struct Base64Visitor; + +impl<'de> serde::de::Visitor<'de> for Base64Visitor { + type Value = Vec; + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str("a base64-encoded string") + } + + fn visit_bytes(self, base64: &[u8]) -> Result + where E: serde::de::Error + { + decode(base64).map_err(|error| E::custom(error.to_string())) + } + + fn visit_str(self, base64: &str) -> Result + where E: serde::de::Error + { + self.visit_bytes(base64.as_bytes()) + } +} + +// +// FixedLengthBase64Visitor impls +// + +struct FixedLengthBase64Visitor(PhantomData); + +impl<'de, T> serde::de::Visitor<'de> for FixedLengthBase64Visitor +where T: AsMut<[u8]> + AsRef<[u8]> + Default +{ + type Value = T; + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str("a base64-encoded string") + } + + fn visit_bytes(self, base64: &[u8]) -> Result + where E: serde::de::Error + { + let mut deserialized = T::default(); + let estimated_length = (base64.len() + 3) / 4 * 3; + if estimated_length > deserialized.as_ref().len() + 2 { + Err(E::custom(format!("base64 parameter length {} > {}", estimated_length, deserialized.as_ref().len()))) + } else { + let data = decode(base64).map_err(|error| E::custom(error.to_string()))?; + if data.len() != deserialized.as_ref().len() { + Err(E::custom(format!("base64 parameter length {} != {}", data.len(), deserialized.as_ref().len()))) + } else { + deserialized.as_mut().copy_from_slice(&data[..]); + Ok(deserialized) + } + } + } + + fn visit_str(self, base64: &str) -> Result + where E: serde::de::Error + { + self.visit_bytes(base64.as_bytes()) + } +} + +pub trait SerdeFixedLengthBase64: Sized + AsMut<[u8]> + AsRef<[u8]> + Default { + fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result { + deserializer.deserialize_bytes(FixedLengthBase64Visitor(PhantomData)) + } + + fn serialize(&self, serializer: S) -> Result { + serialize(self, serializer) + } +} + +impl + AsRef<[u8]> + Default> SerdeFixedLengthBase64 for T { +} diff --git a/kbupd_util/src/duration.rs b/kbupd_util/src/duration.rs new file mode 100644 index 0000000..206140f --- /dev/null +++ b/kbupd_util/src/duration.rs @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2019 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +use std::time::{Duration}; + +use rand::Rng; + +pub const NANOS_PER_SEC: u32 = 1_000_000_000; + +pub fn random(max: Duration) -> Duration { + let secs = rand::thread_rng().gen_range(0, max.as_secs().saturating_add(1)); + let nanos = rand::thread_rng().gen_range(0, max.subsec_nanos().saturating_add(1)); + Duration::new(secs, nanos) +} + +pub fn as_ticks(duration: Duration, tick_interval: Duration) -> u32 { + let duration_ms = duration.as_millis(); + let tick_interval_ms = tick_interval.as_millis(); + let ticks = duration_ms.saturating_add(tick_interval_ms.saturating_sub(1)) + .checked_div(tick_interval_ms) + .unwrap_or(0); + ticks as u32 +} + +pub fn as_secs_f64(duration: Duration) -> f64 { + (duration.as_secs() as f64) + (duration.subsec_nanos() as f64) / (NANOS_PER_SEC as f64) +} + +#[cfg(test)] +mod test { + use std::time::{Duration}; + + use super::*; + + #[test] + fn test_as_ticks() { + let max_duration = Duration::new(u64::max_value(), NANOS_PER_SEC - 1); + + assert_eq!(as_ticks(Duration::from_secs(10), Duration::from_secs(1)), 10); + assert_eq!(as_ticks(Duration::from_secs(10), Duration::from_secs(0)), 0); + assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(10)), 10); + assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(11)), 10); + assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(12)), 9); + assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(99)), 2); + assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(100)), 1); + assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(1000)), 1); + + assert_eq!(as_ticks(max_duration, Duration::from_secs(0)), 0); + assert_eq!(as_ticks(max_duration, max_duration), 1); + assert_eq!(as_ticks(Duration::from_secs(0), max_duration), 0); + assert_eq!(as_ticks(Duration::from_secs(0), Duration::from_secs(0)), 0); + + assert_eq!(as_ticks(Duration::from_millis(1), max_duration), 1); + } +} diff --git a/kbupd_util/src/hex.rs b/kbupd_util/src/hex.rs new file mode 100644 index 0000000..ce25a42 --- /dev/null +++ b/kbupd_util/src/hex.rs @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2019 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +use std::fmt; +use std::marker::{PhantomData}; + +use serde::{Deserializer, Serializer}; + +use super::{ToHex}; + +pub fn parse(hex: &str) -> Result, hex::FromHexError> { + hex::decode(hex) +} + +pub fn parse_fixed(hex: &str) -> Result +where T: Sized + AsMut<[u8]> + AsRef<[u8]> + Default +{ + let mut bytes = T::default(); + let () = hex::decode_to_slice(hex, bytes.as_mut())?; + Ok(bytes) +} + +pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result, D::Error> { + deserializer.deserialize_str(HexVisitor) +} + +pub fn serialize(data: impl AsRef<[u8]>, serializer: S) -> Result { + serializer.serialize_str(&format!("{}", ToHex(data.as_ref()))) +} + +// +// HexVisitor impls +// + +struct HexVisitor; + +impl<'de> serde::de::Visitor<'de> for HexVisitor { + type Value = Vec; + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str("a hexadecimal-encoded string") + } + fn visit_str(self, hex: &str) -> Result, E> + where E: serde::de::Error + { + parse(hex).map_err(|error| E::custom(format!("{}", error))) + } +} + +// +// FixedLengthHexVisitor impls +// + +struct FixedLengthHexVisitor(PhantomData); + +impl<'de, T> serde::de::Visitor<'de> for FixedLengthHexVisitor +where T: AsMut<[u8]> + AsRef<[u8]> + Default +{ + type Value = T; + fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str("a hexadecimal-encoded string") + } + fn visit_str(self, hex: &str) -> Result + where E: serde::de::Error + { + parse_fixed(hex).map_err(|error| E::custom(format!("{}", error))) + } +} + +pub trait SerdeFixedLengthHex: Sized + AsMut<[u8]> + AsRef<[u8]> + Default { + fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result { + deserializer.deserialize_str(FixedLengthHexVisitor(PhantomData)) + } + + fn serialize(&self, serializer: S) -> Result { + serialize(self, serializer) + } +} + +impl + AsRef<[u8]> + Default> SerdeFixedLengthHex for T { +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse() { + assert_eq!(&parse("").unwrap(), b""); + assert_eq!(&parse("616263").unwrap(), b"abc"); + assert_eq!(&parse("00fF").unwrap(), b"\x00\xFF"); + + parse("\n").unwrap_err(); + parse(" ").unwrap_err(); + parse(" 00").unwrap_err(); + parse("00 ").unwrap_err(); + parse("00\n").unwrap_err(); + parse(" 00 ").unwrap_err(); + parse("0 0").unwrap_err(); + parse("0").unwrap_err(); + parse("0g").unwrap_err(); + parse("0\x00").unwrap_err(); + parse("\x00").unwrap_err(); + parse("\x00\x00").unwrap_err(); + parse("FF\x7F").unwrap_err(); + parse("000").unwrap_err(); + } + + #[test] + fn test_parse_fixed() { + assert_eq!(&parse_fixed::<[u8; 0]>("").unwrap(), b""); + assert_eq!(&parse_fixed::<[u8; 3]>("616263").unwrap(), b"abc"); + assert_eq!(&parse_fixed::<[u8; 2]>("00fF").unwrap(), b"\x00\xFF"); + + parse_fixed::<[u8; 1]>("").unwrap_err(); + parse_fixed::<[u8; 0]>("00").unwrap_err(); + parse_fixed::<[u8; 2]>("00").unwrap_err(); + + macro_rules! test_parse_fixed { + ($n:literal) => ({ + parse_fixed::<[u8; $n]>("\n").unwrap_err(); + parse_fixed::<[u8; $n]>(" ").unwrap_err(); + parse_fixed::<[u8; $n]>(" 00").unwrap_err(); + parse_fixed::<[u8; $n]>("00 ").unwrap_err(); + parse_fixed::<[u8; $n]>("00\n").unwrap_err(); + parse_fixed::<[u8; $n]>(" 00 ").unwrap_err(); + parse_fixed::<[u8; $n]>("0 0").unwrap_err(); + parse_fixed::<[u8; $n]>("0").unwrap_err(); + parse_fixed::<[u8; $n]>("0g").unwrap_err(); + parse_fixed::<[u8; $n]>("0\x00").unwrap_err(); + parse_fixed::<[u8; $n]>("\x00").unwrap_err(); + parse_fixed::<[u8; $n]>("\x00\x00").unwrap_err(); + parse_fixed::<[u8; $n]>("FF\x7F").unwrap_err(); + parse_fixed::<[u8; $n]>("000").unwrap_err(); + }) + } + test_parse_fixed!(0); + test_parse_fixed!(1); + test_parse_fixed!(2); + } +} diff --git a/kbupd_util/src/lib.rs b/kbupd_util/src/lib.rs new file mode 100644 index 0000000..e747917 --- /dev/null +++ b/kbupd_util/src/lib.rs @@ -0,0 +1,144 @@ +/* + * Copyright (C) 2019 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +pub mod base64; +pub mod duration; +pub mod hex; +pub mod pem; +pub mod percent; +pub mod thread; + +use std::fmt; +use std::io; +use std::net::{SocketAddr, ToSocketAddrs}; + +pub struct ToHex<'a>(pub &'a [u8]); +pub struct OptionDisplay(pub Option); +pub struct ListDisplay(pub T); +pub struct DisplayAsDebug(pub T); + +pub enum Never {} + +pub fn to_socket_addr(address: impl ToSocketAddrs) -> io::Result { + address + .to_socket_addrs()? + .next() + .ok_or(io::Error::new(io::ErrorKind::Other, "empty listen address")) +} + +// +// ToHex impls +// + +impl<'a> ToHex<'a> { + pub fn new>(bytes: &'a T) -> Self { + ToHex(bytes.as_ref()) + } +} + +impl<'a> fmt::Display for ToHex<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let ToHex(data) = self; + for byte in *data { + write!(fmt, "{:02x}", byte)?; + } + Ok(()) + } +} +impl<'a> fmt::Debug for ToHex<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, fmt) + } +} + +impl fmt::Display for OptionDisplay +where T: fmt::Display +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let OptionDisplay(inner) = self; + match inner { + Some(inner) => fmt::Display::fmt(inner, fmt), + None => write!(fmt, ""), + } + } +} + +// +// OptionDisplay impls +// + +impl fmt::Debug for OptionDisplay +where T: fmt::Display +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, fmt) + } +} + +// +// ListDisplay impls +// + +impl fmt::Display for ListDisplay +where + T: IntoIterator + Clone, + T::Item: fmt::Display, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let ListDisplay(inner) = self; + fmt.debug_list().entries(inner.clone().into_iter().map(DisplayAsDebug)).finish() + } +} + +impl fmt::Debug for ListDisplay +where + T: IntoIterator + Clone, + T::Item: fmt::Display, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, fmt) + } +} + +// +// DisplayAsDebug impls +// + +impl fmt::Debug for DisplayAsDebug +where T: fmt::Display +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let DisplayAsDebug(inner) = self; + fmt::Display::fmt(inner, fmt) + } +} + +// +// Never impls +// + +macro_rules! from_never { + ($type:ty) => { + impl From for $type { + fn from(never: Never) -> Self { + match never {} + } + } + }; +} + +from_never!(()); diff --git a/kbupd_util/src/pem.rs b/kbupd_util/src/pem.rs new file mode 100644 index 0000000..e581b31 --- /dev/null +++ b/kbupd_util/src/pem.rs @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2019 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +use std::str; + +pub fn decode(pem: &[u8]) -> Vec> { + // based on RFC7468 + let re = regex::bytes::Regex::new(r"(?-u)-----BEGIN (?:[\x21-\x2C\x2E-\x7E](?:[- ]?[\x21-\x2C\x2E-\x7E])*)?-----[ \t]*(?:\r|\n|\r\n)[ \t\r\n]*([ \t\r\n\x2B\x2F\x30-\x39\x3D\x41-\x5A\x61-\x7A]*)(?:\r|\n|\r\n)-----END (?:[\x21-\x2C\x2E-\x7E](?:[- ]?[\x21-\x2C\x2E-\x7E])*)?-----").unwrap(); + + let mut certificates: Vec> = Vec::new(); + for captures in re.captures_iter(pem) { + if let Ok(der) = crate::base64::decode(&captures[1]) { + if !der.is_empty() { + certificates.push(der); + } + } + } + certificates +} + +pub fn encode(tag: &str, certificates_der: impl IntoIterator + Clone) -> String +where T: AsRef<[u8]> +{ + const PEM_BEGIN: &'static str = "-----BEGIN "; + const PEM_END: &'static str = "-----END "; + const PEM_TAG_SUFFIX: &'static str = "-----\n"; + + let config = base64::Config::new(base64::CharacterSet::Standard, true); + + let mut approx_len = 0; + for certificate_der in certificates_der.clone() { + let unwrapped_len = certificate_der.as_ref().len() * 4 / 3 + 4; + approx_len += PEM_BEGIN.len() + tag.len() + PEM_TAG_SUFFIX.len(); + approx_len += unwrapped_len + (unwrapped_len / 64); + approx_len += PEM_END.len() + tag.len() + PEM_TAG_SUFFIX.len(); + } + + let mut encoded = String::with_capacity(approx_len); + for certificate_der in certificates_der { + encoded.push_str(PEM_BEGIN); + encoded.push_str(tag); + encoded.push_str(PEM_TAG_SUFFIX); + + let base64: String = base64::encode_config(certificate_der.as_ref(), config); + let base64_bytes: &[u8] = base64.as_ref(); + for base64_line_bytes in base64_bytes.chunks(64) { + let base64_line_str = str::from_utf8(base64_line_bytes) + .unwrap_or_else(|_| unreachable!("base64 is ascii")); + encoded.push_str(base64_line_str); + encoded.push_str("\n"); + } + + encoded.push_str(PEM_END); + encoded.push_str(tag); + encoded.push_str(PEM_TAG_SUFFIX); + } + encoded +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decode() { + assert!(decode(b"").is_empty()); + assert!(decode(b"test").is_empty()); + // malformed data + assert!(decode(b"-----BEGIN TEST-----\n.GVzdA==\n-----END TEST-----").is_empty()); + // malformed tag + assert!(decode(b"-----BEGIN TEST------\ndGVzdA==\n-----END TEST-----").is_empty()); + assert!(decode(b"-----BEGIN TEST -----\ndGVzdA==\n-----END TEST-----").is_empty()); + assert!(decode(b"-----BEGIN TEST-----\ndGVzdA==\n-----END TEST -----").is_empty()); + // missing line break + assert!(decode(b"-----BEGIN TEST -----dGVzdA==\n-----END TEST-----").is_empty()); + assert!(decode(b"-----BEGIN TEST -----\ndGVzdA==-----END TEST-----").is_empty()); + // missing data + assert!(decode(b"-----BEGIN TEST-----\n\n-----END TEST-----").is_empty()); + assert!(decode(b"-----BEGIN TEST-----\n \n-----END TEST-----").is_empty()); + // valid tests + assert_eq!(decode(b"-----BEGIN -----\ndGVzdA==\n-----END -----"), [b"test"]); + assert_eq!(decode(b"-----BEGIN -----\t \r\n dGVzdA== \n\n-----END TEST-----"), [b"test"]); + assert_eq!(decode(b"-----BEGIN TEST-----\t \r\n dGVzdA== \n\n-----END TEST-----"), [b"test"]); + assert_eq!(decode(b"-----BEGIN TEST1 TEST2-----\t \r\n d\tG VzdA== \n\n-----END TEST1-TEST2-----"), [b"test"]); + let test_certs = [b"test"]; + assert_eq!(decode(encode("", &test_certs).as_bytes()), test_certs); + let test_certs = [b"test1", b"test2"]; + assert_eq!(decode(encode("SOME TAG", &test_certs).as_bytes()), test_certs); + } + + #[test] + fn test_encode() { + let no_certs: [&'static [u8]; 0] = []; + assert_eq!(encode("TEST", &no_certs), ""); + assert_eq!(encode("", &[b"test"]), "-----BEGIN -----\ndGVzdA==\n-----END -----\n"); + assert_eq!(encode("TEST", &[b""]), "-----BEGIN TEST-----\n-----END TEST-----\n"); + assert_eq!(encode("TEST", &[b"test1", b"test2"]), + "-----BEGIN TEST-----\ndGVzdDE=\n-----END TEST-----\n-----BEGIN TEST-----\ndGVzdDI=\n-----END TEST-----\n"); + } +} diff --git a/kbupd_util/src/percent.rs b/kbupd_util/src/percent.rs new file mode 100644 index 0000000..745a93c --- /dev/null +++ b/kbupd_util/src/percent.rs @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2019 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +use std::str; + +pub fn decode(encoded: &[u8]) -> Vec { + let re = regex::bytes::Regex::new(r"(?-u)%(?:([\x30-\x39\x41-\x46\x61-\x66]{2})|(%))").unwrap(); + + let mut decoded: Vec = Vec::with_capacity(encoded.len()); + let mut last_match_end = 0; + for capture in re.captures_iter(encoded) { + if let Some(capture_match) = capture.get(1) { + decoded.extend(&encoded[last_match_end..(capture_match.start() - 1)]); + decoded.push(u8::from_str_radix(str::from_utf8(capture_match.as_bytes()).unwrap(), 16).unwrap()); + last_match_end = capture_match.end(); + } else if let Some(capture_match) = capture.get(2) { + decoded.extend(&encoded[last_match_end..capture_match.start()]); + last_match_end = capture_match.end(); + } + } + decoded.extend(&encoded[last_match_end..]); + decoded +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decode() { + // valid tests + assert_eq!(decode(b""), b""); + assert_eq!(decode(b"%20"), b" "); + assert_eq!(decode(b"%%"), b"%"); + assert_eq!(decode(b"ABCD"), b"ABCD"); + assert_eq!(decode(b"%41BC"), b"ABC"); + assert_eq!(decode(b"AB%%%43"), b"AB%C"); + assert_eq!(decode(b"test1%20test2"), b"test1 test2"); + assert_eq!(decode(b"test1%20%20test2"), b"test1 test2"); + assert_eq!(decode(b"%00%FF"), b"\x00\xFF"); + // invalid hex + assert_eq!(decode(b"%"), b"%"); + assert_eq!(decode(b"%A"), b"%A"); + assert_eq!(decode(b"%%A"), b"%A"); + assert_eq!(decode(b"%%%"), b"%%"); + assert_eq!(decode(b"%AZ"), b"%AZ"); + assert_eq!(decode(b"AB%C"), b"AB%C"); + assert_eq!(decode(b"AB%C%D"), b"AB%C%D"); + assert_eq!(decode(b"AB%C%44"), b"AB%CD"); + } +} diff --git a/kbupd_util/src/thread.rs b/kbupd_util/src/thread.rs new file mode 100644 index 0000000..ae03edd --- /dev/null +++ b/kbupd_util/src/thread.rs @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2019 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +use std::thread; +use std::thread::{JoinHandle}; +use std::time::*; +use std::sync::*; + +#[derive(Clone)] +pub struct StopJoinHandle { + stop_state: Arc, + join_handle: Arc>>>, +} + +#[derive(Default)] +pub struct StopState { + condvar: Condvar, + stopped: Mutex, +} + +impl StopJoinHandle { + pub fn new(stop_state: Arc, join_handle: JoinHandle) -> Self { + Self { + stop_state, + join_handle: Arc::new(Mutex::new(Some(join_handle))), + } + } + pub fn stop(&self) { + let mut stopped_guard = match self.stop_state.stopped.lock() { + Ok(guard) => guard, + Err(poison) => poison.into_inner(), + }; + *stopped_guard = true; + self.stop_state.condvar.notify_all(); + } + pub fn join(&self) -> Option> { + let mut join_handle_guard = match self.join_handle.lock() { + Ok(guard) => guard, + Err(poison) => poison.into_inner(), + }; + if let Some(join_handle) = join_handle_guard.take() { + Some(join_handle.join()) + } else { + None + } + } +} +impl StopState { + pub fn sleep_while_running(&self, duration: Duration) -> bool { + let mut stopped_guard = match self.stopped.lock() { + Ok(guard) => guard, + Err(poison) => poison.into_inner(), + }; + let start = Instant::now(); + loop { + if *stopped_guard { + break false; + } + let timeout = match duration.checked_sub(start.elapsed()) { + Some(timeout) => timeout, + None => break true, + }; + stopped_guard = { + let (stopped_guard, wait_timeout_result) = match self.condvar.wait_timeout(stopped_guard, timeout) { + Ok(result) => result, + Err(poison) => poison.into_inner(), + }; + if wait_timeout_result.timed_out() { + break !*stopped_guard; + } else { + stopped_guard + } + }; + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_stop_join() { + let stop_state = Arc::new(StopState::default()); + let stop_state_2 = stop_state.clone(); + let join_handle = std::thread::spawn(move || { + assert!(!stop_state.sleep_while_running(Duration::from_secs(60))); + }); + let stop_join_handle = StopJoinHandle::new(stop_state_2, join_handle); + stop_join_handle.stop(); + let () = stop_join_handle.join().unwrap().unwrap(); + assert!(stop_join_handle.join().is_none()); + } + + #[test] + fn test_sleep_while_running() { + let stop_state = Arc::new(StopState::default()); + let stop_state_2 = stop_state.clone(); + let (tx, rx) = std::sync::mpsc::channel(); + let join_handle = std::thread::spawn(move || { + assert!(stop_state.sleep_while_running(Duration::from_millis(1))); + assert!(stop_state.sleep_while_running(Duration::from_millis(1))); + let _ = tx.send(()); + assert!(!stop_state.sleep_while_running(Duration::from_secs(60))); + }); + let stop_join_handle = StopJoinHandle::new(stop_state_2, join_handle); + let () = rx.recv().unwrap(); + stop_join_handle.stop(); + let () = stop_join_handle.join().unwrap().unwrap(); + } +}