copy ias_client and kbupd_util from SVR
This copies over the ias_client and its dependency kbupd_util over from the SecureValueRecovery repo. These are from the commit 5725cc27c061dac688feb1ff31ec6027dead718f.
This commit is contained in:
parent
8e686faf83
commit
7d753bbf2f
@ -1,3 +1,7 @@
|
||||
[workspace]
|
||||
|
||||
members = ["sgx_sdk_ffi"]
|
||||
members = [
|
||||
"sgx_sdk_ffi",
|
||||
"ias_client",
|
||||
"kbupd_util",
|
||||
]
|
||||
|
||||
19
ias_client/Cargo.toml
Normal file
19
ias_client/Cargo.toml
Normal file
@ -0,0 +1,19 @@
|
||||
[package]
|
||||
authors = ["Open Whisper Systems"]
|
||||
name = "ias_client"
|
||||
version = "0.1.0"
|
||||
license = "AGPL-3.0-or-later"
|
||||
description = "Intel Attestation Services client"
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
failure = "0.1"
|
||||
futures = "0.1"
|
||||
http = "0.1"
|
||||
hyper = "0.12"
|
||||
kbupd_util = { path = "../kbupd_util" }
|
||||
serde = "1.0"
|
||||
serde_derive = "1.0"
|
||||
serde_json = "1.0"
|
||||
sgx_sdk_ffi = { path = "../sgx_sdk_ffi" }
|
||||
try_future = "0.1"
|
||||
324
ias_client/src/lib.rs
Normal file
324
ias_client/src/lib.rs
Normal file
@ -0,0 +1,324 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
use std::fmt;
|
||||
use std::mem;
|
||||
|
||||
use failure::{format_err, ResultExt};
|
||||
use futures::prelude::*;
|
||||
use http::header::HeaderValue;
|
||||
use http::uri::PathAndQuery;
|
||||
use http::{self, HeaderMap, Uri};
|
||||
use hyper::client::connect::Connect;
|
||||
use hyper::{Body, Chunk, Client, Method, Request, Response};
|
||||
use kbupd_util::base64;
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use sgx_sdk_ffi::SgxQuote;
|
||||
use try_future::{try_future, TryFuture};
|
||||
|
||||
pub struct IasClient<C> {
|
||||
base_uri: Uri,
|
||||
api_key: Option<HeaderValue>,
|
||||
client: Client<C, Body>,
|
||||
}
|
||||
|
||||
#[derive(Debug, failure::Fail)]
|
||||
pub enum GetQuoteSignatureError {
|
||||
#[fail(display = "error fetching signed quote: {:?}", _0)]
|
||||
FetchError(#[cause] failure::Error),
|
||||
#[fail(display = "quote verification error: {:?}", _0)]
|
||||
QuoteVerificationError(#[cause] QuoteVerificationError),
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct SignatureRevocationList(pub Vec<u8>);
|
||||
|
||||
impl<C> IasClient<C>
|
||||
where
|
||||
C: Connect + 'static,
|
||||
C::Transport: 'static,
|
||||
C::Future: 'static,
|
||||
{
|
||||
pub fn new(base_uri: &str, api_key: Option<&str>, connector: C) -> Result<Self, failure::Error> {
|
||||
let base_uri = if api_key.is_some() {
|
||||
uri_path_join(base_uri.parse()?, format_args!("/attestation/v3"))?
|
||||
} else {
|
||||
uri_path_join(base_uri.parse()?, format_args!("/attestation/sgx/v3"))?
|
||||
};
|
||||
let client = Client::builder().build(connector);
|
||||
let api_key = match api_key {
|
||||
Some(api_key) => Some(HeaderValue::from_bytes(api_key.as_bytes()).context("invalid IAS API key value")?),
|
||||
None => None,
|
||||
};
|
||||
Ok(Self { base_uri, api_key, client })
|
||||
}
|
||||
|
||||
pub fn get_signature_revocation_list(&self, gid: u32) -> impl Future<Item = SignatureRevocationList, Error = failure::Error> {
|
||||
let uri = try_future!(self.request_uri(format_args!("/sigrl/{:08x}", gid)));
|
||||
|
||||
let mut hyper_request = Request::new(Body::empty());
|
||||
|
||||
*hyper_request.method_mut() = Method::GET;
|
||||
*hyper_request.uri_mut() = uri;
|
||||
|
||||
if let Some(api_key) = &self.api_key {
|
||||
hyper_request.headers_mut().insert("Ocp-Apim-Subscription-Key", api_key.clone());
|
||||
}
|
||||
|
||||
let response = self.client.request(hyper_request);
|
||||
let response_data = response.from_err().and_then(|response: Response<Body>| {
|
||||
if !response.status().is_success() {
|
||||
return TryFuture::from_error(format_err!("HTTP error: {}", response.status().as_str()));
|
||||
}
|
||||
response.into_body().concat2().from_err().into()
|
||||
});
|
||||
|
||||
let decoded_response =
|
||||
response_data.and_then(|data: Chunk| base64::decode(&data).map(SignatureRevocationList).into_future().from_err());
|
||||
|
||||
decoded_response.into()
|
||||
}
|
||||
|
||||
fn fetch_quote_signature(&self, quote: &[u8]) -> impl Future<Item = (http::response::Parts, Chunk), Error = failure::Error> {
|
||||
let uri = try_future!(self.request_uri(format_args!("/report")));
|
||||
|
||||
let request = QuoteSignatureRequest { isvEnclaveQuote: quote };
|
||||
let encoded_request = try_future!(serde_json::to_vec(&request));
|
||||
let mut hyper_request = Request::new(Body::from(encoded_request));
|
||||
|
||||
*hyper_request.method_mut() = Method::POST;
|
||||
*hyper_request.uri_mut() = uri;
|
||||
hyper_request
|
||||
.headers_mut()
|
||||
.insert("Content-Type", HeaderValue::from_static("application/json"));
|
||||
|
||||
if let Some(api_key) = &self.api_key {
|
||||
hyper_request.headers_mut().insert("Ocp-Apim-Subscription-Key", api_key.clone());
|
||||
}
|
||||
|
||||
let response = self.client.request(hyper_request);
|
||||
let full_response = response.and_then(move |response: Response<Body>| {
|
||||
let (response_parts, response_body) = response.into_parts();
|
||||
|
||||
let response_data = response_body.concat2();
|
||||
|
||||
response_data.map(|response_data| (response_parts, response_data))
|
||||
});
|
||||
full_response.from_err().into()
|
||||
}
|
||||
|
||||
pub fn get_quote_signature(
|
||||
&self,
|
||||
quote: Vec<u8>,
|
||||
accept_group_out_of_date: bool,
|
||||
) -> impl Future<Item = SignedQuote, Error = GetQuoteSignatureError>
|
||||
{
|
||||
let response = self.fetch_quote_signature("e);
|
||||
let signed_quote = response.then(move |response_result: Result<(http::response::Parts, Chunk), failure::Error>| {
|
||||
let (response_parts, response_data) = response_result.map_err(GetQuoteSignatureError::FetchError)?;
|
||||
|
||||
let signed_quote_result = validate_quote_signature(response_parts, response_data, quote, accept_group_out_of_date);
|
||||
signed_quote_result.map_err(GetQuoteSignatureError::QuoteVerificationError)
|
||||
});
|
||||
|
||||
signed_quote
|
||||
}
|
||||
|
||||
fn request_uri(&self, request_path: fmt::Arguments<'_>) -> Result<Uri, failure::Error> {
|
||||
uri_path_join(self.base_uri.clone(), request_path)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Clone for IasClient<C> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
base_uri: self.base_uri.clone(),
|
||||
api_key: self.api_key.clone(),
|
||||
client: self.client.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn uri_path_join(uri: Uri, append_path: fmt::Arguments<'_>) -> Result<Uri, failure::Error> {
|
||||
let mut parts = uri.into_parts();
|
||||
let path_base = parts
|
||||
.path_and_query
|
||||
.as_ref()
|
||||
.map(PathAndQuery::path)
|
||||
.unwrap_or_default()
|
||||
.trim_end_matches('/');
|
||||
parts.path_and_query = Some(format!("{}{}", path_base, append_path).parse::<http::uri::PathAndQuery>()?);
|
||||
let uri = Uri::from_parts(parts)?;
|
||||
Ok(uri)
|
||||
}
|
||||
|
||||
fn validate_quote_signature(
|
||||
response_parts: http::response::Parts,
|
||||
response_body_data: Chunk,
|
||||
quote: Vec<u8>,
|
||||
accept_group_out_of_date: bool,
|
||||
) -> Result<SignedQuote, QuoteVerificationError>
|
||||
{
|
||||
if !response_parts.status.is_success() {
|
||||
let response_body_string = String::from_utf8_lossy(&response_body_data).to_string();
|
||||
return Err(QuoteVerificationError::HttpError(
|
||||
response_parts.status,
|
||||
response_parts,
|
||||
response_body_string,
|
||||
));
|
||||
}
|
||||
|
||||
let base64_signature = get_header_str(&response_parts.headers, "X-IASReport-Signature")?;
|
||||
let pem_certificates = get_header_str(&response_parts.headers, "X-IASReport-Signing-Certificate")?;
|
||||
|
||||
let signature =
|
||||
base64::decode(base64_signature.as_bytes()).map_err(|_| QuoteVerificationError::InvalidSignature(base64_signature.to_string()))?;
|
||||
|
||||
let certificates = kbupd_util::pem::decode(&kbupd_util::percent::decode(pem_certificates.as_bytes()));
|
||||
|
||||
if certificates.is_empty() {
|
||||
return Err(QuoteVerificationError::InvalidCertificates(pem_certificates.to_string()));
|
||||
}
|
||||
|
||||
let body = response_body_data.to_vec();
|
||||
|
||||
let parsed_body: QuoteSignatureResponseBody =
|
||||
serde_json::from_slice(&body).map_err(|parse_error| QuoteVerificationError::InvalidJson(parse_error.into()))?;
|
||||
|
||||
if parsed_body.version != 3 {
|
||||
return Err(QuoteVerificationError::WrongVersion(parsed_body.version));
|
||||
}
|
||||
|
||||
if Some(&parsed_body.isvEnclaveQuoteBody[..]) != quote.get(..mem::size_of::<SgxQuote>() - 4) {
|
||||
return Err(QuoteVerificationError::WrongQuote);
|
||||
}
|
||||
|
||||
match parsed_body.isvEnclaveQuoteStatus.as_str() {
|
||||
"OK" => {}
|
||||
"GROUP_OUT_OF_DATE" | "CONFIGURATION_NEEDED" => {
|
||||
if !accept_group_out_of_date {
|
||||
return Err(QuoteVerificationError::GroupOutOfDate(
|
||||
parsed_body.isvEnclaveQuoteStatus.clone(),
|
||||
parsed_body,
|
||||
));
|
||||
}
|
||||
}
|
||||
"GROUP_REVOKED" => {
|
||||
return Err(QuoteVerificationError::GroupOutOfDate(
|
||||
parsed_body.isvEnclaveQuoteStatus.clone(),
|
||||
parsed_body,
|
||||
));
|
||||
}
|
||||
"SIGRL_VERSION_MISMATCH" => {
|
||||
return Err(QuoteVerificationError::StaleRevocationList);
|
||||
}
|
||||
_ => {
|
||||
return Err(QuoteVerificationError::AttestationError(parsed_body.isvEnclaveQuoteStatus));
|
||||
}
|
||||
}
|
||||
|
||||
// XXX validate timestamp
|
||||
|
||||
Ok(SignedQuote {
|
||||
quote,
|
||||
body,
|
||||
signature,
|
||||
certificates,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_header_str<'a>(headers: &'a HeaderMap, name: &'static str) -> Result<&'a str, QuoteVerificationError> {
|
||||
if let Some(header) = headers.get(name) {
|
||||
match header.to_str() {
|
||||
Ok(header) => Ok(header),
|
||||
Err(_) => Err(QuoteVerificationError::InvalidHeaderValue(name, header.clone())),
|
||||
}
|
||||
} else {
|
||||
Err(QuoteVerificationError::MissingHeader(name))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for SignatureRevocationList {
|
||||
type Target = [u8];
|
||||
|
||||
fn deref(&self) -> &[u8] {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(failure::Fail)]
|
||||
pub enum QuoteVerificationError {
|
||||
#[fail(display = "attestation http error: {}", _0)]
|
||||
HttpError(http::status::StatusCode, http::response::Parts, String),
|
||||
#[fail(display = "missing attestation http header {}", _0)]
|
||||
MissingHeader(&'static str),
|
||||
#[fail(display = "invalid attestation http header value for {}: {:?}", _0, _1)]
|
||||
InvalidHeaderValue(&'static str, HeaderValue),
|
||||
#[fail(display = "invalid attestation signature: {}", _0)]
|
||||
InvalidSignature(String),
|
||||
#[fail(display = "invalid attestation certificates: {}", _0)]
|
||||
InvalidCertificates(String),
|
||||
#[fail(display = "invalid attestation report json: {}", _0)]
|
||||
InvalidJson(#[cause] failure::Error),
|
||||
#[fail(display = "invalid attestation report version: {}", _0)]
|
||||
WrongVersion(u64),
|
||||
#[fail(display = "wrong attestation report quote")]
|
||||
WrongQuote,
|
||||
#[fail(display = "stale attestation revocation list")]
|
||||
StaleRevocationList,
|
||||
#[fail(display = "attestation group out of date: {}", _0)]
|
||||
GroupOutOfDate(String, QuoteSignatureResponseBody),
|
||||
#[fail(display = "attestation error: {}", _0)]
|
||||
AttestationError(String),
|
||||
}
|
||||
|
||||
impl fmt::Debug for QuoteVerificationError {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
fmt::Display::fmt(self, fmt)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SignedQuote {
|
||||
pub quote: Vec<u8>,
|
||||
pub body: Vec<u8>,
|
||||
pub signature: Vec<u8>,
|
||||
pub certificates: Vec<Vec<u8>>,
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
#[derive(Serialize)]
|
||||
pub struct QuoteSignatureRequest<'a> {
|
||||
#[serde(with = "base64")]
|
||||
pub isvEnclaveQuote: &'a [u8],
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct QuoteSignatureResponseBody {
|
||||
pub isvEnclaveQuoteStatus: String,
|
||||
|
||||
#[serde(with = "base64")]
|
||||
pub isvEnclaveQuoteBody: Vec<u8>,
|
||||
|
||||
pub version: u64,
|
||||
|
||||
pub timestamp: String,
|
||||
|
||||
pub platformInfoBlob: Option<String>,
|
||||
}
|
||||
16
kbupd_util/Cargo.toml
Normal file
16
kbupd_util/Cargo.toml
Normal file
@ -0,0 +1,16 @@
|
||||
[package]
|
||||
authors = ["Open Whisper Systems"]
|
||||
name = "kbupd_util"
|
||||
version = "0.1.0"
|
||||
license = "AGPL-3.0-or-later"
|
||||
description = "Key Backup Service Daemon Utility Library"
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
base64 = "0.10"
|
||||
bytes = "0.4"
|
||||
hex = "0.4"
|
||||
rand = "0.6"
|
||||
serde = "1.0"
|
||||
serde_derive = "1.0"
|
||||
regex = "1.1"
|
||||
114
kbupd_util/src/base64.rs
Normal file
114
kbupd_util/src/base64.rs
Normal file
@ -0,0 +1,114 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
use std::fmt;
|
||||
use std::marker::{PhantomData};
|
||||
|
||||
use base64;
|
||||
use serde::{Deserializer, Serializer};
|
||||
|
||||
pub fn decode(encoded: &[u8]) -> Result<Vec<u8>, base64::DecodeError> {
|
||||
let space_regex = regex::bytes::Regex::new(r"[ \t\r\n]").unwrap();
|
||||
let base64_data = space_regex.replace_all(encoded, &b""[..]);
|
||||
let config = base64::Config::new(base64::CharacterSet::Standard, true);
|
||||
base64::decode_config(&base64_data, config)
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<u8>, D::Error> {
|
||||
deserializer.deserialize_bytes(Base64Visitor)
|
||||
}
|
||||
|
||||
pub fn serialize<S: Serializer>(data: impl AsRef<[u8]>, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
serializer.serialize_str(&base64::encode(data.as_ref()))
|
||||
}
|
||||
|
||||
//
|
||||
// Base64Visitor impls
|
||||
//
|
||||
|
||||
struct Base64Visitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for Base64Visitor {
|
||||
type Value = Vec<u8>;
|
||||
fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
fmt.write_str("a base64-encoded string")
|
||||
}
|
||||
|
||||
fn visit_bytes<E>(self, base64: &[u8]) -> Result<Self::Value, E>
|
||||
where E: serde::de::Error
|
||||
{
|
||||
decode(base64).map_err(|error| E::custom(error.to_string()))
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, base64: &str) -> Result<Self::Value, E>
|
||||
where E: serde::de::Error
|
||||
{
|
||||
self.visit_bytes(base64.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// FixedLengthBase64Visitor impls
|
||||
//
|
||||
|
||||
struct FixedLengthBase64Visitor<T>(PhantomData<T>);
|
||||
|
||||
impl<'de, T> serde::de::Visitor<'de> for FixedLengthBase64Visitor<T>
|
||||
where T: AsMut<[u8]> + AsRef<[u8]> + Default
|
||||
{
|
||||
type Value = T;
|
||||
fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
fmt.write_str("a base64-encoded string")
|
||||
}
|
||||
|
||||
fn visit_bytes<E>(self, base64: &[u8]) -> Result<Self::Value, E>
|
||||
where E: serde::de::Error
|
||||
{
|
||||
let mut deserialized = T::default();
|
||||
let estimated_length = (base64.len() + 3) / 4 * 3;
|
||||
if estimated_length > deserialized.as_ref().len() + 2 {
|
||||
Err(E::custom(format!("base64 parameter length {} > {}", estimated_length, deserialized.as_ref().len())))
|
||||
} else {
|
||||
let data = decode(base64).map_err(|error| E::custom(error.to_string()))?;
|
||||
if data.len() != deserialized.as_ref().len() {
|
||||
Err(E::custom(format!("base64 parameter length {} != {}", data.len(), deserialized.as_ref().len())))
|
||||
} else {
|
||||
deserialized.as_mut().copy_from_slice(&data[..]);
|
||||
Ok(deserialized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, base64: &str) -> Result<Self::Value, E>
|
||||
where E: serde::de::Error
|
||||
{
|
||||
self.visit_bytes(base64.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SerdeFixedLengthBase64: Sized + AsMut<[u8]> + AsRef<[u8]> + Default {
|
||||
fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
|
||||
deserializer.deserialize_bytes(FixedLengthBase64Visitor(PhantomData))
|
||||
}
|
||||
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
serialize(self, serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]> + Default> SerdeFixedLengthBase64 for T {
|
||||
}
|
||||
69
kbupd_util/src/duration.rs
Normal file
69
kbupd_util/src/duration.rs
Normal file
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
use std::time::{Duration};
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
pub const NANOS_PER_SEC: u32 = 1_000_000_000;
|
||||
|
||||
pub fn random(max: Duration) -> Duration {
|
||||
let secs = rand::thread_rng().gen_range(0, max.as_secs().saturating_add(1));
|
||||
let nanos = rand::thread_rng().gen_range(0, max.subsec_nanos().saturating_add(1));
|
||||
Duration::new(secs, nanos)
|
||||
}
|
||||
|
||||
pub fn as_ticks(duration: Duration, tick_interval: Duration) -> u32 {
|
||||
let duration_ms = duration.as_millis();
|
||||
let tick_interval_ms = tick_interval.as_millis();
|
||||
let ticks = duration_ms.saturating_add(tick_interval_ms.saturating_sub(1))
|
||||
.checked_div(tick_interval_ms)
|
||||
.unwrap_or(0);
|
||||
ticks as u32
|
||||
}
|
||||
|
||||
pub fn as_secs_f64(duration: Duration) -> f64 {
|
||||
(duration.as_secs() as f64) + (duration.subsec_nanos() as f64) / (NANOS_PER_SEC as f64)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::time::{Duration};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_as_ticks() {
|
||||
let max_duration = Duration::new(u64::max_value(), NANOS_PER_SEC - 1);
|
||||
|
||||
assert_eq!(as_ticks(Duration::from_secs(10), Duration::from_secs(1)), 10);
|
||||
assert_eq!(as_ticks(Duration::from_secs(10), Duration::from_secs(0)), 0);
|
||||
assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(10)), 10);
|
||||
assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(11)), 10);
|
||||
assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(12)), 9);
|
||||
assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(99)), 2);
|
||||
assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(100)), 1);
|
||||
assert_eq!(as_ticks(Duration::from_millis(100), Duration::from_millis(1000)), 1);
|
||||
|
||||
assert_eq!(as_ticks(max_duration, Duration::from_secs(0)), 0);
|
||||
assert_eq!(as_ticks(max_duration, max_duration), 1);
|
||||
assert_eq!(as_ticks(Duration::from_secs(0), max_duration), 0);
|
||||
assert_eq!(as_ticks(Duration::from_secs(0), Duration::from_secs(0)), 0);
|
||||
|
||||
assert_eq!(as_ticks(Duration::from_millis(1), max_duration), 1);
|
||||
}
|
||||
}
|
||||
154
kbupd_util/src/hex.rs
Normal file
154
kbupd_util/src/hex.rs
Normal file
@ -0,0 +1,154 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
use std::fmt;
|
||||
use std::marker::{PhantomData};
|
||||
|
||||
use serde::{Deserializer, Serializer};
|
||||
|
||||
use super::{ToHex};
|
||||
|
||||
pub fn parse(hex: &str) -> Result<Vec<u8>, hex::FromHexError> {
|
||||
hex::decode(hex)
|
||||
}
|
||||
|
||||
pub fn parse_fixed<T>(hex: &str) -> Result<T, hex::FromHexError>
|
||||
where T: Sized + AsMut<[u8]> + AsRef<[u8]> + Default
|
||||
{
|
||||
let mut bytes = T::default();
|
||||
let () = hex::decode_to_slice(hex, bytes.as_mut())?;
|
||||
Ok(bytes)
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<u8>, D::Error> {
|
||||
deserializer.deserialize_str(HexVisitor)
|
||||
}
|
||||
|
||||
pub fn serialize<S: Serializer>(data: impl AsRef<[u8]>, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
serializer.serialize_str(&format!("{}", ToHex(data.as_ref())))
|
||||
}
|
||||
|
||||
//
|
||||
// HexVisitor impls
|
||||
//
|
||||
|
||||
struct HexVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for HexVisitor {
|
||||
type Value = Vec<u8>;
|
||||
fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
fmt.write_str("a hexadecimal-encoded string")
|
||||
}
|
||||
fn visit_str<E>(self, hex: &str) -> Result<Vec<u8>, E>
|
||||
where E: serde::de::Error
|
||||
{
|
||||
parse(hex).map_err(|error| E::custom(format!("{}", error)))
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// FixedLengthHexVisitor impls
|
||||
//
|
||||
|
||||
struct FixedLengthHexVisitor<T>(PhantomData<T>);
|
||||
|
||||
impl<'de, T> serde::de::Visitor<'de> for FixedLengthHexVisitor<T>
|
||||
where T: AsMut<[u8]> + AsRef<[u8]> + Default
|
||||
{
|
||||
type Value = T;
|
||||
fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
fmt.write_str("a hexadecimal-encoded string")
|
||||
}
|
||||
fn visit_str<E>(self, hex: &str) -> Result<Self::Value, E>
|
||||
where E: serde::de::Error
|
||||
{
|
||||
parse_fixed(hex).map_err(|error| E::custom(format!("{}", error)))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SerdeFixedLengthHex: Sized + AsMut<[u8]> + AsRef<[u8]> + Default {
|
||||
fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
|
||||
deserializer.deserialize_str(FixedLengthHexVisitor(PhantomData))
|
||||
}
|
||||
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
serialize(self, serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsMut<[u8]> + AsRef<[u8]> + Default> SerdeFixedLengthHex for T {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse() {
|
||||
assert_eq!(&parse("").unwrap(), b"");
|
||||
assert_eq!(&parse("616263").unwrap(), b"abc");
|
||||
assert_eq!(&parse("00fF").unwrap(), b"\x00\xFF");
|
||||
|
||||
parse("\n").unwrap_err();
|
||||
parse(" ").unwrap_err();
|
||||
parse(" 00").unwrap_err();
|
||||
parse("00 ").unwrap_err();
|
||||
parse("00\n").unwrap_err();
|
||||
parse(" 00 ").unwrap_err();
|
||||
parse("0 0").unwrap_err();
|
||||
parse("0").unwrap_err();
|
||||
parse("0g").unwrap_err();
|
||||
parse("0\x00").unwrap_err();
|
||||
parse("\x00").unwrap_err();
|
||||
parse("\x00\x00").unwrap_err();
|
||||
parse("FF\x7F").unwrap_err();
|
||||
parse("000").unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_fixed() {
|
||||
assert_eq!(&parse_fixed::<[u8; 0]>("").unwrap(), b"");
|
||||
assert_eq!(&parse_fixed::<[u8; 3]>("616263").unwrap(), b"abc");
|
||||
assert_eq!(&parse_fixed::<[u8; 2]>("00fF").unwrap(), b"\x00\xFF");
|
||||
|
||||
parse_fixed::<[u8; 1]>("").unwrap_err();
|
||||
parse_fixed::<[u8; 0]>("00").unwrap_err();
|
||||
parse_fixed::<[u8; 2]>("00").unwrap_err();
|
||||
|
||||
macro_rules! test_parse_fixed {
|
||||
($n:literal) => ({
|
||||
parse_fixed::<[u8; $n]>("\n").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>(" ").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>(" 00").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("00 ").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("00\n").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>(" 00 ").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("0 0").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("0").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("0g").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("0\x00").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("\x00").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("\x00\x00").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("FF\x7F").unwrap_err();
|
||||
parse_fixed::<[u8; $n]>("000").unwrap_err();
|
||||
})
|
||||
}
|
||||
test_parse_fixed!(0);
|
||||
test_parse_fixed!(1);
|
||||
test_parse_fixed!(2);
|
||||
}
|
||||
}
|
||||
144
kbupd_util/src/lib.rs
Normal file
144
kbupd_util/src/lib.rs
Normal file
@ -0,0 +1,144 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
pub mod base64;
|
||||
pub mod duration;
|
||||
pub mod hex;
|
||||
pub mod pem;
|
||||
pub mod percent;
|
||||
pub mod thread;
|
||||
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
|
||||
pub struct ToHex<'a>(pub &'a [u8]);
|
||||
pub struct OptionDisplay<T>(pub Option<T>);
|
||||
pub struct ListDisplay<T>(pub T);
|
||||
pub struct DisplayAsDebug<T>(pub T);
|
||||
|
||||
pub enum Never {}
|
||||
|
||||
pub fn to_socket_addr(address: impl ToSocketAddrs) -> io::Result<SocketAddr> {
|
||||
address
|
||||
.to_socket_addrs()?
|
||||
.next()
|
||||
.ok_or(io::Error::new(io::ErrorKind::Other, "empty listen address"))
|
||||
}
|
||||
|
||||
//
|
||||
// ToHex impls
|
||||
//
|
||||
|
||||
impl<'a> ToHex<'a> {
|
||||
pub fn new<T: AsRef<[u8]>>(bytes: &'a T) -> Self {
|
||||
ToHex(bytes.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> fmt::Display for ToHex<'a> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let ToHex(data) = self;
|
||||
for byte in *data {
|
||||
write!(fmt, "{:02x}", byte)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl<'a> fmt::Debug for ToHex<'a> {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt::Display::fmt(self, fmt)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> fmt::Display for OptionDisplay<T>
|
||||
where T: fmt::Display
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let OptionDisplay(inner) = self;
|
||||
match inner {
|
||||
Some(inner) => fmt::Display::fmt(inner, fmt),
|
||||
None => write!(fmt, "<none>"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// OptionDisplay impls
|
||||
//
|
||||
|
||||
impl<T> fmt::Debug for OptionDisplay<T>
|
||||
where T: fmt::Display
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt::Display::fmt(self, fmt)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// ListDisplay impls
|
||||
//
|
||||
|
||||
impl<T> fmt::Display for ListDisplay<T>
|
||||
where
|
||||
T: IntoIterator + Clone,
|
||||
T::Item: fmt::Display,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let ListDisplay(inner) = self;
|
||||
fmt.debug_list().entries(inner.clone().into_iter().map(DisplayAsDebug)).finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for ListDisplay<T>
|
||||
where
|
||||
T: IntoIterator + Clone,
|
||||
T::Item: fmt::Display,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt::Display::fmt(self, fmt)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// DisplayAsDebug impls
|
||||
//
|
||||
|
||||
impl<T> fmt::Debug for DisplayAsDebug<T>
|
||||
where T: fmt::Display
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let DisplayAsDebug(inner) = self;
|
||||
fmt::Display::fmt(inner, fmt)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Never impls
|
||||
//
|
||||
|
||||
macro_rules! from_never {
|
||||
($type:ty) => {
|
||||
impl From<Never> for $type {
|
||||
fn from(never: Never) -> Self {
|
||||
match never {}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
from_never!(());
|
||||
114
kbupd_util/src/pem.rs
Normal file
114
kbupd_util/src/pem.rs
Normal file
@ -0,0 +1,114 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
use std::str;
|
||||
|
||||
pub fn decode(pem: &[u8]) -> Vec<Vec<u8>> {
|
||||
// based on RFC7468
|
||||
let re = regex::bytes::Regex::new(r"(?-u)-----BEGIN (?:[\x21-\x2C\x2E-\x7E](?:[- ]?[\x21-\x2C\x2E-\x7E])*)?-----[ \t]*(?:\r|\n|\r\n)[ \t\r\n]*([ \t\r\n\x2B\x2F\x30-\x39\x3D\x41-\x5A\x61-\x7A]*)(?:\r|\n|\r\n)-----END (?:[\x21-\x2C\x2E-\x7E](?:[- ]?[\x21-\x2C\x2E-\x7E])*)?-----").unwrap();
|
||||
|
||||
let mut certificates: Vec<Vec<u8>> = Vec::new();
|
||||
for captures in re.captures_iter(pem) {
|
||||
if let Ok(der) = crate::base64::decode(&captures[1]) {
|
||||
if !der.is_empty() {
|
||||
certificates.push(der);
|
||||
}
|
||||
}
|
||||
}
|
||||
certificates
|
||||
}
|
||||
|
||||
pub fn encode<T>(tag: &str, certificates_der: impl IntoIterator<Item = T> + Clone) -> String
|
||||
where T: AsRef<[u8]>
|
||||
{
|
||||
const PEM_BEGIN: &'static str = "-----BEGIN ";
|
||||
const PEM_END: &'static str = "-----END ";
|
||||
const PEM_TAG_SUFFIX: &'static str = "-----\n";
|
||||
|
||||
let config = base64::Config::new(base64::CharacterSet::Standard, true);
|
||||
|
||||
let mut approx_len = 0;
|
||||
for certificate_der in certificates_der.clone() {
|
||||
let unwrapped_len = certificate_der.as_ref().len() * 4 / 3 + 4;
|
||||
approx_len += PEM_BEGIN.len() + tag.len() + PEM_TAG_SUFFIX.len();
|
||||
approx_len += unwrapped_len + (unwrapped_len / 64);
|
||||
approx_len += PEM_END.len() + tag.len() + PEM_TAG_SUFFIX.len();
|
||||
}
|
||||
|
||||
let mut encoded = String::with_capacity(approx_len);
|
||||
for certificate_der in certificates_der {
|
||||
encoded.push_str(PEM_BEGIN);
|
||||
encoded.push_str(tag);
|
||||
encoded.push_str(PEM_TAG_SUFFIX);
|
||||
|
||||
let base64: String = base64::encode_config(certificate_der.as_ref(), config);
|
||||
let base64_bytes: &[u8] = base64.as_ref();
|
||||
for base64_line_bytes in base64_bytes.chunks(64) {
|
||||
let base64_line_str = str::from_utf8(base64_line_bytes)
|
||||
.unwrap_or_else(|_| unreachable!("base64 is ascii"));
|
||||
encoded.push_str(base64_line_str);
|
||||
encoded.push_str("\n");
|
||||
}
|
||||
|
||||
encoded.push_str(PEM_END);
|
||||
encoded.push_str(tag);
|
||||
encoded.push_str(PEM_TAG_SUFFIX);
|
||||
}
|
||||
encoded
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decode() {
|
||||
assert!(decode(b"").is_empty());
|
||||
assert!(decode(b"test").is_empty());
|
||||
// malformed data
|
||||
assert!(decode(b"-----BEGIN TEST-----\n.GVzdA==\n-----END TEST-----").is_empty());
|
||||
// malformed tag
|
||||
assert!(decode(b"-----BEGIN TEST------\ndGVzdA==\n-----END TEST-----").is_empty());
|
||||
assert!(decode(b"-----BEGIN TEST -----\ndGVzdA==\n-----END TEST-----").is_empty());
|
||||
assert!(decode(b"-----BEGIN TEST-----\ndGVzdA==\n-----END TEST -----").is_empty());
|
||||
// missing line break
|
||||
assert!(decode(b"-----BEGIN TEST -----dGVzdA==\n-----END TEST-----").is_empty());
|
||||
assert!(decode(b"-----BEGIN TEST -----\ndGVzdA==-----END TEST-----").is_empty());
|
||||
// missing data
|
||||
assert!(decode(b"-----BEGIN TEST-----\n\n-----END TEST-----").is_empty());
|
||||
assert!(decode(b"-----BEGIN TEST-----\n \n-----END TEST-----").is_empty());
|
||||
// valid tests
|
||||
assert_eq!(decode(b"-----BEGIN -----\ndGVzdA==\n-----END -----"), [b"test"]);
|
||||
assert_eq!(decode(b"-----BEGIN -----\t \r\n dGVzdA== \n\n-----END TEST-----"), [b"test"]);
|
||||
assert_eq!(decode(b"-----BEGIN TEST-----\t \r\n dGVzdA== \n\n-----END TEST-----"), [b"test"]);
|
||||
assert_eq!(decode(b"-----BEGIN TEST1 TEST2-----\t \r\n d\tG VzdA== \n\n-----END TEST1-TEST2-----"), [b"test"]);
|
||||
let test_certs = [b"test"];
|
||||
assert_eq!(decode(encode("", &test_certs).as_bytes()), test_certs);
|
||||
let test_certs = [b"test1", b"test2"];
|
||||
assert_eq!(decode(encode("SOME TAG", &test_certs).as_bytes()), test_certs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode() {
|
||||
let no_certs: [&'static [u8]; 0] = [];
|
||||
assert_eq!(encode("TEST", &no_certs), "");
|
||||
assert_eq!(encode("", &[b"test"]), "-----BEGIN -----\ndGVzdA==\n-----END -----\n");
|
||||
assert_eq!(encode("TEST", &[b""]), "-----BEGIN TEST-----\n-----END TEST-----\n");
|
||||
assert_eq!(encode("TEST", &[b"test1", b"test2"]),
|
||||
"-----BEGIN TEST-----\ndGVzdDE=\n-----END TEST-----\n-----BEGIN TEST-----\ndGVzdDI=\n-----END TEST-----\n");
|
||||
}
|
||||
}
|
||||
65
kbupd_util/src/percent.rs
Normal file
65
kbupd_util/src/percent.rs
Normal file
@ -0,0 +1,65 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
use std::str;
|
||||
|
||||
pub fn decode(encoded: &[u8]) -> Vec<u8> {
|
||||
let re = regex::bytes::Regex::new(r"(?-u)%(?:([\x30-\x39\x41-\x46\x61-\x66]{2})|(%))").unwrap();
|
||||
|
||||
let mut decoded: Vec<u8> = Vec::with_capacity(encoded.len());
|
||||
let mut last_match_end = 0;
|
||||
for capture in re.captures_iter(encoded) {
|
||||
if let Some(capture_match) = capture.get(1) {
|
||||
decoded.extend(&encoded[last_match_end..(capture_match.start() - 1)]);
|
||||
decoded.push(u8::from_str_radix(str::from_utf8(capture_match.as_bytes()).unwrap(), 16).unwrap());
|
||||
last_match_end = capture_match.end();
|
||||
} else if let Some(capture_match) = capture.get(2) {
|
||||
decoded.extend(&encoded[last_match_end..capture_match.start()]);
|
||||
last_match_end = capture_match.end();
|
||||
}
|
||||
}
|
||||
decoded.extend(&encoded[last_match_end..]);
|
||||
decoded
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decode() {
|
||||
// valid tests
|
||||
assert_eq!(decode(b""), b"");
|
||||
assert_eq!(decode(b"%20"), b" ");
|
||||
assert_eq!(decode(b"%%"), b"%");
|
||||
assert_eq!(decode(b"ABCD"), b"ABCD");
|
||||
assert_eq!(decode(b"%41BC"), b"ABC");
|
||||
assert_eq!(decode(b"AB%%%43"), b"AB%C");
|
||||
assert_eq!(decode(b"test1%20test2"), b"test1 test2");
|
||||
assert_eq!(decode(b"test1%20%20test2"), b"test1 test2");
|
||||
assert_eq!(decode(b"%00%FF"), b"\x00\xFF");
|
||||
// invalid hex
|
||||
assert_eq!(decode(b"%"), b"%");
|
||||
assert_eq!(decode(b"%A"), b"%A");
|
||||
assert_eq!(decode(b"%%A"), b"%A");
|
||||
assert_eq!(decode(b"%%%"), b"%%");
|
||||
assert_eq!(decode(b"%AZ"), b"%AZ");
|
||||
assert_eq!(decode(b"AB%C"), b"AB%C");
|
||||
assert_eq!(decode(b"AB%C%D"), b"AB%C%D");
|
||||
assert_eq!(decode(b"AB%C%44"), b"AB%CD");
|
||||
}
|
||||
}
|
||||
125
kbupd_util/src/thread.rs
Normal file
125
kbupd_util/src/thread.rs
Normal file
@ -0,0 +1,125 @@
|
||||
/*
|
||||
* Copyright (C) 2019 Open Whisper Systems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
use std::thread;
|
||||
use std::thread::{JoinHandle};
|
||||
use std::time::*;
|
||||
use std::sync::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StopJoinHandle<T> {
|
||||
stop_state: Arc<StopState>,
|
||||
join_handle: Arc<Mutex<Option<JoinHandle<T>>>>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct StopState {
|
||||
condvar: Condvar,
|
||||
stopped: Mutex<bool>,
|
||||
}
|
||||
|
||||
impl<T> StopJoinHandle<T> {
|
||||
pub fn new(stop_state: Arc<StopState>, join_handle: JoinHandle<T>) -> Self {
|
||||
Self {
|
||||
stop_state,
|
||||
join_handle: Arc::new(Mutex::new(Some(join_handle))),
|
||||
}
|
||||
}
|
||||
pub fn stop(&self) {
|
||||
let mut stopped_guard = match self.stop_state.stopped.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(poison) => poison.into_inner(),
|
||||
};
|
||||
*stopped_guard = true;
|
||||
self.stop_state.condvar.notify_all();
|
||||
}
|
||||
pub fn join(&self) -> Option<thread::Result<T>> {
|
||||
let mut join_handle_guard = match self.join_handle.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(poison) => poison.into_inner(),
|
||||
};
|
||||
if let Some(join_handle) = join_handle_guard.take() {
|
||||
Some(join_handle.join())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
impl StopState {
|
||||
pub fn sleep_while_running(&self, duration: Duration) -> bool {
|
||||
let mut stopped_guard = match self.stopped.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(poison) => poison.into_inner(),
|
||||
};
|
||||
let start = Instant::now();
|
||||
loop {
|
||||
if *stopped_guard {
|
||||
break false;
|
||||
}
|
||||
let timeout = match duration.checked_sub(start.elapsed()) {
|
||||
Some(timeout) => timeout,
|
||||
None => break true,
|
||||
};
|
||||
stopped_guard = {
|
||||
let (stopped_guard, wait_timeout_result) = match self.condvar.wait_timeout(stopped_guard, timeout) {
|
||||
Ok(result) => result,
|
||||
Err(poison) => poison.into_inner(),
|
||||
};
|
||||
if wait_timeout_result.timed_out() {
|
||||
break !*stopped_guard;
|
||||
} else {
|
||||
stopped_guard
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_stop_join() {
|
||||
let stop_state = Arc::new(StopState::default());
|
||||
let stop_state_2 = stop_state.clone();
|
||||
let join_handle = std::thread::spawn(move || {
|
||||
assert!(!stop_state.sleep_while_running(Duration::from_secs(60)));
|
||||
});
|
||||
let stop_join_handle = StopJoinHandle::new(stop_state_2, join_handle);
|
||||
stop_join_handle.stop();
|
||||
let () = stop_join_handle.join().unwrap().unwrap();
|
||||
assert!(stop_join_handle.join().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sleep_while_running() {
|
||||
let stop_state = Arc::new(StopState::default());
|
||||
let stop_state_2 = stop_state.clone();
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let join_handle = std::thread::spawn(move || {
|
||||
assert!(stop_state.sleep_while_running(Duration::from_millis(1)));
|
||||
assert!(stop_state.sleep_while_running(Duration::from_millis(1)));
|
||||
let _ = tx.send(());
|
||||
assert!(!stop_state.sleep_while_running(Duration::from_secs(60)));
|
||||
});
|
||||
let stop_join_handle = StopJoinHandle::new(stop_state_2, join_handle);
|
||||
let () = rx.recv().unwrap();
|
||||
stop_join_handle.stop();
|
||||
let () = stop_join_handle.join().unwrap().unwrap();
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user