Co-authored-by: Alan Evans <alan@signal.org> Co-authored-by: Jim Gustafson <jim@signal.org> Co-authored-by: Jordan Rose <jrose@signal.org> Co-authored-by: Peter Thatcher <peter@signal.org>
This commit is contained in:
commit
3b91b08a5c
5
.cargo/audit.toml
Normal file
5
.cargo/audit.toml
Normal file
@ -0,0 +1,5 @@
|
||||
# https://docs.rs/crate/cargo-audit/0.14.1/source/audit.toml.example
|
||||
|
||||
[output]
|
||||
deny = ["unmaintained", "unsound", "yanked"]
|
||||
quiet = false
|
||||
3
.dockerignore
Normal file
3
.dockerignore
Normal file
@ -0,0 +1,3 @@
|
||||
target
|
||||
.git
|
||||
.github
|
||||
4
.github/FUNDING.yml
vendored
Normal file
4
.github/FUNDING.yml
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright 2021 Signal Messenger, LLC
|
||||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
|
||||
custom: https://signal.org/donate/
|
||||
41
.github/workflows/ci.yml
vendored
Normal file
41
.github/workflows/ci.yml
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
name: Continuous integration
|
||||
|
||||
on: # rebuild any PRs and main branch changes
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
ci:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
override: true
|
||||
profile: minimal
|
||||
components: rustfmt, clippy
|
||||
- name: Environment
|
||||
run: rustup --version && cargo --version
|
||||
- name: Build
|
||||
run: cargo build
|
||||
- name: Format
|
||||
run: cargo fmt -- --check
|
||||
- uses: actions-rs/clippy-check@v1
|
||||
with:
|
||||
name: Clippy
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
args: --all-targets -- -D warnings
|
||||
- uses: actions-rs/clippy-check@v1
|
||||
with:
|
||||
name: Clippy (generic UDP)
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
args: --no-default-features -- -D warnings
|
||||
- name: Clippy (fuzz targets)
|
||||
run: cargo clippy --all-targets -- -D warnings
|
||||
working-directory: fuzz
|
||||
env:
|
||||
RUSTFLAGS: --cfg fuzzing
|
||||
- name: Test
|
||||
run: cargo test
|
||||
17
.github/workflows/security_audit.yml
vendored
Normal file
17
.github/workflows/security_audit.yml
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
# https://github.com/actions-rs/audit-check
|
||||
name: Security audit
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- '**/Cargo.toml'
|
||||
- '**/Cargo.lock'
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
jobs:
|
||||
audit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v1
|
||||
- uses: actions-rs/audit-check@v1
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
/target
|
||||
.DS_Store
|
||||
.idea
|
||||
bin
|
||||
6
.idea/copyright/CoreTeamRust.xml
generated
Normal file
6
.idea/copyright/CoreTeamRust.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="CopyrightManager">
|
||||
<copyright>
|
||||
<option name="notice" value=" Copyright 2021 Signal Messenger, LLC SPDX-License-Identifier: AGPL-3.0-only " />
|
||||
<option name="myName" value="CoreTeamRust" />
|
||||
</copyright>
|
||||
</component>
|
||||
9
.idea/copyright/profiles_settings.xml
generated
Normal file
9
.idea/copyright/profiles_settings.xml
generated
Normal file
@ -0,0 +1,9 @@
|
||||
<component name="CopyrightManager">
|
||||
<settings default="CoreTeamRust">
|
||||
<LanguageOptions name="Rust">
|
||||
<option name="fileTypeOverride" value="3" />
|
||||
<option name="block" value="false" />
|
||||
</LanguageOptions>
|
||||
<LanguageOptions name="__TEMPLATE__" />
|
||||
</settings>
|
||||
</component>
|
||||
103
BUILDING.md
Normal file
103
BUILDING.md
Normal file
@ -0,0 +1,103 @@
|
||||
# Building the Calling Server
|
||||
|
||||
## For Development & Debugging
|
||||
|
||||
cargo run
|
||||
|
||||
You can specify a variety of command line arguments. See the [config.rs file](/src/config.rs) file for
|
||||
more details or run:
|
||||
|
||||
cargo run -- --help
|
||||
|
||||
A common example for debugging would be:
|
||||
|
||||
cargo run -- --binding-ip 192.168.1.100 --ice-candidate-ip 192.168.1.100 --diagnostics-interval-secs 1
|
||||
|
||||
where ```--binding-ip``` sets the IP address that the servers will listen on and ```--ice-candidate-ip```
|
||||
is the IP address that will accept media packets from clients. Use the IP addresses specific to your
|
||||
environment. ```--diagnostics-interval-secs``` sets the metrics gathering interval, here to be every
|
||||
second.
|
||||
|
||||
The configuration shown is for debugging and uses the internal http_server for testing, to which clients
|
||||
can connect directly. Usually this is achieved through a TLS veneer such as [ngrok](https://ngrok.com/).
|
||||
|
||||
## For Running Tests
|
||||
|
||||
cargo test
|
||||
|
||||
or
|
||||
|
||||
cargo test --release
|
||||
|
||||
## For Release Builds and Performance Testing
|
||||
|
||||
Release builds and all performance testing should use the ```--release``` build option:
|
||||
|
||||
cargo run --release
|
||||
|
||||
For best performance, the target CPU should also be specified. In this example, ```native``` is used
|
||||
to instruct the compiler to optimize for the CPU that is performing the build itself:
|
||||
|
||||
RUSTFLAGS="-C target-cpu=native" cargo run --release
|
||||
|
||||
## For Deployment
|
||||
|
||||
Signal uses the provided Dockerfile to build images for deployment. This uses a multi-stage process,
|
||||
creating a builder image, a minimal image for delivery, and a runnable image for testing.
|
||||
|
||||
### Building the Docker Image
|
||||
|
||||
Images currently run on AWS EC2 instances supporting the Intel Skylake architecture. When building
|
||||
the image, we can target that specific CPU (or choose any other that matches the platform where the
|
||||
container will be run):
|
||||
|
||||
docker build --build-arg rust_flags=-Ctarget-cpu=skylake -t signal-calling-server .
|
||||
|
||||
The ```build-arg``` can also be omitted to maintain maximum compatibility.
|
||||
|
||||
_Note: At the time of this writing, the skylake-avx512 target is not compatible with some dependencies._
|
||||
|
||||
### Deploying the Docker Image
|
||||
|
||||
The deployment is specific to the type of service or registry being used. For testing, the
|
||||
image can be saved and copied somewhere for running. To save:
|
||||
|
||||
docker save signal-calling-server:latest | gzip > signal-calling-server-latest.tar.gz
|
||||
|
||||
### Running the Docker Container
|
||||
|
||||
To run the container, the following docker command can be used:
|
||||
|
||||
docker run -d --rm -p 8080:8080 -p 10000:10000/udp signal-calling-server:latest
|
||||
|
||||
- ```-d``` runs the container in detached mode (can be omitted for easier testing)
|
||||
- ```--rm``` will clean up the container when it is stopped
|
||||
- ```-p 8080:8080``` connects the TCP port 8080 to the same one on the host
|
||||
- ```-p 10000:10000/udp``` connects the UDP port 10000 to the same one on the host
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
Certain configuration options can be passed when running the container. Currently, any of the
|
||||
following can be specified:
|
||||
- ICE_CANDIDATE_IP
|
||||
- SIGNALING_IP
|
||||
- DIAGNOSTICS_INTERVAL_SECS
|
||||
|
||||
For example:
|
||||
|
||||
docker run --rm -p 8080:8080 -p 10000:10000/udp \
|
||||
-e ICE_CANDIDATE_IP=192.168.1.100 \
|
||||
-e DIAGNOSTICS_INTERVAL_SECS=1 \
|
||||
signal-calling-server:latest
|
||||
|
||||
The host will listen on port 8080 for requests and publish 192.168.1.100:10000 to clients for media
|
||||
access. Packets to that address will be routed to the running container.
|
||||
|
||||
### Binary Deployment
|
||||
|
||||
The docker file can also be used to obtain a binary file:
|
||||
|
||||
docker build --build-arg rust_flags=-Ctarget-cpu=skylake -t signal-calling-server --target export-stage -o bin .
|
||||
|
||||
This will build the calling_server binary executable for Linux and copy it to the ./bin directory of
|
||||
the host. The command will stop at the export-stage and not create the runnable docker image.
|
||||
2040
Cargo.lock
generated
Normal file
2040
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
94
Cargo.toml
Normal file
94
Cargo.toml
Normal file
@ -0,0 +1,94 @@
|
||||
#
|
||||
# Copyright 2019-2021 Signal Messenger, LLC
|
||||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
#
|
||||
|
||||
[package]
|
||||
name = "calling_server"
|
||||
version = "1.0.0"
|
||||
authors = ["Peter Thatcher <peter@signal.org>", "Jim Gustafson <jim@signal.org>", "Jordan Rose <jrose@signal.org>", "Alan Evans <alan@signal.org>"]
|
||||
edition = "2018"
|
||||
description = "Media forwarding server for group calls."
|
||||
license = "AGPL-3.0-only"
|
||||
|
||||
[dependencies]
|
||||
# For error handling
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
|
||||
# For logging and command line operations
|
||||
log = "0.4"
|
||||
env_logger = "0.8"
|
||||
structopt = "0.3"
|
||||
|
||||
# For runtime and threading
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
parking_lot = "0.11"
|
||||
lazy_static = "1.4"
|
||||
futures = "0.3"
|
||||
num_cpus = "1.13"
|
||||
|
||||
# For http
|
||||
warp = "0.3"
|
||||
|
||||
# For general conversions
|
||||
base64 = "0.13"
|
||||
hex = { version = "0.4", features = ["serde"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
rcgen = "0.8"
|
||||
prost = "0.7"
|
||||
|
||||
# For common
|
||||
sha2 = "0.9"
|
||||
|
||||
# For ICE
|
||||
crc = "1.8"
|
||||
hmac = "0.11"
|
||||
sha-1 = "0.9"
|
||||
|
||||
# For DTLS
|
||||
aes = "0.7"
|
||||
aes-gcm = "0.9"
|
||||
rand = "0.8"
|
||||
p256 = { version = "0.8", features = ["ecdh", "ecdsa", "pkcs8"] }
|
||||
ecdsa = { version = "0.11", features = ["sign"] }
|
||||
# Already needed by p256. But we also need it for certificate parsing.
|
||||
der = "0.3"
|
||||
|
||||
# For DTLS and RTP
|
||||
zeroize = "1.4.1"
|
||||
|
||||
# For congestion control
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
async-stream = "0.3"
|
||||
# For congestion-control-specific helpers
|
||||
pin-project = "1.0"
|
||||
|
||||
# For low-level UDP sockets
|
||||
nix = { version = "0.20", optional = true }
|
||||
|
||||
# For current process memory stats
|
||||
psutil = { version = "3.2", default-features = false, features = ["process"] }
|
||||
|
||||
[dev-dependencies]
|
||||
unzip3 = "1.0"
|
||||
|
||||
# For simulating passage of time in timing tests
|
||||
mock_instant = { version = "0.2" }
|
||||
hex-literal = "0.3.2"
|
||||
|
||||
# For matching WebRTC's randomness
|
||||
rand_distr = "0.4.1"
|
||||
|
||||
# For testing warp responses
|
||||
serde_json = "1.0"
|
||||
|
||||
[profile.release]
|
||||
debug = 1
|
||||
panic = "abort"
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
|
||||
[features]
|
||||
default = ["epoll"]
|
||||
epoll = ["nix"]
|
||||
43
Dockerfile
Normal file
43
Dockerfile
Normal file
@ -0,0 +1,43 @@
|
||||
#
|
||||
# Copyright 2019-2021 Signal Messenger, LLC
|
||||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
#
|
||||
|
||||
# Use the current rust environment for building.
|
||||
FROM rust:1.53.0 AS build-stage
|
||||
RUN apt-get update
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
# Create a stub version of the project to cache dependencies.
|
||||
RUN USER=root cargo new calling-server
|
||||
WORKDIR /usr/src/calling-server
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
|
||||
# Take in a build argument to specify RUSTFLAGS environment, usually a target-cpu.
|
||||
ARG rust_flags
|
||||
ENV RUSTFLAGS=$rust_flags
|
||||
|
||||
# Do the initial stub build.
|
||||
RUN cargo build --release
|
||||
|
||||
# Copy the source and build the calling-server proper.
|
||||
COPY src ./src
|
||||
RUN cargo build --release
|
||||
|
||||
# Export the calling-server executable if the '-o' option is specified.
|
||||
FROM scratch AS export-stage
|
||||
|
||||
COPY --from=build-stage /usr/src/calling-server/target/release/calling_server calling_server
|
||||
|
||||
# Create a minimal container to deploy and run the calling-server.
|
||||
FROM debian:buster-slim AS run-stage
|
||||
|
||||
# Expose http and udp server access ports to this container.
|
||||
EXPOSE 8080
|
||||
EXPOSE 10000/udp
|
||||
|
||||
COPY --from=build-stage /usr/src/calling-server/target/release/calling_server .
|
||||
USER 1000
|
||||
|
||||
ENTRYPOINT ["./calling_server"]
|
||||
619
LICENSE
Normal file
619
LICENSE
Normal file
@ -0,0 +1,619 @@
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
22
README.md
Normal file
22
README.md
Normal file
@ -0,0 +1,22 @@
|
||||
# Calling Service
|
||||
|
||||
Forwards media from 1 group call device to N group call devices.
|
||||
|
||||
# Thanks
|
||||
|
||||
We thank WebRTC for the "googcc" congestion control algorithm (see googcc.rs for more details).
|
||||
|
||||
We thank Ilana Volfin, Israel Cohen, and Jitsi for the "Dominant Speaker Identification" algorithm (see audio.rs for more details).
|
||||
|
||||
# Legal things
|
||||
## Cryptography Notice
|
||||
|
||||
This distribution includes cryptographic software. The country in which you currently reside may have restrictions on the import, possession, use, and/or re-export to another country, of encryption software. BEFORE using any encryption software, please check your country's laws, regulations and policies concerning the import, possession, or use, and re-export of encryption software, to see if this is permitted. See <http://www.wassenaar.org/> for more information.
|
||||
|
||||
The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software using or performing cryptographic functions with asymmetric algorithms. The form and manner of this distribution makes it eligible for export under the License Exception ENC Technology Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for both object code and source code.
|
||||
|
||||
## License
|
||||
|
||||
Copyright 2019-2021 Signal Messenger, LLC<br/>
|
||||
|
||||
Licensed under [AGPLv3](https://www.gnu.org/licenses/agpl-3.0.html) only.
|
||||
1
fuzz/.gitattributes
vendored
Normal file
1
fuzz/.gitattributes
vendored
Normal file
@ -0,0 +1 @@
|
||||
seeds/*/* binary
|
||||
5
fuzz/.gitignore
vendored
Normal file
5
fuzz/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
Cargo.lock
|
||||
target
|
||||
corpus
|
||||
artifacts
|
||||
coverage
|
||||
59
fuzz/Cargo.toml
Normal file
59
fuzz/Cargo.toml
Normal file
@ -0,0 +1,59 @@
|
||||
|
||||
[package]
|
||||
name = "calling_server-fuzz"
|
||||
version = "0.0.0"
|
||||
authors = ["Automatically generated"]
|
||||
publish = false
|
||||
edition = "2018"
|
||||
|
||||
[package.metadata]
|
||||
cargo-fuzz = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
hex-literal = "0.3.2"
|
||||
libfuzzer-sys = "0.4"
|
||||
rand = "0.8"
|
||||
|
||||
[dependencies.calling_server]
|
||||
path = ".."
|
||||
|
||||
# Prevent this from interfering with workspaces
|
||||
[workspace]
|
||||
members = ["."]
|
||||
|
||||
[[bin]]
|
||||
name = "ice"
|
||||
path = "fuzz_targets/ice.rs"
|
||||
test = false
|
||||
doc = false
|
||||
|
||||
[[bin]]
|
||||
name = "vp8"
|
||||
path = "fuzz_targets/vp8.rs"
|
||||
test = false
|
||||
doc = false
|
||||
|
||||
[[bin]]
|
||||
name = "rtcp"
|
||||
path = "fuzz_targets/rtcp.rs"
|
||||
test = false
|
||||
doc = false
|
||||
|
||||
[[bin]]
|
||||
name = "transportcc"
|
||||
path = "fuzz_targets/transportcc.rs"
|
||||
test = false
|
||||
doc = false
|
||||
|
||||
[[bin]]
|
||||
name = "rtp"
|
||||
path = "fuzz_targets/rtp.rs"
|
||||
test = false
|
||||
doc = false
|
||||
|
||||
[[bin]]
|
||||
name = "dtls"
|
||||
path = "fuzz_targets/dtls.rs"
|
||||
test = false
|
||||
doc = false
|
||||
16
fuzz/README.md
Normal file
16
fuzz/README.md
Normal file
@ -0,0 +1,16 @@
|
||||
This directory contains fuzz targets used with `cargo fuzz`.
|
||||
|
||||
```
|
||||
// In the top-level source directory
|
||||
cargo install cargo-fuzz
|
||||
cargo fuzz list
|
||||
cargo +nightly fuzz run <fuzz-target>
|
||||
|
||||
// If you have custom seed inputs
|
||||
cargo +nightly fuzz run <fuzz-target> fuzz/corpus/<fuzz-target> fuzz/seeds/<fuzz-target>
|
||||
|
||||
// If you find a crash
|
||||
RUST_BACKTRACE=1 cargo +nightly fuzz run -D <fuzz-target> <crash-artifact>
|
||||
```
|
||||
|
||||
For more information, including how to check the coverage of the explored corpus, see <https://rust-fuzz.github.io>.
|
||||
BIN
fuzz/fuzz_targets/dtls-server-private-key.der
Normal file
BIN
fuzz/fuzz_targets/dtls-server-private-key.der
Normal file
Binary file not shown.
56
fuzz/fuzz_targets/dtls.rs
Normal file
56
fuzz/fuzz_targets/dtls.rs
Normal file
@ -0,0 +1,56 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#![no_main]
|
||||
|
||||
use calling_server::*;
|
||||
use hex_literal::hex;
|
||||
use libfuzzer_sys::fuzz_target;
|
||||
use rand::{rngs::StdRng, CryptoRng, Rng, SeedableRng};
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
/// Generated DER from the sample private key in RFC 6979 A.2.5.
|
||||
static SERVER_PRIVATE_KEY_DER: &[u8] = include_bytes!("dtls-server-private-key.der");
|
||||
/// Client-provided entropy that matches the packet in seeds/dtls/certificate-with-prefixed-rng-seed
|
||||
static CLIENT_RANDOM: [u8; 32] =
|
||||
hex!("c84bd1aeaa4d7456e6c30aa3d514feb9f33afce48cab77a53d83d7d0b4bdb0b7");
|
||||
/// Client-provided fingerprint that matches the packet in seeds/dtls/certificate-with-prefixed-rng-seed
|
||||
static CLIENT_FINGERPRINT: [u8; 32] =
|
||||
hex!("74daf1fa3bbd8705527b5045ba20348d626d9cf002a7b468c30faf600f40f5f4");
|
||||
/// The timestamp that matches the packet in seeds/dtls/certificate-with-prefixed-rng-seed
|
||||
static NOW_MILLIS: u64 = 1625082261580;
|
||||
|
||||
fn random_dtls_state(
|
||||
now: SystemTime,
|
||||
rng_seed: u64,
|
||||
) -> (dtls::HandshakeState, impl Rng + CryptoRng) {
|
||||
// Make sure no one uses this RNG before HandshakeState::after_hello,
|
||||
// to make sure it's consistent with what's in seeds/dtls/certificate-with-prefixed-rng-seed.
|
||||
let mut rng = StdRng::seed_from_u64(rng_seed);
|
||||
// A seed of 0 was used to generate seeds/dtls/certificate-with-prefixed-rng-seed,
|
||||
// so we need to make sure that that seed tests certificate packet parsing.
|
||||
if rng_seed & 1 == 0 {
|
||||
let state = dtls::HandshakeState::after_hello(CLIENT_RANDOM, now, &mut rng);
|
||||
(state, rng)
|
||||
} else {
|
||||
(dtls::HandshakeState::new(), rng)
|
||||
}
|
||||
}
|
||||
|
||||
fuzz_target!(|input: (u64, &[u8])| {
|
||||
let (seed, data) = input;
|
||||
if dtls::looks_like_packet(data) {
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(NOW_MILLIS);
|
||||
let (mut state, mut rng) = random_dtls_state(now, seed);
|
||||
state.process_packet(
|
||||
data,
|
||||
&[],
|
||||
SERVER_PRIVATE_KEY_DER,
|
||||
&CLIENT_FINGERPRINT,
|
||||
now,
|
||||
&mut rng,
|
||||
);
|
||||
}
|
||||
});
|
||||
29
fuzz/fuzz_targets/ice.rs
Normal file
29
fuzz/fuzz_targets/ice.rs
Normal file
@ -0,0 +1,29 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#![no_main]
|
||||
|
||||
use calling_server::{
|
||||
common::try_scoped,
|
||||
ice::{BindingRequest, VerifiedBindingRequest},
|
||||
};
|
||||
use libfuzzer_sys::fuzz_target;
|
||||
|
||||
fuzz_target!(|data: Vec<u8>| {
|
||||
let _ = try_scoped(|| {
|
||||
if BindingRequest::looks_like_header(&data) {
|
||||
let ice_binding_request = BindingRequest::parse(&data)?;
|
||||
let ice_request_username = ice_binding_request.username();
|
||||
let pwd = &[0u8; 20];
|
||||
|
||||
let _ = ice_binding_request.verify_hmac(pwd);
|
||||
|
||||
VerifiedBindingRequest::new_for_fuzzing(&ice_binding_request)
|
||||
.to_binding_response(&ice_request_username, pwd);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
});
|
||||
16
fuzz/fuzz_targets/rtcp.rs
Normal file
16
fuzz/fuzz_targets/rtcp.rs
Normal file
@ -0,0 +1,16 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#![no_main]
|
||||
|
||||
use calling_server::*;
|
||||
use libfuzzer_sys::fuzz_target;
|
||||
|
||||
fuzz_target!(|data: Vec<u8>| {
|
||||
let mut data = data;
|
||||
if rtp::looks_like_rtcp(&data) {
|
||||
rtp::parse_rtcp(&mut data);
|
||||
}
|
||||
});
|
||||
15
fuzz/fuzz_targets/rtp.rs
Normal file
15
fuzz/fuzz_targets/rtp.rs
Normal file
@ -0,0 +1,15 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#![no_main]
|
||||
|
||||
use calling_server::*;
|
||||
use libfuzzer_sys::fuzz_target;
|
||||
|
||||
fuzz_target!(|data: Vec<u8>| {
|
||||
if rtp::looks_like_rtp(&data) {
|
||||
let _ = rtp::parse_and_forward_rtp_for_fuzzing(data);
|
||||
}
|
||||
});
|
||||
19
fuzz/fuzz_targets/transportcc.rs
Normal file
19
fuzz/fuzz_targets/transportcc.rs
Normal file
@ -0,0 +1,19 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#![no_main]
|
||||
|
||||
use calling_server::*;
|
||||
use libfuzzer_sys::fuzz_target;
|
||||
|
||||
fuzz_target!(|input: (&[u8], transportcc::FullSequenceNumber)| {
|
||||
let data = input.0;
|
||||
let mut seqnum = input.1;
|
||||
// Assume we never get close to the max seqnum
|
||||
if seqnum > u64::MAX - u32::MAX as u64 {
|
||||
return
|
||||
}
|
||||
let _ = transportcc::read_feedback(data, &mut seqnum);
|
||||
});
|
||||
13
fuzz/fuzz_targets/vp8.rs
Normal file
13
fuzz/fuzz_targets/vp8.rs
Normal file
@ -0,0 +1,13 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#![no_main]
|
||||
|
||||
use calling_server::*;
|
||||
use libfuzzer_sys::fuzz_target;
|
||||
|
||||
fuzz_target!(|data: &[u8]| {
|
||||
vp8::ParsedHeader::read(data).ok();
|
||||
});
|
||||
BIN
fuzz/seeds/dtls/certificate-with-prefixed-rng-seed
Normal file
BIN
fuzz/seeds/dtls/certificate-with-prefixed-rng-seed
Normal file
Binary file not shown.
BIN
fuzz/seeds/dtls/hello-with-prefixed-rng-seed
Normal file
BIN
fuzz/seeds/dtls/hello-with-prefixed-rng-seed
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtcp/bye
Normal file
BIN
fuzz/seeds/rtcp/bye
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtcp/generic-feedback-nack
Normal file
BIN
fuzz/seeds/rtcp/generic-feedback-nack
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtcp/generic-feedback-transportcc
Normal file
BIN
fuzz/seeds/rtcp/generic-feedback-transportcc
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtcp/receiver-report
Normal file
BIN
fuzz/seeds/rtcp/receiver-report
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtcp/sdes
Normal file
BIN
fuzz/seeds/rtcp/sdes
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtcp/sender-report
Normal file
BIN
fuzz/seeds/rtcp/sender-report
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtcp/specific-feedback-format-loss
Normal file
BIN
fuzz/seeds/rtcp/specific-feedback-format-loss
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtcp/specific-feedback-pli
Normal file
BIN
fuzz/seeds/rtcp/specific-feedback-pli
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtp/ext-audio-level
Normal file
BIN
fuzz/seeds/rtp/ext-audio-level
Normal file
Binary file not shown.
BIN
fuzz/seeds/rtp/ext-tcc-seqnum
Normal file
BIN
fuzz/seeds/rtp/ext-tcc-seqnum
Normal file
Binary file not shown.
2
fuzz/seeds/rtp/rtx-vp8
Normal file
2
fuzz/seeds/rtp/rtx-vp8
Normal file
@ -0,0 +1,2 @@
|
||||
€v
|
||||
|
||||
BIN
fuzz/seeds/rtp/rtx-vp8-ext-tcc-seqnum
Normal file
BIN
fuzz/seeds/rtp/rtx-vp8-ext-tcc-seqnum
Normal file
Binary file not shown.
1
rust-toolchain
Normal file
1
rust-toolchain
Normal file
@ -0,0 +1 @@
|
||||
1.53.0
|
||||
0
rustfmt.toml
Normal file
0
rustfmt.toml
Normal file
308
src/audio.rs
Normal file
308
src/audio.rs
Normal file
@ -0,0 +1,308 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{cmp::min, default::Default};
|
||||
|
||||
use log::*;
|
||||
|
||||
use crate::common::{count_in_chunks_exact, Duration, RingBuffer};
|
||||
|
||||
// Higher is louder
|
||||
pub type Level = u8;
|
||||
|
||||
// TODO: Consider rewriting this in the googcc async stream style.
|
||||
#[derive(Clone, Copy, Default)]
|
||||
struct LevelFloorTracker {
|
||||
floor: Option<Level>,
|
||||
floor_since_reset: Option<Level>,
|
||||
samples_since_reset: u32,
|
||||
}
|
||||
|
||||
impl LevelFloorTracker {
|
||||
const RECALCULCATION_INTERVAL: Duration = Duration::from_secs(15);
|
||||
const ASSUMED_SAMPLE_DURATION: Duration = Duration::from_millis(20);
|
||||
const SAMPLES_PER_RECALCULATION: u32 = (Self::RECALCULCATION_INTERVAL.as_millis()
|
||||
/ Self::ASSUMED_SAMPLE_DURATION.as_millis())
|
||||
as u32;
|
||||
|
||||
fn reset(floor: Level) -> Self {
|
||||
Self {
|
||||
floor: Some(floor),
|
||||
..Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn get(self) -> Option<Level> {
|
||||
self.floor
|
||||
}
|
||||
|
||||
fn update(self, sample: Level) -> Self {
|
||||
if sample == 0 {
|
||||
// We ignore 0 levels for calculation.
|
||||
// The input is likely muted and would throw off the unmute value.
|
||||
return self;
|
||||
}
|
||||
|
||||
if self.floor.is_none() {
|
||||
// We treat our first sample as the initial floor.
|
||||
return Self::reset(sample);
|
||||
}
|
||||
let floor = self.floor.unwrap();
|
||||
|
||||
if sample < floor {
|
||||
// Any time we get a sample below the floor, immediately drop to
|
||||
// that level as if it were the first value.
|
||||
return Self::reset(sample);
|
||||
}
|
||||
|
||||
if self.floor_since_reset.is_none() {
|
||||
// Our first value since a reset becomes our new floor_since_reset.
|
||||
return Self {
|
||||
floor: Some(floor),
|
||||
floor_since_reset: Some(sample),
|
||||
samples_since_reset: 1,
|
||||
};
|
||||
}
|
||||
let floor_since_reset = min(sample, self.floor_since_reset.unwrap());
|
||||
let samples_since_reset = self.samples_since_reset + 1;
|
||||
|
||||
// We have enough samples to trigger an average and reset.
|
||||
// This slowly creeps up the floor if it increases over time.
|
||||
if samples_since_reset >= Self::SAMPLES_PER_RECALCULATION {
|
||||
let average_floor = ((floor as f32) * (floor_since_reset as f32)).sqrt() as Level;
|
||||
return Self::reset(average_floor);
|
||||
}
|
||||
|
||||
// We don't have enough samples to recalculate, so track the state until we do.
|
||||
Self {
|
||||
floor: Some(floor),
|
||||
floor_since_reset: Some(floor_since_reset),
|
||||
samples_since_reset,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Based on the "dominant speaker identification" algorithm found at
|
||||
// https://github.com/jitsi/jitsi-utils/blob/master/src/main/java/org/jitsi/utils/dsi/DominantSpeakerIdentification.java
|
||||
// which is based on the paper "Dominant Speaker Identification for Multipoint Videoconferencing"
|
||||
// by Ilana Volfin and Israel Cohen found at
|
||||
// https://israelcohen.com/wp-content/uploads/2018/05/IEEEI2012_Volfin.pdf
|
||||
// Although this code does much less math, it should produce the same results,
|
||||
// at least for the range of audio levels 0-127.
|
||||
pub struct LevelsTracker {
|
||||
floor: LevelFloorTracker,
|
||||
levels: RingBuffer<Level>,
|
||||
}
|
||||
|
||||
impl Default for LevelsTracker {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
floor: LevelFloorTracker::default(),
|
||||
levels: RingBuffer::new(50),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LevelsTracker {
|
||||
pub fn push(&mut self, mut sample: Level) {
|
||||
self.floor = self.floor.update(sample);
|
||||
let threshold = self.floor.get().unwrap_or(0) + 10;
|
||||
// Treat anything near the floor as 0
|
||||
// for the purposes of tracking levels
|
||||
// and which levels are more active.
|
||||
if sample <= threshold {
|
||||
sample = 0;
|
||||
}
|
||||
self.levels.push(sample);
|
||||
}
|
||||
|
||||
fn iter_latest_first(&self) -> impl Iterator<Item = Level> + '_ {
|
||||
self.levels.iter().rev().copied()
|
||||
}
|
||||
|
||||
fn latest(&self) -> Option<Level> {
|
||||
self.iter_latest_first().next()
|
||||
}
|
||||
|
||||
fn count_latest_chunk_above_threshold(&self, n: usize, threshold: Level) -> usize {
|
||||
self.iter_latest_first()
|
||||
.take(n)
|
||||
.filter(|level| *level > threshold)
|
||||
.count()
|
||||
}
|
||||
|
||||
fn count_chunks_above_threshold(&self, chunk_size: usize, threshold: Level) -> usize {
|
||||
// Here is how it would read if we could group iterators or copy:
|
||||
// self.iter_latest_first().collect::<Vec<_>>()
|
||||
// .chunks_exact(chunk_size)
|
||||
// .filter(|chunk| chunk.iter().all(|level| *level > threshold))
|
||||
// .count()
|
||||
count_in_chunks_exact(
|
||||
self.iter_latest_first().map(|level| level > threshold),
|
||||
chunk_size,
|
||||
)
|
||||
.filter(|high_count| *high_count == chunk_size)
|
||||
.count()
|
||||
}
|
||||
|
||||
pub fn more_active_than_most_active(&self, most_active: &LevelsTracker) -> bool {
|
||||
const HIGH: Level = 70;
|
||||
const LOW: Level = 40;
|
||||
const CHUNK_SIZE: usize = 5;
|
||||
|
||||
if self.latest().unwrap_or(0) <= LOW {
|
||||
trace!(
|
||||
"The contender isn't active enough (latest sample = {:?})",
|
||||
self.latest()
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
if most_active.latest().unwrap_or(0) >= LOW {
|
||||
trace!("The most active is still active (latest sample)");
|
||||
return false;
|
||||
}
|
||||
|
||||
let self_first_chunk = self.count_latest_chunk_above_threshold(CHUNK_SIZE, HIGH);
|
||||
|
||||
if self_first_chunk < CHUNK_SIZE {
|
||||
trace!("The contender isn't active enough (latest chunk)");
|
||||
// We're not active enough.
|
||||
return false;
|
||||
}
|
||||
|
||||
let most_active_first_chunk =
|
||||
most_active.count_latest_chunk_above_threshold(CHUNK_SIZE, HIGH);
|
||||
|
||||
if most_active_first_chunk > 0 {
|
||||
trace!("The most active is still active (latest chunk)");
|
||||
// The most active is too active.
|
||||
return false;
|
||||
}
|
||||
|
||||
if self_first_chunk < most_active_first_chunk {
|
||||
trace!("The most active is more active (first chunk)");
|
||||
// We can't compete with the most active
|
||||
return false;
|
||||
}
|
||||
|
||||
let self_high_chunks = self.count_chunks_above_threshold(CHUNK_SIZE, HIGH);
|
||||
let most_active_high_chunks = most_active.count_chunks_above_threshold(CHUNK_SIZE, HIGH);
|
||||
|
||||
if self_high_chunks <= most_active_high_chunks {
|
||||
trace!("The most active is more active (number of active chunk)");
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_audio_noise_floor_tracker() {
|
||||
let floor = LevelFloorTracker::default();
|
||||
assert_eq!(None, floor.get());
|
||||
|
||||
let floor = floor.update(10);
|
||||
assert_eq!(Some(10), floor.get());
|
||||
|
||||
// Not enough to do another reset
|
||||
let mut floor = floor;
|
||||
for i in 0..(749u16) {
|
||||
floor = floor.update(20 + (i % 20) as Level);
|
||||
}
|
||||
assert_eq!(Some(10), floor.get());
|
||||
|
||||
// Now enough to do another reset
|
||||
let floor = floor.update(20);
|
||||
assert_eq!(Some(14), floor.get());
|
||||
|
||||
// And another
|
||||
let mut floor = floor;
|
||||
for i in 0..(750u16) {
|
||||
floor = floor.update(20 + (i % 20) as Level);
|
||||
}
|
||||
assert_eq!(Some(16), floor.get());
|
||||
|
||||
// Another low value pushes it back down
|
||||
let floor = floor.update(12);
|
||||
assert_eq!(Some(12), floor.get());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_audio_activity() {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
|
||||
let mut most_active = LevelsTracker::default();
|
||||
let mut contender = LevelsTracker::default();
|
||||
|
||||
assert!(!contender.more_active_than_most_active(&most_active));
|
||||
|
||||
// Establishes the noise floor
|
||||
contender.push(60);
|
||||
// Shows activity
|
||||
for _ in 0..4 {
|
||||
contender.push(80);
|
||||
}
|
||||
// Not quite active enough yet
|
||||
assert!(!contender.more_active_than_most_active(&most_active));
|
||||
|
||||
// OK, now we have enough
|
||||
contender.push(80);
|
||||
assert!(contender.more_active_than_most_active(&most_active));
|
||||
|
||||
// Not any more, though
|
||||
contender.push(70);
|
||||
assert!(!contender.more_active_than_most_active(&most_active));
|
||||
|
||||
// OK, active enough again
|
||||
for _ in 0..5 {
|
||||
contender.push(80);
|
||||
}
|
||||
assert!(contender.more_active_than_most_active(&most_active));
|
||||
|
||||
// But not if the most active is active again
|
||||
most_active.push(50); // Establishes noise floor
|
||||
assert!(contender.more_active_than_most_active(&most_active));
|
||||
most_active.push(60); // Not yet above noise floor
|
||||
assert!(contender.more_active_than_most_active(&most_active));
|
||||
most_active.push(80); // Now it is
|
||||
assert!(!contender.more_active_than_most_active(&most_active));
|
||||
|
||||
// If it goes inactive a little, that's not enough.
|
||||
most_active.push(50);
|
||||
assert!(!contender.more_active_than_most_active(&most_active));
|
||||
|
||||
// But if it's inactive a long time, we win.
|
||||
for _ in 0..4 {
|
||||
most_active.push(50);
|
||||
}
|
||||
assert!(contender.more_active_than_most_active(&most_active));
|
||||
|
||||
// Unless it was also active even longer ago. Then it's harder to dislodge.
|
||||
for _ in 0..5 {
|
||||
most_active.push(80);
|
||||
}
|
||||
for _ in 0..5 {
|
||||
most_active.push(50);
|
||||
}
|
||||
assert!(!contender.more_active_than_most_active(&most_active));
|
||||
|
||||
assert_eq!(1, most_active.count_chunks_above_threshold(5, 70));
|
||||
assert_eq!(1, contender.count_chunks_above_threshold(5, 70));
|
||||
|
||||
// But it is possible with enough activity
|
||||
for _ in 0..5 {
|
||||
contender.push(80);
|
||||
}
|
||||
assert_eq!(1, most_active.count_chunks_above_threshold(5, 70));
|
||||
assert_eq!(2, contender.count_chunks_above_threshold(5, 70));
|
||||
assert!(contender.more_active_than_most_active(&most_active));
|
||||
}
|
||||
}
|
||||
3301
src/call.rs
Normal file
3301
src/call.rs
Normal file
File diff suppressed because it is too large
Load Diff
444
src/common.rs
Normal file
444
src/common.rs
Normal file
@ -0,0 +1,444 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
//! Common functionality for ice, rtp, rtcp, dtls, or googcc.
|
||||
|
||||
mod bits;
|
||||
mod bytes_reader;
|
||||
mod collections;
|
||||
mod counters;
|
||||
mod data_rate;
|
||||
mod integers;
|
||||
mod math;
|
||||
mod serialize;
|
||||
mod time;
|
||||
|
||||
use std::{cmp::PartialEq, convert::TryInto, fmt::Write};
|
||||
|
||||
use anyhow::Result;
|
||||
pub use bits::*;
|
||||
pub use bytes_reader::*;
|
||||
pub use collections::*;
|
||||
pub use counters::*;
|
||||
pub use data_rate::*;
|
||||
use hex::FromHex;
|
||||
pub use integers::*;
|
||||
pub use math::*;
|
||||
use rand::{thread_rng, Rng};
|
||||
pub use serialize::*;
|
||||
pub use time::*;
|
||||
|
||||
// It's (value, rest)
|
||||
// TODO: Change to Result
|
||||
pub type ReadOption<'a, T> = Option<(T, &'a [u8])>;
|
||||
|
||||
pub fn read_u16_len_prefixed_u16s(input: &[u8]) -> ReadOption<Vec<u16>> {
|
||||
let (bytes, rest) = read_u16_len_prefixed(input)?;
|
||||
let (values, _) = read_n(bytes, bytes.len() / 2, read_u16)?;
|
||||
Some((values, rest))
|
||||
}
|
||||
|
||||
pub fn read_u24_len_prefixed(input: &[u8]) -> ReadOption<&[u8]> {
|
||||
let (len, rest) = read_u24(input)?;
|
||||
let (bytes, rest) = read_bytes(rest, len.into())?;
|
||||
Some((bytes, rest))
|
||||
}
|
||||
|
||||
pub fn read_u16_len_prefixed(input: &[u8]) -> ReadOption<&[u8]> {
|
||||
let (len, rest) = read_u16(input)?;
|
||||
let (bytes, rest) = read_bytes(rest, len as usize)?;
|
||||
Some((bytes, rest))
|
||||
}
|
||||
|
||||
pub fn read_u8_len_prefixed(input: &[u8]) -> ReadOption<&[u8]> {
|
||||
let (len, rest) = read_u8(input)?;
|
||||
let (bytes, rest) = read_bytes(rest, len as usize)?;
|
||||
Some((bytes, rest))
|
||||
}
|
||||
|
||||
pub fn read_u48(input: &[u8]) -> ReadOption<U48> {
|
||||
let (bytes, rest) = read_bytes(input, 6)?;
|
||||
Some((parse_u48(bytes), rest))
|
||||
}
|
||||
|
||||
pub fn read_u32(input: &[u8]) -> ReadOption<u32> {
|
||||
let (bytes, rest) = read_bytes(input, 4)?;
|
||||
Some((parse_u32(bytes), rest))
|
||||
}
|
||||
|
||||
pub fn read_u24(input: &[u8]) -> ReadOption<U24> {
|
||||
let (bytes, rest) = read_bytes(input, 3)?;
|
||||
Some((parse_u24(bytes), rest))
|
||||
}
|
||||
|
||||
pub fn read_u16(input: &[u8]) -> ReadOption<u16> {
|
||||
let (bytes, rest) = read_bytes(input, 2)?;
|
||||
Some((parse_u16(bytes), rest))
|
||||
}
|
||||
|
||||
pub fn read_i16(input: &[u8]) -> ReadOption<i16> {
|
||||
let (bytes, rest) = read_bytes(input, 2)?;
|
||||
Some((parse_i16(bytes), rest))
|
||||
}
|
||||
|
||||
pub fn read_u8(input: &[u8]) -> ReadOption<u8> {
|
||||
let (bytes, rest) = read_bytes(input, 1)?;
|
||||
Some((bytes[0], rest))
|
||||
}
|
||||
|
||||
pub fn read_n<'a, T>(
|
||||
mut input: &'a [u8],
|
||||
n: usize,
|
||||
read_one: impl Fn(&'a [u8]) -> ReadOption<'a, T>,
|
||||
) -> ReadOption<'a, Vec<T>> {
|
||||
let mut values = Vec::with_capacity(n);
|
||||
for _ in 0..n {
|
||||
let (val, rest) = read_one(input)?;
|
||||
values.push(val);
|
||||
input = rest;
|
||||
}
|
||||
Some((values, input))
|
||||
}
|
||||
|
||||
pub fn read_as_many_as_possible<'a, T>(
|
||||
mut input: &'a [u8],
|
||||
read_one: impl Fn(&'a [u8]) -> ReadOption<'a, T>,
|
||||
) -> ReadOption<'a, Vec<T>> {
|
||||
let mut values = vec![];
|
||||
while !input.is_empty() {
|
||||
let (val, rest) = read_one(input)?;
|
||||
values.push(val);
|
||||
input = rest;
|
||||
}
|
||||
Some((values, input))
|
||||
}
|
||||
|
||||
// Returns (read, rest)
|
||||
pub fn read_bytes(input: &[u8], len: usize) -> ReadOption<&[u8]> {
|
||||
let bytes = input.get(0..len)?;
|
||||
let rest = &input[len..];
|
||||
Some((bytes, rest))
|
||||
}
|
||||
|
||||
pub fn read_from_end(input: &[u8], len: usize) -> ReadOption<&[u8]> {
|
||||
if input.len() < len {
|
||||
return None;
|
||||
}
|
||||
let (before, after) = read_bytes(input, input.len() - len)?;
|
||||
Some((after, before))
|
||||
}
|
||||
|
||||
pub fn parse_u16(bytes: &[u8]) -> u16 {
|
||||
u16::from_be_bytes(bytes[0..2].try_into().unwrap())
|
||||
}
|
||||
|
||||
pub fn parse_u16_le(bytes: &[u8]) -> u16 {
|
||||
u16::from_le_bytes(bytes[0..2].try_into().unwrap())
|
||||
}
|
||||
|
||||
pub fn parse_i16(bytes: &[u8]) -> i16 {
|
||||
i16::from_be_bytes(bytes[0..2].try_into().unwrap())
|
||||
}
|
||||
|
||||
pub fn parse_u24(bytes: &[u8]) -> U24 {
|
||||
U24::from_be_bytes(bytes[0..U24::SIZE].try_into().unwrap())
|
||||
}
|
||||
|
||||
pub fn parse_u32(bytes: &[u8]) -> u32 {
|
||||
u32::from_be_bytes(bytes[0..4].try_into().unwrap())
|
||||
}
|
||||
|
||||
pub fn parse_u48(bytes: &[u8]) -> U48 {
|
||||
U48::from_be_bytes(bytes[0..U48::SIZE].try_into().unwrap())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod parse_tests {
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_48() {
|
||||
assert_eq!(
|
||||
U48::try_from(0x010203040506u64).unwrap(),
|
||||
parse_u48(vec![1, 2, 3, 4, 5, 6].as_slice())
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_24() {
|
||||
assert_eq!(
|
||||
U24::try_from(0x010203u32).unwrap(),
|
||||
parse_u24(vec![1, 2, 3].as_slice())
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn random_hex_string(n: usize) -> String {
|
||||
const HEXCHARSET: &[u8] = b"abcdef0123456789";
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let string: String = (0..n)
|
||||
.map(|_| {
|
||||
let index = rng.gen_range(0..HEXCHARSET.len());
|
||||
HEXCHARSET[index] as char
|
||||
})
|
||||
.collect();
|
||||
string
|
||||
}
|
||||
|
||||
/// Const generic expressions may replace this in future, but for now we must have a macro
|
||||
macro_rules! random_base64_string_of_length {
|
||||
($string_length:expr) => {{
|
||||
base64::encode(thread_rng().gen::<[u8; $string_length * 6 / 8]>())
|
||||
}};
|
||||
}
|
||||
|
||||
/// Create a random Base64 string of length 32.
|
||||
/// ```
|
||||
/// # use calling_server::common::random_base64_string_of_length_32;
|
||||
///
|
||||
/// let string = random_base64_string_of_length_32();
|
||||
/// assert_eq!(32, string.len());
|
||||
///
|
||||
/// # let string2 = random_base64_string_of_length_32();
|
||||
/// # assert_ne!(string, string2);
|
||||
/// ```
|
||||
pub fn random_base64_string_of_length_32() -> String {
|
||||
random_base64_string_of_length!(32)
|
||||
}
|
||||
|
||||
/// Create a random Base64 string of length 4.
|
||||
/// ```
|
||||
/// # use calling_server::common::random_base64_string_of_length_4;
|
||||
///
|
||||
/// let string = random_base64_string_of_length_4();
|
||||
/// assert_eq!(4, string.len());
|
||||
///
|
||||
/// # let string2 = random_base64_string_of_length_4();
|
||||
/// # assert_ne!(string, string2);
|
||||
/// ```
|
||||
pub fn random_base64_string_of_length_4() -> String {
|
||||
random_base64_string_of_length!(4)
|
||||
}
|
||||
|
||||
/// Encodes a slice of bytes as a hexadecimal fingerprint string.
|
||||
///
|
||||
/// ```
|
||||
/// use calling_server::common::bytes_to_colon_separated_hexstring;
|
||||
///
|
||||
/// assert_eq!(bytes_to_colon_separated_hexstring(&[]), "");
|
||||
/// assert_eq!(bytes_to_colon_separated_hexstring(&[0x01, 0xAB, 0xCD]), "01:AB:CD");
|
||||
/// ```
|
||||
pub fn bytes_to_colon_separated_hexstring(bytes: &[u8]) -> String {
|
||||
let mut result = String::with_capacity(bytes.len() * 3);
|
||||
for byte in bytes {
|
||||
write!(&mut result, "{:02X}:", byte).expect("Should be safe to write to String");
|
||||
}
|
||||
if !result.is_empty() {
|
||||
// Remove the extra colon.
|
||||
result.pop();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Decodes hexadecimal fingerprint string to a u8 array of 32 bytes.
|
||||
///
|
||||
/// ```
|
||||
/// use calling_server::common::colon_separated_hexstring_to_array;
|
||||
///
|
||||
/// assert_eq!(colon_separated_hexstring_to_array(
|
||||
/// "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00").unwrap(),
|
||||
/// [0u8; 32]);
|
||||
/// assert_eq!(colon_separated_hexstring_to_array(
|
||||
/// "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:").unwrap(),
|
||||
/// [0u8; 32]);
|
||||
/// assert!(colon_separated_hexstring_to_array("").is_err());
|
||||
/// assert!(colon_separated_hexstring_to_array(":").is_err());
|
||||
/// assert!(colon_separated_hexstring_to_array("01").is_err());
|
||||
/// assert!(colon_separated_hexstring_to_array("01:AB:CD").is_err());
|
||||
/// assert!(colon_separated_hexstring_to_array("01:AB:CD:").is_err());
|
||||
/// assert!(colon_separated_hexstring_to_array("1:A:B").is_err());
|
||||
/// assert!(colon_separated_hexstring_to_array("01:AB:B").is_err());
|
||||
/// assert!(colon_separated_hexstring_to_array("01:AB:B:").is_err());
|
||||
/// ```
|
||||
pub fn colon_separated_hexstring_to_array(string: &str) -> Result<[u8; 32]> {
|
||||
let string = string.replace(":", "");
|
||||
let result = <[u8; 32]>::from_hex(string)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Allows using `?` syntax in a scope and collecting failures in a `Result`.
|
||||
pub fn try_scoped<T>(call: impl FnOnce() -> anyhow::Result<T>) -> anyhow::Result<T> {
|
||||
call()
|
||||
}
|
||||
|
||||
// Can be used for video resolution
|
||||
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq)]
|
||||
pub struct PixelSize {
|
||||
pub width: u16,
|
||||
pub height: u16,
|
||||
}
|
||||
|
||||
/// Number of pixels
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Copy, Ord, PartialOrd, Default)]
|
||||
pub struct VideoHeight(u16);
|
||||
|
||||
impl From<u16> for VideoHeight {
|
||||
fn from(height: u16) -> Self {
|
||||
VideoHeight(height)
|
||||
}
|
||||
}
|
||||
|
||||
impl VideoHeight {
|
||||
pub fn as_u16(self) -> u16 {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// Values beyond a multiple of chunk_size are ignored.
|
||||
// Panics with a chunk_size of 0.
|
||||
// Just like https://doc.rust-lang.org/beta/std/primitive.slice.html#method.chunks_exact
|
||||
pub fn count_in_chunks_exact(
|
||||
inputs: impl Iterator<Item = bool>,
|
||||
chunk_size: usize,
|
||||
) -> impl Iterator<Item = usize> {
|
||||
fold_in_chunks_exact(
|
||||
inputs,
|
||||
chunk_size,
|
||||
|| 0usize,
|
||||
|count, bit| count + (bit as usize),
|
||||
)
|
||||
}
|
||||
|
||||
// Values beyond a multiple of chunk_size are ignored.
|
||||
// Panics with a chunk_size of 0.
|
||||
// Just like https://doc.rust-lang.org/beta/std/primitive.slice.html#method.chunks_exact
|
||||
pub fn fold_in_chunks_exact<Input, Output, Init, Acc>(
|
||||
mut inputs: impl Iterator<Item = Input>,
|
||||
chunk_size: usize,
|
||||
init: Init,
|
||||
acc: Acc,
|
||||
) -> impl Iterator<Item = Output>
|
||||
where
|
||||
Init: Fn() -> Output,
|
||||
Acc: Fn(Output, Input) -> Output,
|
||||
{
|
||||
assert!(chunk_size != 0);
|
||||
std::iter::from_fn(move || {
|
||||
let mut output = init();
|
||||
for _ in 0..chunk_size {
|
||||
let input = inputs.next()?;
|
||||
output = acc(output, input);
|
||||
}
|
||||
Some(output)
|
||||
})
|
||||
}
|
||||
|
||||
pub trait CheckedSplitAt {
|
||||
fn checked_split_at(&self, mid: usize) -> Option<(&[u8], &[u8])>;
|
||||
}
|
||||
|
||||
impl CheckedSplitAt for [u8] {
|
||||
fn checked_split_at(&self, mid: usize) -> Option<(&[u8], &[u8])> {
|
||||
if self.len() < mid {
|
||||
None
|
||||
} else {
|
||||
Some(self.split_at(mid))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CheckedSplitAtMut {
|
||||
fn checked_split_at_mut(&mut self, mid: usize) -> Option<(&mut [u8], &mut [u8])>;
|
||||
}
|
||||
|
||||
impl CheckedSplitAtMut for [u8] {
|
||||
fn checked_split_at_mut(&mut self, mid: usize) -> Option<(&mut [u8], &mut [u8])> {
|
||||
if self.len() < mid {
|
||||
None
|
||||
} else {
|
||||
Some(self.split_at_mut(mid))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
lazy_static::lazy_static! {
|
||||
pub(crate) static ref RANDOM_SEED_FOR_TESTS: u64 = {
|
||||
let seed = match std::env::var("RANDOM_SEED") {
|
||||
Ok(v) => v.parse().unwrap(),
|
||||
Err(_) => thread_rng().gen(),
|
||||
};
|
||||
|
||||
println!("\n*** Using RANDOM_SEED={}", seed);
|
||||
seed
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_count_in_groups_exact() {
|
||||
let vals: Vec<bool> = vec![];
|
||||
assert_eq!(
|
||||
vec![0usize; 0],
|
||||
count_in_chunks_exact(vals.iter().copied(), 1).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let vals = [true, false, false, true, true, false, false];
|
||||
assert_eq!(
|
||||
vec![1, 0, 0, 1, 1, 0, 0],
|
||||
count_in_chunks_exact(vals.iter().copied(), 1).collect::<Vec<_>>()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![1, 1, 1],
|
||||
count_in_chunks_exact(vals.iter().copied(), 2).collect::<Vec<_>>()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![1, 2],
|
||||
count_in_chunks_exact(vals.iter().copied(), 3).collect::<Vec<_>>()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![3],
|
||||
count_in_chunks_exact(vals.iter().copied(), 5).collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checked_split_at() {
|
||||
assert_eq!(Some((&b""[..], &b"ab"[..])), b"ab".checked_split_at(0));
|
||||
assert_eq!(Some((&b"a"[..], &b"b"[..])), b"ab".checked_split_at(1));
|
||||
assert_eq!(Some((&b"ab"[..], &b""[..])), b"ab".checked_split_at(2));
|
||||
assert_eq!(None, b"ab".checked_split_at(3));
|
||||
assert_eq!(None, b"ab".checked_split_at(30));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::clone_on_copy)]
|
||||
fn test_checked_split_at_mut() {
|
||||
let mut empty = [];
|
||||
let mut zero = [0];
|
||||
let mut one = [1];
|
||||
let mut zero_one = [0, 1];
|
||||
assert_eq!(
|
||||
Some((&mut empty[..], &mut zero_one.clone()[..])),
|
||||
zero_one.checked_split_at_mut(0)
|
||||
);
|
||||
assert_eq!(
|
||||
Some((&mut zero[..], &mut one[..])),
|
||||
zero_one.checked_split_at_mut(1)
|
||||
);
|
||||
assert_eq!(
|
||||
Some((&mut zero_one.clone()[..], &mut empty[..])),
|
||||
zero_one.checked_split_at_mut(2)
|
||||
);
|
||||
assert_eq!(None, zero_one.checked_split_at_mut(3));
|
||||
assert_eq!(None, zero_one.checked_split_at_mut(30));
|
||||
}
|
||||
}
|
||||
239
src/common/bits.rs
Normal file
239
src/common/bits.rs
Normal file
@ -0,0 +1,239 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::ops::{BitAnd, BitOr, Shl, Shr};
|
||||
|
||||
pub trait Bits: Sized + Copy {
|
||||
const BIT_WIDTH: u8 = (std::mem::size_of::<Self>() * 8) as u8;
|
||||
|
||||
/// Returns true iff the bit at the index is one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `index` - The 0 based index starting at the most significant bit.
|
||||
fn ms_bit(self, index: u8) -> bool;
|
||||
|
||||
/// Sets the bit to one at the index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `index` - The 0 based index starting at the most significant bit.
|
||||
fn set_ms_bit(self, index: u8) -> Self;
|
||||
|
||||
/// Returns true iff the bit at the index is one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `index` - The 0 based index starting at the least significant bit.
|
||||
fn ls_bit(self, index: u8) -> bool;
|
||||
|
||||
/// Sets the bit to one at the index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `index` - The 0 based index starting at the least significant bit.
|
||||
fn set_ls_bit(self, index: u8) -> Self;
|
||||
}
|
||||
|
||||
impl<T> Bits for T
|
||||
where
|
||||
T: Copy
|
||||
+ Shr<u8, Output = T>
|
||||
+ Shl<u8, Output = T>
|
||||
+ BitAnd<T, Output = T>
|
||||
+ BitOr<T, Output = T>
|
||||
+ From<u8>
|
||||
+ Eq,
|
||||
{
|
||||
fn ms_bit(self, index: u8) -> bool {
|
||||
assert!(index < Self::BIT_WIDTH);
|
||||
|
||||
self >> (Self::BIT_WIDTH - index - 1) & T::from(1) == T::from(1)
|
||||
}
|
||||
|
||||
fn set_ms_bit(self, index: u8) -> Self {
|
||||
assert!(index < Self::BIT_WIDTH);
|
||||
|
||||
self | T::from(1) << (Self::BIT_WIDTH - index - 1)
|
||||
}
|
||||
|
||||
fn ls_bit(self, index: u8) -> bool {
|
||||
assert!(index < Self::BIT_WIDTH);
|
||||
|
||||
self >> index & T::from(1) == T::from(1)
|
||||
}
|
||||
|
||||
fn set_ls_bit(self, index: u8) -> Self {
|
||||
assert!(index < Self::BIT_WIDTH);
|
||||
|
||||
self | T::from(1) << index
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod msb_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn is_set_leading_ms_bit_u8() {
|
||||
assert!(0b1111_1111u8.ms_bit(0));
|
||||
assert!(!0b0111_1111u8.ms_bit(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_set_trailing_ms_bit_u8() {
|
||||
assert!(0b1111_1111u8.ms_bit(7));
|
||||
assert!(!0b1111_1110u8.ms_bit(7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn get_panics_when_over_bit_length_u8() {
|
||||
0u8.ms_bit(8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn set_panics_when_over_bit_length_u8() {
|
||||
0u8.set_ms_bit(8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn get_panics_when_over_bit_length_u32() {
|
||||
0u32.ms_bit(32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_set_leading_ms_bit_u32() {
|
||||
assert!(0b1111_1111_1111_1111_1111_1111_1111_1111u32.ms_bit(0));
|
||||
assert!(!0b0111_1111_1111_1111_1111_1111_1111_1111u32.ms_bit(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_set_trailing_ms_bit_u32() {
|
||||
assert!(0b1111_1111_1111_1111_1111_1111_1111_1111u32.ms_bit(31));
|
||||
assert!(!0b1111_1111_1111_1111_1111_1111_1111_1110u32.ms_bit(31));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_same_bit_twice() {
|
||||
let byte = 0b0000_0000u8.set_ms_bit(0);
|
||||
assert_eq!(0b1000_0000, byte);
|
||||
let byte = byte.set_ms_bit(0);
|
||||
assert_eq!(0b1000_0000, byte);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_bits_u8() {
|
||||
let byte = 0b0000_0000u8.set_ms_bit(0);
|
||||
assert_eq!(0b1000_0000, byte);
|
||||
let byte = byte.set_ms_bit(2);
|
||||
assert_eq!(0b1010_0000, byte);
|
||||
let byte = byte.set_ms_bit(4);
|
||||
assert_eq!(0b1010_1000, byte);
|
||||
let byte = byte.set_ms_bit(6);
|
||||
assert_eq!(0b1010_1010, byte);
|
||||
let byte = byte.set_ms_bit(7);
|
||||
assert_eq!(0b1010_1011, byte);
|
||||
let byte = byte.set_ms_bit(5);
|
||||
assert_eq!(0b1010_1111, byte);
|
||||
let byte = byte.set_ms_bit(3);
|
||||
assert_eq!(0b1011_1111, byte);
|
||||
let byte = byte.set_ms_bit(1);
|
||||
assert_eq!(0b1111_1111, byte);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_bits_u16() {
|
||||
let byte = 0b0000_0000_0000_0000u16.set_ms_bit(0);
|
||||
assert_eq!(0b1000_0000_0000_0000, byte);
|
||||
let byte = byte.set_ms_bit(15);
|
||||
assert_eq!(0b1000_0000_0000_0001, byte);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod lsb_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn is_set_trailing_ls_bit_u8() {
|
||||
assert!(0b1111_1111u8.ls_bit(0));
|
||||
assert!(!0b1111_1110u8.ls_bit(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_set_leading_ls_bit_u8() {
|
||||
assert!(0b1111_1111u8.ls_bit(7));
|
||||
assert!(!0b0111_1111u8.ls_bit(7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn get_panics_when_over_bit_length_u8() {
|
||||
0u8.ls_bit(8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn set_panics_when_over_bit_length_u8() {
|
||||
0u8.set_ls_bit(8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn get_panics_when_over_bit_length_u32() {
|
||||
0u32.ls_bit(32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_set_trailing_ls_bit_u32() {
|
||||
assert!(0b1111_1111_1111_1111_1111_1111_1111_1111u32.ls_bit(0));
|
||||
assert!(!0b1111_1111_1111_1111_1111_1111_1111_1110u32.ls_bit(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_set_leading_ls_bit_u32() {
|
||||
assert!(0b1111_1111_1111_1111_1111_1111_1111_1111u32.ls_bit(31));
|
||||
assert!(!0b0111_1111_1111_1111_1111_1111_1111_1111u32.ls_bit(31));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_same_bit_twice() {
|
||||
let byte = 0b0000_0000u8.set_ls_bit(0);
|
||||
assert_eq!(0b0000_0001, byte);
|
||||
let byte = byte.set_ls_bit(0);
|
||||
assert_eq!(0b0000_0001, byte);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_bits_u8() {
|
||||
let byte = 0b0000_0000u8.set_ls_bit(0);
|
||||
assert_eq!(0b0000_0001, byte);
|
||||
let byte = byte.set_ls_bit(2);
|
||||
assert_eq!(0b0000_0101, byte);
|
||||
let byte = byte.set_ls_bit(4);
|
||||
assert_eq!(0b0001_0101, byte);
|
||||
let byte = byte.set_ls_bit(6);
|
||||
assert_eq!(0b0101_0101, byte);
|
||||
let byte = byte.set_ls_bit(7);
|
||||
assert_eq!(0b1101_0101, byte);
|
||||
let byte = byte.set_ls_bit(5);
|
||||
assert_eq!(0b1111_0101, byte);
|
||||
let byte = byte.set_ls_bit(3);
|
||||
assert_eq!(0b1111_1101, byte);
|
||||
let byte = byte.set_ls_bit(1);
|
||||
assert_eq!(0b1111_1111, byte);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_bits_u16() {
|
||||
let byte = 0b0000_0000_0000_0000u16.set_ls_bit(0);
|
||||
assert_eq!(0b0000_0000_0000_0001, byte);
|
||||
let byte = byte.set_ls_bit(15);
|
||||
assert_eq!(0b1000_0000_0000_0001, byte);
|
||||
}
|
||||
}
|
||||
285
src/common/bytes_reader.rs
Normal file
285
src/common/bytes_reader.rs
Normal file
@ -0,0 +1,285 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{convert::TryInto, fmt::Debug};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::common::{U24, U48};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct BytesReader<'a> {
|
||||
data: &'a [u8],
|
||||
}
|
||||
|
||||
impl<'a> BytesReader<'a> {
|
||||
pub fn from_slice(slice: &'a [u8]) -> Self {
|
||||
Self { data: slice }
|
||||
}
|
||||
|
||||
pub fn read_u8(&mut self) -> ReadResult<u8> {
|
||||
Ok(self.read_bytes(1)?[0])
|
||||
}
|
||||
|
||||
/// Gets the next `n` bytes from the stream.
|
||||
///
|
||||
/// If there are not at least `n` bytes remaining it reads them all and returns the Error `End`.
|
||||
pub fn read_bytes(&mut self, n: usize) -> ReadResult<&'a [u8]> {
|
||||
if self.data.len() < n {
|
||||
self.data = &self.data[self.data.len()..];
|
||||
return Err(ReadError::End);
|
||||
}
|
||||
let result = &self.data[..n];
|
||||
self.data = &self.data[n..];
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Reads the next `n` bytes as an independent `BytesReader`.
|
||||
///
|
||||
/// If there are not at least `n` bytes remaining it reads them all and returns the Error `End`.
|
||||
pub fn read(&mut self, n: usize) -> ReadResult<BytesReader<'a>> {
|
||||
Ok(Self::from_slice(self.read_bytes(n)?))
|
||||
}
|
||||
|
||||
/// Reads all remaining bytes.
|
||||
pub fn read_all(&mut self) -> &[u8] {
|
||||
let result = self.data;
|
||||
self.data = &self.data[self.data.len()..];
|
||||
result
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
|
||||
/// Reads a `T` using the supplied function.
|
||||
///
|
||||
/// This will only stop reading when the stream end is hit, this means that once read_one starts,
|
||||
/// it must complete a whole read.
|
||||
///
|
||||
/// Panics if `read_one` reads nothing.
|
||||
pub fn read_until_end_exactly<T>(
|
||||
&mut self,
|
||||
read_one: impl Fn(&mut Self) -> ReadResult<T>,
|
||||
) -> ReadResult<Vec<T>> {
|
||||
let mut values = vec![];
|
||||
while !self.is_empty() {
|
||||
let data_length = self.data.len();
|
||||
values.push(
|
||||
read_one(self)
|
||||
.map_err(|_| ReadError::EndReachedInLoop(values.len(), data_length))?,
|
||||
);
|
||||
assert!(
|
||||
data_length > self.data.len(),
|
||||
"The function did not read anything"
|
||||
);
|
||||
}
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! read_implementation {
|
||||
($T:ty, $be_fn_name:ident, $le_fn_name:ident) => {
|
||||
impl<'a> BytesReader<'a> {
|
||||
pub fn $be_fn_name(&mut self) -> ReadResult<$T> {
|
||||
let size = <$T>::BITS as usize / 8;
|
||||
let bytes = self.read_bytes(size)?;
|
||||
Ok(<$T>::from_be_bytes(bytes[0..size].try_into().unwrap()))
|
||||
}
|
||||
|
||||
pub fn $le_fn_name(&mut self) -> ReadResult<$T> {
|
||||
let size = <$T>::BITS as usize / 8;
|
||||
let bytes = self.read_bytes(size)?;
|
||||
Ok(<$T>::from_le_bytes(bytes[0..size].try_into().unwrap()))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
read_implementation! { u16, read_u16_be, read_u16_le }
|
||||
read_implementation! { i16, read_i16_be, read_i16_le }
|
||||
read_implementation! { U24, read_u24_be, read_u24_le }
|
||||
read_implementation! { u32, read_u32_be, read_u32_le }
|
||||
read_implementation! { i32, read_i32_be, read_i32_le }
|
||||
read_implementation! { U48, read_u48_be, read_u48_le }
|
||||
|
||||
pub type ReadResult<T> = Result<T, ReadError>;
|
||||
|
||||
#[derive(Error, Eq, PartialEq, Debug, Copy, Clone)]
|
||||
pub enum ReadError {
|
||||
#[error("Reached end of stream")]
|
||||
End,
|
||||
#[error("Reached end of stream in loop. After {0} successful loops, there were only {1} bytes available at the next loop start.")]
|
||||
EndReachedInLoop(usize, usize),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn read_u8() {
|
||||
let data = &hex!("01 02 03");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(1), reader.read_u8());
|
||||
assert_eq!(Ok(2), reader.read_u8());
|
||||
assert_eq!(Ok(3), reader.read_u8());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u8());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u16_be() {
|
||||
let data = &hex!("0102 0304");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(0x0102), reader.read_u16_be());
|
||||
assert_eq!(Ok(0x0304), reader.read_u16_be());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u16_be());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u16_le() {
|
||||
let data = &hex!("0102 0304");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(0x0201), reader.read_u16_le());
|
||||
assert_eq!(Ok(0x0403), reader.read_u16_le());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u16_le());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_i16_be() {
|
||||
let data = &hex!("0102 ffff fffe 7fff 8000");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(0x0102), reader.read_i16_be());
|
||||
assert_eq!(Ok(-1), reader.read_i16_be());
|
||||
assert_eq!(Ok(-2), reader.read_i16_be());
|
||||
assert_eq!(Ok(i16::MAX), reader.read_i16_be());
|
||||
assert_eq!(Ok(i16::MIN), reader.read_i16_be());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u16_be());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u8_after_u16_fail_to_get_all_bytes() {
|
||||
let data = &hex!("0102 0304 05");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(0x0102), reader.read_u16_be());
|
||||
assert_eq!(Ok(0x0304), reader.read_u16_be());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u16_be());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u8());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u32_be() {
|
||||
let data = &hex!("01020304 05060708");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(0x01020304), reader.read_u32_be());
|
||||
assert_eq!(Ok(0x05060708), reader.read_u32_be());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u32_be());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u24_be() {
|
||||
let data = &hex!("010203 040506");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(U24::truncate(0x010203)), reader.read_u24_be());
|
||||
assert_eq!(Ok(U24::truncate(0x040506)), reader.read_u24_be());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u24_be());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u48_be() {
|
||||
let data = &hex!("010203040506 0708090a0b0c");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(U48::truncate(0x010203040506u64)), reader.read_u48_be());
|
||||
assert_eq!(Ok(U48::truncate(0x0708090a0b0cu64)), reader.read_u48_be());
|
||||
assert_eq!(Err(ReadError::End), reader.read_u48_be());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_as_many_as_possible() {
|
||||
let data = &hex!("010203 040506");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
let result = reader.read_until_end_exactly(|s| Ok((s.read_u8()?, s.read_u16_be()?)));
|
||||
assert_eq!(Ok(vec![(1u8, 0x0203u16), (4u8, 0x0506u16)]), result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_as_many_as_possible_but_empty() {
|
||||
let data = &hex!("");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
let result = reader.read_until_end_exactly(|s| Ok((s.read_u8()?, s.read_u16_be()?)));
|
||||
assert_eq!(Ok(vec![]), result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_as_many_as_possible_length_mismatch_two_loops_one_byte_remains() {
|
||||
let data = &hex!("010203 040506 07");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
let result = reader.read_until_end_exactly(|s| Ok((s.read_u8()?, s.read_u16_be()?)));
|
||||
assert_eq!(Err(ReadError::EndReachedInLoop(2, 1)), result);
|
||||
assert!(reader.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_as_many_as_possible_length_mismatch_three_loops_two_bytes_remain() {
|
||||
let data = &hex!("010203 040506 070809 0a0b");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
let result = reader.read_until_end_exactly(|s| Ok((s.read_u8()?, s.read_u16_be()?)));
|
||||
assert_eq!(Err(ReadError::EndReachedInLoop(3, 2)), result);
|
||||
assert!(reader.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "The function did not read anything")]
|
||||
fn read_as_many_as_possible_no_movement() {
|
||||
let data = &hex!("010203");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
let _result = reader.read_until_end_exactly(|_| Ok(()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_as_many_as_possible_no_movement_but_empty_anyway() {
|
||||
let data = &hex!("");
|
||||
let mut reader = BytesReader::from_slice(data);
|
||||
let result = reader.read_until_end_exactly(|_| Ok(()));
|
||||
assert_eq!(Ok(vec![]), result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clone() {
|
||||
let data = &hex!("010203");
|
||||
let mut reader1 = BytesReader::from_slice(data);
|
||||
assert_eq!(Ok(1), reader1.read_u8());
|
||||
let mut reader2 = BytesReader::clone(&reader1);
|
||||
assert_eq!(Ok(2), reader2.read_u8());
|
||||
assert_eq!(Ok(2), reader1.read_u8());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_all() {
|
||||
let data = &hex!("010203");
|
||||
let mut byte_array = BytesReader::from_slice(data);
|
||||
let read = byte_array.read_all();
|
||||
assert_eq!(data, read);
|
||||
let read = byte_array.read_all();
|
||||
assert_eq!(vec![0u8; 0], read);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_all_via_read() {
|
||||
let data = &hex!("0102 03");
|
||||
let mut byte_array = BytesReader::from_slice(data);
|
||||
let mut sub_reader = byte_array.read(2).unwrap();
|
||||
let read = sub_reader.read_all();
|
||||
assert_eq!(&hex!("0102"), read);
|
||||
let read = sub_reader.read_all();
|
||||
assert_eq!(vec![0u8; 0], read);
|
||||
let read = byte_array.read_all();
|
||||
assert_eq!(&hex!("03"), read);
|
||||
let read = byte_array.read_all();
|
||||
assert_eq!(vec![0u8; 0], read);
|
||||
}
|
||||
}
|
||||
12
src/common/collections.rs
Normal file
12
src/common/collections.rs
Normal file
@ -0,0 +1,12 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
mod key_sorted_cache;
|
||||
mod ring_buffer;
|
||||
mod two_generation_cache;
|
||||
|
||||
pub use key_sorted_cache::*;
|
||||
pub use ring_buffer::*;
|
||||
pub use two_generation_cache::*;
|
||||
182
src/common/collections/key_sorted_cache.rs
Normal file
182
src/common/collections/key_sorted_cache.rs
Normal file
@ -0,0 +1,182 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// A (Key, Value) cache, that keeps the largest keys (by Ord trait) up to the size limit specified
|
||||
/// dropping the smallest key on insert if full.
|
||||
pub struct KeySortedCache<K, V> {
|
||||
limit: usize,
|
||||
value_by_key: BTreeMap<K, V>,
|
||||
}
|
||||
|
||||
impl<K: Ord + Clone, V> KeySortedCache<K, V> {
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
limit,
|
||||
value_by_key: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: K, value: V) {
|
||||
self.value_by_key.insert(key, value);
|
||||
// TODO: use pop_first when out of experimental
|
||||
if self.value_by_key.len() > self.limit {
|
||||
let smallest_key = self.value_by_key.keys().next().unwrap().clone();
|
||||
self.value_by_key.remove(&smallest_key);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
|
||||
self.value_by_key.iter()
|
||||
}
|
||||
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = (&K, &mut V)> + '_ {
|
||||
self.value_by_key.iter_mut()
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, key: &K) {
|
||||
self.value_by_key.remove(key);
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.value_by_key.is_empty()
|
||||
}
|
||||
|
||||
pub fn retain(&mut self, f: impl FnMut(&K, &mut V) -> bool) {
|
||||
self.value_by_key.retain(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn fill_buffer() {
|
||||
let mut buffer = KeySortedCache::new(2);
|
||||
assert!(buffer.is_empty());
|
||||
buffer.insert(1, "A");
|
||||
assert!(!buffer.is_empty());
|
||||
buffer.insert(2, "B");
|
||||
assert!(!buffer.is_empty());
|
||||
|
||||
assert_eq!(
|
||||
vec![(&1, &"A"), (&2, &"B")],
|
||||
buffer.iter().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overfill_buffer() {
|
||||
let mut buffer = KeySortedCache::new(2);
|
||||
buffer.insert(1, "A");
|
||||
buffer.insert(2, "B");
|
||||
buffer.insert(3, "C");
|
||||
|
||||
assert_eq!(
|
||||
vec![(&2, &"B"), (&3, &"C")],
|
||||
buffer.iter().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overfill_buffer_with_lower_key() {
|
||||
let mut buffer = KeySortedCache::new(2);
|
||||
buffer.insert(2, "B");
|
||||
buffer.insert(3, "C");
|
||||
buffer.insert(1, "A");
|
||||
|
||||
assert_eq!(
|
||||
vec![(&2, &"B"), (&3, &"C")],
|
||||
buffer.iter().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overfill_buffer_with_middle_key() {
|
||||
let mut buffer = KeySortedCache::new(2);
|
||||
buffer.insert(1, "A");
|
||||
buffer.insert(3, "C");
|
||||
buffer.insert(2, "B");
|
||||
|
||||
assert_eq!(
|
||||
vec![(&2, &"B"), (&3, &"C")],
|
||||
buffer.iter().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replace_key() {
|
||||
let mut buffer = KeySortedCache::new(3);
|
||||
buffer.insert(1, "A");
|
||||
buffer.insert(2, "B");
|
||||
buffer.insert(1, "C");
|
||||
|
||||
assert_eq!(
|
||||
vec![(&1, &"C"), (&2, &"B")],
|
||||
buffer.iter().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replace_key_once_full() {
|
||||
let mut buffer = KeySortedCache::new(2);
|
||||
buffer.insert(1, "A");
|
||||
buffer.insert(2, "B");
|
||||
buffer.insert(1, "C");
|
||||
|
||||
assert_eq!(
|
||||
vec![(&1, &"C"), (&2, &"B")],
|
||||
buffer.iter().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_key() {
|
||||
let mut buffer = KeySortedCache::new(3);
|
||||
buffer.insert(1, "A");
|
||||
buffer.insert(2, "B");
|
||||
buffer.remove(&1);
|
||||
|
||||
assert_eq!(vec![(&2, &"B")], buffer.iter().collect::<Vec<_>>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iter_mut_update_value() {
|
||||
let mut buffer = KeySortedCache::new(3);
|
||||
buffer.insert(1, "A".to_string());
|
||||
buffer.insert(2, "B".to_string());
|
||||
buffer.insert(3, "C".to_string());
|
||||
|
||||
for (_k, v) in buffer.iter_mut() {
|
||||
*v = format!("{}x{}", v, v);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
vec![
|
||||
(&1, &"AxA".to_string()),
|
||||
(&2, &"BxB".to_string()),
|
||||
(&3, &"CxC".to_string())
|
||||
],
|
||||
buffer.iter().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retain() {
|
||||
let mut buffer = KeySortedCache::new(4);
|
||||
buffer.insert(1, "A");
|
||||
buffer.insert(2, "B");
|
||||
buffer.insert(3, "C");
|
||||
buffer.insert(4, "D");
|
||||
buffer.retain(|key, _value| *key > 2);
|
||||
|
||||
assert_eq!(
|
||||
vec![(&3, &"C"), (&4, &"D")],
|
||||
buffer.iter().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
120
src/common/collections/ring_buffer.rs
Normal file
120
src/common/collections/ring_buffer.rs
Normal file
@ -0,0 +1,120 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::collections::VecDeque;
|
||||
|
||||
/// A fixed size RingBuffer. On insert drops the oldest inserted item iff full.
|
||||
pub struct RingBuffer<T> {
|
||||
limit: usize,
|
||||
values: VecDeque<T>,
|
||||
}
|
||||
|
||||
impl<T> RingBuffer<T> {
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
limit,
|
||||
// + 1 as we push before pop
|
||||
values: VecDeque::with_capacity(limit + 1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Iff pushing popped off an old value, return it.
|
||||
pub fn push(&mut self, value: T) -> Option<T> {
|
||||
self.values.push_back(value);
|
||||
if self.values.len() > self.limit {
|
||||
self.values.pop_front()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl DoubleEndedIterator<Item = &T> + ExactSizeIterator + Clone + '_ {
|
||||
self.values.iter()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.values.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.values.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_full(&self) -> bool {
|
||||
self.len() == self.limit
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::RingBuffer;
|
||||
|
||||
#[test]
|
||||
fn is_empty() {
|
||||
let mut buffer = RingBuffer::new(3);
|
||||
assert!(buffer.is_empty());
|
||||
buffer.push(1);
|
||||
assert!(!buffer.is_empty());
|
||||
buffer.push(2);
|
||||
assert!(!buffer.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn len() {
|
||||
let mut buffer = RingBuffer::new(3);
|
||||
assert_eq!(0, buffer.len());
|
||||
buffer.push(1);
|
||||
assert_eq!(1, buffer.len());
|
||||
buffer.push(1);
|
||||
assert_eq!(2, buffer.len());
|
||||
buffer.push(1);
|
||||
assert_eq!(3, buffer.len());
|
||||
buffer.push(1);
|
||||
assert_eq!(3, buffer.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full() {
|
||||
let mut buffer = RingBuffer::new(3);
|
||||
assert!(!buffer.is_full());
|
||||
buffer.push(1);
|
||||
assert!(!buffer.is_full());
|
||||
buffer.push(1);
|
||||
assert!(!buffer.is_full());
|
||||
buffer.push(1);
|
||||
assert!(buffer.is_full());
|
||||
buffer.push(2);
|
||||
assert!(buffer.is_full());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_to_limit() {
|
||||
let mut buffer = RingBuffer::new(3);
|
||||
assert_eq!(None, buffer.push(1));
|
||||
assert_eq!(None, buffer.push(3));
|
||||
assert_eq!(None, buffer.push(5));
|
||||
assert_eq!(vec![1, 3, 5], buffer.iter().copied().collect::<Vec<_>>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_beyond_limit() {
|
||||
let mut buffer = RingBuffer::new(2);
|
||||
assert_eq!(None, buffer.push(1));
|
||||
assert_eq!(None, buffer.push(3));
|
||||
assert_eq!(Some(1), buffer.push(5));
|
||||
assert_eq!(vec![3, 5], buffer.iter().copied().collect::<Vec<_>>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_beyond_twice_limit() {
|
||||
let mut buffer = RingBuffer::new(2);
|
||||
assert_eq!(None, buffer.push(1));
|
||||
assert_eq!(None, buffer.push(3));
|
||||
assert_eq!(Some(1), buffer.push(5));
|
||||
assert_eq!(Some(3), buffer.push(7));
|
||||
assert_eq!(Some(5), buffer.push(9));
|
||||
assert_eq!(vec![7, 9], buffer.iter().copied().collect::<Vec<_>>());
|
||||
}
|
||||
}
|
||||
356
src/common/collections/two_generation_cache.rs
Normal file
356
src/common/collections/two_generation_cache.rs
Normal file
@ -0,0 +1,356 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{borrow::Borrow, collections::HashMap, hash::Hash, mem};
|
||||
|
||||
use crate::common::{Duration, Instant};
|
||||
|
||||
/// A cache that keeps values for at least the specified generation lifetime, and at most around 2x
|
||||
/// generation lifetime assuming regular invocations of [insert]. Note ejections are only done
|
||||
/// on [insert] so it is possible that an entry can be returned after 2x generation_lifetime has
|
||||
/// passed.
|
||||
///
|
||||
/// Users should consider the implications of dropping potentially large numbers of items at once.
|
||||
///
|
||||
/// [insert]: TwoGenerationCache::insert
|
||||
pub struct TwoGenerationCache<K, V>(TwoGenerationCacheWithManualRemoveOld<K, V>)
|
||||
where
|
||||
K: Hash + Eq;
|
||||
|
||||
impl<K, V> TwoGenerationCache<K, V>
|
||||
where
|
||||
K: Hash + Eq,
|
||||
{
|
||||
pub fn new(generation_lifetime: Duration, now: Instant) -> Self {
|
||||
Self(TwoGenerationCacheWithManualRemoveOld::new(
|
||||
generation_lifetime,
|
||||
now,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: K, value: V, now: Instant) {
|
||||
self.0.swap_generations_if_its_been_too_long(now);
|
||||
self.0.insert_without_removing_old(key, value);
|
||||
}
|
||||
|
||||
pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq,
|
||||
{
|
||||
self.0.get(key)
|
||||
}
|
||||
|
||||
pub fn remove<Q: ?Sized>(&mut self, key: &Q) -> Option<V>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq,
|
||||
{
|
||||
self.0.remove(key)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn generation0_len(&self) -> usize {
|
||||
self.0.generation0.len()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn generation1_len(&self) -> usize {
|
||||
self.0.generation1.len()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TwoGenerationCacheWithManualRemoveOld<K, V>
|
||||
where
|
||||
K: Hash + Eq,
|
||||
{
|
||||
generation_lifetime: Duration,
|
||||
generation0: HashMap<K, V>,
|
||||
generation1: HashMap<K, V>,
|
||||
generation1_expires: Instant,
|
||||
}
|
||||
|
||||
impl<K, V> TwoGenerationCacheWithManualRemoveOld<K, V>
|
||||
where
|
||||
K: Hash + Eq,
|
||||
{
|
||||
pub fn new(generation_lifetime: Duration, now: Instant) -> Self {
|
||||
Self {
|
||||
generation_lifetime,
|
||||
generation0: Default::default(),
|
||||
generation1: Default::default(),
|
||||
generation1_expires: now + generation_lifetime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert_without_removing_old(&mut self, key: K, value: V) {
|
||||
self.generation0.insert(key, value);
|
||||
}
|
||||
|
||||
pub fn remove_old(&mut self, now: Instant) -> Vec<K> {
|
||||
if let Some(mut removed_gen) = self.swap_generations_if_its_been_too_long(now) {
|
||||
removed_gen
|
||||
.drain()
|
||||
.map(|(k, _v)| k)
|
||||
.filter(|k| !self.generation1.contains_key(&k))
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
fn swap_generations_if_its_been_too_long(&mut self, now: Instant) -> Option<HashMap<K, V>> {
|
||||
if now >= self.generation1_expires {
|
||||
let removed_gen = mem::replace(&mut self.generation1, mem::take(&mut self.generation0));
|
||||
self.generation1_expires = now + self.generation_lifetime;
|
||||
Some(removed_gen)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq,
|
||||
{
|
||||
self.generation0
|
||||
.get(key)
|
||||
.or_else(|| self.generation1.get(key))
|
||||
}
|
||||
|
||||
pub fn remove<Q: ?Sized>(&mut self, key: &Q) -> Option<V>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq,
|
||||
{
|
||||
let removed0 = self.generation0.remove(key);
|
||||
let removed1 = self.generation1.remove(key);
|
||||
removed0.or(removed1)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod two_generation_cache_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn insert_and_get() {
|
||||
let mut lru = TwoGenerationCache::new(Duration::from_secs(1), Instant::now());
|
||||
|
||||
lru.insert("K", "V", Instant::now());
|
||||
|
||||
assert_eq!(Some(&"V"), lru.get("K"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_still_read_from_second_generation() {
|
||||
let lifetime = Duration::from_secs(1);
|
||||
let mut now = Instant::now();
|
||||
let mut lru = TwoGenerationCache::new(lifetime, now);
|
||||
|
||||
lru.insert("K1", "V1", now);
|
||||
now += lifetime;
|
||||
lru.insert("K2", "V2", now);
|
||||
|
||||
assert_eq!(Some(&"V1"), lru.get("K1"));
|
||||
assert_eq!(Some(&"V2"), lru.get("K2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_present_after_two_generations() {
|
||||
let lifetime = Duration::from_secs(1);
|
||||
let mut now = Instant::now();
|
||||
let mut lru = TwoGenerationCache::new(lifetime, now);
|
||||
|
||||
lru.insert("K1", "V1", now);
|
||||
now += lifetime;
|
||||
lru.insert("K2", "V2", now);
|
||||
now += lifetime;
|
||||
lru.insert("K3", "V3", now);
|
||||
|
||||
assert_eq!(None, lru.get("K1"));
|
||||
assert_eq!(Some(&"V2"), lru.get("K2"));
|
||||
assert_eq!(Some(&"V3"), lru.get("K3"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_present_after_three_generations() {
|
||||
let lifetime = Duration::from_secs(1);
|
||||
let mut now = Instant::now();
|
||||
let mut lru = TwoGenerationCache::new(lifetime, now);
|
||||
|
||||
lru.insert("K1", "V1", now);
|
||||
now += lifetime;
|
||||
lru.insert("K2", "V2", now);
|
||||
now += lifetime;
|
||||
lru.insert("K3", "V3", now);
|
||||
now += lifetime;
|
||||
lru.insert("K4", "V4", now);
|
||||
|
||||
assert_eq!(None, lru.get("K1"));
|
||||
assert_eq!(None, lru.get("K2"));
|
||||
assert_eq!(Some(&"V3"), lru.get("K3"));
|
||||
assert_eq!(Some(&"V4"), lru.get("K4"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generation_expiry_resets_to_last_inserted_time_plus_duration() {
|
||||
let lifetime = Duration::from_secs(2);
|
||||
let mut now = Instant::now();
|
||||
let mut lru = TwoGenerationCache::new(lifetime, now);
|
||||
|
||||
lru.insert("K1", "V1", now);
|
||||
now += lifetime;
|
||||
lru.insert("K2", "V2", now);
|
||||
now += lifetime;
|
||||
lru.insert("K3", "V3", now);
|
||||
now += Duration::from_millis(1999);
|
||||
lru.insert("K4", "V4", now);
|
||||
|
||||
assert_eq!(None, lru.get("K1"));
|
||||
|
||||
// Generation 1
|
||||
assert_eq!(Some(&"V2"), lru.get("K2"));
|
||||
|
||||
// Generation 0
|
||||
assert_eq!(Some(&"V3"), lru.get("K3"));
|
||||
assert_eq!(Some(&"V4"), lru.get("K4"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn longest_and_shortest_lifetime_is_in_range_1_2x_lifetime() {
|
||||
let lifetime = Duration::from_secs(1);
|
||||
let mut now = Instant::now();
|
||||
let mut lru = TwoGenerationCache::new(lifetime, now);
|
||||
|
||||
let longest_insert_time = now;
|
||||
lru.insert("KLongest", "VLongest", longest_insert_time);
|
||||
now += lifetime;
|
||||
let shortest_insert_time = now - Duration::from_nanos(1);
|
||||
lru.insert("KShortest", "VShortest", shortest_insert_time);
|
||||
|
||||
assert_eq!((2, 0), (lru.generation0_len(), lru.generation1_len()));
|
||||
|
||||
// Cause the generation to shift
|
||||
lru.insert("KTrigger", "VTrigger", now);
|
||||
assert_eq!((1, 2), (lru.generation0_len(), lru.generation1_len()));
|
||||
|
||||
assert_eq!(Some(&"VLongest"), lru.get("KLongest"));
|
||||
assert_eq!(Some(&"VShortest"), lru.get("KShortest"));
|
||||
|
||||
// Insert, just before the next eviction time
|
||||
now += lifetime;
|
||||
now -= Duration::from_nanos(1);
|
||||
lru.insert("KTrigger2", "VTrigger2", now);
|
||||
assert_eq!((2, 2), (lru.generation0_len(), lru.generation1_len()));
|
||||
|
||||
// These are the longest lifetimes we will see for these, [1..2) * lifetime.
|
||||
assert_eq!(
|
||||
(lifetime * 2).checked_sub(Duration::from_nanos(1)),
|
||||
now.checked_duration_since(longest_insert_time)
|
||||
);
|
||||
assert_eq!(
|
||||
lifetime,
|
||||
now.checked_duration_since(shortest_insert_time).unwrap()
|
||||
);
|
||||
|
||||
// Cause ejection of generation1
|
||||
now += Duration::from_nanos(1);
|
||||
lru.insert("KTrigger3", "VTrigger3", now);
|
||||
assert_eq!((1, 2), (lru.generation0_len(), lru.generation1_len()));
|
||||
|
||||
assert_eq!(None, lru.get("KLongest"));
|
||||
assert_eq!(None, lru.get("KShortest"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_when_in_gen_0() {
|
||||
let lifetime = Duration::from_secs(1);
|
||||
let now = Instant::now();
|
||||
let mut lru = TwoGenerationCache::new(lifetime, now);
|
||||
|
||||
lru.insert("R1", "V1", now);
|
||||
|
||||
assert_eq!(Some(&"V1"), lru.get("R1"));
|
||||
assert_eq!(Some("V1"), lru.remove("R1"));
|
||||
assert_eq!(None, lru.get("R1"));
|
||||
assert_eq!(None, lru.remove("R1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_when_in_gen_1() {
|
||||
let lifetime = Duration::from_secs(1);
|
||||
let mut now = Instant::now();
|
||||
let mut lru = TwoGenerationCache::new(lifetime, now);
|
||||
|
||||
lru.insert("R1", "V1", now);
|
||||
now += lifetime;
|
||||
lru.insert("K1", "V2", now);
|
||||
assert_eq!((1, 1), (lru.generation0_len(), lru.generation1_len()));
|
||||
|
||||
assert_eq!(Some(&"V1"), lru.get("R1"));
|
||||
assert_eq!(Some("V1"), lru.remove("R1"));
|
||||
assert_eq!(None, lru.get("R1"));
|
||||
assert_eq!(None, lru.remove("R1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_when_in_both_generations() {
|
||||
let lifetime = Duration::from_secs(1);
|
||||
let mut now = Instant::now();
|
||||
let mut lru = TwoGenerationCache::new(lifetime, now);
|
||||
|
||||
lru.insert("R1", "V1", now);
|
||||
now += lifetime;
|
||||
lru.insert("R1", "V2", now);
|
||||
assert_eq!((1, 1), (lru.generation0_len(), lru.generation1_len()));
|
||||
|
||||
assert_eq!(Some(&"V2"), lru.get("R1"));
|
||||
assert_eq!(Some("V2"), lru.remove("R1"));
|
||||
assert_eq!(None, lru.get("R1"));
|
||||
assert_eq!(None, lru.remove("R1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_old() {
|
||||
let lifetime = Duration::from_secs(1);
|
||||
let now = Instant::now();
|
||||
let mut lru: TwoGenerationCacheWithManualRemoveOld<u32, String> =
|
||||
TwoGenerationCacheWithManualRemoveOld::new(lifetime, now);
|
||||
|
||||
lru.insert_without_removing_old(1, "a".to_owned());
|
||||
lru.insert_without_removing_old(2, "b".to_owned());
|
||||
lru.insert_without_removing_old(3, "c".to_owned());
|
||||
|
||||
assert_eq!(
|
||||
vec![0u32; 0],
|
||||
lru.remove_old(now + Duration::from_millis(999))
|
||||
);
|
||||
assert_eq!(Some(&"a".to_owned()), lru.get(&1));
|
||||
assert_eq!(Some(&"b".to_owned()), lru.get(&2));
|
||||
assert_eq!(Some(&"c".to_owned()), lru.get(&3));
|
||||
|
||||
assert_eq!(
|
||||
vec![0u32; 0],
|
||||
lru.remove_old(now + Duration::from_millis(1001))
|
||||
);
|
||||
assert_eq!(Some(&"a".to_owned()), lru.get(&1));
|
||||
assert_eq!(Some(&"b".to_owned()), lru.get(&2));
|
||||
assert_eq!(Some(&"c".to_owned()), lru.get(&3));
|
||||
|
||||
lru.insert_without_removing_old(4, "d".to_owned());
|
||||
|
||||
let mut removed = lru.remove_old(now + Duration::from_millis(2001));
|
||||
removed.sort_unstable();
|
||||
assert_eq!(vec![1, 2, 3], removed,);
|
||||
assert_eq!(None, lru.get(&1));
|
||||
assert_eq!(None, lru.get(&2));
|
||||
assert_eq!(None, lru.get(&3));
|
||||
assert_eq!(Some(&"d".to_owned()), lru.get(&4));
|
||||
|
||||
assert_eq!(vec![4], lru.remove_old(now + Duration::from_millis(3001)));
|
||||
}
|
||||
}
|
||||
112
src/common/counters.rs
Normal file
112
src/common/counters.rs
Normal file
@ -0,0 +1,112 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{
|
||||
convert::{TryFrom, TryInto},
|
||||
ops::Sub,
|
||||
};
|
||||
|
||||
/// Expands a truncated counter value to the full length by using the previous largest value as
|
||||
/// guide to rollover/rollunder. Updates this maximum.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `truncated` - The truncated counter value.
|
||||
/// * `max` - The previously return value from this function.
|
||||
/// * `width` - The bit width the supplied value has been truncated to.
|
||||
pub fn expand_truncated_counter<Truncated>(truncated: Truncated, max: &mut u64, width: usize) -> u64
|
||||
where
|
||||
Truncated: TryFrom<u64> + Into<u64> + Sub<Truncated, Output = Truncated> + Ord + Copy,
|
||||
<Truncated as TryFrom<u64>>::Error: std::fmt::Debug,
|
||||
{
|
||||
let mask: u64 = (1 << width) - 1;
|
||||
let really_big: Truncated = (1 << (width - 1)).try_into().unwrap();
|
||||
|
||||
let truncated_max = (*max & mask).try_into().unwrap();
|
||||
let max_roc = *max >> width;
|
||||
let roc: u64 = if truncated_max > truncated && truncated_max - truncated > really_big {
|
||||
// Truncated is a lot smaller than the max; It's likely a rollover.
|
||||
max_roc + 1
|
||||
} else if max_roc > 0 && truncated > truncated_max && truncated - truncated_max > really_big {
|
||||
// Truncated is a lot bigger than the max; It's likely a rollunder.
|
||||
max_roc - 1
|
||||
} else {
|
||||
// Truncated is close to the max, so it's neither rollover nor rollunder.
|
||||
max_roc
|
||||
};
|
||||
let full = (roc << width) | (truncated.into() & mask);
|
||||
if full > *max {
|
||||
*max = full;
|
||||
}
|
||||
full
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn zero_max() {
|
||||
let mut max = 0u64;
|
||||
let expanded = expand_truncated_counter(0x3u16, &mut max, 16);
|
||||
assert_eq!(0x3u64, expanded);
|
||||
assert_eq!(0x3u64, max);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roll_over() {
|
||||
let mut max = 0xffffu64;
|
||||
let expanded = expand_truncated_counter(0x0001u16, &mut max, 16);
|
||||
assert_eq!(0x10001u64, expanded);
|
||||
assert_eq!(0x10001u64, max);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roll_over_larger_roc() {
|
||||
let mut max = 0x3ffffu64;
|
||||
let expanded = expand_truncated_counter(0x0001u16, &mut max, 16);
|
||||
assert_eq!(0x40001u64, expanded);
|
||||
assert_eq!(0x40001u64, max);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roll_under() {
|
||||
let mut max = 0x10001u64;
|
||||
let expanded = expand_truncated_counter(0xffffu16, &mut max, 16);
|
||||
assert_eq!(0xffffu64, expanded);
|
||||
assert_eq!(0x10001u64, max);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roll_under_larger_roc() {
|
||||
let mut max = 0x30001u64;
|
||||
let expanded = expand_truncated_counter(0xffffu16, &mut max, 16);
|
||||
assert_eq!(0x2ffffu64, expanded);
|
||||
assert_eq!(0x30001u64, max);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_8_multiple_bits() {
|
||||
let mut max = 0b0011_1111;
|
||||
let expanded = expand_truncated_counter(0b0000u8, &mut max, 4);
|
||||
assert_eq!(0b0100_0000u64, expanded);
|
||||
assert_eq!(0b0100_0000u64, max);
|
||||
let expanded = expand_truncated_counter(0b1000u8, &mut max, 4);
|
||||
assert_eq!(0b0100_1000u64, expanded);
|
||||
assert_eq!(0b0100_1000u64, max);
|
||||
let expanded = expand_truncated_counter(0b0100u8, &mut max, 4);
|
||||
assert_eq!(0b0100_0100u64, expanded);
|
||||
assert_eq!(0b0100_1000u64, max);
|
||||
let expanded = expand_truncated_counter(0b1101u8, &mut max, 4);
|
||||
assert_eq!(0b0100_1101u64, expanded);
|
||||
assert_eq!(0b0100_1101u64, max);
|
||||
let expanded = expand_truncated_counter(0b0001u8, &mut max, 4);
|
||||
assert_eq!(0b0101_0001u64, expanded);
|
||||
assert_eq!(0b0101_0001u64, max);
|
||||
let expanded = expand_truncated_counter(0b1101u8, &mut max, 4);
|
||||
assert_eq!(0b0100_1101u64, expanded);
|
||||
assert_eq!(0b0101_0001u64, max);
|
||||
}
|
||||
}
|
||||
534
src/common/data_rate.rs
Normal file
534
src/common/data_rate.rs
Normal file
@ -0,0 +1,534 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{
|
||||
fmt::{self, Display, Formatter},
|
||||
iter::Sum,
|
||||
ops::{Add, AddAssign, Div, Mul, Sub, SubAssign},
|
||||
};
|
||||
|
||||
use crate::common::Duration;
|
||||
|
||||
#[derive(Debug, Copy, Clone, Eq, Ord, PartialEq, PartialOrd)]
|
||||
pub struct DataSize {
|
||||
bits: u64,
|
||||
}
|
||||
|
||||
impl Default for DataSize {
|
||||
fn default() -> Self {
|
||||
Self::from_bits(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl DataSize {
|
||||
const BITS_PER_BYTE: u64 = 8;
|
||||
const BITS_PER_KILO_BIT: u64 = 1000;
|
||||
const BITS_PER_MEGA_BIT: u64 = Self::BITS_PER_KILO_BIT * Self::BITS_PER_KILO_BIT;
|
||||
|
||||
pub fn from_bits(bits: u64) -> Self {
|
||||
Self { bits }
|
||||
}
|
||||
|
||||
pub fn as_bits(&self) -> u64 {
|
||||
self.bits
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> u64 {
|
||||
self.bits / Self::BITS_PER_BYTE
|
||||
}
|
||||
|
||||
pub fn from_bytes(bytes: u64) -> Self {
|
||||
Self::from_bits(bytes * Self::BITS_PER_BYTE)
|
||||
}
|
||||
|
||||
pub fn from_kilobits(kbits: u64) -> Self {
|
||||
Self::from_bits(kbits * Self::BITS_PER_KILO_BIT)
|
||||
}
|
||||
|
||||
pub fn saturating_sub(self, other: Self) -> Self {
|
||||
if self > other {
|
||||
self - other
|
||||
} else {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<DataSize> for DataSize {
|
||||
type Output = DataSize;
|
||||
|
||||
fn add(self, other: DataSize) -> DataSize {
|
||||
DataSize::from_bits(self.bits + other.bits)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign<DataSize> for DataSize {
|
||||
fn add_assign(&mut self, rhs: DataSize) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Sum for DataSize {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
Self::from_bits(iter.map(|size| size.bits).sum())
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<DataSize> for DataSize {
|
||||
type Output = DataSize;
|
||||
|
||||
fn sub(self, other: DataSize) -> DataSize {
|
||||
DataSize::from_bits(self.bits - other.bits)
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAssign<DataSize> for DataSize {
|
||||
fn sub_assign(&mut self, rhs: DataSize) {
|
||||
*self = *self - rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<f64> for DataSize {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, x: f64) -> Self {
|
||||
Self::from_bits((self.bits as f64 * x as f64) as u64)
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<DataSize> for DataSize {
|
||||
type Output = f64;
|
||||
|
||||
fn div(self, other: DataSize) -> f64 {
|
||||
self.bits as f64 / other.bits as f64
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<f64> for DataSize {
|
||||
type Output = Self;
|
||||
|
||||
fn div(self, x: f64) -> Self {
|
||||
Self::from_bits((self.bits as f64 / x as f64) as u64)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod data_size_tests {
|
||||
use super::DataSize;
|
||||
|
||||
#[test]
|
||||
fn default() {
|
||||
assert_eq!(DataSize::from_bits(0), Default::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_bits() {
|
||||
assert_eq!(1, DataSize::from_bits(1).as_bits());
|
||||
assert_eq!(8, DataSize::from_bits(8).as_bits());
|
||||
assert_eq!(16, DataSize::from_bits(16).as_bits());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_bytes() {
|
||||
assert_eq!(8, DataSize::from_bytes(1).as_bits());
|
||||
assert_eq!(16, DataSize::from_bytes(2).as_bits());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_kilobits() {
|
||||
assert_eq!(1_000, DataSize::from_kilobits(1).as_bits());
|
||||
assert_eq!(2_000, DataSize::from_kilobits(2).as_bits());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn as_bytes_rounds_down() {
|
||||
assert_eq!(0, DataSize::from_bits(1).as_bytes());
|
||||
assert_eq!(0, DataSize::from_bits(7).as_bytes());
|
||||
assert_eq!(1, DataSize::from_bits(8).as_bytes());
|
||||
assert_eq!(1, DataSize::from_bits(15).as_bytes());
|
||||
assert_eq!(2, DataSize::from_bits(16).as_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ordinal_comparisons() {
|
||||
assert!(DataSize::from_bits(2) > DataSize::from_bits(1));
|
||||
assert!(DataSize::from_bits(1) < DataSize::from_bits(2));
|
||||
assert!(DataSize::from_bits(2) >= DataSize::from_bits(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn addition() {
|
||||
assert_eq!(
|
||||
DataSize::from_bits(1_008),
|
||||
DataSize::from_kilobits(1) + DataSize::from_bytes(1)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_assign() {
|
||||
let mut size = DataSize::from_kilobits(1);
|
||||
size += DataSize::from_bytes(1);
|
||||
assert_eq!(DataSize::from_bits(1_008), size);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subtraction() {
|
||||
assert_eq!(
|
||||
DataSize::from_bits(992),
|
||||
DataSize::from_kilobits(1) - DataSize::from_bytes(1)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_assign() {
|
||||
let mut size = DataSize::from_kilobits(1);
|
||||
size -= DataSize::from_bytes(1);
|
||||
assert_eq!(DataSize::from_bits(992), size);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn saturating_subtraction() {
|
||||
assert_eq!(
|
||||
DataSize::from_bits(901),
|
||||
DataSize::from_kilobits(1).saturating_sub(DataSize::from_bits(99))
|
||||
);
|
||||
assert_eq!(
|
||||
DataSize::from_bits(1),
|
||||
DataSize::from_bits(4).saturating_sub(DataSize::from_bits(3))
|
||||
);
|
||||
assert_eq!(
|
||||
DataSize::from_bits(0),
|
||||
DataSize::from_bits(4).saturating_sub(DataSize::from_bits(4))
|
||||
);
|
||||
assert_eq!(
|
||||
DataSize::from_bits(0),
|
||||
DataSize::from_bits(4).saturating_sub(DataSize::from_bits(5))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiplication_by_scalar() {
|
||||
assert_eq!(DataSize::from_bytes(56), DataSize::from_bytes(8) * 7.0f64);
|
||||
assert_eq!(DataSize::from_bytes(60), DataSize::from_bytes(8) * 7.5f64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn division_by_data_size() {
|
||||
assert_eq!(7.0f64, DataSize::from_bytes(56) / DataSize::from_bytes(8));
|
||||
assert_eq!(7.5f64, DataSize::from_bytes(60) / DataSize::from_bytes(8));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn division_by_scalar() {
|
||||
assert_eq!(DataSize::from_bytes(8), DataSize::from_bytes(56) / 7.0f64);
|
||||
assert_eq!(DataSize::from_bytes(8), DataSize::from_bytes(60) / 7.5f64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum() {
|
||||
let data_sizes = vec![
|
||||
DataSize::from_bits(1),
|
||||
DataSize::from_bits(2),
|
||||
DataSize::from_bits(5),
|
||||
];
|
||||
assert_eq!(DataSize::from_bits(8), data_sizes.into_iter().sum());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Eq, Ord, PartialEq, PartialOrd)]
|
||||
pub struct DataRate {
|
||||
size_per_second: DataSize,
|
||||
}
|
||||
|
||||
impl Default for DataRate {
|
||||
fn default() -> Self {
|
||||
Self::from_bps(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl DataRate {
|
||||
pub fn per_second(size_per_second: DataSize) -> Self {
|
||||
Self { size_per_second }
|
||||
}
|
||||
|
||||
pub fn from_bps(bps: u64) -> Self {
|
||||
Self::per_second(DataSize::from_bits(bps))
|
||||
}
|
||||
|
||||
pub fn from_kbps(kbps: u64) -> Self {
|
||||
Self::per_second(DataSize::from_kilobits(kbps))
|
||||
}
|
||||
|
||||
pub fn as_bps(&self) -> u64 {
|
||||
self.size_per_second.as_bits()
|
||||
}
|
||||
|
||||
pub fn as_kbps(&self) -> u64 {
|
||||
self.as_bps() / DataSize::BITS_PER_KILO_BIT
|
||||
}
|
||||
|
||||
pub fn saturating_sub(self, other: Self) -> Self {
|
||||
Self::per_second(self.size_per_second.saturating_sub(other.size_per_second))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for DataRate {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
let bits = self.size_per_second.bits;
|
||||
if bits < DataSize::BITS_PER_KILO_BIT {
|
||||
write!(f, "{} bps", bits)
|
||||
} else if bits < DataSize::BITS_PER_MEGA_BIT {
|
||||
write!(
|
||||
f,
|
||||
"{:.1} Kbps",
|
||||
(bits * 10 / DataSize::BITS_PER_KILO_BIT) as f64 / 10f64
|
||||
)
|
||||
} else {
|
||||
write!(
|
||||
f,
|
||||
"{:.1} Mbps",
|
||||
(bits * 10 / DataSize::BITS_PER_MEGA_BIT) as f64 / 10f64
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<DataRate> for DataRate {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, other: Self) -> Self {
|
||||
DataRate::per_second(self.size_per_second + other.size_per_second)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sum for DataRate {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
Self::per_second(iter.map(|rate| rate.size_per_second).sum())
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<DataRate> for DataRate {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, other: Self) -> Self {
|
||||
DataRate::per_second(self.size_per_second - other.size_per_second)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<f64> for DataRate {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, x: f64) -> Self {
|
||||
Self::per_second(self.size_per_second * x)
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<DataRate> for DataRate {
|
||||
type Output = f64;
|
||||
|
||||
fn div(self, other: Self) -> f64 {
|
||||
self.size_per_second / other.size_per_second
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<f64> for DataRate {
|
||||
type Output = Self;
|
||||
|
||||
fn div(self, x: f64) -> Self {
|
||||
Self::per_second(self.size_per_second / x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod data_rate_tests {
|
||||
use super::DataRate;
|
||||
|
||||
#[test]
|
||||
fn default() {
|
||||
assert_eq!(DataRate::from_bps(0), Default::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_bps() {
|
||||
assert_eq!(1, DataRate::from_bps(1).as_bps());
|
||||
assert_eq!(8, DataRate::from_bps(8).as_bps());
|
||||
assert_eq!(16, DataRate::from_bps(16).as_bps());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_kbps() {
|
||||
assert_eq!(1_000, DataRate::from_kbps(1).as_bps());
|
||||
assert_eq!(8_000, DataRate::from_kbps(8).as_bps());
|
||||
assert_eq!(16_000, DataRate::from_kbps(16).as_bps());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn as_kbps_rounds_down() {
|
||||
assert_eq!(0, DataRate::from_bps(1).as_kbps());
|
||||
assert_eq!(0, DataRate::from_bps(999).as_kbps());
|
||||
assert_eq!(1, DataRate::from_bps(1_000).as_kbps());
|
||||
assert_eq!(2, DataRate::from_bps(2_999).as_kbps());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ordinal_comparisons() {
|
||||
assert!(DataRate::from_bps(2) > DataRate::from_bps(1));
|
||||
assert!(DataRate::from_bps(1) < DataRate::from_bps(2));
|
||||
assert!(DataRate::from_bps(2) >= DataRate::from_bps(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn addition() {
|
||||
assert_eq!(
|
||||
DataRate::from_bps(1_099),
|
||||
DataRate::from_kbps(1) + DataRate::from_bps(99)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subtraction() {
|
||||
assert_eq!(
|
||||
DataRate::from_bps(901),
|
||||
DataRate::from_kbps(1) - DataRate::from_bps(99)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn saturating_subtraction() {
|
||||
assert_eq!(
|
||||
DataRate::from_bps(901),
|
||||
DataRate::from_kbps(1).saturating_sub(DataRate::from_bps(99))
|
||||
);
|
||||
assert_eq!(
|
||||
DataRate::from_bps(1),
|
||||
DataRate::from_bps(4).saturating_sub(DataRate::from_bps(3))
|
||||
);
|
||||
assert_eq!(
|
||||
DataRate::from_bps(0),
|
||||
DataRate::from_bps(4).saturating_sub(DataRate::from_bps(4))
|
||||
);
|
||||
assert_eq!(
|
||||
DataRate::from_bps(0),
|
||||
DataRate::from_bps(4).saturating_sub(DataRate::from_bps(5))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiplication_by_scalar() {
|
||||
assert_eq!(DataRate::from_bps(56), DataRate::from_bps(8) * 7.0f64);
|
||||
assert_eq!(DataRate::from_bps(60), DataRate::from_bps(8) * 7.5f64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn division_by_data_rate() {
|
||||
assert_eq!(7.0f64, DataRate::from_bps(56) / DataRate::from_bps(8));
|
||||
assert_eq!(7.5f64, DataRate::from_bps(60) / DataRate::from_bps(8));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn division_by_scalar() {
|
||||
assert_eq!(DataRate::from_bps(8), DataRate::from_bps(56) / 7.0f64);
|
||||
assert_eq!(DataRate::from_bps(8), DataRate::from_bps(60) / 7.5f64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum() {
|
||||
let data_rates = vec![
|
||||
DataRate::from_bps(1),
|
||||
DataRate::from_bps(2),
|
||||
DataRate::from_bps(5),
|
||||
];
|
||||
assert_eq!(DataRate::from_bps(8), data_rates.into_iter().sum());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_rounds_down_to_1_decimal_point() {
|
||||
assert_eq!("0 bps", format!("{}", DataRate::default()));
|
||||
assert_eq!("1 bps", format!("{}", DataRate::from_bps(1)));
|
||||
assert_eq!("999 bps", format!("{}", DataRate::from_bps(999)));
|
||||
assert_eq!("1.0 Kbps", format!("{}", DataRate::from_bps(1_000)));
|
||||
assert_eq!("1.5 Kbps", format!("{}", DataRate::from_bps(1_550)));
|
||||
assert_eq!("1.9 Kbps", format!("{}", DataRate::from_bps(1_999)));
|
||||
assert_eq!("999.9 Kbps", format!("{}", DataRate::from_bps(999_999)));
|
||||
assert_eq!("1.0 Mbps", format!("{}", DataRate::from_bps(1_000_000)));
|
||||
assert_eq!("2.3 Mbps", format!("{}", DataRate::from_bps(2_350_000)));
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<Duration> for DataRate {
|
||||
type Output = DataSize;
|
||||
|
||||
fn mul(self, duration: Duration) -> DataSize {
|
||||
DataSize::from_bits(((self.as_bps() as f64) * duration.as_secs_f64()) as u64)
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<Duration> for DataSize {
|
||||
type Output = DataRate;
|
||||
|
||||
fn div(self, duration: Duration) -> DataRate {
|
||||
DataRate::from_bps((self.as_bits() as f64 / duration.as_secs_f64()) as u64)
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<DataRate> for DataSize {
|
||||
type Output = Duration;
|
||||
|
||||
fn div(self, rate: DataRate) -> Duration {
|
||||
Duration::from_secs_f64((self.as_bits() as f64) / (rate.as_bps() as f64))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod data_rate_and_data_size_interaction_tests {
|
||||
use super::{DataRate, DataSize};
|
||||
use crate::common::Duration;
|
||||
|
||||
#[test]
|
||||
fn per_second() {
|
||||
assert_eq!(
|
||||
DataRate::from_bps(8),
|
||||
DataRate::per_second(DataSize::from_bytes(1))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn data_rate_multiplication_by_duration_gives_data_size() {
|
||||
assert_eq!(
|
||||
DataSize::from_bits(56),
|
||||
DataRate::from_bps(8) * Duration::from_secs(7)
|
||||
);
|
||||
assert_eq!(
|
||||
DataSize::from_bits(61_455),
|
||||
DataRate::from_bps(8_194) * Duration::from_secs_f64(7.5f64)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn data_size_division_by_duration_gives_data_rate() {
|
||||
assert_eq!(
|
||||
DataRate::from_bps(8),
|
||||
DataSize::from_bits(56) / Duration::from_secs(7)
|
||||
);
|
||||
assert_eq!(
|
||||
DataRate::from_bps(8_194),
|
||||
DataSize::from_bits(61_455) / Duration::from_secs_f64(7.5f64)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn data_size_division_by_data_rate_gives_duration() {
|
||||
assert_eq!(
|
||||
Duration::from_secs(7),
|
||||
DataSize::from_bits(56) / DataRate::from_bps(8)
|
||||
);
|
||||
assert_eq!(
|
||||
Duration::from_secs_f64(7.5f64),
|
||||
DataSize::from_bits(61_455) / DataRate::from_bps(8_194)
|
||||
);
|
||||
}
|
||||
}
|
||||
355
src/common/integers.rs
Normal file
355
src/common/integers.rs
Normal file
@ -0,0 +1,355 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
//! Contains non-standard integer lengths.
|
||||
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
fmt,
|
||||
fmt::{Debug, Display, Formatter},
|
||||
};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug, Copy, Clone, PartialEq, Eq)]
|
||||
#[error("out of range integral type conversion attempted")]
|
||||
pub struct TryFromIntError(());
|
||||
|
||||
macro_rules! non_standard_unsigned_int {
|
||||
(pub struct $T:ident($UNDERLYING_TYPE:ty) { bytes = $BYTE_WIDTH:literal }) => {
|
||||
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
|
||||
pub struct $T($UNDERLYING_TYPE);
|
||||
|
||||
impl $T {
|
||||
#[doc = "Number of bytes the type uses"]
|
||||
pub const SIZE: usize = $BYTE_WIDTH;
|
||||
|
||||
#[doc = "Number of bits the type uses"]
|
||||
pub const BITS: u32 = $BYTE_WIDTH * 8;
|
||||
|
||||
pub const MAX: $T = $T((1 << ($BYTE_WIDTH * 8)) - 1);
|
||||
pub const MIN: $T = $T(0);
|
||||
pub const ZERO: $T = $T(0);
|
||||
|
||||
pub fn truncate(value: $UNDERLYING_TYPE) -> Self {
|
||||
Self(value & Self::MAX.0)
|
||||
}
|
||||
|
||||
pub fn wrapping_add(self, other: Self) -> Self {
|
||||
Self::truncate(self.0 + other.0)
|
||||
}
|
||||
|
||||
pub fn from_be_bytes(bytes: [u8; Self::SIZE]) -> Self {
|
||||
let mut r: $UNDERLYING_TYPE = 0;
|
||||
for b in bytes.iter() {
|
||||
r = r << 8 | *b as $UNDERLYING_TYPE;
|
||||
}
|
||||
Self(r)
|
||||
}
|
||||
|
||||
pub fn from_le_bytes(bytes: [u8; Self::SIZE]) -> Self {
|
||||
let mut r: $UNDERLYING_TYPE = 0;
|
||||
for b in bytes.iter().rev() {
|
||||
r = r << 8 | *b as $UNDERLYING_TYPE;
|
||||
}
|
||||
Self(r)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for $T {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
concat![stringify!($T), "({:#0width$x})"],
|
||||
self.0,
|
||||
width = $BYTE_WIDTH * 2 + 2
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for $T {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<$T> for $UNDERLYING_TYPE {
|
||||
fn from(value: $T) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
non_standard_unsigned_int! {
|
||||
pub struct U48(u64) { bytes = 6 }
|
||||
}
|
||||
|
||||
impl TryFrom<u64> for U48 {
|
||||
type Error = TryFromIntError;
|
||||
|
||||
fn try_from(value: u64) -> Result<Self, Self::Error> {
|
||||
if value > Self::MAX.0 {
|
||||
Err(TryFromIntError(()))
|
||||
} else {
|
||||
Ok(U48(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u16> for U48 {
|
||||
fn from(value: u16) -> Self {
|
||||
U48(value as u64)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for U48 {
|
||||
fn from(value: u32) -> Self {
|
||||
U48(value as u64)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_pointer_width = "64")]
|
||||
impl From<U48> for usize {
|
||||
fn from(value: U48) -> Self {
|
||||
value.0 as usize
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod u48_tests {
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::U48;
|
||||
|
||||
#[test]
|
||||
fn zero() {
|
||||
assert_eq!(Ok(U48(0)), 0u16.try_into());
|
||||
assert_eq!(U48(0), 0u16.into());
|
||||
assert_eq!(U48::ZERO, 0u16.into());
|
||||
assert_eq!(U48::MIN, 0u16.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one() {
|
||||
assert_eq!(Ok(U48(1)), 1u16.try_into());
|
||||
assert_eq!(U48(1), 1u16.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn largest_value() {
|
||||
assert_eq!(Ok(U48(0xffffffffffff)), 0xffffffffffffu64.try_into());
|
||||
assert_eq!(Ok(U48::MAX), 0xffffffffffffu64.try_into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn too_large() {
|
||||
assert!(U48::try_from(0x1000000000000u64).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_u64() {
|
||||
let u48: U48 = 0x101112131415u64.try_into().unwrap();
|
||||
assert_eq!(0x101112131415u64, u64::from(u48));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(target_pointer_width = "64")]
|
||||
fn to_usize() {
|
||||
let u48: U48 = 0x101112131415u64.try_into().unwrap();
|
||||
assert_eq!(0x101112131415usize, usize::from(u48));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn debug_print() {
|
||||
assert_eq!("U48(0x000000000064)", format!("{:?}", U48::from(100u16)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn print() {
|
||||
assert_eq!("100", format!("{}", U48::from(100u16)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_u16() {
|
||||
assert_eq!(U48(0xffff), 0xffffu16.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_u32() {
|
||||
assert_eq!(U48(0xffff0102), 0xffff0102u32.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate() {
|
||||
assert_eq!(U48(0xf7123456789a), U48::truncate(0xffeef7123456789au64));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrapping_add() {
|
||||
assert_eq!(U48::from(1u16), U48::ZERO.wrapping_add(U48::from(1u16)));
|
||||
assert_eq!(
|
||||
U48::from(3u16),
|
||||
U48::from(2u32).wrapping_add(U48::from(1u32))
|
||||
);
|
||||
assert_eq!(U48::from(9u16), U48::MAX.wrapping_add(U48::from(10u32)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compare() {
|
||||
assert!(U48(1) > U48::ZERO);
|
||||
assert!(U48(2) > U48(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_be_bytes() {
|
||||
assert_eq!(
|
||||
U48(0x111213141516),
|
||||
U48::from_be_bytes(hex!("11 12 13 14 15 16"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_le_bytes() {
|
||||
assert_eq!(
|
||||
U48(0x161514131211),
|
||||
U48::from_le_bytes(hex!("11 12 13 14 15 16"))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
non_standard_unsigned_int! {
|
||||
pub struct U24(u32) { bytes = 3 }
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for U24 {
|
||||
type Error = TryFromIntError;
|
||||
|
||||
fn try_from(value: u32) -> Result<Self, Self::Error> {
|
||||
if value > Self::MAX.0 {
|
||||
Err(TryFromIntError(()))
|
||||
} else {
|
||||
Ok(U24(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u16> for U24 {
|
||||
fn from(value: u16) -> Self {
|
||||
U24(value as u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<U24> for u64 {
|
||||
fn from(value: U24) -> Self {
|
||||
value.0 as u64
|
||||
}
|
||||
}
|
||||
|
||||
impl From<U24> for usize {
|
||||
fn from(value: U24) -> Self {
|
||||
value.0 as usize
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod u24_tests {
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::U24;
|
||||
|
||||
#[test]
|
||||
fn zero() {
|
||||
assert_eq!(Ok(U24(0)), 0u16.try_into());
|
||||
assert_eq!(U24(0), 0.into());
|
||||
assert_eq!(U24::ZERO, 0.into());
|
||||
assert_eq!(U24::MIN, 0.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one() {
|
||||
assert_eq!(Ok(U24(1)), 1u16.try_into());
|
||||
assert_eq!(U24(1), 1.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn largest_value() {
|
||||
assert_eq!(Ok(U24(0xffffff)), 0xffffffu32.try_into());
|
||||
assert_eq!(Ok(U24::MAX), 0xffffffu32.try_into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn too_large() {
|
||||
assert!(U24::try_from(0x1000000u32).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_u32() {
|
||||
let u24: U24 = 0x121314u32.try_into().unwrap();
|
||||
assert_eq!(0x121314u32, u32::from(u24));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_u64() {
|
||||
let u24: U24 = 0x121314u32.try_into().unwrap();
|
||||
assert_eq!(0x121314u64, u64::from(u24));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_usize() {
|
||||
let u24: U24 = 0x121314u32.try_into().unwrap();
|
||||
assert_eq!(0x121314usize, usize::from(u24));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn debug_print() {
|
||||
assert_eq!("U24(0x000064)", format!("{:?}", U24::from(100)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn print() {
|
||||
assert_eq!("100", format!("{}", U24::from(100)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_u16() {
|
||||
assert_eq!(U24(0xffff), 0xffff.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate() {
|
||||
assert_eq!(U24(0xf71234), U24::truncate(0xfef71234u32));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrapping_add() {
|
||||
assert_eq!(U24::from(1u16), U24::ZERO.wrapping_add(U24::from(1u16)));
|
||||
assert_eq!(
|
||||
U24::from(3u16),
|
||||
U24::from(2u16).wrapping_add(U24::from(1u16))
|
||||
);
|
||||
assert_eq!(U24::from(9u16), U24::MAX.wrapping_add(U24::from(10u16)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compare() {
|
||||
assert!(U24(1) > U24::ZERO);
|
||||
assert!(U24(2) > U24(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_be_bytes() {
|
||||
assert_eq!(U24(0x141516), U24::from_be_bytes(hex!("14 15 16")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_le_bytes() {
|
||||
assert_eq!(U24(0x161514), U24::from_le_bytes(hex!("14 15 16")));
|
||||
}
|
||||
}
|
||||
133
src/common/math.rs
Normal file
133
src/common/math.rs
Normal file
@ -0,0 +1,133 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::ops::{Add, Mul, Sub};
|
||||
|
||||
pub fn round_up_to_multiple_of<const M: usize>(n: usize) -> usize {
|
||||
(n + (M - 1)) / M * M
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod round_up_multiple_of_n_tests {
|
||||
use super::round_up_to_multiple_of;
|
||||
|
||||
#[test]
|
||||
fn round_up_multiple_4() {
|
||||
assert_eq!(0, round_up_to_multiple_of::<4>(0));
|
||||
assert_eq!(4, round_up_to_multiple_of::<4>(1));
|
||||
assert_eq!(4, round_up_to_multiple_of::<4>(2));
|
||||
assert_eq!(4, round_up_to_multiple_of::<4>(3));
|
||||
assert_eq!(4, round_up_to_multiple_of::<4>(4));
|
||||
assert_eq!(8, round_up_to_multiple_of::<4>(5));
|
||||
assert_eq!(8, round_up_to_multiple_of::<4>(6));
|
||||
assert_eq!(8, round_up_to_multiple_of::<4>(7));
|
||||
assert_eq!(8, round_up_to_multiple_of::<4>(8));
|
||||
assert_eq!(12, round_up_to_multiple_of::<4>(9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_up_multiple_5() {
|
||||
assert_eq!(0, round_up_to_multiple_of::<5>(0));
|
||||
assert_eq!(5, round_up_to_multiple_of::<5>(1));
|
||||
assert_eq!(5, round_up_to_multiple_of::<5>(2));
|
||||
assert_eq!(5, round_up_to_multiple_of::<5>(3));
|
||||
assert_eq!(5, round_up_to_multiple_of::<5>(4));
|
||||
assert_eq!(5, round_up_to_multiple_of::<5>(5));
|
||||
assert_eq!(10, round_up_to_multiple_of::<5>(6));
|
||||
assert_eq!(10, round_up_to_multiple_of::<5>(10));
|
||||
assert_eq!(15, round_up_to_multiple_of::<5>(11));
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Square: Copy + Mul + Sized {
|
||||
fn square(self) -> Self::Output {
|
||||
self * self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Copy + Mul + Sized> Square for T {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod square_tests {
|
||||
use super::Square;
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn simple() {
|
||||
assert_eq!(0, 0.square());
|
||||
assert_eq!(4, 2.square());
|
||||
assert_eq!(4, (-2).square());
|
||||
|
||||
assert_eq!(0.25, 0.5.square());
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AbsDiff: PartialOrd + Sub + Sized {
|
||||
fn abs_diff(self, other: Self) -> Self::Output {
|
||||
if self > other {
|
||||
self - other
|
||||
} else {
|
||||
other - self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialOrd + Sub + Sized> AbsDiff for T {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod abs_diff_tests {
|
||||
use super::AbsDiff;
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn positive() {
|
||||
assert_eq!(0, 1.abs_diff(1));
|
||||
assert_eq!(5, 10.abs_diff(15));
|
||||
assert_eq!(5, 15.abs_diff(10));
|
||||
|
||||
assert_eq!(5.0, 15.0.abs_diff(10.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn negative() {
|
||||
assert_eq!(0, (-1).abs_diff(-1));
|
||||
assert_eq!(5, (-10).abs_diff(-15));
|
||||
assert_eq!(5, (-15).abs_diff(-10));
|
||||
|
||||
assert_eq!(5.0, (-15.0).abs_diff(-10.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_numeric() {
|
||||
// Use the std Instant because ours deliberately doesn't implement Sub.
|
||||
let now = std::time::Instant::now();
|
||||
let interval = std::time::Duration::from_millis(25);
|
||||
|
||||
assert_eq!(interval, now.abs_diff(now + interval));
|
||||
assert_eq!(interval, (now + interval).abs_diff(now));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exponential_moving_average<T: Mul<f64, Output = T> + Add<T, Output = T>>(
|
||||
average: T,
|
||||
alpha: f64,
|
||||
update: T,
|
||||
) -> T {
|
||||
(update * alpha) + (average * (1.0 - alpha))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod exponential_moving_average_tests {
|
||||
use crate::common::exponential_moving_average;
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn interpolation() {
|
||||
assert_eq!(10.0, exponential_moving_average(10.0, 0.0, 20.0));
|
||||
assert_eq!(15.0, exponential_moving_average(10.0, 0.5, 20.0));
|
||||
assert_eq!(20.0, exponential_moving_average(10.0, 1.0, 20.0));
|
||||
}
|
||||
}
|
||||
471
src/common/serialize.rs
Normal file
471
src/common/serialize.rs
Normal file
@ -0,0 +1,471 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
//! Allows the serialization of datastructures to Vec<u8>.
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::common::integers::{U24, U48};
|
||||
|
||||
pub fn write_u8_len_prefixed(val: impl Writer) -> impl Writer {
|
||||
let len_prefix = u8::try_from(val.written_len()).expect("Length exceeds 8 bits");
|
||||
([len_prefix], val)
|
||||
}
|
||||
|
||||
pub fn write_u16_len_prefixed(val: impl Writer) -> impl Writer {
|
||||
let len_prefix = u16::try_from(val.written_len()).expect("Length exceeds 16 bits");
|
||||
(len_prefix, val)
|
||||
}
|
||||
|
||||
pub fn write_u24_len_prefixed(val: impl Writer) -> impl Writer {
|
||||
let len_prefix = U24::try_from(val.written_len() as u32).expect("The length exceeded 24 bits");
|
||||
(len_prefix, val)
|
||||
}
|
||||
|
||||
pub trait Writer {
|
||||
fn written_len(&self) -> usize;
|
||||
fn write(&self, out: &mut dyn Writable);
|
||||
fn to_vec(&self) -> Vec<u8> {
|
||||
let mut vec = Vec::with_capacity(self.written_len());
|
||||
self.write(&mut vec);
|
||||
vec
|
||||
}
|
||||
fn to_sha256(&self) -> Sha256 {
|
||||
let mut digest = Sha256::new();
|
||||
self.write(&mut digest);
|
||||
digest
|
||||
}
|
||||
}
|
||||
|
||||
// Like std::io::Write but can't fail or only do partial writes.
|
||||
pub trait Writable {
|
||||
fn write(&mut self, input: &[u8]);
|
||||
}
|
||||
|
||||
impl Writable for Vec<u8> {
|
||||
fn write(&mut self, input: &[u8]) {
|
||||
self.extend_from_slice(input);
|
||||
}
|
||||
}
|
||||
|
||||
impl Writable for Sha256 {
|
||||
fn write(&mut self, input: &[u8]) {
|
||||
sha2::Digest::update(self, input);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Empty {}
|
||||
|
||||
impl Writer for Empty {
|
||||
fn written_len(&self) -> usize {
|
||||
0
|
||||
}
|
||||
fn write(&self, _out: &mut dyn Writable) {}
|
||||
}
|
||||
|
||||
impl<T: Writer> Writer for Option<T> {
|
||||
fn written_len(&self) -> usize {
|
||||
match self {
|
||||
None => 0,
|
||||
Some(writer) => writer.written_len(),
|
||||
}
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
match self {
|
||||
None => {}
|
||||
Some(writer) => writer.write(out),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We don't impl u8 directly so as to avoid a conflict between [u8] and [T: Writer]
|
||||
impl<const N: usize> Writer for [u8; N] {
|
||||
fn written_len(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
out.write(&self[..]);
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer for [u8] {
|
||||
fn written_len(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
out.write(self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer for Vec<u8> {
|
||||
fn written_len(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
out.write(&self[..]);
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer for u16 {
|
||||
fn written_len(&self) -> usize {
|
||||
2
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
self.to_be_bytes().write(out)
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer for U24 {
|
||||
fn written_len(&self) -> usize {
|
||||
3
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
(&(u32::from(*self)).to_be_bytes()[1..4]).write(out);
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer for u32 {
|
||||
fn written_len(&self) -> usize {
|
||||
4
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
self.to_be_bytes().write(out)
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer for U48 {
|
||||
fn written_len(&self) -> usize {
|
||||
6
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
(&u64::from(*self).to_be_bytes()[2..8]).write(out);
|
||||
}
|
||||
}
|
||||
|
||||
impl Writer for Box<dyn Writer> {
|
||||
fn written_len(&self) -> usize {
|
||||
(**self).written_len()
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
(**self).write(out)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_writer_tuple {
|
||||
($($name:ident)+) => (
|
||||
impl<$($name: Writer),+> Writer for ($($name,)+) {
|
||||
#[allow(non_snake_case)]
|
||||
fn written_len(&self) -> usize {
|
||||
let ($(ref $name,)+) = *self;
|
||||
let mut len = 0;
|
||||
$(len += $name.written_len();)+
|
||||
len
|
||||
}
|
||||
#[allow(non_snake_case)]
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
let ($(ref $name,)+) = *self;
|
||||
$($name.write(out);)+
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
impl_writer_tuple! { A }
|
||||
impl_writer_tuple! { A B }
|
||||
impl_writer_tuple! { A B C }
|
||||
impl_writer_tuple! { A B C D }
|
||||
impl_writer_tuple! { A B C D E }
|
||||
|
||||
impl<T: Writer, const N: usize> Writer for [T; N] {
|
||||
fn written_len(&self) -> usize {
|
||||
self.iter().map(|writable| writable.written_len()).sum()
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
for writable in self {
|
||||
writable.write(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Writer> Writer for [T] {
|
||||
fn written_len(&self) -> usize {
|
||||
self.iter().map(|writable| writable.written_len()).sum()
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
for writable in self {
|
||||
writable.write(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Writer> Writer for Vec<T> {
|
||||
fn written_len(&self) -> usize {
|
||||
self.iter().map(|writable| writable.written_len()).sum()
|
||||
}
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
for writable in self {
|
||||
writable.write(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Necessary for composition with other impls (such as tuples).
|
||||
impl<T: Writer + ?Sized> Writer for &T {
|
||||
fn written_len(&self) -> usize {
|
||||
T::written_len(self)
|
||||
}
|
||||
|
||||
fn write(&self, out: &mut dyn Writable) {
|
||||
T::write(self, out)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn u16() {
|
||||
assert_eq!("0064", hex::encode(100u16.to_vec()));
|
||||
assert_eq!("2778", hex::encode(10104u16.to_vec()));
|
||||
assert_eq!(2, 100u16.written_len());
|
||||
assert_eq!(
|
||||
"b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2",
|
||||
hex::encode(1u16.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn u32() {
|
||||
assert_eq!("00000064", hex::encode(100u32.to_vec()));
|
||||
assert_eq!("00002778", hex::encode(10104u32.to_vec()));
|
||||
assert_eq!("7e8a6925", hex::encode(2_123_000_101u32.to_vec()));
|
||||
assert_eq!(4, 100u32.written_len());
|
||||
assert_eq!(
|
||||
"b40711a88c7039756fb8a73827eabe2c0fe5a0346ca7e0a104adc0fc764f528d",
|
||||
hex::encode(1u32.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn u24() {
|
||||
assert_eq!("000064", hex::encode(U24::from(100u16).to_vec()));
|
||||
assert_eq!("002778", hex::encode(U24::from(10104u16).to_vec()));
|
||||
assert_eq!(
|
||||
"800000",
|
||||
hex::encode(U24::try_from(1u32 << 23).unwrap().to_vec())
|
||||
);
|
||||
assert_eq!(
|
||||
"8a6925",
|
||||
hex::encode(U24::try_from(0x8a6925u32).unwrap().to_vec())
|
||||
);
|
||||
assert_eq!(3, U24::from(100u16).written_len());
|
||||
assert_eq!(
|
||||
"cf7605ed1bc735f6c825554154627467e1cac9df54cee8699218ed434603c568",
|
||||
hex::encode(U24::from(1u16).to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn u48() {
|
||||
assert_eq!("000000000064", hex::encode(U48::from(100u16).to_vec()));
|
||||
assert_eq!("000000002778", hex::encode(U48::from(10104u32).to_vec()));
|
||||
assert_eq!(
|
||||
"800000000000",
|
||||
hex::encode(U48::try_from(1u64 << 47).unwrap().to_vec())
|
||||
);
|
||||
assert_eq!(6, U48::from(100u16).written_len());
|
||||
assert_eq!(
|
||||
"186128bf8a4d60eb4b51102ae2a2cb6a0b80011977582480395a454454bec7e1",
|
||||
hex::encode(U48::from(1u16).to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single_element_slice_of_writable() {
|
||||
let element = U48::from(100u16);
|
||||
let slice_of_1 = [element];
|
||||
assert_eq!(element.written_len(), slice_of_1.written_len());
|
||||
assert_eq!(element.to_vec(), slice_of_1.to_vec());
|
||||
assert_eq!(
|
||||
element.to_sha256().finalize(),
|
||||
slice_of_1.to_sha256().finalize()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn double_element_slice_of_writable() {
|
||||
let element_1 = U48::from(100u16);
|
||||
let element_2 = U48::from(203u16);
|
||||
let slice_of_2 = [element_1, element_2];
|
||||
assert_eq!(
|
||||
element_1.written_len() + element_2.written_len(),
|
||||
slice_of_2.written_len()
|
||||
);
|
||||
let mut vec = element_1.to_vec();
|
||||
vec.append(&mut element_2.to_vec());
|
||||
assert_eq!(vec, slice_of_2.to_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single_element_slice_of_u8() {
|
||||
let element = 100u8;
|
||||
let slice_of_1 = [element];
|
||||
assert_eq!(1, slice_of_1.written_len());
|
||||
assert_eq!(vec![element], slice_of_1.to_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn double_element_slice_of_u8() {
|
||||
let element_1 = 100u8;
|
||||
let element_2 = 234u8;
|
||||
let slice_of_2 = [element_1, element_2];
|
||||
assert_eq!(2, slice_of_2.written_len());
|
||||
assert_eq!(vec![element_1, element_2], slice_of_2.to_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_of_u8() {
|
||||
let vec = vec![1u8, 2u8, 255u8];
|
||||
assert_eq!("0102ff", hex::encode(vec.to_vec()));
|
||||
assert_eq!(
|
||||
"0526d0e18ea19dfaad9d79166bec1e18d6221ef6b1830385fe9bf67022ed5f96",
|
||||
hex::encode(vec.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tuple2() {
|
||||
let tuple = (100u16, 2_123_000_101u32);
|
||||
assert_eq!("00647e8a6925", hex::encode(tuple.to_vec()));
|
||||
assert_eq!(
|
||||
"dffc18faa457d5aa0a27c5bc8cd065d837cf997bb37940abcf5cef505b31b725",
|
||||
hex::encode(tuple.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tuple3() {
|
||||
let tuple = ([255u8], 100u16, 2_123_000_101u32);
|
||||
assert_eq!("ff00647e8a6925", hex::encode(tuple.to_vec()));
|
||||
assert_eq!(
|
||||
"52ab2cba6473730d0e6a0f7feba988e59b9cc83ca04e0343cd34e0f27a924ee0",
|
||||
hex::encode(tuple.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tuple4() {
|
||||
let tuple = ([255u8], 100u16, [127u8], 2_123_000_101u32);
|
||||
assert_eq!("ff00647f7e8a6925", hex::encode(tuple.to_vec()));
|
||||
assert_eq!(
|
||||
"8ec96a4f46c07b2d92a9009278fa6675ff89dec47985dcb7e76acc30b621a685",
|
||||
hex::encode(tuple.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tuple5() {
|
||||
let tuple = (
|
||||
[127u8],
|
||||
65535u16,
|
||||
[1u8],
|
||||
1_000_000_000u32,
|
||||
U24::try_from(1u32 << 23).unwrap(),
|
||||
);
|
||||
assert_eq!("7fffff013b9aca00800000", hex::encode(tuple.to_vec()));
|
||||
assert_eq!(
|
||||
"521e73b5adde4db9a2ddf9c8cc263a327dfecf9e54f78114deb286ed65574a21",
|
||||
hex::encode(tuple.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn static_vec() {
|
||||
let vec = vec![1u32, 1 << 31];
|
||||
assert_eq!("0000000180000000", hex::encode(vec.to_vec()));
|
||||
assert_eq!(
|
||||
"3c258dec7ff9182db1c9ceac940453011fcc3ce440309a310f2a2c8475509c8a",
|
||||
hex::encode(vec.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tuple_and_vec_u8() {
|
||||
let vec1 = vec![1u8, 1 << 7];
|
||||
let tuple = (1u16, vec1);
|
||||
assert_eq!("00010180", hex::encode(tuple.to_vec()));
|
||||
assert_eq!(
|
||||
"ba57af15c8d49bc2d673bdbcc15b9761dddf25386be00890bb7cf56cc02b0dba",
|
||||
hex::encode(tuple.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_vec() {
|
||||
let mut vec: Vec<&dyn Writer> = vec![&[127u8]];
|
||||
vec.push(&65535u16);
|
||||
vec.push(&[1u8]);
|
||||
vec.push(&1_000_000_000u32);
|
||||
let u24 = U24::try_from(1u32 << 23).unwrap();
|
||||
vec.push(&u24);
|
||||
assert_eq!("7fffff013b9aca00800000", hex::encode(vec.to_vec()));
|
||||
assert_eq!(
|
||||
"521e73b5adde4db9a2ddf9c8cc263a327dfecf9e54f78114deb286ed65574a21",
|
||||
hex::encode(vec.to_sha256().finalize())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_with_length_as_u8() {
|
||||
assert_eq!(
|
||||
"0600647e8a6925",
|
||||
hex::encode(write_u8_len_prefixed((100u16, 2_123_000_101u32)).to_vec())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_with_length_as_u8_max() {
|
||||
let data: &[u8] = &[1u8; u8::MAX as usize];
|
||||
assert!(hex::encode(write_u8_len_prefixed(data).to_vec()).starts_with("ff0101"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Length exceeds 8 bits")]
|
||||
fn prefix_with_length_as_u8_max_plus_1_panics() {
|
||||
let data: &[u8] = &[1u8; u8::MAX as usize + 1];
|
||||
write_u8_len_prefixed(data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_with_length_as_u16() {
|
||||
assert_eq!(
|
||||
"000600647e8a6925",
|
||||
hex::encode(write_u16_len_prefixed((100u16, 2_123_000_101u32)).to_vec())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_with_length_as_u16_max() {
|
||||
let data: &[u8] = &[1u8; u16::MAX as usize];
|
||||
assert!(hex::encode(write_u16_len_prefixed(data).to_vec()).starts_with("ffff0101"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Length exceeds 16 bits")]
|
||||
fn prefix_with_length_as_u16_max_plus_1_panics() {
|
||||
let data: &[u8] = &[1u8; u16::MAX as usize + 1];
|
||||
write_u16_len_prefixed(data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix_with_length_as_u24() {
|
||||
assert_eq!(
|
||||
"00000600647e8a6925",
|
||||
hex::encode(write_u24_len_prefixed((100u16, 2_123_000_101u32)).to_vec())
|
||||
);
|
||||
}
|
||||
}
|
||||
320
src/common/time.rs
Normal file
320
src/common/time.rs
Normal file
@ -0,0 +1,320 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
iter::Sum,
|
||||
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign},
|
||||
};
|
||||
|
||||
/// A wrapper around [`std::time::Instant`] that does not expose panicking `duration_since` operations.
|
||||
///
|
||||
/// Instead of subtraction, use `checked_duration_since` or `saturating_duration_since`.
|
||||
///
|
||||
/// Note that addition and subtraction of durations that would result in distant-future or
|
||||
/// distant-past Instants may still panic. Only operations that might result in a negative
|
||||
/// duration have been forbidden.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct Instant(std::time::Instant);
|
||||
|
||||
impl Instant {
|
||||
pub fn checked_duration_since(&self, earlier: Instant) -> Option<Duration> {
|
||||
self.0.checked_duration_since(earlier.0).map(Duration)
|
||||
}
|
||||
|
||||
pub fn saturating_duration_since(&self, earlier: Instant) -> Duration {
|
||||
Duration(self.0.saturating_duration_since(earlier.0))
|
||||
}
|
||||
|
||||
pub fn now() -> Instant {
|
||||
Instant(std::time::Instant::now())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::time::Instant> for Instant {
|
||||
fn from(instant: std::time::Instant) -> Self {
|
||||
Self(instant)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Instant> for std::time::Instant {
|
||||
fn from(instant: Instant) -> Self {
|
||||
instant.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Instant {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<Duration> for Instant {
|
||||
type Output = Instant;
|
||||
|
||||
fn add(self, rhs: Duration) -> Self::Output {
|
||||
Self(self.0 + rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign<Duration> for Instant {
|
||||
fn add_assign(&mut self, rhs: Duration) {
|
||||
self.0 += rhs.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<Duration> for Instant {
|
||||
type Output = Instant;
|
||||
|
||||
fn sub(self, rhs: Duration) -> Self::Output {
|
||||
Self(self.0 - rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAssign<Duration> for Instant {
|
||||
fn sub_assign(&mut self, rhs: Duration) {
|
||||
self.0 -= rhs.0
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper around [`std::time::Duration`] that does not expose panicking difference operations.
|
||||
///
|
||||
/// Instead of subtraction, use `checked_sub` or `saturating_sub`.
|
||||
///
|
||||
/// Note that addition or multiplication of durations that would result in an overly large duration
|
||||
/// may still panic. Only operations that could result in a negative duration have been forbidden.
|
||||
///
|
||||
/// Only methods of `std::time::Duration` that are used in the project are exposed here.
|
||||
/// If you need another method of `std::time::Duration`,
|
||||
/// add a wrapper here rather than converting to the underlying value.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct Duration(std::time::Duration);
|
||||
|
||||
impl Duration {
|
||||
pub const ZERO: Duration = Duration::from_secs(0);
|
||||
pub const MILLISECOND: Duration = Duration::from_millis(1);
|
||||
pub const SECOND: Duration = Duration::from_secs(1);
|
||||
|
||||
pub const fn from_secs(secs: u64) -> Duration {
|
||||
Duration(std::time::Duration::from_secs(secs))
|
||||
}
|
||||
|
||||
pub fn as_secs(&self) -> u64 {
|
||||
self.0.as_secs()
|
||||
}
|
||||
|
||||
pub fn from_secs_f64(secs: f64) -> Duration {
|
||||
Duration(std::time::Duration::from_secs_f64(secs))
|
||||
}
|
||||
|
||||
pub fn as_secs_f64(&self) -> f64 {
|
||||
self.0.as_secs_f64()
|
||||
}
|
||||
|
||||
pub const fn from_millis(millis: u64) -> Duration {
|
||||
Duration(std::time::Duration::from_millis(millis))
|
||||
}
|
||||
|
||||
pub const fn as_millis(&self) -> u128 {
|
||||
self.0.as_millis()
|
||||
}
|
||||
|
||||
pub const fn from_micros(micros: u64) -> Duration {
|
||||
Duration(std::time::Duration::from_micros(micros))
|
||||
}
|
||||
|
||||
pub const fn as_micros(&self) -> u128 {
|
||||
self.0.as_micros()
|
||||
}
|
||||
|
||||
pub const fn from_nanos(nanos: u64) -> Duration {
|
||||
Duration(std::time::Duration::from_nanos(nanos))
|
||||
}
|
||||
|
||||
pub const fn as_nanos(&self) -> u128 {
|
||||
self.0.as_nanos()
|
||||
}
|
||||
|
||||
pub fn checked_sub(&self, rhs: Duration) -> Option<Duration> {
|
||||
self.0.checked_sub(rhs.0).map(Duration)
|
||||
}
|
||||
|
||||
pub fn saturating_sub(&self, rhs: Duration) -> Duration {
|
||||
// TODO: use std::time::Duration::saturating_sub when stabilized
|
||||
self.checked_sub(rhs).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::time::Duration> for Duration {
|
||||
fn from(duration: std::time::Duration) -> Self {
|
||||
Self(duration)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Duration> for std::time::Duration {
|
||||
fn from(duration: Duration) -> Self {
|
||||
duration.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Duration {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<Duration> for Duration {
|
||||
type Output = Duration;
|
||||
|
||||
fn add(self, rhs: Duration) -> Self::Output {
|
||||
Duration(self.0 + rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign<Duration> for Duration {
|
||||
fn add_assign(&mut self, rhs: Duration) {
|
||||
self.0 += rhs.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<u32> for Duration {
|
||||
type Output = Duration;
|
||||
|
||||
fn mul(self, rhs: u32) -> Self::Output {
|
||||
Duration(self.0 * rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<Duration> for u32 {
|
||||
type Output = Duration;
|
||||
|
||||
fn mul(self, rhs: Duration) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
impl MulAssign<u32> for Duration {
|
||||
fn mul_assign(&mut self, rhs: u32) {
|
||||
self.0 *= rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<u32> for Duration {
|
||||
type Output = Duration;
|
||||
|
||||
fn div(self, rhs: u32) -> Self::Output {
|
||||
Duration(self.0 / rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl DivAssign<u32> for Duration {
|
||||
fn div_assign(&mut self, rhs: u32) {
|
||||
self.0 /= rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl Sum for Duration {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
Duration(iter.map(|x| x.0).sum())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Sum<&'a Duration> for Duration {
|
||||
fn sum<I: Iterator<Item = &'a Duration>>(iter: I) -> Self {
|
||||
Duration(iter.map(|x| x.0).sum())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn wrap_unwrap() {
|
||||
let now = std::time::Instant::now();
|
||||
assert_eq!(now, Instant::from(now).into());
|
||||
|
||||
let duration = std::time::Duration::new(5, 10);
|
||||
assert_eq!(duration, Duration::from(duration).into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transparent_debug() {
|
||||
let now = std::time::Instant::now();
|
||||
assert_eq!(format!("{:?}", now), format!("{:?}", Instant::from(now)));
|
||||
|
||||
let duration = std::time::Duration::new(5, 10);
|
||||
assert_eq!(
|
||||
format!("{:?}", duration),
|
||||
format!("{:?}", Duration::from(duration))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn duration_from_as() {
|
||||
assert_eq!(2.5, Duration::from_secs_f64(2.5).as_secs_f64());
|
||||
assert_eq!(2, Duration::from_millis(2).as_millis());
|
||||
assert_eq!(2, Duration::from_micros(2).as_micros());
|
||||
assert_eq!(2, Duration::from_nanos(2).as_nanos());
|
||||
|
||||
assert_eq!(2.0, Duration::from_secs(2).as_secs_f64());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duration_default() {
|
||||
assert_eq!(Duration::from_secs(0), Duration::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duration_arithmetic() {
|
||||
let short = Duration::from_millis(2);
|
||||
let long = Duration::from_secs(5);
|
||||
let sum = Duration::from_millis(5002);
|
||||
|
||||
assert_eq!(sum, short + long);
|
||||
assert_eq!(Some(long), sum.checked_sub(short));
|
||||
assert_eq!(None, short.checked_sub(sum));
|
||||
assert_eq!(long, sum.saturating_sub(short));
|
||||
assert_eq!(Duration::default(), short.saturating_sub(sum));
|
||||
assert_eq!(sum, [short, long].iter().sum());
|
||||
assert_eq!(sum, vec![short, long].into_iter().sum());
|
||||
|
||||
let mut manual_sum = short;
|
||||
manual_sum += long;
|
||||
assert_eq!(sum, manual_sum);
|
||||
|
||||
assert_eq!(long, short * 2500);
|
||||
let mut manual_product = short;
|
||||
manual_product *= 2500;
|
||||
assert_eq!(long, manual_product);
|
||||
|
||||
assert_eq!(short, long / 2500);
|
||||
let mut manual_quotient = long;
|
||||
manual_quotient /= 2500;
|
||||
assert_eq!(short, manual_quotient);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn instant_arithmetic() {
|
||||
let now = Instant::now();
|
||||
let duration = Duration::from_millis(2);
|
||||
let soon = now + duration;
|
||||
|
||||
assert_eq!(now, soon - duration);
|
||||
assert_eq!(Some(duration), soon.checked_duration_since(now));
|
||||
assert_eq!(None, now.checked_duration_since(soon));
|
||||
assert_eq!(duration, soon.saturating_duration_since(now));
|
||||
assert_eq!(Duration::default(), now.saturating_duration_since(soon));
|
||||
|
||||
let mut manual_sum = now;
|
||||
manual_sum += duration;
|
||||
assert_eq!(soon, manual_sum);
|
||||
|
||||
let mut manual_difference = soon;
|
||||
manual_difference -= duration;
|
||||
assert_eq!(now, manual_difference);
|
||||
}
|
||||
}
|
||||
147
src/config.rs
Normal file
147
src/config.rs
Normal file
@ -0,0 +1,147 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
//! Configuration options for the calling server.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use serde::Deserialize;
|
||||
use structopt::StructOpt;
|
||||
|
||||
/// General configuration options, set by command line arguments or
|
||||
/// falls back to default or environment variables (in some cases).
|
||||
#[derive(Default, StructOpt, Debug, Clone)]
|
||||
#[structopt(name = "calling_server")]
|
||||
pub struct Config {
|
||||
/// The IP address to bind to for all servers.
|
||||
#[structopt(long, default_value = "0.0.0.0")]
|
||||
pub binding_ip: String,
|
||||
|
||||
/// The IP address to share for for ICE candidates. Clients will connect
|
||||
/// to the calling_server using this IP.
|
||||
#[structopt(long, env = "ICE_CANDIDATE_IP")]
|
||||
pub ice_candidate_ip: Option<String>,
|
||||
|
||||
/// The port to use for ICE candidates. Clients will connect to the
|
||||
/// calling_server using this port.
|
||||
#[structopt(long, default_value = "10000")]
|
||||
pub ice_candidate_port: u16,
|
||||
|
||||
/// The IP address to share for direct access to the signaling_server. If
|
||||
/// defined, then the signaling_server will be used, otherwise the
|
||||
/// http_server will be used for testing.
|
||||
#[structopt(long, env = "SIGNALING_IP")]
|
||||
pub signaling_ip: Option<String>,
|
||||
|
||||
/// The port to use for the signaling interface.
|
||||
#[structopt(long, default_value = "8080")]
|
||||
pub signaling_port: u16,
|
||||
|
||||
/// Maximum clients per call, if using the http_server for testing.
|
||||
#[structopt(long, default_value = "8")]
|
||||
pub max_clients_per_call: u32,
|
||||
|
||||
#[structopt(long)]
|
||||
pub udp_threads: Option<usize>,
|
||||
|
||||
/// The initial bitrate target for sending. In a 16-person call with
|
||||
/// each base layer at 200kbps you'd need 3.2mbps to send them all.
|
||||
/// With an increase of 8% per second it will take 10 seconds to
|
||||
/// increase from 1.5mbps to 3.2.
|
||||
#[structopt(long, default_value = "1500")]
|
||||
pub initial_target_send_rate_kbps: u64,
|
||||
|
||||
/// Timer tick period for operating on the Sfu state (ms).
|
||||
#[structopt(long, default_value = "100")]
|
||||
pub tick_interval_ms: u64,
|
||||
|
||||
/// Optional interval used to post diagnostics to the log. If not defined
|
||||
/// then no periodic information about calls will be posted to the log.
|
||||
#[structopt(long, env = "DIAGNOSTICS_INTERVAL_SECS")]
|
||||
pub diagnostics_interval_secs: Option<u64>,
|
||||
|
||||
/// Interval for sending active speaker messages (ms). The amount of time
|
||||
/// to wait between sending messages to the clients to remind them of the
|
||||
/// current active speaker for the call. Using milliseconds in case sub-
|
||||
/// second resolution is needed.
|
||||
#[structopt(long, default_value = "1000")]
|
||||
pub active_speaker_message_interval_ms: u64,
|
||||
|
||||
/// Inactivity check interval (seconds). The amount of time to wait between
|
||||
/// iterating structures for inactive calls and clients.
|
||||
#[structopt(long, default_value = "5")]
|
||||
pub inactivity_check_interval_secs: u64,
|
||||
|
||||
/// Amount of time to wait before dropping a call or client due to inactivity (seconds).
|
||||
#[structopt(long, default_value = "30")]
|
||||
pub inactivity_timeout_secs: u64,
|
||||
|
||||
#[structopt(flatten)]
|
||||
pub metrics: MetricsOptions,
|
||||
|
||||
// For DTLS, not a command line argument, storage for the der private key.
|
||||
#[structopt(skip)]
|
||||
pub server_private_key_der: Vec<u8>,
|
||||
|
||||
// For DTLS, not a command line argument, storage for the der certificate.
|
||||
#[structopt(skip)]
|
||||
pub server_certificate_der: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(StructOpt, Clone, Debug, Default)]
|
||||
pub struct MetricsOptions {
|
||||
/// Host and port of Datadog StatsD agent. Typically 127.0.0.1:8125.
|
||||
#[structopt(long)]
|
||||
pub datadog: Option<String>,
|
||||
|
||||
/// Region appears as a tag in metrics and logging.
|
||||
#[structopt(long = "metrics-region", default_value = "unspecified")]
|
||||
pub region: String,
|
||||
|
||||
/// Deployment version appears as a tag in metrics and in logging if specified.
|
||||
#[structopt(long = "metrics-version")]
|
||||
pub version: Option<String>,
|
||||
}
|
||||
|
||||
/// Deployment configuration options, used to set sensitive information
|
||||
/// at runtime from a configuration file.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DeploymentConfig {
|
||||
#[serde(rename = "authenticationKey")]
|
||||
pub authentication_key: String,
|
||||
}
|
||||
|
||||
/// Returns the public address of the server for media/UDP as per configuration.
|
||||
pub fn get_server_media_address(config: &'static Config) -> SocketAddr {
|
||||
let ip = config
|
||||
.ice_candidate_ip
|
||||
.as_ref()
|
||||
.unwrap_or(&config.binding_ip)
|
||||
.parse()
|
||||
.expect("ice_candidate_ip should parse");
|
||||
SocketAddr::new(ip, config.ice_candidate_port)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn default_test_config() -> Config {
|
||||
Config {
|
||||
binding_ip: "127.0.0.1".to_string(),
|
||||
ice_candidate_ip: Some("127.0.0.1".to_string()),
|
||||
signaling_ip: Some("127.0.0.1".to_string()),
|
||||
signaling_port: 8080,
|
||||
ice_candidate_port: 10000,
|
||||
max_clients_per_call: 8,
|
||||
udp_threads: Some(1),
|
||||
initial_target_send_rate_kbps: 1500,
|
||||
tick_interval_ms: 100,
|
||||
diagnostics_interval_secs: None,
|
||||
active_speaker_message_interval_ms: 1000,
|
||||
inactivity_check_interval_secs: 5,
|
||||
inactivity_timeout_secs: 30,
|
||||
metrics: Default::default(),
|
||||
server_private_key_der: vec![],
|
||||
server_certificate_der: vec![],
|
||||
}
|
||||
}
|
||||
1240
src/connection.rs
Normal file
1240
src/connection.rs
Normal file
File diff suppressed because it is too large
Load Diff
1172
src/dtls.rs
Normal file
1172
src/dtls.rs
Normal file
File diff suppressed because it is too large
Load Diff
1010
src/googcc.rs
Normal file
1010
src/googcc.rs
Normal file
File diff suppressed because it is too large
Load Diff
332
src/googcc/ack_rates.rs
Normal file
332
src/googcc/ack_rates.rs
Normal file
@ -0,0 +1,332 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use async_stream::stream;
|
||||
use futures::{pin_mut, Stream, StreamExt};
|
||||
|
||||
use crate::{
|
||||
common::{AbsDiff, DataRate, DataSize, Duration, Square},
|
||||
transportcc::Ack,
|
||||
};
|
||||
|
||||
// Break up the series of acks into groups of accumulated (size, duration).
|
||||
// To be passed into estimate_acked_rates for estimateing the rate over time.
|
||||
fn accumulate_acked_sizes(
|
||||
acks: impl Stream<Item = Ack>,
|
||||
) -> impl Stream<Item = (DataSize, Duration)> {
|
||||
// TODO: Maybe make some of these configurable
|
||||
let initial_ack_group_duration = Duration::from_millis(500);
|
||||
let subsequent_ack_group_duration = Duration::from_millis(150);
|
||||
|
||||
stream! {
|
||||
pin_mut!(acks);
|
||||
if let Some(mut ack1) = acks.next().await {
|
||||
let mut accumulated_size = ack1.size;
|
||||
let mut accumulated_duration = Duration::default();
|
||||
let mut target_ack_group_duration = initial_ack_group_duration;
|
||||
while let Some(ack2) = acks.next().await {
|
||||
if ack2.arrival < ack1.arrival {
|
||||
// Reset when we hit out-of-order packets
|
||||
accumulated_size = DataSize::default();
|
||||
accumulated_duration = Duration::default();
|
||||
} else {
|
||||
let arrival_delta = ack2.arrival.saturating_duration_since(ack1.arrival);
|
||||
accumulated_duration += arrival_delta;
|
||||
if arrival_delta > target_ack_group_duration {
|
||||
// Reset if it's been too long since we've received an ACK
|
||||
accumulated_size = DataSize::default();
|
||||
accumulated_duration = Duration::from_micros(
|
||||
accumulated_duration.as_micros() as u64
|
||||
% target_ack_group_duration.as_micros() as u64,
|
||||
);
|
||||
} else if accumulated_duration >= target_ack_group_duration {
|
||||
yield (accumulated_size, target_ack_group_duration);
|
||||
|
||||
// Use what's "left over" for the next group.
|
||||
accumulated_size = Default::default();
|
||||
accumulated_duration =
|
||||
accumulated_duration.saturating_sub(target_ack_group_duration);
|
||||
|
||||
// Now that we have a group, we can use a smaller window.
|
||||
target_ack_group_duration = subsequent_ack_group_duration;
|
||||
}
|
||||
}
|
||||
accumulated_size += ack2.size;
|
||||
ack1 = ack2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod accumulate_acked_sizes_tests {
|
||||
use futures::FutureExt;
|
||||
|
||||
use super::*;
|
||||
use crate::{common::Instant, transportcc::RemoteInstant};
|
||||
|
||||
/// Creates an `Ack` for each duration with a size of 10.
|
||||
///
|
||||
/// The departure and feedback-arrival times should be ignored.
|
||||
fn acks_from_arrival_durations(
|
||||
durations: impl IntoIterator<Item = u64>,
|
||||
) -> impl Stream<Item = Ack> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
futures::stream::iter(durations.into_iter().map(move |duration| Ack {
|
||||
size: DataSize::from_bytes(10),
|
||||
departure: start_time,
|
||||
arrival: RemoteInstant::from_millis(duration),
|
||||
feedback_arrival: start_time,
|
||||
}))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn every_millisecond() {
|
||||
let acks = acks_from_arrival_durations(0..1000);
|
||||
let stream = accumulate_acked_sizes(acks);
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
&[(5000, 500), (1500, 150), (1500, 150), (1500, 150)],
|
||||
&stream
|
||||
.map(|(size, duration)| (size.as_bytes(), duration.as_millis()))
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn every_hundred_ms() {
|
||||
let acks = acks_from_arrival_durations((0..20).map(|x| x * 100));
|
||||
let stream = accumulate_acked_sizes(acks);
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
&[
|
||||
(50, 500),
|
||||
(20, 150),
|
||||
(10, 150),
|
||||
(20, 150),
|
||||
(10, 150),
|
||||
(20, 150),
|
||||
(10, 150),
|
||||
(20, 150),
|
||||
(10, 150),
|
||||
(20, 150)
|
||||
],
|
||||
&stream
|
||||
.map(|(size, duration)| (size.as_bytes(), duration.as_millis()))
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn start_time_does_not_matter() {
|
||||
let acks = acks_from_arrival_durations(1000..2000);
|
||||
let stream = accumulate_acked_sizes(acks);
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
&[(5000, 500), (1500, 150), (1500, 150), (1500, 150)],
|
||||
&stream
|
||||
.map(|(size, duration)| (size.as_bytes(), duration.as_millis()))
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_on_out_of_order() {
|
||||
let acks = acks_from_arrival_durations(vec![
|
||||
0, 1, // reset!
|
||||
0, // first group
|
||||
500, 600, // reset!
|
||||
550, 600, 650, // second group
|
||||
700, // force the second group to be emitted
|
||||
]);
|
||||
let stream = accumulate_acked_sizes(acks);
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
&[(10, 500), (30, 150)],
|
||||
&stream
|
||||
.map(|(size, duration)| (size.as_bytes(), duration.as_millis()))
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reset_on_large_gap() {
|
||||
let acks = acks_from_arrival_durations(vec![
|
||||
0, // reset!
|
||||
1001, 1500, // first group
|
||||
// reset!
|
||||
1651, 1700, 1750, // second group
|
||||
1800, // force the second group to be emitted
|
||||
]);
|
||||
let stream = accumulate_acked_sizes(acks);
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
&[(10, 500), (30, 150)],
|
||||
&stream
|
||||
.map(|(size, duration)| (size.as_bytes(), duration.as_millis()))
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Make initial variance and other variance numbers (10.0 and 5.0) below configurable
|
||||
fn estimate_acked_rates_from_groups(
|
||||
ack_groups: impl Stream<Item = (DataSize, Duration)>,
|
||||
) -> impl Stream<Item = DataRate> {
|
||||
stream! {
|
||||
pin_mut!(ack_groups);
|
||||
if let Some((size, duration)) = ack_groups.next().await {
|
||||
let mut estimate: DataRate = size / duration;
|
||||
let mut variance: f64 = 50.0;
|
||||
|
||||
yield estimate;
|
||||
|
||||
while let Some((size, duration)) = ack_groups.next().await {
|
||||
let sample: DataRate = size / duration;
|
||||
let sample_variance = ((sample.abs_diff(estimate) / estimate) * 10.0).square();
|
||||
let pred_variance = variance + 5.0;
|
||||
estimate = ((estimate * sample_variance) + (sample * pred_variance))
|
||||
/ (sample_variance + pred_variance);
|
||||
variance = (sample_variance * pred_variance) / (sample_variance + pred_variance);
|
||||
|
||||
yield estimate;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod estimate_acked_rates_from_groups_tests {
|
||||
use std::{cmp::Ordering, future::ready};
|
||||
|
||||
use futures::FutureExt;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
use super::*;
|
||||
use crate::common::RANDOM_SEED_FOR_TESTS;
|
||||
|
||||
/// Creates a stream of size groups with the given bits-per-second ratio.
|
||||
fn size_groups_from_bps(
|
||||
ratios: impl IntoIterator<Item = u64>,
|
||||
) -> impl Stream<Item = (DataSize, Duration)> {
|
||||
futures::stream::iter(ratios).map(|bps| (DataSize::from_bits(bps), Duration::from_secs(1)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_result_is_simple_division() {
|
||||
let stream = estimate_acked_rates_from_groups(stream! {
|
||||
yield (DataSize::from_bits(100), Duration::from_secs(2))
|
||||
});
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
&[50],
|
||||
&stream
|
||||
.map(|rate| rate.as_bps())
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reference_rates() {
|
||||
let stream = estimate_acked_rates_from_groups(size_groups_from_bps(vec![
|
||||
500, 1000, 1000, 500, 2000, 2000, 2000, 2000, 2000,
|
||||
]));
|
||||
pin_mut!(stream);
|
||||
// These values came from running the test and seeing the output.
|
||||
assert_eq!(
|
||||
&[500, 677, 883, 687, 737, 813, 928, 1100, 1355],
|
||||
&stream
|
||||
.map(|rate| rate.as_bps())
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eventually_converges_upward() {
|
||||
let stream = estimate_acked_rates_from_groups(size_groups_from_bps(
|
||||
std::iter::once(0).chain(std::iter::repeat(2000)),
|
||||
));
|
||||
pin_mut!(stream);
|
||||
assert!(stream
|
||||
.take(20_000)
|
||||
.take_while(|rate| ready(rate < &DataRate::from_bps(1990)))
|
||||
.next()
|
||||
.now_or_never()
|
||||
.unwrap()
|
||||
.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eventually_converges_downward() {
|
||||
let stream = estimate_acked_rates_from_groups(size_groups_from_bps(
|
||||
std::iter::once(10_000).chain(std::iter::repeat(2000)),
|
||||
));
|
||||
pin_mut!(stream);
|
||||
assert!(stream
|
||||
.take(20_000)
|
||||
.take_while(|rate| ready(rate > &DataRate::from_bps(2010)))
|
||||
.next()
|
||||
.now_or_never()
|
||||
.unwrap()
|
||||
.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn direction_follows_samples() {
|
||||
let mut rng = StdRng::seed_from_u64(*RANDOM_SEED_FOR_TESTS);
|
||||
let rates = std::iter::from_fn(move || Some(rng.gen_range(0..100_000)));
|
||||
let stream = estimate_acked_rates_from_groups(size_groups_from_bps(rates.clone()));
|
||||
pin_mut!(stream);
|
||||
stream
|
||||
.zip(futures::stream::iter(rates))
|
||||
.take(10_000)
|
||||
.fold(
|
||||
DataRate::default(),
|
||||
|previous_estimate, (current_estimate, previous_sample)| {
|
||||
// If the previous sample went up, the estimate goes up.
|
||||
// If it went down, the estimate goes down.
|
||||
// Except...we're doing this in floating-point math,
|
||||
// so we could have rounding errors when we go back to integers.
|
||||
if previous_estimate.as_bps().abs_diff(previous_sample) > 1 {
|
||||
let change = previous_estimate.cmp(¤t_estimate);
|
||||
// And estimate_acked_rates weights by variance to avoid outliers,
|
||||
// so a sample can end up not making a change.
|
||||
if change != Ordering::Equal {
|
||||
assert_eq!(
|
||||
previous_estimate.cmp(&DataRate::from_bps(previous_sample)),
|
||||
change,
|
||||
"pe: {:?}, ce: {:?}, ps: {:?}",
|
||||
previous_estimate,
|
||||
current_estimate,
|
||||
previous_sample
|
||||
);
|
||||
}
|
||||
}
|
||||
ready(current_estimate)
|
||||
},
|
||||
)
|
||||
.now_or_never()
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn estimate_acked_rates(acks: impl Stream<Item = Ack>) -> impl Stream<Item = DataRate> {
|
||||
estimate_acked_rates_from_groups(accumulate_acked_sizes(acks))
|
||||
}
|
||||
1810
src/googcc/delay_directions.rs
Normal file
1810
src/googcc/delay_directions.rs
Normal file
File diff suppressed because it is too large
Load Diff
153
src/googcc/delay_directions/time_delta.rs
Normal file
153
src/googcc/delay_directions/time_delta.rs
Normal file
@ -0,0 +1,153 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::ops::{Add, Mul, Sub};
|
||||
|
||||
use crate::common::{Duration, Instant};
|
||||
|
||||
/// Like [Duration], but can be negative, and may not have the same precision.
|
||||
#[derive(Debug, Copy, Clone, PartialOrd, PartialEq)]
|
||||
pub struct TimeDelta {
|
||||
secs: f64,
|
||||
}
|
||||
|
||||
impl TimeDelta {
|
||||
pub fn from_secs(secs: f64) -> Self {
|
||||
Self { secs }
|
||||
}
|
||||
|
||||
pub fn from_millis(millis: f64) -> Self {
|
||||
Self::from_secs(millis / 1000.0)
|
||||
}
|
||||
|
||||
pub fn as_secs(self) -> f64 {
|
||||
self.secs
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<TimeDelta> for TimeDelta {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, other: Self) -> Self {
|
||||
Self::from_secs(self.as_secs() + other.as_secs())
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<Duration> for TimeDelta {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, other: Duration) -> Self {
|
||||
Self::from_secs(self.as_secs() + other.as_secs_f64())
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<TimeDelta> for TimeDelta {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, other: TimeDelta) -> Self {
|
||||
Self::from_secs(self.as_secs() - other.as_secs())
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<Duration> for TimeDelta {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, other: Duration) -> Self {
|
||||
Self::from_secs(self.as_secs() - other.as_secs_f64())
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<f64> for TimeDelta {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, other: f64) -> Self {
|
||||
Self::from_secs(self.as_secs() * other)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait TimeDeltaSince {
|
||||
fn time_delta_since(self, other: Self) -> TimeDelta;
|
||||
}
|
||||
|
||||
impl TimeDeltaSince for Instant {
|
||||
fn time_delta_since(self, other: Self) -> TimeDelta {
|
||||
if self > other {
|
||||
TimeDelta::from_secs(self.checked_duration_since(other).unwrap().as_secs_f64())
|
||||
} else {
|
||||
TimeDelta::from_secs(-other.checked_duration_since(self).unwrap().as_secs_f64())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TimeDeltaSince for crate::transportcc::RemoteInstant {
|
||||
fn time_delta_since(self, other: Self) -> TimeDelta {
|
||||
if self > other {
|
||||
TimeDelta::from_secs(self.saturating_duration_since(other).as_secs_f64())
|
||||
} else {
|
||||
TimeDelta::from_secs(-other.saturating_duration_since(self).as_secs_f64())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::transportcc::RemoteInstant;
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn from_and_as() {
|
||||
assert_eq!(TimeDelta::from_secs(2.5).as_secs(), 2.5);
|
||||
assert_eq!(TimeDelta::from_millis(2.5).as_secs(), 0.0025);
|
||||
assert_eq!(TimeDelta::from_secs(-2.5).as_secs(), -2.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp)]
|
||||
fn time_delta_since() {
|
||||
assert_eq!(
|
||||
RemoteInstant::from_millis(2500)
|
||||
.time_delta_since(RemoteInstant::from_millis(250))
|
||||
.as_secs(),
|
||||
2.25
|
||||
);
|
||||
assert_eq!(
|
||||
RemoteInstant::from_millis(250)
|
||||
.time_delta_since(RemoteInstant::from_millis(2500))
|
||||
.as_secs(),
|
||||
-2.25
|
||||
);
|
||||
|
||||
let now = Instant::now();
|
||||
let duration = Duration::from_millis(2500);
|
||||
assert_eq!((now + duration).time_delta_since(now).as_secs(), 2.5);
|
||||
assert_eq!(now.time_delta_since(now + duration).as_secs(), -2.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_and_sub() {
|
||||
assert_eq!(
|
||||
TimeDelta::from_secs(2.5) + TimeDelta::from_secs(5.25),
|
||||
TimeDelta::from_secs(7.75)
|
||||
);
|
||||
assert_eq!(
|
||||
TimeDelta::from_secs(2.5) + Duration::from_millis(5250),
|
||||
TimeDelta::from_secs(7.75)
|
||||
);
|
||||
assert_eq!(
|
||||
TimeDelta::from_secs(2.25) - TimeDelta::from_secs(2.5),
|
||||
TimeDelta::from_secs(-0.25)
|
||||
);
|
||||
assert_eq!(
|
||||
TimeDelta::from_secs(2.25) - Duration::from_millis(2500),
|
||||
TimeDelta::from_secs(-0.25)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul() {
|
||||
assert_eq!(TimeDelta::from_secs(2.5) * 3.0, TimeDelta::from_secs(7.5));
|
||||
}
|
||||
}
|
||||
148
src/googcc/feedback_rtts.rs
Normal file
148
src/googcc/feedback_rtts.rs
Normal file
@ -0,0 +1,148 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use async_stream::stream;
|
||||
use futures::{pin_mut, Stream, StreamExt};
|
||||
|
||||
use crate::{
|
||||
common::{Duration, RingBuffer},
|
||||
transportcc::Ack,
|
||||
};
|
||||
|
||||
// TODO: Consider making this configurable
|
||||
const FEEDBACK_RTTS_HISTORY_LEN: usize = 32;
|
||||
|
||||
pub fn estimate_feedback_rtts(
|
||||
ack_reports: impl Stream<Item = Vec<Ack>>,
|
||||
) -> impl Stream<Item = Duration> {
|
||||
stream! {
|
||||
let mut history: RingBuffer<Duration> = RingBuffer::new(FEEDBACK_RTTS_HISTORY_LEN);
|
||||
pin_mut!(ack_reports);
|
||||
while let Some(acks) = ack_reports.next().await {
|
||||
if let Some(max_feedback_rtt) = acks
|
||||
.iter()
|
||||
.map(|ack| {
|
||||
ack.feedback_arrival
|
||||
.saturating_duration_since(ack.departure)
|
||||
})
|
||||
.max()
|
||||
{
|
||||
history.push(max_feedback_rtt);
|
||||
let mean_feedback_rtt: Duration =
|
||||
history.iter().sum::<Duration>() / (history.len() as u32);
|
||||
yield mean_feedback_rtt;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::FutureExt;
|
||||
|
||||
use super::*;
|
||||
use crate::{common::Instant, transportcc::RemoteInstant};
|
||||
|
||||
/// Creates an `Ack` for each RTT, setting the departure and feedback-arrival times.
|
||||
///
|
||||
/// The size and arrival duration should be ignored.
|
||||
fn acks_from_rtts(rtts: &[u64]) -> Vec<Ack> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
rtts.iter()
|
||||
.map(|rtt| Ack {
|
||||
size: Default::default(),
|
||||
departure: start_time,
|
||||
arrival: RemoteInstant::from_millis(0),
|
||||
feedback_arrival: start_time + Duration::from_millis(*rtt),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn running_average_of_max() {
|
||||
let stream = estimate_feedback_rtts(stream! {
|
||||
yield acks_from_rtts(&[10, 20, 30]);
|
||||
yield acks_from_rtts(&[60, 50, 40]);
|
||||
yield acks_from_rtts(&[60, 50]);
|
||||
});
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
&[30, 45, 50],
|
||||
&stream
|
||||
.map(|d| d.as_millis())
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_update_for_no_acks() {
|
||||
let stream = estimate_feedback_rtts(stream! {
|
||||
yield vec![];
|
||||
yield acks_from_rtts(&[20]);
|
||||
yield vec![];
|
||||
yield acks_from_rtts(&[30]);
|
||||
yield vec![];
|
||||
});
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
&[20, 25],
|
||||
&stream
|
||||
.map(|d| d.as_millis())
|
||||
.collect::<Vec<_>>()
|
||||
.now_or_never()
|
||||
.unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bounded_history() {
|
||||
let stream = estimate_feedback_rtts(stream! {
|
||||
yield acks_from_rtts(&[1000]);
|
||||
loop {
|
||||
yield acks_from_rtts(&[1]);
|
||||
}
|
||||
});
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
1,
|
||||
stream
|
||||
.skip(FEEDBACK_RTTS_HISTORY_LEN)
|
||||
.next()
|
||||
.now_or_never()
|
||||
.expect("stream is ready")
|
||||
.expect("and not complete")
|
||||
.as_millis(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handles_negative_rtt() {
|
||||
let start_time = Instant::now();
|
||||
|
||||
let acks = vec![Ack {
|
||||
size: Default::default(),
|
||||
departure: start_time,
|
||||
arrival: RemoteInstant::from_millis(0),
|
||||
feedback_arrival: start_time - Duration::from_millis(1),
|
||||
}];
|
||||
|
||||
let stream = estimate_feedback_rtts(stream! {
|
||||
yield acks;
|
||||
});
|
||||
pin_mut!(stream);
|
||||
assert_eq!(
|
||||
0,
|
||||
stream
|
||||
.next()
|
||||
.now_or_never()
|
||||
.expect("stream is ready")
|
||||
.expect("and not complete")
|
||||
.as_millis(),
|
||||
);
|
||||
}
|
||||
}
|
||||
255
src/googcc/stream.rs
Normal file
255
src/googcc/stream.rs
Normal file
@ -0,0 +1,255 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures::{stream::FusedStream, Stream};
|
||||
use pin_project::pin_project;
|
||||
|
||||
/// The type of [`StreamExt::last`].
|
||||
///
|
||||
/// Manually expanded once so you never have to type it again.
|
||||
#[allow(dead_code)] // Silence the warning about 'pub'; it affects the docs for StreamExt.
|
||||
pub type Last<S> = futures::stream::Fold<
|
||||
S,
|
||||
futures::future::Ready<Option<<S as Stream>::Item>>,
|
||||
Option<<S as Stream>::Item>,
|
||||
fn(
|
||||
Option<<S as Stream>::Item>,
|
||||
<S as Stream>::Item,
|
||||
) -> futures::future::Ready<Option<<S as Stream>::Item>>,
|
||||
>;
|
||||
|
||||
/// Additional adapters for [`Stream`] in the style of [`futures::StreamExt`].
|
||||
pub trait StreamExt: futures::StreamExt {
|
||||
fn last(self) -> Last<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.fold(None, |_, val| futures::future::ready(Some(val)))
|
||||
}
|
||||
|
||||
fn latest_only(self) -> LatestOnly<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
LatestOnly(Some(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Stream> StreamExt for S {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod last_tests {
|
||||
use async_stream::stream;
|
||||
use futures::{
|
||||
future::FutureExt,
|
||||
pin_mut,
|
||||
stream::{empty, iter, pending, StreamExt},
|
||||
};
|
||||
|
||||
use super::StreamExt as OurStreamExt;
|
||||
|
||||
#[test]
|
||||
fn empty_stream() {
|
||||
let stream = empty::<i32>();
|
||||
assert_eq!(None, stream.last().now_or_never().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pending_stream() {
|
||||
let stream = pending::<i32>();
|
||||
assert_eq!(None, stream.last().now_or_never());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single() {
|
||||
let stream = iter([1]);
|
||||
assert_eq!(Some(1), stream.last().now_or_never().unwrap());
|
||||
|
||||
let unfinished_stream = iter([1]).chain(pending());
|
||||
assert_eq!(None, unfinished_stream.last().now_or_never());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last() {
|
||||
let stream = iter([1, 2, 3]).latest_only();
|
||||
assert_eq!(Some(3), stream.last().now_or_never().unwrap());
|
||||
|
||||
let unfinished_stream = iter([1, 2, 3]).chain(pending());
|
||||
assert_eq!(None, unfinished_stream.last().now_or_never());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn groups() {
|
||||
let (sender1, receiver1) = futures::channel::oneshot::channel::<()>();
|
||||
let (sender2, receiver2) = futures::channel::oneshot::channel::<()>();
|
||||
let last = stream! {
|
||||
yield 1i32;
|
||||
yield 2;
|
||||
receiver1.await.unwrap();
|
||||
yield 10;
|
||||
yield 20;
|
||||
receiver2.await.unwrap();
|
||||
yield 100;
|
||||
yield 200;
|
||||
}
|
||||
.last();
|
||||
pin_mut!(last);
|
||||
|
||||
assert_eq!(None, last.as_mut().now_or_never());
|
||||
|
||||
sender1.send(()).unwrap();
|
||||
// Deliberately send the second signal too; now the groups will be coalesced.
|
||||
sender2.send(()).unwrap();
|
||||
|
||||
assert_eq!(Some(200), last.now_or_never().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub struct LatestOnly<S>(#[pin] Option<S>);
|
||||
|
||||
impl<S: Stream> Stream for LatestOnly<S> {
|
||||
type Item = S::Item;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
|
||||
match self.as_mut().project().0.as_pin_mut() {
|
||||
None => Poll::Ready(None),
|
||||
Some(mut inner) => {
|
||||
let mut last_val = None;
|
||||
loop {
|
||||
match inner.as_mut().poll_next(cx) {
|
||||
Poll::Ready(Some(val)) => {
|
||||
last_val = Some(val);
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
self.project().0.set(None);
|
||||
return Poll::Ready(last_val);
|
||||
}
|
||||
Poll::Pending => {
|
||||
return if last_val.is_some() {
|
||||
Poll::Ready(last_val)
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
match &self.0 {
|
||||
None => (0, Some(0)),
|
||||
Some(stream) => {
|
||||
let (original_min, max) = stream.size_hint();
|
||||
let min = if original_min == 0 { 0 } else { 1 };
|
||||
(min, max)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Stream> FusedStream for LatestOnly<S> {
|
||||
fn is_terminated(&self) -> bool {
|
||||
self.0.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod latest_only_tests {
|
||||
use async_stream::stream;
|
||||
use futures::{
|
||||
future::FutureExt,
|
||||
pin_mut,
|
||||
stream::{empty, iter, pending, StreamExt},
|
||||
};
|
||||
|
||||
use super::{StreamExt as OurStreamExt, *};
|
||||
|
||||
#[test]
|
||||
fn empty_stream() {
|
||||
let stream = empty::<i32>().latest_only();
|
||||
assert_eq!((0, Some(0)), stream.size_hint());
|
||||
assert_eq!(
|
||||
&[] as &[i32],
|
||||
&stream.collect::<Vec<_>>().now_or_never().unwrap()[..]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pending_stream() {
|
||||
let mut stream = pending::<i32>().latest_only();
|
||||
// The upper bound here is provided by pending().
|
||||
assert_eq!((0, Some(0)), stream.size_hint());
|
||||
assert_eq!(None, stream.next().now_or_never());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single() {
|
||||
let stream = iter([1]).latest_only();
|
||||
assert_eq!((1, Some(1)), stream.size_hint());
|
||||
assert_eq!(
|
||||
&[1],
|
||||
&stream.collect::<Vec<_>>().now_or_never().unwrap()[..]
|
||||
);
|
||||
|
||||
let mut unfinished_stream = iter([1]).chain(pending()).latest_only();
|
||||
assert_eq!(Some(1), unfinished_stream.next().now_or_never().unwrap());
|
||||
assert_eq!(None, unfinished_stream.next().now_or_never());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last() {
|
||||
let stream = iter([1, 2, 3]).latest_only();
|
||||
assert_eq!((1, Some(3)), stream.size_hint());
|
||||
assert_eq!(
|
||||
&[3],
|
||||
&stream.collect::<Vec<_>>().now_or_never().unwrap()[..]
|
||||
);
|
||||
|
||||
let mut unfinished_stream = iter([1, 2, 3]).chain(pending()).latest_only();
|
||||
assert_eq!(Some(3), unfinished_stream.next().now_or_never().unwrap());
|
||||
assert_eq!(None, unfinished_stream.next().now_or_never());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn groups() {
|
||||
let (sender1, receiver1) = futures::channel::oneshot::channel::<()>();
|
||||
let (sender2, receiver2) = futures::channel::oneshot::channel::<()>();
|
||||
let stream = stream! {
|
||||
yield 1i32;
|
||||
yield 2;
|
||||
receiver1.await.unwrap();
|
||||
yield 10;
|
||||
yield 20;
|
||||
receiver2.await.unwrap();
|
||||
yield 100;
|
||||
yield 200;
|
||||
}
|
||||
.latest_only();
|
||||
pin_mut!(stream);
|
||||
|
||||
assert_eq!((0, None), stream.size_hint());
|
||||
assert_eq!(Some(2), stream.next().now_or_never().unwrap());
|
||||
assert_eq!((0, None), stream.size_hint());
|
||||
assert_eq!(None, stream.next().now_or_never());
|
||||
assert_eq!((0, None), stream.size_hint());
|
||||
|
||||
sender1.send(()).unwrap();
|
||||
// Deliberately send the second signal too; now the groups will be coalesced.
|
||||
sender2.send(()).unwrap();
|
||||
|
||||
assert_eq!(Some(200), stream.next().now_or_never().unwrap());
|
||||
assert_eq!((0, Some(0)), stream.size_hint());
|
||||
assert_eq!(None, stream.next().now_or_never().unwrap());
|
||||
assert_eq!((0, Some(0)), stream.size_hint());
|
||||
}
|
||||
}
|
||||
873
src/http_server.rs
Normal file
873
src/http_server.rs
Normal file
@ -0,0 +1,873 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
//! Implementation of the http server. This version is based on warp.
|
||||
//! Supported APIs:
|
||||
//! GET /health
|
||||
//! GET /metrics
|
||||
//! GET /v1/conference/participants
|
||||
//! PUT /v1/conference/participants
|
||||
//! DELETE /v1/conference/participants/endpoint_id
|
||||
|
||||
use std::{
|
||||
convert::TryInto,
|
||||
net::IpAddr,
|
||||
str,
|
||||
str::FromStr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::UNIX_EPOCH,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use hex::{FromHex, ToHex};
|
||||
use log::*;
|
||||
use parking_lot::Mutex;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::sync::oneshot::Receiver;
|
||||
use warp::{http::StatusCode, Filter, Reply};
|
||||
|
||||
use crate::{
|
||||
common,
|
||||
common::Instant,
|
||||
config, ice,
|
||||
sfu::{self, Sfu},
|
||||
};
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ParticipantsResponse {
|
||||
#[serde(rename = "conferenceId")]
|
||||
pub era_id: String,
|
||||
#[serde(rename = "maxConferenceSize")]
|
||||
pub max_devices: u32,
|
||||
pub participants: Vec<SfuParticipant>,
|
||||
// TODO: Make this with hex too.
|
||||
pub creator: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct Candidate {
|
||||
pub port: u16,
|
||||
pub ip: String,
|
||||
#[serde(rename = "type")]
|
||||
pub candidate_type: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct Fingerprint {
|
||||
pub fingerprint: String,
|
||||
pub hash: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub struct Transport {
|
||||
pub candidates: Vec<Candidate>,
|
||||
pub fingerprints: Vec<Fingerprint>,
|
||||
pub ufrag: String,
|
||||
pub pwd: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct PayloadParameters {
|
||||
pub minptime: Option<u32>,
|
||||
pub useinbandfec: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct RtcpFbs {
|
||||
#[serde(rename = "type")]
|
||||
pub fbs_type: String,
|
||||
pub subtype: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct PayloadType {
|
||||
pub id: u8,
|
||||
pub name: String,
|
||||
pub clockrate: u32,
|
||||
pub channels: u32,
|
||||
pub parameters: Option<PayloadParameters>,
|
||||
#[serde(rename = "rtcp-fbs")]
|
||||
pub rtcp_fbs: Option<Vec<RtcpFbs>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct HeaderExtension {
|
||||
pub id: u32,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct SsrcGroup {
|
||||
pub semantics: String,
|
||||
pub sources: Vec<u32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct JoinRequest {
|
||||
pub transport: Transport,
|
||||
#[serde(rename = "audioPayloadType")]
|
||||
pub audio_payload_type: PayloadType,
|
||||
#[serde(rename = "videoPayloadType")]
|
||||
pub video_payload_type: PayloadType,
|
||||
#[serde(rename = "dataPayloadType")]
|
||||
pub data_payload_type: PayloadType,
|
||||
#[serde(rename = "audioHeaderExtensions")]
|
||||
pub audio_header_extensions: Vec<HeaderExtension>,
|
||||
#[serde(rename = "videoHeaderExtensions")]
|
||||
pub video_header_extensions: Vec<HeaderExtension>,
|
||||
#[serde(rename = "audioSsrcs")]
|
||||
pub audio_ssrcs: Vec<u32>,
|
||||
#[serde(rename = "audioSsrcGroups")]
|
||||
pub audio_ssrc_groups: Vec<SsrcGroup>,
|
||||
#[serde(rename = "dataSsrcs")]
|
||||
pub data_ssrcs: Vec<u32>,
|
||||
#[serde(rename = "dataSsrcGroups")]
|
||||
pub data_ssrc_groups: Vec<SsrcGroup>,
|
||||
#[serde(rename = "videoSsrcs")]
|
||||
pub video_ssrcs: Vec<u32>,
|
||||
#[serde(rename = "videoSsrcGroups")]
|
||||
pub video_ssrc_groups: Vec<SsrcGroup>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct JoinResponse {
|
||||
#[serde(rename = "endpointId")]
|
||||
pub endpoint_id: String,
|
||||
#[serde(rename = "ssrcPrefix")]
|
||||
pub demux_id: u32,
|
||||
pub transport: Transport,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
pub struct SfuParticipant {
|
||||
#[serde(rename = "endpointId")]
|
||||
pub endpoint_id: String,
|
||||
#[serde(rename = "ssrcPrefix")]
|
||||
pub demux_id: u32,
|
||||
}
|
||||
|
||||
mod metrics {
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Client {
|
||||
pub user_id: String,
|
||||
pub demux_id: u32,
|
||||
pub video0_incoming_bps: u64,
|
||||
pub video1_incoming_bps: u64,
|
||||
pub video2_incoming_bps: u64,
|
||||
pub target_send_bps: u64,
|
||||
pub ideal_send_bps: u64,
|
||||
pub allocated_send_bps: u64,
|
||||
pub padding_send_bps: u64,
|
||||
pub video0_incoming_height: u64,
|
||||
pub video1_incoming_height: u64,
|
||||
pub video2_incoming_height: u64,
|
||||
pub max_requested_height: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Call {
|
||||
pub call_id: String,
|
||||
pub client_count: usize,
|
||||
pub clients: Vec<Client>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Response {
|
||||
pub call_count: usize,
|
||||
pub client_count: usize,
|
||||
pub calls: Vec<Call>,
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtain a demux_id from the given endpoint_id.
|
||||
///
|
||||
/// The demux_id is the first 112 bits of the SHA-256 hash of the endpoint_id string byte
|
||||
/// representation.
|
||||
///
|
||||
/// ```
|
||||
/// use calling_server::http_server::demux_id_from_endpoint_id;
|
||||
/// use std::convert::TryInto;
|
||||
///
|
||||
/// assert_eq!(demux_id_from_endpoint_id("abcdef-0"), 3487943312.try_into().unwrap());
|
||||
/// assert_eq!(demux_id_from_endpoint_id("abcdef-12345"), 2175944000.try_into().unwrap());
|
||||
/// assert_eq!(demux_id_from_endpoint_id(""), 3820012608.try_into().unwrap());
|
||||
/// ```
|
||||
pub fn demux_id_from_endpoint_id(endpoint_id: &str) -> sfu::DemuxId {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(endpoint_id.as_bytes());
|
||||
|
||||
// Get the 32-bit hash but mask out 4 bits since DemuxIDs must leave
|
||||
// these unset for "SSRC space".
|
||||
(u32::from_be_bytes(hasher.finalize()[0..4].try_into().unwrap()) & 0xfffffff0)
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Authenticate the header and return the (user_id, call_id) tuple or an error.
|
||||
fn authenticate(
|
||||
_config: &'static config::Config,
|
||||
password: &str,
|
||||
) -> Result<(sfu::UserId, sfu::CallId)> {
|
||||
let (user_id_hex, call_id_hex) = match password.split(':').collect::<Vec<_>>()[..] {
|
||||
[user_id_hex, call_id_hex, _timestamp, _mac_hex]
|
||||
if !user_id_hex.is_empty() && !call_id_hex.is_empty() =>
|
||||
{
|
||||
Ok((user_id_hex, call_id_hex))
|
||||
}
|
||||
["2", user_id_hex, call_id_hex, _timestamp, _permission, _mac_hex]
|
||||
if !user_id_hex.is_empty() && !call_id_hex.is_empty() =>
|
||||
{
|
||||
Ok((user_id_hex, call_id_hex))
|
||||
}
|
||||
_ => Err(anyhow!("Password not valid")),
|
||||
}?;
|
||||
|
||||
let user_id = Vec::from_hex(user_id_hex)?.into();
|
||||
let call_id = Vec::from_hex(call_id_hex)?.into();
|
||||
|
||||
// The http_server is used for testing and therefore will not perform
|
||||
// actual GV2 auth, as this is done by the frontend.
|
||||
Ok((user_id, call_id))
|
||||
}
|
||||
|
||||
/// Parses an authorization header using the basic authentication scheme. Returns
|
||||
/// a tuple of the credentials (username, password).
|
||||
fn parse_basic_authorization_header(authorization_header: &str) -> Result<(String, String)> {
|
||||
// Get the credentials from the Basic authorization header.
|
||||
if let ["Basic", base_64_encoded_values] =
|
||||
authorization_header.splitn(2, ' ').collect::<Vec<_>>()[..]
|
||||
{
|
||||
// Decode the credentials to utf-8 format.
|
||||
let decoded_values = base64::decode(base_64_encoded_values)?;
|
||||
let credentials = std::str::from_utf8(&decoded_values)?;
|
||||
|
||||
// Split the credentials into the username and password.
|
||||
if let [username, password] = credentials.splitn(2, ':').collect::<Vec<_>>()[..] {
|
||||
Ok((username.to_string(), password.to_string()))
|
||||
} else {
|
||||
// Malformed token.
|
||||
Err(anyhow!("Authorization header not valid"))
|
||||
}
|
||||
} else {
|
||||
// Malformed header.
|
||||
Err(anyhow!("Could not parse authorization header"))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_and_authenticate(
|
||||
config: &'static config::Config,
|
||||
authorization_header: &str,
|
||||
) -> Result<(sfu::UserId, sfu::CallId)> {
|
||||
let (_, password) = parse_basic_authorization_header(&authorization_header)?;
|
||||
authenticate(config, &password)
|
||||
}
|
||||
|
||||
async fn get_metrics(sfu: Arc<Mutex<Sfu>>) -> Result<warp::reply::Response, warp::Rejection> {
|
||||
trace!("get_metrics():");
|
||||
|
||||
let calls = sfu.lock().get_calls_snapshot(); // SFU lock released here.
|
||||
|
||||
let calls = calls
|
||||
.iter()
|
||||
.map(|call| {
|
||||
// We can take this call lock after closing the SFU lock because we are only reading it
|
||||
// and do not care if it is removed from the list of active calls around the same time.
|
||||
// This is in contrast to if we were updating it with a mut reference and we might revive
|
||||
// the call.
|
||||
let call = call.lock();
|
||||
let clients = call
|
||||
.get_stats()
|
||||
.clients
|
||||
.iter()
|
||||
.map(|client| metrics::Client {
|
||||
demux_id: client.demux_id.into(),
|
||||
user_id: String::from_utf8_lossy(client.user_id.as_slice()).to_string(),
|
||||
target_send_bps: client.target_send_rate.as_bps(),
|
||||
video0_incoming_bps: client.video0_incoming_rate.unwrap_or_default().as_bps(),
|
||||
video1_incoming_bps: client.video1_incoming_rate.unwrap_or_default().as_bps(),
|
||||
video2_incoming_bps: client.video2_incoming_rate.unwrap_or_default().as_bps(),
|
||||
padding_send_bps: client.padding_send_rate.as_bps(),
|
||||
ideal_send_bps: client.ideal_send_rate.as_bps(),
|
||||
allocated_send_bps: client.allocated_send_rate.as_bps(),
|
||||
video0_incoming_height: client
|
||||
.video0_incoming_height
|
||||
.unwrap_or_default()
|
||||
.as_u16() as u64,
|
||||
video1_incoming_height: client
|
||||
.video1_incoming_height
|
||||
.unwrap_or_default()
|
||||
.as_u16() as u64,
|
||||
video2_incoming_height: client
|
||||
.video2_incoming_height
|
||||
.unwrap_or_default()
|
||||
.as_u16() as u64,
|
||||
max_requested_height: client.max_requested_height.unwrap_or_default().as_u16()
|
||||
as u64,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
metrics::Call {
|
||||
call_id: call.loggable_call_id().to_string(),
|
||||
client_count: clients.len(),
|
||||
clients,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let response = metrics::Response {
|
||||
call_count: calls.len(),
|
||||
client_count: calls.iter().map(|c| c.client_count).sum(),
|
||||
calls,
|
||||
};
|
||||
|
||||
Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK).into_response())
|
||||
}
|
||||
|
||||
async fn get_participants(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
authorization_header: String,
|
||||
) -> Result<warp::reply::Response, warp::Rejection> {
|
||||
trace!("get():");
|
||||
|
||||
let call_id = match parse_and_authenticate(config, &authorization_header) {
|
||||
Ok((_, call_id)) => call_id,
|
||||
Err(err) => {
|
||||
warn!("get(): unauthorized {}", err);
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&err.to_string()),
|
||||
StatusCode::UNAUTHORIZED,
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let sfu = sfu.lock();
|
||||
|
||||
if let Some(signaling) = sfu.get_call_signaling_info(call_id) {
|
||||
let max_devices = sfu.config.max_clients_per_call;
|
||||
drop(sfu);
|
||||
// Release the SFU lock as early as possible. Before call lock is fine for a imut ref to a
|
||||
// call, as nothing we will do later with the reference can affect SFU decisions around call
|
||||
// dropping.
|
||||
|
||||
let participants = signaling
|
||||
.client_ids
|
||||
.iter()
|
||||
.map(|(demux_id, active_speaker_id)| SfuParticipant {
|
||||
endpoint_id: active_speaker_id.to_owned(),
|
||||
demux_id: u32::from(*demux_id),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let era_id = signaling
|
||||
.created
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("Time went backwards")
|
||||
.as_millis()
|
||||
.to_string();
|
||||
let response = ParticipantsResponse {
|
||||
// TODO: Consider handling the expectation and returning an internal server error for it.
|
||||
era_id,
|
||||
max_devices,
|
||||
participants,
|
||||
creator: signaling.creator_id.as_slice().encode_hex(),
|
||||
};
|
||||
|
||||
Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK).into_response())
|
||||
} else {
|
||||
Ok(StatusCode::NOT_FOUND.into_response())
|
||||
}
|
||||
}
|
||||
|
||||
async fn join(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
authorization_header: String,
|
||||
join_request: JoinRequest,
|
||||
) -> Result<warp::reply::Response, warp::Rejection> {
|
||||
trace!("join():");
|
||||
|
||||
let (user_id, call_id) = match parse_and_authenticate(config, &authorization_header) {
|
||||
Ok((user_id, call_id)) => (user_id, call_id),
|
||||
Err(err) => {
|
||||
warn!("join(): unauthorized {}", err);
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&err.to_string()),
|
||||
StatusCode::UNAUTHORIZED,
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// Evaluate the request with basic assertions.
|
||||
if join_request.audio_header_extensions.len() != 3
|
||||
|| join_request.audio_payload_type.id != 102
|
||||
|| !join_request.audio_ssrc_groups.is_empty()
|
||||
|| join_request.audio_ssrcs.len() != 1
|
||||
|| join_request.data_payload_type.id != 101
|
||||
|| join_request.data_ssrcs.len() != 1
|
||||
|| join_request.video_header_extensions.len() != 3
|
||||
|| join_request.video_payload_type.id != 108
|
||||
|| join_request.video_ssrc_groups.len() != 4
|
||||
|| join_request.video_ssrcs.len() != 6
|
||||
|| join_request.transport.fingerprints.len() != 1
|
||||
{
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&"Missing required fields in the request.".to_string()),
|
||||
StatusCode::NOT_ACCEPTABLE,
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
|
||||
let client_dtls_fingerprint = match common::colon_separated_hexstring_to_array(
|
||||
&join_request
|
||||
.transport
|
||||
.fingerprints
|
||||
.get(0)
|
||||
.unwrap() // It has already been established that this fingerprint exists.
|
||||
.fingerprint,
|
||||
) {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&"Invalid dtls_fingerprint in the request.".to_string()),
|
||||
StatusCode::NOT_ACCEPTABLE,
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// Generate ids for the client.
|
||||
let resolution_request_id = rand::thread_rng().gen::<u64>();
|
||||
// The endpoint_id is the term currently used on the client side, it is
|
||||
// equivalent to the active_speaker_id in the Sfu.
|
||||
let endpoint_id = format!(
|
||||
"{}-{}",
|
||||
user_id.as_slice().encode_hex::<String>(),
|
||||
resolution_request_id
|
||||
);
|
||||
let demux_id = demux_id_from_endpoint_id(&endpoint_id);
|
||||
let server_ice_ufrag = ice::random_ufrag();
|
||||
let server_ice_pwd = ice::random_pwd();
|
||||
|
||||
let mut sfu = sfu.lock();
|
||||
sfu.get_or_create_call_and_add_client(
|
||||
call_id,
|
||||
&user_id,
|
||||
resolution_request_id,
|
||||
endpoint_id.clone(),
|
||||
demux_id,
|
||||
server_ice_ufrag.clone(),
|
||||
server_ice_pwd.clone(),
|
||||
join_request.transport.ufrag,
|
||||
client_dtls_fingerprint,
|
||||
)
|
||||
.unwrap();
|
||||
let socket_addr = config::get_server_media_address(config);
|
||||
let candidate = Candidate {
|
||||
port: socket_addr.port(),
|
||||
ip: socket_addr.ip().to_string(),
|
||||
candidate_type: "host".to_string(),
|
||||
};
|
||||
|
||||
let candidates = vec![candidate];
|
||||
|
||||
let fingerprint = Fingerprint {
|
||||
fingerprint: common::bytes_to_colon_separated_hexstring(sfu.server_dtls_fingerprint()),
|
||||
hash: "sha-256".to_string(),
|
||||
};
|
||||
|
||||
let fingerprints = vec![fingerprint];
|
||||
|
||||
let transport = Transport {
|
||||
candidates,
|
||||
fingerprints,
|
||||
ufrag: server_ice_ufrag,
|
||||
pwd: server_ice_pwd,
|
||||
};
|
||||
|
||||
let response = JoinResponse {
|
||||
endpoint_id,
|
||||
demux_id: demux_id.into(),
|
||||
transport,
|
||||
};
|
||||
|
||||
Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK).into_response())
|
||||
}
|
||||
|
||||
async fn leave(
|
||||
endpoint_id: String,
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
authorization_header: String,
|
||||
) -> Result<warp::reply::Response, warp::Rejection> {
|
||||
trace!("leave():");
|
||||
|
||||
let call_id = match parse_and_authenticate(config, &authorization_header) {
|
||||
Ok((_, call_id)) => call_id,
|
||||
Err(err) => {
|
||||
warn!("leave(): unauthorized {}", err);
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&err.to_string()),
|
||||
StatusCode::UNAUTHORIZED,
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// Calculate the demux_id with some simple validation of the endpoint_id.
|
||||
let demux_id = if endpoint_id.chars().count() > 3 && endpoint_id.contains('-') {
|
||||
demux_id_from_endpoint_id(&endpoint_id)
|
||||
} else {
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&"Invalid endpoint_id format".to_string()),
|
||||
StatusCode::BAD_REQUEST,
|
||||
)
|
||||
.into_response());
|
||||
};
|
||||
|
||||
sfu.lock()
|
||||
.remove_client_from_call(Instant::now(), call_id, demux_id);
|
||||
|
||||
// TODO: When the function above returns a result, handle that.
|
||||
// TODO: Also, if there was no group call, should we return NOT_FOUND?
|
||||
Ok(StatusCode::NO_CONTENT.into_response())
|
||||
}
|
||||
|
||||
/// A warp filter for providing the config for a route.
|
||||
fn with_config(
|
||||
config: &'static config::Config,
|
||||
) -> impl Filter<Extract = (&'static config::Config,), Error = std::convert::Infallible> + Clone {
|
||||
warp::any().map(move || config)
|
||||
}
|
||||
|
||||
/// A warp filter for extracting the Sfu state for a route.
|
||||
fn with_sfu(
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
) -> impl Filter<Extract = (Arc<Mutex<Sfu>>,), Error = std::convert::Infallible> + Clone {
|
||||
warp::any().map(move || sfu.clone())
|
||||
}
|
||||
|
||||
pub async fn start(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
http_ender_rx: Receiver<()>,
|
||||
is_healthy: Arc<AtomicBool>,
|
||||
) -> Result<()> {
|
||||
// Filter to support: GET /health
|
||||
let health_check_api = warp::path!("about" / "health")
|
||||
.and(warp::get())
|
||||
.map(move || {
|
||||
if is_healthy.load(Ordering::Relaxed) {
|
||||
Ok(StatusCode::OK.into_response())
|
||||
} else {
|
||||
Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response())
|
||||
}
|
||||
});
|
||||
|
||||
// Filter to support: GET /metrics
|
||||
let metrics_api = warp::path!("metrics")
|
||||
.and(warp::get())
|
||||
.and(with_sfu(sfu.clone()))
|
||||
.and_then(get_metrics);
|
||||
|
||||
// Filter to support: GET /v1/conference/participants
|
||||
let get_participants_api = warp::path!("v1" / "conference" / "participants")
|
||||
.and(warp::get())
|
||||
.and(with_config(config))
|
||||
.and(with_sfu(sfu.clone()))
|
||||
.and(warp::header("authorization"))
|
||||
.and_then(get_participants);
|
||||
|
||||
// Filter to support: PUT /v1/conference/participants
|
||||
let join_api = warp::path!("v1" / "conference" / "participants")
|
||||
.and(warp::put())
|
||||
.and(with_config(config))
|
||||
.and(with_sfu(sfu.clone()))
|
||||
.and(warp::header("authorization"))
|
||||
.and(warp::body::json())
|
||||
.and_then(join);
|
||||
|
||||
// Filter to support: DELETE v1/conference/participants/endpoint-id 204 Success
|
||||
let leave_api = warp::path!("v1" / "conference" / "participants" / String)
|
||||
.and(warp::delete())
|
||||
.and(with_config(config))
|
||||
.and(with_sfu(sfu))
|
||||
.and(warp::header("authorization"))
|
||||
.and_then(leave);
|
||||
|
||||
let api = health_check_api
|
||||
.or(metrics_api)
|
||||
.or(get_participants_api)
|
||||
.or(join_api)
|
||||
.or(leave_api);
|
||||
|
||||
// Add other options to form the final routes to be served.
|
||||
// TODO: Disabling the "with(log)" mechanism since it causes the following
|
||||
// error when trying to launch with tokio::spawn():
|
||||
// implementation of `warp::reply::Reply` is not general enough
|
||||
//let routes = api.with(warp::log("calling_service"));
|
||||
|
||||
let (addr, server) = warp::serve(api).bind_with_graceful_shutdown(
|
||||
(IpAddr::from_str(&config.binding_ip)?, config.signaling_port),
|
||||
async {
|
||||
http_ender_rx.await.ok();
|
||||
},
|
||||
);
|
||||
|
||||
info!("http_server ready: {}", addr);
|
||||
server.await;
|
||||
|
||||
info!("http_server shutdown");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod http_server_tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use hex::{FromHex, ToHex};
|
||||
use lazy_static::lazy_static;
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn string_id_vector(n: usize, id_size: usize) -> Vec<String> {
|
||||
let mut vector: Vec<String> = Vec::new();
|
||||
for _ in 0..n {
|
||||
vector.push(common::random_hex_string(id_size));
|
||||
}
|
||||
vector
|
||||
}
|
||||
|
||||
fn random_byte_vector(n: usize) -> Vec<u8> {
|
||||
let mut numbers: Vec<u8> = Vec::new();
|
||||
let mut rng = thread_rng();
|
||||
for _ in 0..n {
|
||||
numbers.push(rng.gen());
|
||||
}
|
||||
numbers
|
||||
}
|
||||
|
||||
fn random_byte_id_vector(n: usize, id_size: usize) -> Vec<Vec<u8>> {
|
||||
let mut vector: Vec<Vec<u8>> = Vec::new();
|
||||
for _ in 0..n {
|
||||
vector.push(random_byte_vector(id_size));
|
||||
}
|
||||
vector
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hex_decode_bench() {
|
||||
// Create 1000 id's to use.
|
||||
let count = 1000;
|
||||
let ids = string_id_vector(count, 64);
|
||||
|
||||
// Pre-allocate the outer vec.
|
||||
let mut ids_bytes: Vec<Vec<u8>> = Vec::with_capacity(count);
|
||||
|
||||
let start = Instant::now();
|
||||
for id in ids {
|
||||
ids_bytes.push(Vec::from_hex(id).unwrap());
|
||||
}
|
||||
let end = Instant::now();
|
||||
|
||||
assert_eq!(count, ids_bytes.len());
|
||||
assert_eq!(32, ids_bytes.get(0).unwrap().len());
|
||||
|
||||
println!(
|
||||
"hex_decode() for {} ids took {}ns",
|
||||
count,
|
||||
end.duration_since(start).as_nanos()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hex_encode_bench() {
|
||||
// Create 1000 id's to use.
|
||||
let count = 1000;
|
||||
let ids = random_byte_id_vector(count, 32);
|
||||
|
||||
// Pre-allocate the outer vec.
|
||||
let mut ids_strings: Vec<String> = Vec::with_capacity(count);
|
||||
|
||||
let start = Instant::now();
|
||||
for id in ids {
|
||||
ids_strings.push(id.encode_hex::<String>());
|
||||
}
|
||||
let end = Instant::now();
|
||||
|
||||
assert_eq!(count, ids_strings.len());
|
||||
assert_eq!(64, ids_strings.get(0).unwrap().len());
|
||||
|
||||
println!(
|
||||
"hex_encode() for {} ids took {}ns",
|
||||
count,
|
||||
end.duration_since(start).as_nanos()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_basic_authorization_header() {
|
||||
let result = parse_basic_authorization_header("");
|
||||
assert!(result.is_err());
|
||||
assert_eq!(
|
||||
result.err().unwrap().to_string(),
|
||||
"Could not parse authorization header"
|
||||
);
|
||||
|
||||
// Error: Could not parse authorization header
|
||||
assert!(parse_basic_authorization_header("B").is_err());
|
||||
assert!(parse_basic_authorization_header("Basic").is_err());
|
||||
assert!(parse_basic_authorization_header("Basic ").is_err());
|
||||
assert!(parse_basic_authorization_header("B X").is_err());
|
||||
assert!(parse_basic_authorization_header("Basi XYZ").is_err());
|
||||
|
||||
// DecodeError: Encoded text cannot have a 6-bit remainder.
|
||||
assert!(parse_basic_authorization_header("Basic X").is_err());
|
||||
|
||||
// DecodeError: Invalid last symbol 90, offset 2.
|
||||
assert!(parse_basic_authorization_header("Basic XYZ").is_err());
|
||||
|
||||
// Utf8Error: invalid utf-8 sequence of 1 bytes from index 0
|
||||
assert!(parse_basic_authorization_header("Basic //3//Q==").is_err());
|
||||
|
||||
// Utf8Error: invalid utf-8 sequence of 1 bytes from index 8
|
||||
assert!(
|
||||
parse_basic_authorization_header("Basic MTIzNDU2Nzj95v3n/ej96f3q/ev97P3t/e797w==")
|
||||
.is_err()
|
||||
);
|
||||
|
||||
let result = parse_basic_authorization_header("Basic VGVzdA==");
|
||||
assert!(result.is_err());
|
||||
assert_eq!(
|
||||
result.err().unwrap().to_string(),
|
||||
"Authorization header not valid"
|
||||
);
|
||||
|
||||
// Error: Authorization header not valid
|
||||
assert!(parse_basic_authorization_header("Basic MSAy").is_err());
|
||||
assert!(parse_basic_authorization_header("Basic MWIgMmI=").is_err());
|
||||
|
||||
// ":"
|
||||
assert_eq!(
|
||||
parse_basic_authorization_header("Basic Og==").unwrap(),
|
||||
("".to_string(), "".to_string())
|
||||
);
|
||||
|
||||
// "username:password"
|
||||
assert_eq!(
|
||||
parse_basic_authorization_header("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap(),
|
||||
("username".to_string(), "password".to_string())
|
||||
);
|
||||
|
||||
// ":password"
|
||||
assert_eq!(
|
||||
parse_basic_authorization_header("Basic OnBhc3N3b3Jk").unwrap(),
|
||||
("".to_string(), "password".to_string())
|
||||
);
|
||||
|
||||
// "username:"
|
||||
assert_eq!(
|
||||
parse_basic_authorization_header("Basic dXNlcm5hbWU6").unwrap(),
|
||||
("username".to_string(), "".to_string())
|
||||
);
|
||||
|
||||
// "::"
|
||||
assert_eq!(
|
||||
parse_basic_authorization_header("Basic Ojo=").unwrap(),
|
||||
("".to_string(), ":".to_string())
|
||||
);
|
||||
|
||||
// ":::::"
|
||||
assert_eq!(
|
||||
parse_basic_authorization_header("Basic Ojo6Ojo=").unwrap(),
|
||||
("".to_string(), "::::".to_string())
|
||||
);
|
||||
|
||||
// "1a2b3c:1a2b3c:1a2b3c:1a2b3c"
|
||||
assert_eq!(
|
||||
parse_basic_authorization_header("Basic MWEyYjNjOjFhMmIzYzoxYTJiM2M6MWEyYjNj").unwrap(),
|
||||
("1a2b3c".to_string(), "1a2b3c:1a2b3c:1a2b3c".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authenticate() {
|
||||
lazy_static! {
|
||||
static ref CONFIG: config::Config = config::default_test_config();
|
||||
}
|
||||
let config = &CONFIG;
|
||||
|
||||
let result = authenticate(config, "1:2");
|
||||
assert!(result.is_err());
|
||||
assert_eq!(result.err().unwrap().to_string(), "Password not valid");
|
||||
|
||||
// Error: Password not valid
|
||||
assert!(authenticate(config, "").is_err());
|
||||
assert!(authenticate(config, ":").is_err());
|
||||
assert!(authenticate(config, "::").is_err());
|
||||
assert!(authenticate(config, ":::").is_err());
|
||||
assert!(authenticate(config, "::::").is_err());
|
||||
assert!(authenticate(config, ":::::").is_err());
|
||||
assert!(authenticate(config, "2:::::").is_err());
|
||||
assert!(authenticate(config, "1:2:3").is_err());
|
||||
assert!(authenticate(config, "1:2:3:4:5").is_err());
|
||||
|
||||
// Error: Odd number of digits
|
||||
assert!(authenticate(config, "1:2b::").is_err());
|
||||
assert!(authenticate(config, "1a:2::").is_err());
|
||||
assert!(authenticate(config, "1a2:2b:1:3c").is_err());
|
||||
assert!(authenticate(config, "2:1:2b:::").is_err());
|
||||
assert!(authenticate(config, "2:1a:2:::").is_err());
|
||||
assert!(authenticate(config, "2:1a2:2b:1:1:3c").is_err());
|
||||
|
||||
// Error: Invalid character 'x' at position 1
|
||||
assert!(authenticate(config, "1x:2b:1:").is_err());
|
||||
assert!(authenticate(config, "1a:2x:1:").is_err());
|
||||
assert!(authenticate(config, "2:1x:2b:1::").is_err());
|
||||
assert!(authenticate(config, "2:1a:2x:1::").is_err());
|
||||
|
||||
// Error: Unknown version
|
||||
assert!(authenticate(config, ":1a:2b:1:2:3").is_err());
|
||||
assert!(authenticate(config, "1:1a:2b:1:2:3").is_err());
|
||||
assert!(authenticate(config, "3:1a:2b:1:2:3").is_err());
|
||||
|
||||
assert_eq!(
|
||||
authenticate(config, "1a:2b:1:").unwrap(),
|
||||
(sfu::UserId::from(vec![26]), sfu::CallId::from(vec![43]))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
authenticate(config, "2:1a:2b:1:2:3").unwrap(),
|
||||
(vec![26].into(), vec![43].into())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_and_authenticate() {
|
||||
lazy_static! {
|
||||
static ref CONFIG: config::Config = config::default_test_config();
|
||||
}
|
||||
let config = &CONFIG;
|
||||
|
||||
// Version 1: "username:1a:2b:1:"
|
||||
let result = parse_and_authenticate(config, "Basic dXNlcm5hbWU6MWE6MmI6MTo=");
|
||||
assert!(!result.is_err());
|
||||
assert_eq!(result.unwrap(), (vec![26].into(), vec![43].into()));
|
||||
|
||||
// Version 2: "username:2:1a:2b:1:2:3"
|
||||
let result = parse_and_authenticate(config, "Basic dXNlcm5hbWU6MjoxYToyYjoxOjI6Mw==");
|
||||
assert!(!result.is_err());
|
||||
assert_eq!(result.unwrap(), (vec![26].into(), vec![43].into()));
|
||||
}
|
||||
}
|
||||
912
src/ice.rs
Normal file
912
src/ice.rs
Normal file
@ -0,0 +1,912 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
//! Implementation of ICE lite. See https://tools.ietf.org/html/rfc5389
|
||||
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
ops::{Deref, Range},
|
||||
};
|
||||
|
||||
use crc::crc32;
|
||||
use hmac::{crypto_mac::MacError, Hmac, Mac, NewMac};
|
||||
use log::*;
|
||||
use sha1::Sha1;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::common::{
|
||||
parse_u16, random_base64_string_of_length_32, random_base64_string_of_length_4,
|
||||
round_up_to_multiple_of, Empty, Writer,
|
||||
};
|
||||
|
||||
const HEADER_LEN: usize = 20;
|
||||
const HMAC_LEN: usize = 20;
|
||||
const FINGERPRINT_LEN: usize = 4;
|
||||
const ATTR_HEADER_LEN: usize = 4;
|
||||
const BINDING_REQUEST_ID: [u8; 2] = [0x00, 0x01];
|
||||
const BINDING_RESPONSE_ID: [u8; 2] = [0x01, 0x01];
|
||||
const MAGIC_COOKIE: [u8; 4] = [0x21, 0x12, 0xA4, 0x42];
|
||||
|
||||
struct AttributeId {}
|
||||
|
||||
impl AttributeId {
|
||||
const USERNAME: u16 = 0x0006;
|
||||
const MESSAGE_INTEGRITY: u16 = 0x0008;
|
||||
const FINGERPRINT: u16 = 0x8028;
|
||||
/// AKA USE-CANDIDATE
|
||||
const NOMINATION: u16 = 0x0025;
|
||||
}
|
||||
|
||||
const FINGERPRINT_XOR_VALUE: u32 = 0x5354554E;
|
||||
|
||||
struct BindingRequestRanges {
|
||||
username: Range<usize>,
|
||||
hmac: Range<usize>,
|
||||
fingerprint: Range<usize>,
|
||||
}
|
||||
|
||||
pub fn join_username(sender_ufrag: &[u8], receiver_ufrag: &[u8]) -> Vec<u8> {
|
||||
[receiver_ufrag, sender_ufrag].join(b":".as_ref())
|
||||
}
|
||||
|
||||
pub struct BindingRequest<'a> {
|
||||
packet: &'a [u8],
|
||||
is_nominated: bool,
|
||||
ranges: BindingRequestRanges,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Eq, PartialEq)]
|
||||
pub enum ParseError {
|
||||
#[error("ICE binding request has no complete header, packet length {0}.")]
|
||||
IncompleteHeader(usize),
|
||||
#[error("ICE binding request has no username.")]
|
||||
MissingUsernameAttribute,
|
||||
#[error("ICE binding request has no hmac.")]
|
||||
MissingHMacAttribute,
|
||||
#[error("ICE binding request has no fingerprint.")]
|
||||
MissingFingerprintAttribute,
|
||||
#[error("ICE binding request message length was declared as {0:#06x} but was {1:#06x}.")]
|
||||
DeclaredMessageLengthMismatch(usize, usize),
|
||||
#[error("ICE binding request hmac length was {0} but expected {1}.")]
|
||||
WrongHMacLength(u16, u16),
|
||||
#[error("ICE binding request fingerprint length was {0} but expected {1}.")]
|
||||
WrongFingerprintLength(u16, u16),
|
||||
#[error("ICE binding request fingerprint seen before hmac")]
|
||||
FingerprintBeforeHMac,
|
||||
#[error("ICE binding request saw attribute {0:#06x} but expected fingerprint.")]
|
||||
ExpectedFingerprint(u16),
|
||||
#[error("ICE binding request saw attribute {0:#06x} after the fingerprint.")]
|
||||
AttributeAfterFingerprint(u16),
|
||||
#[error("ICE binding request attribute {0:#06x} is {1} bytes past packet end.")]
|
||||
AttributeRangePastPacketEnd(u16, usize),
|
||||
}
|
||||
|
||||
impl<'a> BindingRequest<'a> {
|
||||
pub fn looks_like_header(packet: &[u8]) -> bool {
|
||||
packet.len() >= 8 && packet[0..2] == BINDING_REQUEST_ID && packet[4..8] == MAGIC_COOKIE
|
||||
}
|
||||
|
||||
pub fn parse(packet: &'a [u8]) -> Result<BindingRequest<'a>, ParseError> {
|
||||
if packet.len() < HEADER_LEN {
|
||||
return Err(ParseError::IncompleteHeader(packet.len()));
|
||||
}
|
||||
|
||||
let declared_message_length = parse_u16(&packet[2..4]) as usize;
|
||||
let actual_message_length = packet.len() - HEADER_LEN;
|
||||
if declared_message_length != actual_message_length {
|
||||
return Err(ParseError::DeclaredMessageLengthMismatch(
|
||||
declared_message_length,
|
||||
actual_message_length,
|
||||
));
|
||||
}
|
||||
|
||||
/// State machine states for parsing, to help ensure mac and fingerprint are last.
|
||||
enum ParseState {
|
||||
ReadingAttributes,
|
||||
ExpectFingerprint,
|
||||
Done,
|
||||
}
|
||||
|
||||
let mut parse_mode = ParseState::ReadingAttributes;
|
||||
let mut username: Option<Range<usize>> = None;
|
||||
let mut hmac: Option<Range<usize>> = None;
|
||||
let mut fingerprint: Option<Range<usize>> = None;
|
||||
let mut nomination: Option<Range<usize>> = None;
|
||||
|
||||
let mut attr_start = HEADER_LEN;
|
||||
while packet.len() >= attr_start + ATTR_HEADER_LEN {
|
||||
let attr_header = &packet[attr_start..][..ATTR_HEADER_LEN];
|
||||
let attr_id = parse_u16(&attr_header[0..2]);
|
||||
let attr_len = parse_u16(&attr_header[2..4]);
|
||||
let attr_val_start = attr_start + ATTR_HEADER_LEN;
|
||||
let attr_val_end = attr_val_start + attr_len as usize;
|
||||
let attr_range = attr_val_start..attr_val_end;
|
||||
if attr_range.end > packet.len() {
|
||||
return Err(ParseError::AttributeRangePastPacketEnd(
|
||||
attr_id,
|
||||
attr_range.end - packet.len(),
|
||||
));
|
||||
}
|
||||
match parse_mode {
|
||||
ParseState::ReadingAttributes => match attr_id {
|
||||
AttributeId::USERNAME => username = Some(attr_range.clone()),
|
||||
AttributeId::NOMINATION => nomination = Some(attr_range.clone()),
|
||||
AttributeId::MESSAGE_INTEGRITY => {
|
||||
if attr_range.len() != HMAC_LEN {
|
||||
return Err(ParseError::WrongHMacLength(attr_len, HMAC_LEN as u16));
|
||||
}
|
||||
hmac = Some(attr_range.clone());
|
||||
parse_mode = ParseState::ExpectFingerprint;
|
||||
}
|
||||
AttributeId::FINGERPRINT => {
|
||||
return Err(ParseError::FingerprintBeforeHMac);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
ParseState::ExpectFingerprint => {
|
||||
if attr_id != AttributeId::FINGERPRINT {
|
||||
return Err(ParseError::ExpectedFingerprint(attr_id));
|
||||
}
|
||||
if attr_range.len() != FINGERPRINT_LEN {
|
||||
return Err(ParseError::WrongFingerprintLength(
|
||||
attr_len,
|
||||
FINGERPRINT_LEN as u16,
|
||||
));
|
||||
}
|
||||
fingerprint = Some(attr_range.clone());
|
||||
parse_mode = ParseState::Done;
|
||||
}
|
||||
ParseState::Done => {
|
||||
return Err(ParseError::AttributeAfterFingerprint(attr_id));
|
||||
}
|
||||
}
|
||||
attr_start = round_up_to_multiple_of::<4>(attr_range.end);
|
||||
}
|
||||
|
||||
let username = username.ok_or(ParseError::MissingUsernameAttribute)?;
|
||||
let hmac = hmac.ok_or(ParseError::MissingHMacAttribute)?;
|
||||
let fingerprint = fingerprint.ok_or(ParseError::MissingFingerprintAttribute)?;
|
||||
|
||||
if log_enabled!(Level::Trace) {
|
||||
trace!("ICE binding request:");
|
||||
trace!(" username: {:?}", username.clone().collect::<Vec<_>>());
|
||||
trace!(" hmac: {:?}", hmac.clone().collect::<Vec<_>>());
|
||||
trace!(
|
||||
" fingerprint: {:?}",
|
||||
fingerprint.clone().collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(BindingRequest {
|
||||
packet,
|
||||
is_nominated: nomination.is_some(),
|
||||
ranges: BindingRequestRanges {
|
||||
username,
|
||||
hmac,
|
||||
fingerprint,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn nominated(&self) -> bool {
|
||||
self.is_nominated
|
||||
}
|
||||
|
||||
pub fn hmac(&self) -> &[u8] {
|
||||
&self.packet[self.ranges.hmac.clone()]
|
||||
}
|
||||
|
||||
pub fn fingerprint(&self) -> &[u8] {
|
||||
&self.packet[self.ranges.fingerprint.clone()]
|
||||
}
|
||||
|
||||
pub fn username(&self) -> &[u8] {
|
||||
&self.packet[self.ranges.username.clone()]
|
||||
}
|
||||
|
||||
pub fn verify_hmac(&self, pwd: &[u8]) -> Result<VerifiedBindingRequest, MacError> {
|
||||
Self::calculate_hmac(self.packet, &self.ranges, pwd)
|
||||
.verify(&self.hmac())
|
||||
.map(|_| VerifiedBindingRequest::new(self))
|
||||
}
|
||||
|
||||
fn calculate_hmac(packet: &[u8], ranges: &BindingRequestRanges, pwd: &[u8]) -> Hmac<Sha1> {
|
||||
// ICE HMACs are strange in that they are computed without the HMAC attribute,
|
||||
// but are computed with a length that includes the HMAC attribute.
|
||||
let mut mac = Hmac::<Sha1>::new_from_slice(pwd).expect("All key lengths are valid");
|
||||
mac.update(&packet[0..2]);
|
||||
// The length in the header excludes the header itself, but includes the HMAC attribute.
|
||||
mac.update(&((ranges.hmac.end - HEADER_LEN) as u16).to_be_bytes());
|
||||
mac.update(&packet[4..(ranges.hmac.start - ATTR_HEADER_LEN)]);
|
||||
mac
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VerifiedBindingRequest<'a> {
|
||||
request: &'a BindingRequest<'a>,
|
||||
}
|
||||
|
||||
impl<'a> Deref for VerifiedBindingRequest<'a> {
|
||||
type Target = BindingRequest<'a>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.request
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> VerifiedBindingRequest<'a> {
|
||||
fn new(request: &'a BindingRequest<'a>) -> VerifiedBindingRequest<'a> {
|
||||
VerifiedBindingRequest { request }
|
||||
}
|
||||
|
||||
/// Public constructor for fuzzing only, which allows creation of a verified binding request
|
||||
/// even though it's not possible for the fuzzer to get past the hmac verification.
|
||||
#[cfg(fuzzing)]
|
||||
pub fn new_for_fuzzing(request: &'a BindingRequest<'a>) -> VerifiedBindingRequest<'a> {
|
||||
VerifiedBindingRequest { request }
|
||||
}
|
||||
|
||||
pub fn to_binding_response(&self, username: &[u8], pwd: &[u8]) -> Vec<u8> {
|
||||
let mut packet: Vec<u8> = self.packet.to_vec();
|
||||
packet[0..2].copy_from_slice(&BINDING_RESPONSE_ID);
|
||||
packet[self.ranges.username.clone()].copy_from_slice(username);
|
||||
Self::recalculate_hmac_and_fingerprint_of_packet(&mut packet, &self.ranges, pwd);
|
||||
packet
|
||||
}
|
||||
|
||||
fn recalculate_hmac_and_fingerprint_of_packet(
|
||||
packet: &mut [u8],
|
||||
ranges: &BindingRequestRanges,
|
||||
pwd: &[u8],
|
||||
) {
|
||||
let hmac = BindingRequest::calculate_hmac(packet, ranges, pwd)
|
||||
.finalize()
|
||||
.into_bytes();
|
||||
packet[ranges.hmac.clone()].copy_from_slice(&hmac);
|
||||
let fingerprint = FINGERPRINT_XOR_VALUE
|
||||
^ crc32::checksum_ieee(&packet[..(ranges.fingerprint.start - ATTR_HEADER_LEN)]);
|
||||
packet[ranges.fingerprint.clone()].copy_from_slice(&fingerprint.to_be_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn random_ufrag() -> String {
|
||||
random_base64_string_of_length_4()
|
||||
}
|
||||
|
||||
pub fn random_pwd() -> String {
|
||||
random_base64_string_of_length_32()
|
||||
}
|
||||
|
||||
type TransactionId = [u8; 16];
|
||||
|
||||
pub fn create_binding_request_packet(
|
||||
transaction_id: &TransactionId,
|
||||
username: &[u8],
|
||||
pwd: &[u8],
|
||||
nominated: bool,
|
||||
) -> Vec<u8> {
|
||||
create_packet(BINDING_REQUEST_ID, transaction_id, username, pwd, nominated)
|
||||
}
|
||||
|
||||
// Responses don't need a nomination bit, but for now we include it because it's easier for writing tests.
|
||||
pub fn create_binding_response_packet(
|
||||
transaction_id: &TransactionId,
|
||||
username: &[u8],
|
||||
pwd: &[u8],
|
||||
nominated: bool,
|
||||
) -> Vec<u8> {
|
||||
create_packet(
|
||||
BINDING_RESPONSE_ID,
|
||||
transaction_id,
|
||||
username,
|
||||
pwd,
|
||||
nominated,
|
||||
)
|
||||
}
|
||||
|
||||
fn create_packet(
|
||||
message_type: [u8; 2],
|
||||
transaction_id: &TransactionId,
|
||||
username: &[u8],
|
||||
pwd: &[u8],
|
||||
nominated: bool,
|
||||
) -> Vec<u8> {
|
||||
let username = write_stun_attribute(AttributeId::USERNAME, username);
|
||||
let nomination = if nominated {
|
||||
Some(write_stun_attribute(AttributeId::NOMINATION, Empty {}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let hmaced_attrs = (username, nomination);
|
||||
let hmaced_attrs_len = hmaced_attrs.written_len();
|
||||
let dummy_hmac = write_stun_attribute(AttributeId::MESSAGE_INTEGRITY, [0u8; HMAC_LEN]);
|
||||
let fingerprinted_attrs = (hmaced_attrs, dummy_hmac);
|
||||
let fingerprinted_attrs_len = fingerprinted_attrs.written_len();
|
||||
let dummy_fingerprint = write_stun_attribute(AttributeId::FINGERPRINT, 0u32);
|
||||
let attrs = (fingerprinted_attrs, dummy_fingerprint);
|
||||
let attrs_len = attrs.written_len();
|
||||
let header = (message_type, attrs_len as u16, transaction_id);
|
||||
let header_len = header.written_len();
|
||||
let mut packet = (header, attrs).to_vec();
|
||||
|
||||
let write_len_in_header = |packet: &mut [u8], len: usize| {
|
||||
packet[2..4].copy_from_slice(&(len as u16).to_be_bytes());
|
||||
};
|
||||
|
||||
let write_value_in_attr =
|
||||
|packet: &mut [u8], attr_start_relative_to_body: usize, value: &[u8]| {
|
||||
packet[header_len + attr_start_relative_to_body + ATTR_HEADER_LEN..][..value.len()]
|
||||
.copy_from_slice(value);
|
||||
};
|
||||
|
||||
// ICE HMACs are strange in that they are computed without the HMAC attribute,
|
||||
// but are computed with a length that includes the HMAC attribute.
|
||||
write_len_in_header(&mut packet, fingerprinted_attrs_len);
|
||||
let hmac_value = {
|
||||
let mut hmac = Hmac::<Sha1>::new_from_slice(pwd).expect("All key lengths are valid");
|
||||
hmac.update(&packet[..header_len + hmaced_attrs_len]);
|
||||
hmac.finalize().into_bytes()
|
||||
};
|
||||
write_value_in_attr(&mut packet, hmaced_attrs_len, &hmac_value);
|
||||
|
||||
write_len_in_header(&mut packet, attrs_len);
|
||||
let fingerprint_value = {
|
||||
FINGERPRINT_XOR_VALUE
|
||||
^ crc32::checksum_ieee(&packet[..header_len + fingerprinted_attrs_len])
|
||||
};
|
||||
write_value_in_attr(
|
||||
&mut packet,
|
||||
fingerprinted_attrs_len,
|
||||
&fingerprint_value.to_be_bytes(),
|
||||
);
|
||||
|
||||
packet
|
||||
}
|
||||
|
||||
fn write_stun_attribute(attribute_id: u16, value: impl Writer) -> impl Writer {
|
||||
let value_len = value.written_len();
|
||||
let padded_len = round_up_to_multiple_of::<4>(value_len);
|
||||
let padding_len = padded_len - value_len;
|
||||
|
||||
let value_len =
|
||||
u16::try_from(value.written_len()).expect("STUN attribute is less than u16::MAX in len.");
|
||||
let padding = vec![0u8; padding_len];
|
||||
(attribute_id, value_len, value, padding)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_join_username() {
|
||||
assert_eq!(b"B:A".to_vec(), join_username(b"A", b"B"));
|
||||
assert_eq!(b"DEF:ABC".to_vec(), join_username(b"ABC", b"DEF"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ice_packets() {
|
||||
let transaction_id = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
|
||||
let request_username = b"server:client";
|
||||
let response_username = b"client:server";
|
||||
let pwd = b"this should be a long pwd";
|
||||
let nominated = true;
|
||||
let expected_request: &[u8] = &hex!(
|
||||
"
|
||||
/* header */ 0001 0038 0102030405060708090a0b0c0d0e0f10
|
||||
/* username */ 0006 000D 7365727665723a636c69656e74000000
|
||||
/* nomination */ 0025 0000
|
||||
/* hmac */ 0008 0014 73df552e08ec7ceef7f2056411ec82115ba1198e
|
||||
/* fingerprint */ 8028 0004 f2273ea4
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
expected_request,
|
||||
create_binding_request_packet(&transaction_id, request_username, pwd, nominated)
|
||||
);
|
||||
|
||||
let nominated = false;
|
||||
let expected_response: &[u8] = &hex!(
|
||||
"
|
||||
/* header */ 0101 0034 0102030405060708090a0b0c0d0e0f10
|
||||
/* username */ 0006 000D 636c69656e743a736572766572000000
|
||||
/* hmac */ 0008 0014 b7e18a20e28e82c1f0168c223b6e8c2cab599e2e
|
||||
/* fingerprint */ 8028 0004 dedaa207
|
||||
"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
expected_response,
|
||||
create_binding_response_packet(&transaction_id, response_username, pwd, nominated)
|
||||
);
|
||||
}
|
||||
|
||||
mod header_identification_tests {
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::BindingRequest;
|
||||
|
||||
#[test]
|
||||
fn looks_like_binding_request_header() {
|
||||
assert!(BindingRequest::looks_like_header(&hex!(
|
||||
"0001 0000 2112A442"
|
||||
)));
|
||||
assert!(BindingRequest::looks_like_header(&hex!(
|
||||
"0001 FFFF 2112A442"
|
||||
)));
|
||||
assert!(BindingRequest::looks_like_header(&hex!(
|
||||
"0001 0000 2112A442 01"
|
||||
)));
|
||||
assert!(BindingRequest::looks_like_header(&hex!(
|
||||
"0001 FFFF 2112A442 0102"
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn does_not_look_like_binding_request_header() {
|
||||
assert!(
|
||||
!BindingRequest::looks_like_header(&hex!("0001 0000 2112A4")),
|
||||
"Too short"
|
||||
);
|
||||
assert!(
|
||||
!BindingRequest::looks_like_header(&hex!("0101 0000 2112A442")),
|
||||
"Wrong first byte"
|
||||
);
|
||||
assert!(
|
||||
!BindingRequest::looks_like_header(&hex!("0002 0000 2112A442")),
|
||||
"Wrong second byte"
|
||||
);
|
||||
assert!(
|
||||
!BindingRequest::looks_like_header(&hex!("0001 0000 FF12A442")),
|
||||
"Wrong first byte of magic"
|
||||
);
|
||||
assert!(
|
||||
!BindingRequest::looks_like_header(&hex!("0001 0000 2112A4FF")),
|
||||
"Wrong last byte of magic"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
mod parse_binding_requests_failure_tests {
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::{BindingRequest, ParseError};
|
||||
|
||||
#[test]
|
||||
fn prevent_empty_packet() {
|
||||
assert_eq!(
|
||||
Some(ParseError::IncompleteHeader(0)),
|
||||
BindingRequest::parse(&[]).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_incomplete_header() {
|
||||
let packet: &[u8] = &hex!("0001 004c 2112a4422b6a714565766478326f5a");
|
||||
assert_eq!(
|
||||
Some(ParseError::IncompleteHeader(19)),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_header_only() {
|
||||
let packet: &[u8] = &hex!("0001 0000 2112a44271536e422b33695952394469");
|
||||
assert_eq!(
|
||||
Some(ParseError::MissingUsernameAttribute),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_missing_username() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 0040 2112a44271536e422b33695952394469
|
||||
/* username */ // 0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 f66e672cbb22165d
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
0008 0014749225e1798cdcf19c72a48d36b8de0da89effb6
|
||||
8028 000456d8838f
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::MissingUsernameAttribute),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_missing_hmac() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 004c 2112a4422b6a714565766478326f5a55
|
||||
0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ // 0008 0014749225e1798cdcf19c72a48d36b8de0da89effb6
|
||||
/* fingerprint */ // 8028 000456d8838f
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::MissingHMacAttribute),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_missing_fingerprint() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 0064 2112a442535a6370696c496c46696d33
|
||||
0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 4ac4f93cd6809b35be287203a673b3033b2769da
|
||||
/* fingerprint */ // 8028 000456d8838f
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::MissingFingerprintAttribute),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_fingerprint_before_hmac() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006c 2112a44238656d797950694b78506e6e
|
||||
0006 0025 63636431623031363037303065383364616232386435303135636563346362653a31315453000000
|
||||
c057 0004 00010032
|
||||
802a 0008 f66e672cbb22165d
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* fingerprint */ 8028 0004 d48fbba0
|
||||
/* hmac */ 0008 0014 5be1331d09c86d8cbfaf48f64687669096d32d3b
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::FingerprintBeforeHMac),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_wrong_hmac_length() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 0068 2112a442516b77624e657155454a4635
|
||||
0006 0025 63636431623031363037303065383364616232386435303135636563346362653a31315453000000
|
||||
c057 0004 00010032
|
||||
802a 0008 f66e672cbb22165d
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0010 4de1e857695f0804f5f8e9fcf3150977
|
||||
8028 0004 b7b01d0b
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::WrongHMacLength(16, 20)),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_something_between_hmac_and_fingerprint() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 0070 2112a442656b72774b55515041495476
|
||||
0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 d385d7f2f2222979333c405cea0b291444592aca
|
||||
abcd 0000
|
||||
/* fingerprint */ 8028 000429560496
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::ExpectedFingerprint(0xabcd)),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_appending_1_byte_to_packet() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006c 2112a442665175732f33426771346c7a
|
||||
0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
0008 0014 62d3395bf9d117fa6b915cccd60d4dc141d39c92
|
||||
/* fingerprint */ 8028 0004 1698f47f
|
||||
00 // appended 1 bytes past declared message length
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::DeclaredMessageLengthMismatch(
|
||||
0x006c,
|
||||
0x006c + 1
|
||||
)),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_appending_whole_attribute_to_packet() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006c 2112a442665175732f33426771346c7a
|
||||
0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
0008 0014 62d3395bf9d117fa6b915cccd60d4dc141d39c92
|
||||
/* fingerprint */ 8028 0004 1698f47f
|
||||
abcd 0000 // appended 4 bytes past declared message length
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::DeclaredMessageLengthMismatch(
|
||||
0x006c,
|
||||
0x006c + 4
|
||||
)),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_attribute_after_fingerprint_within_declared_message_length() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 0070 2112a442665175732f33426771346c7a
|
||||
0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
0008 0014 62d3395bf9d117fa6b915cccd60d4dc141d39c92
|
||||
/* fingerprint */ 8028 0004 1698f47f
|
||||
abcd 0000
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::AttributeAfterFingerprint(0xabcd)),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_attribute_past_end_of_packet() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006c 2112a442665175732f33426771346c7a
|
||||
/* Too long */ 0006 0069 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
0008 0014 62d3395bf9d117fa6b915cccd60d4dc141d39c92
|
||||
8028 0004 1698f47f
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::AttributeRangePastPacketEnd(0x0006, 1)),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prevent_wrong_length_fingerprint() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006b 2112a442665175732f33426771346c7a
|
||||
0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
0008 0014 62d3395bf9d117fa6b915cccd60d4dc141d39c92
|
||||
/* fingerprint */ 8028 0003 1698f4
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
Some(ParseError::WrongFingerprintLength(3, 4)),
|
||||
BindingRequest::parse(packet).err()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
mod parse_binding_request_tests {
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_with_nomination() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006c 2112a44238656d797950694b78506e6e
|
||||
/* username */ 0006 0025 63636431623031363037303065383364616232386435303135636563346362653a31315453000000
|
||||
c057 0004 00010032
|
||||
802a 0008 f66e672cbb22165d
|
||||
/* nominated */ 0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 5be1331d09c86d8cbfaf48f64687669096d32d3b
|
||||
/* fingerprint */ 8028 0004 d48fbba0
|
||||
"
|
||||
);
|
||||
let packet = BindingRequest::parse(packet).expect("Parsed");
|
||||
assert!(packet.nominated());
|
||||
assert_eq!(
|
||||
hex!("63636431623031363037303065383364616232386435303135636563346362653a31315453"),
|
||||
packet.username()
|
||||
);
|
||||
assert_eq!(
|
||||
hex!("5be1331d09c86d8cbfaf48f64687669096d32d3b"),
|
||||
packet.hmac()
|
||||
);
|
||||
assert_eq!(hex!("d48fbba0"), packet.fingerprint());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_without_nomination() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 0068 2112a44271536e422b33695952394469
|
||||
/* username */ 0006 0025 63636431623031363037303065383364616232386435303135636563346362653a31315453000000
|
||||
c057 0004 00010032
|
||||
802a 0008 f66e672cbb22165d
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 749225e1798cdcf19c72a48d36b8de0da89effb6
|
||||
/* fingerprint */ 8028 0004 56d8838f
|
||||
"
|
||||
);
|
||||
let packet = BindingRequest::parse(packet).expect("Parsed");
|
||||
assert!(!packet.nominated());
|
||||
assert_eq!(
|
||||
hex!("63636431623031363037303065383364616232386435303135636563346362653a31315453"),
|
||||
packet.username()
|
||||
);
|
||||
assert_eq!(
|
||||
hex!("749225e1798cdcf19c72a48d36b8de0da89effb6"),
|
||||
packet.hmac()
|
||||
);
|
||||
assert_eq!(hex!("56d8838f"), packet.fingerprint());
|
||||
}
|
||||
}
|
||||
|
||||
mod hmac_verification_tests {
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::BindingRequest;
|
||||
|
||||
#[test]
|
||||
fn hmac_verify() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006c 2112a442716e517877595a6c5853332f
|
||||
0006 0025 63636431623031363037303065383364616232386435303135636563346362653a31315453000000
|
||||
c057 0004 00010032
|
||||
802a 0008 f66e672cbb22165d
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 f2929850b442ffc08489031630696a4473534113
|
||||
/* fingerprint */ 8028 0004 4654be07
|
||||
"
|
||||
);
|
||||
|
||||
let ice_packet = BindingRequest::parse(packet).expect("Parsed");
|
||||
|
||||
assert!(ice_packet
|
||||
.verify_hmac(b"000102030405060708090a0b0c0d0e0f")
|
||||
.is_ok());
|
||||
|
||||
assert!(
|
||||
ice_packet
|
||||
.verify_hmac(b"0102030405060708090a0b0c0d0e0f10")
|
||||
.is_err(),
|
||||
"Should not verify with another password"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hmac_does_not_verify_if_packet_manipulated() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006c 2112a442716e517877595a6c5853332f
|
||||
0006 0025 63636431623031363037303065383364616232386435303135636563346362653a31315453000000
|
||||
c057 0004 00010033
|
||||
802a 0008 f66e672cbb22165d
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 f2929850b442ffc08489031630696a4473534113
|
||||
/* fingerprint */ 8028 0004 4654be07
|
||||
"
|
||||
);
|
||||
|
||||
let ice_packet = BindingRequest::parse(packet).expect("Parsed");
|
||||
|
||||
assert!(ice_packet
|
||||
.verify_hmac(b"000102030405060708090a0b0c0d0e0f")
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hmac_does_not_verify_if_hmac_modified_in_packet() {
|
||||
let packet: &[u8] = &hex!(
|
||||
"
|
||||
0001 006c 2112a442716e517877595a6c5853332f
|
||||
0006 0025 63636431623031363037303065383364616232386435303135636563346362653a31315453000000
|
||||
c057 0004 00010032
|
||||
802a 0008 f66e672cbb22165d
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 f2929850b442ffc08489031630696a4473534114
|
||||
/* fingerprint */ 8028 0004 4654be07
|
||||
"
|
||||
);
|
||||
|
||||
let ice_packet = BindingRequest::parse(packet).expect("Parsed");
|
||||
|
||||
assert!(ice_packet
|
||||
.verify_hmac(b"000102030405060708090a0b0c0d0e0f")
|
||||
.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_a_verified_binding_response_from_a_binding_request() {
|
||||
let packet = &hex!(
|
||||
"
|
||||
/* header */ 0001 006c 2112a442656b72774b55515041495476
|
||||
/* username */ 0006 0025 33643462313062303033306363646638353762393063663962373032353939383a416d3356000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 4dc36c18cad0520147769290da6ec1c8996355b0
|
||||
/* fingerprint */ 8028 0004 8eff8489
|
||||
"
|
||||
);
|
||||
|
||||
let response_packet = BindingRequest::parse(packet)
|
||||
.expect("Parsed")
|
||||
.verify_hmac(b"000102030405060708090a0b0c0d0e0f")
|
||||
.expect("Verified")
|
||||
.to_binding_response(
|
||||
&hex!("63636431623031363037303065383364616232386435303135636563346362653a31315453"),
|
||||
b"0102030405060708090a0b0c0d0e0f10",
|
||||
);
|
||||
|
||||
let expected_response: &[u8] = &hex!(
|
||||
"
|
||||
/* header */ 0101 006c 2112a442656b72774b55515041495476
|
||||
/* username */ 0006 0025 63636431623031363037303065383364616232386435303135636563346362653a31315453000000
|
||||
c057 0004 00010032
|
||||
802a 0008 eef8294dc5f11c9c
|
||||
0025 0000
|
||||
0024 0004 6e7f1eff
|
||||
/* hmac */ 0008 0014 b615c21d9e81e3786dbf40a5ad3825d2f39fbb37
|
||||
/* fingerprint */ 8028 0004 137e0acf
|
||||
"
|
||||
);
|
||||
assert_eq!(expected_response, &response_packet);
|
||||
}
|
||||
}
|
||||
25
src/lib.rs
Normal file
25
src/lib.rs
Normal file
@ -0,0 +1,25 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#[macro_use]
|
||||
pub mod metrics;
|
||||
|
||||
pub mod audio;
|
||||
pub mod call;
|
||||
pub mod common;
|
||||
pub mod config;
|
||||
pub mod connection;
|
||||
pub mod dtls;
|
||||
pub mod googcc;
|
||||
pub mod http_server;
|
||||
pub mod ice;
|
||||
pub mod metrics_server;
|
||||
pub mod protos;
|
||||
pub mod rtp;
|
||||
pub mod sfu;
|
||||
pub mod signaling_server;
|
||||
pub mod transportcc;
|
||||
pub mod udp_server;
|
||||
pub mod vp8;
|
||||
202
src/main.rs
Normal file
202
src/main.rs
Normal file
@ -0,0 +1,202 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#[macro_use]
|
||||
extern crate lazy_static;
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
use std::sync::{atomic::AtomicBool, Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use calling_server::{
|
||||
common::{DataRate, Duration, Instant},
|
||||
config, http_server, metrics_server,
|
||||
sfu::Sfu,
|
||||
signaling_server, udp_server,
|
||||
};
|
||||
use env_logger::Env;
|
||||
use parking_lot::Mutex;
|
||||
use rand::Rng;
|
||||
use rcgen::{Certificate, CertificateParams, DnType};
|
||||
use structopt::StructOpt;
|
||||
use tokio::{
|
||||
runtime,
|
||||
signal::unix::{signal, SignalKind},
|
||||
sync::{mpsc, oneshot},
|
||||
};
|
||||
|
||||
lazy_static! {
|
||||
// Load the config and treat it as a read-only static value.
|
||||
static ref CONFIG: config::Config = {
|
||||
let mut config = config::Config::from_args();
|
||||
|
||||
// Generate the server's x.509 certificate and key.
|
||||
let mut params = CertificateParams::new(vec![]);
|
||||
params.distinguished_name.push(DnType::CommonName, "WebRTC");
|
||||
params.alg = &rcgen::PKCS_ECDSA_P256_SHA256;
|
||||
params.serial_number = Some(rand::thread_rng().gen::<u64>());
|
||||
let cert = Certificate::from_params(params).unwrap();
|
||||
|
||||
config.server_certificate_der = cert.serialize_der().expect("certificate should exist");
|
||||
config.server_private_key_der = cert.serialize_private_key_der();
|
||||
|
||||
config
|
||||
};
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
fn print_config(config: &'static config::Config) {
|
||||
info!("config:");
|
||||
info!(" {:38}{}", "binding_ip:", config.binding_ip);
|
||||
info!(" {:38}{:?}", "ice_candidate_ip:", config.ice_candidate_ip);
|
||||
info!(" {:38}{}", "ice_candidate_port:", config.ice_candidate_port);
|
||||
info!(" {:38}{:?}", "signaling_ip:", config.signaling_ip);
|
||||
info!(" {:38}{}", "signaling_port:", config.signaling_port);
|
||||
info!(" {:38}{}", "max_clients_per_call:", config.max_clients_per_call);
|
||||
info!(" {:38}{} ({})", "initial_target_send_rate_kbps:", config.initial_target_send_rate_kbps, DataRate::from_kbps(config.initial_target_send_rate_kbps));
|
||||
info!(" {:38}{}", "tick_interval_ms:", config.tick_interval_ms);
|
||||
info!(" {:38}{:?}", "diagnostics_interval_secs:", config.diagnostics_interval_secs);
|
||||
info!(" {:38}{}", "active_speaker_message_interval_ms:", config.active_speaker_message_interval_ms);
|
||||
info!(" {:38}{}", "inactivity_check_interval_secs:", config.inactivity_check_interval_secs);
|
||||
info!(" {:38}{}", "inactivity_timeout_secs:", config.inactivity_timeout_secs);
|
||||
info!(" {:38}{}", "datadog metrics:",
|
||||
match &config.metrics.datadog {
|
||||
Some(host) => host,
|
||||
None => "Off",
|
||||
});
|
||||
}
|
||||
|
||||
/// Waits for a SIGINT or SIGTERM signal and returns. Can be cancelled
|
||||
/// by sending something to the channel.
|
||||
pub async fn wait_for_signal(mut canceller: mpsc::Receiver<()>) {
|
||||
tokio::select!(
|
||||
_ = async {
|
||||
if let Ok(mut stream) = signal(SignalKind::interrupt()) {
|
||||
stream.recv().await;
|
||||
}
|
||||
} => {
|
||||
// Handle SIGINT for ctrl+c and debug stop command.
|
||||
info!("terminating by signal: SIGINT");
|
||||
},
|
||||
_ = async {
|
||||
if let Ok(mut stream) = signal(SignalKind::terminate()) {
|
||||
stream.recv().await;
|
||||
}
|
||||
} => {
|
||||
// Handle SIGTERM for docker stop command.
|
||||
info!("terminating by signal: SIGTERM");
|
||||
},
|
||||
_ = async { canceller.recv().await } => {},
|
||||
)
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
std::env::set_var("RUST_BACKTRACE", "full");
|
||||
|
||||
// Initialize logging.
|
||||
env_logger::Builder::from_env(
|
||||
Env::default()
|
||||
.default_filter_or("calling_server=info")
|
||||
.default_write_style_or("never"),
|
||||
)
|
||||
.format_timestamp_millis()
|
||||
.init();
|
||||
|
||||
info!("Signal Calling Server starting up...");
|
||||
|
||||
// Log information about the environment we are running in.
|
||||
info!(
|
||||
"calling_server: v{}",
|
||||
option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
|
||||
);
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
{
|
||||
match option_env!("RUSTFLAGS") {
|
||||
None => {
|
||||
warn!("for optimal performance, build with RUSTFLAGS=\"-C target-cpu=native\" or better");
|
||||
}
|
||||
Some(rust_flags) => {
|
||||
info!("built with: RUSTFLAGS=\"{}\"", rust_flags);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the command line arguments.
|
||||
let config = &CONFIG;
|
||||
print_config(config);
|
||||
|
||||
// Create the shared SFU context.
|
||||
let sfu: Arc<Mutex<Sfu>> = Arc::new(Mutex::new(Sfu::new(Instant::now(), config)?));
|
||||
|
||||
// Create a threaded tokio runtime. By default, starts a worker thread
|
||||
// for each core on the system.
|
||||
let threaded_rt = runtime::Runtime::new()?;
|
||||
|
||||
let (signaling_ender_tx, signaling_ender_rx) = oneshot::channel();
|
||||
let (udp_ender_tx, udp_ender_rx) = oneshot::channel();
|
||||
let (metrics_ender_tx, metrics_ender_rx) = oneshot::channel();
|
||||
let (signal_canceller_tx, signal_canceller_rx) = mpsc::channel(1);
|
||||
let is_healthy = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let sfu_clone_for_udp = sfu.clone();
|
||||
let sfu_clone_for_metrics = sfu.clone();
|
||||
let signal_canceller_tx_clone_for_udp = signal_canceller_tx.clone();
|
||||
let signal_canceller_tx_clone_for_metrics = signal_canceller_tx.clone();
|
||||
let is_healthy_clone_for_udp = is_healthy.clone();
|
||||
|
||||
let _ = threaded_rt.block_on(async {
|
||||
// Start the signaling server, either the signaling_server for production
|
||||
// or the http_server for testing.
|
||||
let signaling_server_handle = tokio::spawn(async move {
|
||||
if config.signaling_ip.is_some() {
|
||||
let _ = signaling_server::start(config, sfu, signaling_ender_rx, is_healthy).await;
|
||||
} else {
|
||||
let _ = http_server::start(config, sfu, signaling_ender_rx, is_healthy).await;
|
||||
}
|
||||
let _ = signal_canceller_tx.send(()).await;
|
||||
});
|
||||
|
||||
// Start the udp_server.
|
||||
let udp_server_handle = tokio::spawn(async move {
|
||||
let _ = udp_server::start(
|
||||
config,
|
||||
sfu_clone_for_udp,
|
||||
udp_ender_rx,
|
||||
is_healthy_clone_for_udp,
|
||||
)
|
||||
.await;
|
||||
let _ = signal_canceller_tx_clone_for_udp.send(()).await;
|
||||
});
|
||||
|
||||
// Start the metrics_server.
|
||||
let metrics_server_handle = tokio::spawn(async move {
|
||||
let _ = metrics_server::start(config, sfu_clone_for_metrics, metrics_ender_rx).await;
|
||||
let _ = signal_canceller_tx_clone_for_metrics.send(()).await;
|
||||
});
|
||||
|
||||
// Wait for any signals to be detected, or cancel due to one of the
|
||||
// servers not being able to be started (the channel is buffered).
|
||||
let _ = wait_for_signal(signal_canceller_rx).await;
|
||||
|
||||
// Gracefully exit the servers if needed.
|
||||
let _ = signaling_ender_tx.send(());
|
||||
let _ = udp_ender_tx.send(());
|
||||
let _ = metrics_ender_tx.send(());
|
||||
|
||||
// Wait for the servers to exit.
|
||||
let _ = tokio::join!(
|
||||
signaling_server_handle,
|
||||
udp_server_handle,
|
||||
metrics_server_handle,
|
||||
);
|
||||
});
|
||||
|
||||
info!("shutting down the runtime");
|
||||
threaded_rt.shutdown_timeout(Duration::from_millis(500).into());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
18
src/metrics.rs
Normal file
18
src/metrics.rs
Normal file
@ -0,0 +1,18 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
pub use datadog_statsd::*;
|
||||
pub use histogram::*;
|
||||
pub use macros::*;
|
||||
pub use reporter::*;
|
||||
pub use timing_options::*;
|
||||
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
mod datadog_statsd;
|
||||
mod histogram;
|
||||
mod reporter;
|
||||
mod test_utils;
|
||||
mod timing_options;
|
||||
704
src/metrics/datadog_statsd.rs
Normal file
704
src/metrics/datadog_statsd.rs
Normal file
@ -0,0 +1,704 @@
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
// Based on https://github.com/minato128/rust-dogstatsd
|
||||
// forked from https://github.com/markstory/rust-statsd
|
||||
//
|
||||
extern crate rand;
|
||||
|
||||
use std::{
|
||||
error, fmt,
|
||||
io::Error,
|
||||
mem,
|
||||
net::{AddrParseError, SocketAddr, ToSocketAddrs, UdpSocket},
|
||||
};
|
||||
|
||||
use log::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum StatsdError {
|
||||
IoError(Error),
|
||||
AddrParseError(String),
|
||||
}
|
||||
|
||||
pub trait EventSink {
|
||||
fn send(&mut self, data: String);
|
||||
fn flush(&mut self);
|
||||
}
|
||||
|
||||
impl From<AddrParseError> for StatsdError {
|
||||
fn from(_: AddrParseError) -> StatsdError {
|
||||
StatsdError::AddrParseError("Address parsing error".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for StatsdError {
|
||||
fn from(err: Error) -> StatsdError {
|
||||
StatsdError::IoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for StatsdError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
StatsdError::IoError(ref e) => write!(f, "{}", e),
|
||||
StatsdError::AddrParseError(ref e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for StatsdError {}
|
||||
|
||||
/// Client socket for statsd servers.
|
||||
///
|
||||
/// After creating a metric you can use `Client`
|
||||
/// to send metrics to the configured statsd server
|
||||
pub struct Client<T: EventSink> {
|
||||
sink: T,
|
||||
prefix: String,
|
||||
constant_tags: Vec<String>,
|
||||
}
|
||||
|
||||
pub struct UdpEventSink {
|
||||
socket: UdpSocket,
|
||||
server_address: SocketAddr,
|
||||
}
|
||||
|
||||
impl UdpEventSink {
|
||||
pub fn new<T: ToSocketAddrs>(host: T) -> Result<UdpEventSink, StatsdError> {
|
||||
let server_address = host
|
||||
.to_socket_addrs()?
|
||||
.next()
|
||||
.ok_or_else(|| StatsdError::AddrParseError("Address parsing error".to_string()))?;
|
||||
|
||||
// Bind to a generic port as we'll only be writing on this
|
||||
// socket.
|
||||
let socket = if server_address.is_ipv4() {
|
||||
UdpSocket::bind("0.0.0.0:0")?
|
||||
} else {
|
||||
UdpSocket::bind("[::]:0")?
|
||||
};
|
||||
Ok(UdpEventSink {
|
||||
socket,
|
||||
server_address,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EventSink for UdpEventSink {
|
||||
fn send(&mut self, data: String) {
|
||||
let _ = self.socket.send_to(data.as_bytes(), self.server_address);
|
||||
}
|
||||
|
||||
fn flush(&mut self) {
|
||||
// nothing to flush, everything was sent immediately
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PipelineSink<'a, T: EventSink> {
|
||||
sink: &'a mut T,
|
||||
max_udp_size: usize,
|
||||
buffer: String,
|
||||
}
|
||||
|
||||
impl<'a, T: EventSink> PipelineSink<'a, T> {
|
||||
fn new(sink: &'a mut T) -> PipelineSink<'a, T> {
|
||||
const COMMODITY_INTERNET_PACKET_SIZE: usize = 512;
|
||||
|
||||
Self::new_with_size(sink, COMMODITY_INTERNET_PACKET_SIZE)
|
||||
}
|
||||
|
||||
/// See https://github.com/statsd/statsd/blob/master/docs/metric_types.md#multi-metric-packets
|
||||
/// for guidance. 512 is a safe minimum.
|
||||
fn new_with_size(sink: &'a mut T, max_udp_size: usize) -> PipelineSink<'a, T> {
|
||||
Self {
|
||||
sink,
|
||||
max_udp_size,
|
||||
buffer: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: EventSink> Drop for PipelineSink<'_, T> {
|
||||
fn drop(&mut self) {
|
||||
self.flush();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: EventSink> EventSink for PipelineSink<'_, T> {
|
||||
fn send(&mut self, data: String) {
|
||||
if data.len() > self.max_udp_size {
|
||||
warn!(
|
||||
"Not able to send metric packet of length {}, as was over udp size {}",
|
||||
data.len(),
|
||||
self.max_udp_size
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if self.buffer.len() + data.len() >= self.max_udp_size {
|
||||
// cannot buffer, must send this
|
||||
let buffer_contents = mem::replace(&mut self.buffer, data);
|
||||
self.sink.send(buffer_contents);
|
||||
} else {
|
||||
// queue for later
|
||||
if self.buffer.is_empty() {
|
||||
self.buffer = data;
|
||||
} else {
|
||||
self.buffer += "\n";
|
||||
self.buffer += data.as_str();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) {
|
||||
if !self.buffer.is_empty() {
|
||||
let buffer_contents = mem::take(&mut self.buffer);
|
||||
self.sink.send(buffer_contents);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: EventSink> Client<E> {
|
||||
/// Construct a new statsd client given a sink
|
||||
pub fn new(sink: E, prefix: &str, constant_tags: Option<Vec<&str>>) -> Client<E> {
|
||||
Client {
|
||||
sink,
|
||||
prefix: prefix.to_string(),
|
||||
constant_tags: constant_tags
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment a metric by 1
|
||||
///
|
||||
/// This modifies a counter with an effective sampling rate of 1.0.
|
||||
pub fn incr(&mut self, metric: &str, tags: &Option<Vec<&str>>) {
|
||||
self.count(metric, 1.0, tags);
|
||||
}
|
||||
|
||||
/// Decrement a metric by 1
|
||||
///
|
||||
/// This modifies a counter with an effective sampling rate of 1.0.
|
||||
pub fn decr(&mut self, metric: &str, tags: &Option<Vec<&str>>) {
|
||||
self.count(metric, -1.0, tags);
|
||||
}
|
||||
|
||||
/// Modify a counter by `value`.
|
||||
///
|
||||
/// Will increment or decrement a counter by `value` with a sampling rate of 1.0.
|
||||
pub fn count(&mut self, metric: &str, value: f64, tags: &Option<Vec<&str>>) {
|
||||
let data = self.prepare_with_tags(format!("{}:{}|c", metric, value), tags);
|
||||
self.send(data);
|
||||
}
|
||||
|
||||
/// Set a gauge value.
|
||||
pub fn gauge(&mut self, metric: &str, value: f64, tags: &Option<Vec<&str>>) {
|
||||
let data = self.prepare_with_tags(format!("{}:{}|g", metric, value), tags);
|
||||
self.send(data);
|
||||
}
|
||||
|
||||
/// Send a timer value.
|
||||
pub fn timer(&mut self, metric: &str, milliseconds: f64, tags: &Option<Vec<&str>>) {
|
||||
let data = self.prepare_with_tags(format!("{}:{}|ms", metric, milliseconds), tags);
|
||||
self.send(data);
|
||||
}
|
||||
|
||||
/// Send a timer value at a specified sample rate in 0..1 range.
|
||||
pub fn timer_at_rate(
|
||||
&mut self,
|
||||
metric: &str,
|
||||
milliseconds: f64,
|
||||
rate: f64,
|
||||
tags: &Option<Vec<&str>>,
|
||||
) {
|
||||
let data =
|
||||
self.prepare_with_tags(format!("{}:{}|ms|@{}", metric, milliseconds, rate), tags);
|
||||
self.send(data);
|
||||
}
|
||||
|
||||
fn prepare<T: AsRef<str>>(&self, data: T) -> String {
|
||||
if self.prefix.is_empty() {
|
||||
data.as_ref().to_string()
|
||||
} else {
|
||||
format!("{}.{}", self.prefix, data.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_with_tags<T: AsRef<str>>(&self, data: T, tags: &Option<Vec<&str>>) -> String {
|
||||
self.append_tags(self.prepare(data), tags)
|
||||
}
|
||||
|
||||
fn append_tags<T: AsRef<str>>(&self, data: T, tags: &Option<Vec<&str>>) -> String {
|
||||
if self.constant_tags.is_empty() && tags.is_none() {
|
||||
data.as_ref().to_string()
|
||||
} else {
|
||||
let mut all_tags = self.constant_tags.clone();
|
||||
match tags {
|
||||
Some(v) => {
|
||||
for tag in v {
|
||||
all_tags.push(tag.to_string());
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// nothing to do
|
||||
}
|
||||
}
|
||||
format!("{}|#{}", data.as_ref(), all_tags.join(","))
|
||||
}
|
||||
}
|
||||
|
||||
/// Send data along to the sink.
|
||||
fn send(&mut self, data: String) {
|
||||
self.sink.send(data);
|
||||
}
|
||||
|
||||
/// Get a pipeline struct that allows optimizes the number of UDP
|
||||
/// packets used to send multiple metrics.
|
||||
pub fn pipeline(&mut self) -> Client<PipelineSink<E>> {
|
||||
Client {
|
||||
sink: PipelineSink::new(&mut self.sink),
|
||||
prefix: self.prefix.clone(),
|
||||
constant_tags: self.constant_tags.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pipeline_client_of_size(&mut self, max_udp_size: usize) -> Client<PipelineSink<E>> {
|
||||
Client {
|
||||
sink: PipelineSink::new_with_size(&mut self.sink, max_udp_size),
|
||||
prefix: self.prefix.clone(),
|
||||
constant_tags: self.constant_tags.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a histogram value.
|
||||
pub fn histogram(&mut self, metric: &str, value: f64, tags: &Option<Vec<&str>>) {
|
||||
let data = self.prepare_with_tags(format!("{}:{}|h", metric, value), tags);
|
||||
self.send(data);
|
||||
}
|
||||
|
||||
/// Send a histogram value at a specified sample rate in 0..1 range.
|
||||
pub fn histogram_at_rate(
|
||||
&mut self,
|
||||
metric: &str,
|
||||
value: f64,
|
||||
rate: f64,
|
||||
tags: &Option<Vec<&str>>,
|
||||
) {
|
||||
let data = self.prepare_with_tags(format!("{}:{}|h|@{}", metric, value, rate), tags);
|
||||
self.send(data);
|
||||
}
|
||||
|
||||
/// Send a distribution value.
|
||||
pub fn distribution(&mut self, metric: &str, value: f64, tags: &Option<Vec<&str>>) {
|
||||
let data = self.prepare_with_tags(format!("{}.d:{}|d", metric, value), tags);
|
||||
self.send(data);
|
||||
}
|
||||
|
||||
/// Send a distribution value at a specified sample rate in 0..1 range.
|
||||
pub fn distribution_at_rate(
|
||||
&mut self,
|
||||
metric: &str,
|
||||
value: f64,
|
||||
rate: f64,
|
||||
tags: &Option<Vec<&str>>,
|
||||
) {
|
||||
let data = self.prepare_with_tags(format!("{}.d:{}|d|@{}", metric, value, rate), tags);
|
||||
self.send(data);
|
||||
}
|
||||
|
||||
/// Send a event.
|
||||
pub fn event(
|
||||
&mut self,
|
||||
title: &str,
|
||||
text: &str,
|
||||
alert_type: AlertType,
|
||||
tags: &Option<Vec<&str>>,
|
||||
) {
|
||||
let mut d = vec![];
|
||||
d.push(format!("_e{{{},{}}}:{}", title.len(), text.len(), title));
|
||||
d.push(text.to_string());
|
||||
if alert_type != AlertType::Info {
|
||||
d.push(format!("t:{}", alert_type.to_string().to_lowercase()))
|
||||
}
|
||||
let event_with_tags = self.append_tags(d.join("|"), tags);
|
||||
self.send(event_with_tags)
|
||||
}
|
||||
|
||||
/// Send a service check.
|
||||
pub fn service_check(
|
||||
&mut self,
|
||||
service_check_name: &str,
|
||||
status: ServiceCheckStatus,
|
||||
tags: &Option<Vec<&str>>,
|
||||
) {
|
||||
let mut d = vec![];
|
||||
let status_code = (status as u32).to_string();
|
||||
d.push("_sc");
|
||||
d.push(service_check_name);
|
||||
d.push(&status_code);
|
||||
let sc_with_tags = self.append_tags(d.join("|"), tags);
|
||||
self.send(sc_with_tags)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum AlertType {
|
||||
Info,
|
||||
Error,
|
||||
Warning,
|
||||
Success,
|
||||
}
|
||||
|
||||
impl fmt::Display for AlertType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ServiceCheckStatus {
|
||||
Ok = 0,
|
||||
Warning = 1,
|
||||
Critical = 2,
|
||||
Unknown = 3,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
extern crate rand;
|
||||
|
||||
use std::{cell::RefCell, rc::Rc};
|
||||
|
||||
use super::*;
|
||||
|
||||
struct MockServer {
|
||||
packets: Rc<RefCell<Vec<String>>>,
|
||||
}
|
||||
|
||||
struct MockUdpPort {
|
||||
packets: Rc<RefCell<Vec<String>>>,
|
||||
}
|
||||
|
||||
impl MockServer {
|
||||
fn new() -> MockServer {
|
||||
Self {
|
||||
packets: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_port(&self) -> MockUdpPort {
|
||||
MockUdpPort {
|
||||
packets: Rc::clone(&self.packets),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_packet(&mut self) -> String {
|
||||
let mut cell = self.packets.borrow_mut();
|
||||
cell.remove(0)
|
||||
}
|
||||
|
||||
fn expect_no_more_packets(&self) {
|
||||
assert_eq!(0, self.packets.borrow().len());
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> EventSink for MockUdpPort {
|
||||
fn send(&mut self, data: String) {
|
||||
self.packets.borrow_mut().push(data);
|
||||
}
|
||||
|
||||
fn flush(&mut self) {}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_gauge() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
client.gauge("metric", 9.1, &None);
|
||||
|
||||
assert_eq!("myapp.metric:9.1|g", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_gauge_with_tags() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", Some(vec!["tag1", "tag2:value"]));
|
||||
|
||||
client.gauge("metric", 9.1, &Some(vec!["tag3", "tag4:value"]));
|
||||
|
||||
assert_eq!(
|
||||
"myapp.metric:9.1|g|#tag1,tag2:value,tag3,tag4:value",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_gauge_without_prefix() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "", None);
|
||||
|
||||
client.gauge("metric", 9.1, &None);
|
||||
|
||||
assert_eq!("metric:9.1|g", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_incr() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
client.incr("metric", &None);
|
||||
|
||||
assert_eq!("myapp.metric:1|c", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_decr() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
client.decr("metric", &None);
|
||||
|
||||
assert_eq!("myapp.metric:-1|c", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_count() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
client.count("metric", 12.2, &None);
|
||||
|
||||
assert_eq!("myapp.metric:12.2|c", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_count_with_tags() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", Some(vec!["tag1", "tag2:value"]));
|
||||
|
||||
client.count("metric", 12.2, &Some(vec!["tag3", "tag4:value"]));
|
||||
|
||||
assert_eq!(
|
||||
"myapp.metric:12.2|c|#tag1,tag2:value,tag3,tag4:value",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_timer() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
client.timer("metric", 21.39, &None);
|
||||
|
||||
assert_eq!("myapp.metric:21.39|ms", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_timer_at_rate() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
client.timer_at_rate("metric", 21.39, 0.123, &None);
|
||||
|
||||
assert_eq!("myapp.metric:21.39|ms|@0.123", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_histogram() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
// without tags
|
||||
client.histogram("metric", 9.1, &None);
|
||||
assert_eq!("myapp.metric:9.1|h", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
|
||||
// with tags
|
||||
client.histogram_at_rate("metric", 9.1, 0.2, &Some(vec!["tag1", "tag2:test"]));
|
||||
assert_eq!(
|
||||
"myapp.metric:9.1|h|@0.2|#tag1,tag2:test",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_histogram_with_constant_tags() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(
|
||||
server.new_port(),
|
||||
"myapp",
|
||||
Some(vec!["tag1common", "tag2common:test"]),
|
||||
);
|
||||
|
||||
// without tags
|
||||
client.histogram("metric", 9.1, &None);
|
||||
assert_eq!(
|
||||
"myapp.metric:9.1|h|#tag1common,tag2common:test",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
|
||||
// with tags
|
||||
let tags = &Some(vec!["tag1", "tag2:test"]);
|
||||
client.histogram("metric", 9.1, tags);
|
||||
assert_eq!(
|
||||
"myapp.metric:9.1|h|#tag1common,tag2common:test,tag1,tag2:test",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
|
||||
// repeat
|
||||
client.histogram_at_rate("metric", 19.12, 0.2, tags);
|
||||
assert_eq!(
|
||||
"myapp.metric:19.12|h|@0.2|#tag1common,tag2common:test,tag1,tag2:test",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_distribution() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
// without tags
|
||||
client.distribution("metric", 9.1, &None);
|
||||
assert_eq!("myapp.metric.d:9.1|d", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
|
||||
// with tags
|
||||
client.distribution_at_rate("metric", 9.1, 0.1, &Some(vec!["tag1", "tag2:test"]));
|
||||
assert_eq!(
|
||||
"myapp.metric.d:9.1|d|@0.1|#tag1,tag2:test",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_event_with_tags() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
client.event(
|
||||
"Title Test",
|
||||
"Text ABC",
|
||||
AlertType::Error,
|
||||
&Some(vec!["tag1", "tag2:test"]),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
"_e{10,8}:Title Test|Text ABC|t:error|#tag1,tag2:test",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_service_check_with_tags() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
|
||||
client.service_check(
|
||||
"Service.check.name",
|
||||
ServiceCheckStatus::Critical,
|
||||
&Some(vec!["tag1", "tag2:test"]),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
"_sc|Service.check.name|2|#tag1,tag2:test",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipeline_sending_gauge() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
let mut pipeline = client.pipeline();
|
||||
pipeline.gauge("metric", 9.1, &None);
|
||||
drop(pipeline);
|
||||
|
||||
assert_eq!("myapp.metric:9.1|g", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipeline_sending_histogram() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
let mut pipeline = client.pipeline();
|
||||
pipeline.histogram("metric", 9.1, &None);
|
||||
drop(pipeline);
|
||||
|
||||
assert_eq!("myapp.metric:9.1|h", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipeline_sending_multiple_data() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
let mut pipeline = client.pipeline();
|
||||
pipeline.gauge("metric", 9.1, &None);
|
||||
pipeline.count("metric", 12.2, &None);
|
||||
drop(pipeline);
|
||||
|
||||
assert_eq!(
|
||||
"myapp.metric:9.1|g\nmyapp.metric:12.2|c",
|
||||
server.read_packet()
|
||||
);
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipeline_set_max_udp_size() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
let mut pipeline = client.pipeline_client_of_size(20);
|
||||
pipeline.gauge("metric", 9.1, &None);
|
||||
pipeline.count("metric", 12.2, &None);
|
||||
drop(pipeline);
|
||||
|
||||
assert_eq!("myapp.metric:9.1|g", server.read_packet());
|
||||
assert_eq!("myapp.metric:12.2|c", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipeline_send_metric_after_pipeline() {
|
||||
let mut server = MockServer::new();
|
||||
let mut client = Client::new(server.new_port(), "myapp", None);
|
||||
let mut pipeline = client.pipeline();
|
||||
|
||||
pipeline.gauge("load", 9.0, &None);
|
||||
pipeline.count("customers", 7.0, &None);
|
||||
drop(pipeline);
|
||||
|
||||
// Should still be able to send metrics
|
||||
// with the client.
|
||||
client.count("customers", 6.0, &None);
|
||||
|
||||
assert_eq!("myapp.load:9|g\nmyapp.customers:7|c", server.read_packet());
|
||||
assert_eq!("myapp.customers:6|c", server.read_packet());
|
||||
server.expect_no_more_packets();
|
||||
}
|
||||
}
|
||||
151
src/metrics/histogram.rs
Normal file
151
src/metrics/histogram.rs
Normal file
@ -0,0 +1,151 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{collections::HashMap, hash::Hash, iter::FromIterator};
|
||||
|
||||
pub struct Histogram<T> {
|
||||
counts_by_value: HashMap<T, usize>,
|
||||
}
|
||||
|
||||
impl<T> Histogram<T> {
|
||||
pub fn push_n(&mut self, value: T, n: usize)
|
||||
where
|
||||
T: Hash + Eq + Copy,
|
||||
{
|
||||
if n == 0 {
|
||||
return;
|
||||
}
|
||||
let count = self.counts_by_value.entry(value).or_insert(0);
|
||||
*count += n;
|
||||
}
|
||||
|
||||
pub fn push(&mut self, value: T)
|
||||
where
|
||||
T: Hash + Eq + Copy,
|
||||
{
|
||||
self.push_n(value, 1)
|
||||
}
|
||||
|
||||
pub fn push_all(&mut self, values: impl IntoIterator<Item = T>)
|
||||
where
|
||||
T: Hash + Eq + Copy,
|
||||
{
|
||||
for a in values {
|
||||
self.push(a);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.counts_by_value.is_empty()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&T, &usize)> {
|
||||
self.counts_by_value.iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for Histogram<T> {
|
||||
fn default() -> Self {
|
||||
Histogram {
|
||||
counts_by_value: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Hash + Eq + Copy> FromIterator<A> for Histogram<A> {
|
||||
fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
|
||||
let mut histogram = Histogram::default();
|
||||
histogram.push_all(iter);
|
||||
histogram
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::metrics::test_utils::assert_histogram_eq;
|
||||
|
||||
#[test]
|
||||
fn collect_i32_values_into_histogram() {
|
||||
let items = vec![1, 2, 3];
|
||||
|
||||
let histogram = items.into_iter().collect::<Histogram<_>>();
|
||||
|
||||
assert_histogram_eq(&histogram, vec![(1, 1), (2, 1), (3, 1)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_u128_values_into_histogram() {
|
||||
let items = vec![1u128, 100u128, 100u128, 100u128];
|
||||
|
||||
let histogram = items.into_iter().collect::<Histogram<_>>();
|
||||
|
||||
assert_histogram_eq(&histogram, vec![(1u128, 1), (100u128, 3)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_histogram() {
|
||||
let histogram: Histogram<u32> = Histogram::default();
|
||||
|
||||
assert_histogram_eq(&histogram, vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_single_value_to_histogram() {
|
||||
let mut histogram = Histogram::default();
|
||||
|
||||
histogram.push(100);
|
||||
|
||||
assert_histogram_eq(&histogram, vec![(100, 1)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_two_values_to_histogram() {
|
||||
let mut histogram = Histogram::default();
|
||||
|
||||
histogram.push(20);
|
||||
histogram.push(20);
|
||||
|
||||
assert_histogram_eq(&histogram, vec![(20, 2)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_n_values_to_histogram() {
|
||||
let mut histogram = Histogram::default();
|
||||
|
||||
histogram.push_n(20, 5);
|
||||
|
||||
assert_histogram_eq(&histogram, vec![(20, 5)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_0_values_to_histogram() {
|
||||
let mut histogram = Histogram::default();
|
||||
|
||||
histogram.push_n(20, 0);
|
||||
|
||||
assert_histogram_eq(&histogram, vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_all_to_histogram() {
|
||||
let mut histogram: Histogram<i32> = Histogram::default();
|
||||
|
||||
histogram.push_all(vec![100, 100, 200]);
|
||||
|
||||
assert_histogram_eq(&histogram, vec![(100, 2), (200, 1)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_and_push_all_histogram() {
|
||||
let mut histogram = Histogram::default();
|
||||
|
||||
histogram.push(50);
|
||||
histogram.push_all(vec![100, 100, 200, 50]);
|
||||
histogram.push(50);
|
||||
|
||||
assert_histogram_eq(&histogram, vec![(50, 3), (100, 2), (200, 1)]);
|
||||
}
|
||||
}
|
||||
388
src/metrics/macros.rs
Normal file
388
src/metrics/macros.rs
Normal file
@ -0,0 +1,388 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use lazy_static::*;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::metrics::{
|
||||
EventCountReporter, EventReport, HistogramReport, NumericValueReporter, TimingOptions,
|
||||
};
|
||||
|
||||
/// A global structure that contains a map to each of the registered Timing Reporters.
|
||||
///
|
||||
/// The mutex lock is only used once to register a new reporter, and then once by the report
|
||||
/// generation.
|
||||
pub struct Metrics {
|
||||
enabled: AtomicBool,
|
||||
registry: Mutex<Registry>,
|
||||
}
|
||||
|
||||
struct Registry {
|
||||
registered_names: HashSet<&'static str>,
|
||||
numeric_reporters: Vec<Arc<NumericValueReporter>>,
|
||||
event_reporters: Vec<Arc<EventCountReporter>>,
|
||||
}
|
||||
|
||||
impl Default for Registry {
|
||||
fn default() -> Self {
|
||||
Registry {
|
||||
registered_names: Default::default(),
|
||||
numeric_reporters: Default::default(),
|
||||
event_reporters: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Report {
|
||||
pub histograms: Vec<HistogramReport>,
|
||||
pub events: Vec<EventReport>,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
pub static ref __METRICS: Metrics = Metrics::new_enabled();
|
||||
}
|
||||
|
||||
impl Metrics {
|
||||
fn new_enabled() -> Metrics {
|
||||
Metrics {
|
||||
enabled: AtomicBool::new(true),
|
||||
registry: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn clear(&self) {
|
||||
let mut registry = self.registry.lock();
|
||||
*registry = Default::default();
|
||||
}
|
||||
|
||||
/// Locks the internal structure and adds a new timer.
|
||||
pub fn create_and_register_timer(
|
||||
&self,
|
||||
name: &'static str,
|
||||
options: TimingOptions,
|
||||
) -> Arc<NumericValueReporter> {
|
||||
let numeric_reporter = Arc::new(NumericValueReporter::new(name, options));
|
||||
|
||||
if !self.enabled() {
|
||||
numeric_reporter.disable();
|
||||
}
|
||||
|
||||
let mut registry = self.registry.lock();
|
||||
|
||||
if !registry.registered_names.insert(name) {
|
||||
panic!("The metric name \"{}\" has been used elsewhere.", name);
|
||||
}
|
||||
|
||||
registry
|
||||
.numeric_reporters
|
||||
.push(Arc::clone(&numeric_reporter));
|
||||
numeric_reporter
|
||||
}
|
||||
|
||||
/// Locks the internal structure and adds a new event.
|
||||
pub fn create_and_register_event(&self, name: &'static str) -> Arc<EventCountReporter> {
|
||||
let event_reporter = Arc::new(EventCountReporter::new(name));
|
||||
|
||||
let mut registry = self.registry.lock();
|
||||
|
||||
if !registry.registered_names.insert(name) {
|
||||
panic!("The metric name \"{}\" has been used elsewhere.", name);
|
||||
}
|
||||
|
||||
registry.event_reporters.push(Arc::clone(&event_reporter));
|
||||
event_reporter
|
||||
}
|
||||
|
||||
pub fn enabled(&self) -> bool {
|
||||
self.enabled.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Returns reports and resets timer reporters sorted by name.
|
||||
///
|
||||
/// The lock is open this whole time, but the only other use of the lock is registering new timers.
|
||||
pub fn report(&self) -> Report {
|
||||
let registry = self.registry.lock();
|
||||
|
||||
let mut histograms = registry
|
||||
.numeric_reporters
|
||||
.iter()
|
||||
.map(|reporter| reporter.report())
|
||||
.collect::<Vec<_>>();
|
||||
histograms.sort_unstable_by_key(|report| report.name());
|
||||
|
||||
let mut events = registry
|
||||
.event_reporters
|
||||
.iter()
|
||||
.map(|reporter| reporter.report())
|
||||
.collect::<Vec<_>>();
|
||||
events.sort_unstable_by_key(|report| report.name());
|
||||
|
||||
Report { histograms, events }
|
||||
}
|
||||
|
||||
pub fn disable(&self) {
|
||||
self.enabled.store(false, Ordering::Relaxed);
|
||||
self.registry
|
||||
.lock()
|
||||
.numeric_reporters
|
||||
.iter()
|
||||
.for_each(|reporter| reporter.disable());
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! reporter {
|
||||
($name:expr, $options:expr) => {{
|
||||
lazy_static::lazy_static! {
|
||||
pub static ref __REPORTER: std::sync::Arc<crate::metrics::NumericValueReporter> =
|
||||
crate::metrics::__METRICS.create_and_register_timer($name, $options);
|
||||
};
|
||||
&__REPORTER
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! event_reporter {
|
||||
($name:expr) => {{
|
||||
lazy_static::lazy_static! {
|
||||
pub static ref __REPORTER: std::sync::Arc<crate::metrics::EventCountReporter> =
|
||||
crate::metrics::__METRICS.create_and_register_event($name);
|
||||
};
|
||||
&__REPORTER
|
||||
}};
|
||||
}
|
||||
|
||||
/// Start a timer, and manually choose when to stop the timer.
|
||||
#[macro_export]
|
||||
macro_rules! start_timer {
|
||||
($name:expr) => {
|
||||
reporter!($name, Default::default()).start_timer()
|
||||
};
|
||||
($name:expr, $options:expr) => {
|
||||
reporter!($name, $options).start_timer()
|
||||
};
|
||||
}
|
||||
|
||||
/// Start a timer that automatically stops when it falls out of scope.
|
||||
#[macro_export]
|
||||
macro_rules! time_scope {
|
||||
($name:expr) => {
|
||||
let _t = reporter!($name, Default::default()).start_timer();
|
||||
};
|
||||
($name:expr, $options:expr) => {
|
||||
let _t = reporter!($name, $options).start_timer();
|
||||
};
|
||||
}
|
||||
|
||||
/// Time the scope in microseconds, 1000 samples per reporting minute
|
||||
#[macro_export]
|
||||
macro_rules! time_scope_us {
|
||||
($name:expr) => {
|
||||
time_scope!(
|
||||
$name,
|
||||
crate::metrics::TimingOptions::microsecond_1000_per_minute()
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
/// Start timer in microseconds, 1000 samples per reporting minute
|
||||
#[macro_export]
|
||||
macro_rules! start_timer_us {
|
||||
($name:expr) => {{
|
||||
start_timer!(
|
||||
$name,
|
||||
crate::metrics::TimingOptions::microsecond_1000_per_minute()
|
||||
)
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! event {
|
||||
($name:expr) => {
|
||||
event_reporter!($name).count();
|
||||
};
|
||||
($name:expr, $count:expr) => {
|
||||
event_reporter!($name).count_n($count);
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! metrics {
|
||||
() => {{
|
||||
&crate::metrics::__METRICS
|
||||
}};
|
||||
}
|
||||
|
||||
/// Sample the value produced by the supplied function and produce a histogram.
|
||||
#[macro_export]
|
||||
macro_rules! sampling_histogram {
|
||||
($name:expr, $sampler:expr) => {
|
||||
reporter!($name, Default::default()).push($sampler)
|
||||
};
|
||||
($name:expr, $options:expr, $sampler:expr) => {
|
||||
reporter!($name, $options).push($sampler)
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use mock_instant::MockClock;
|
||||
|
||||
use crate::{
|
||||
metrics::{test_utils::assert_histogram_eq, Metrics, Timer},
|
||||
*,
|
||||
};
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "The metric name \"A\" has been used elsewhere.")]
|
||||
fn cant_register_same_timer_twice() {
|
||||
let metrics = Metrics::new_enabled();
|
||||
|
||||
metrics.create_and_register_timer("A", Default::default());
|
||||
metrics.create_and_register_timer("A", Default::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "The metric name \"A\" has been used elsewhere.")]
|
||||
fn cant_register_same_event_twice() {
|
||||
let metrics = Metrics::new_enabled();
|
||||
|
||||
metrics.create_and_register_event("A");
|
||||
metrics.create_and_register_event("A");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "The metric name \"A\" has been used elsewhere.")]
|
||||
fn cant_register_same_name_for_an_event_and_timer() {
|
||||
let metrics = Metrics::new_enabled();
|
||||
|
||||
metrics.create_and_register_timer("A", Default::default());
|
||||
metrics.create_and_register_event("A");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registrations_are_enabled() {
|
||||
let metrics = Metrics::new_enabled();
|
||||
|
||||
let timing_reporter = metrics.create_and_register_timer("A", Default::default());
|
||||
|
||||
timing_reporter.start_timer().stop();
|
||||
timing_reporter.start_timer().stop();
|
||||
|
||||
assert_eq!(2, timing_reporter.report().sample_count());
|
||||
|
||||
timing_reporter.start_timer().stop();
|
||||
|
||||
assert_eq!(1, timing_reporter.report().sample_count());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registrations_after_disabled_are_not_enabled() {
|
||||
let metrics = Metrics::new_enabled();
|
||||
|
||||
metrics.disable();
|
||||
|
||||
let timing_reporter = metrics.create_and_register_timer("A", Default::default());
|
||||
|
||||
timing_reporter.start_timer().stop();
|
||||
timing_reporter.start_timer().stop();
|
||||
|
||||
assert_eq!(0, timing_reporter.report().sample_count());
|
||||
|
||||
timing_reporter.start_timer().stop();
|
||||
|
||||
assert_eq!(0, timing_reporter.report().sample_count());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registrations_before_disabled_are_later_disabled() {
|
||||
let metrics = Metrics::new_enabled();
|
||||
|
||||
let timing_reporter = metrics.create_and_register_timer("A", Default::default());
|
||||
|
||||
metrics.disable();
|
||||
|
||||
timing_reporter.start_timer().stop();
|
||||
timing_reporter.start_timer().stop();
|
||||
|
||||
assert_eq!(0, timing_reporter.report().sample_count());
|
||||
|
||||
timing_reporter.start_timer().stop();
|
||||
|
||||
assert_eq!(0, timing_reporter.report().sample_count());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accurate_timers_using_macros() {
|
||||
// Other tests that trigger reports will cause this test to fail unless we clear it first.
|
||||
metrics!().clear();
|
||||
|
||||
{
|
||||
time_scope!("outer");
|
||||
{
|
||||
time_scope_us!("inner");
|
||||
|
||||
MockClock::advance(Duration::from_secs(2));
|
||||
}
|
||||
|
||||
let timer = start_timer_us!("manual");
|
||||
MockClock::advance(Duration::from_millis(600));
|
||||
timer.stop();
|
||||
MockClock::advance(Duration::from_millis(400));
|
||||
}
|
||||
|
||||
for _ in 0..2 {
|
||||
sampling_histogram!("event1", || 100);
|
||||
}
|
||||
|
||||
sampling_histogram!("event2", || 50);
|
||||
|
||||
event!("event3");
|
||||
event!("event4", 10);
|
||||
|
||||
let reports = metrics!().report();
|
||||
let histograms = reports.histograms;
|
||||
|
||||
let mut iter = histograms.iter();
|
||||
let report1 = iter.next().unwrap();
|
||||
let report2 = iter.next().unwrap();
|
||||
let report3 = iter.next().unwrap();
|
||||
let report4 = iter.next().unwrap();
|
||||
let report5 = iter.next().unwrap();
|
||||
|
||||
assert_eq!("event1", report1.name());
|
||||
assert_eq!("event2", report2.name());
|
||||
assert_eq!("inner", report3.name());
|
||||
assert_eq!("manual", report4.name());
|
||||
assert_eq!("outer", report5.name());
|
||||
|
||||
assert_histogram_eq(&report1.histogram, vec![(100, 2)]);
|
||||
assert_histogram_eq(&report2.histogram, vec![(50, 1)]);
|
||||
assert_histogram_eq(&report3.histogram, vec![(2_000_000, 1)]);
|
||||
assert_histogram_eq(&report4.histogram, vec![(600_000, 1)]);
|
||||
assert_histogram_eq(&report5.histogram, vec![(3_000, 1)]);
|
||||
|
||||
let events = reports.events;
|
||||
let mut iter = events.iter();
|
||||
let event3 = iter.next().unwrap();
|
||||
let event4 = iter.next().unwrap();
|
||||
|
||||
assert_eq!("event3", event3.name());
|
||||
assert_eq!("event4", event4.name());
|
||||
assert_eq!(1, event3.event_count());
|
||||
assert_eq!(10, event4.event_count());
|
||||
}
|
||||
}
|
||||
739
src/metrics/reporter.rs
Normal file
739
src/metrics/reporter.rs
Normal file
@ -0,0 +1,739 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#[cfg(not(test))]
|
||||
use std::time::Instant;
|
||||
use std::{
|
||||
mem,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
use mock_instant::Instant;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::metrics::{Histogram, Precision, TimingOptions};
|
||||
|
||||
/// Represents a sampler for collecting histograms of values of any unit.
|
||||
/// e.g. they might be times, or packets sizes.
|
||||
pub struct NumericValueReporter {
|
||||
name: &'static str,
|
||||
measurements_since_last_report: Mutex<SinceLastReport>,
|
||||
event_counter: AtomicUsize,
|
||||
/// 1 in every sample_interval will be actually be measured.
|
||||
sample_interval: AtomicUsize,
|
||||
options: TimingOptions,
|
||||
}
|
||||
|
||||
/// The internally mutable component of the performance measuring system.
|
||||
struct SinceLastReport {
|
||||
histogram: Histogram<usize>,
|
||||
initial_event_counter: usize,
|
||||
sample_count: usize,
|
||||
}
|
||||
|
||||
impl NumericValueReporter {
|
||||
pub fn new(name: &'static str, options: TimingOptions) -> NumericValueReporter {
|
||||
NumericValueReporter {
|
||||
name,
|
||||
measurements_since_last_report: Mutex::new(SinceLastReport::new(0)),
|
||||
event_counter: AtomicUsize::new(0),
|
||||
sample_interval: AtomicUsize::new(1),
|
||||
options,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_interval_is_enabled(sample_interval: usize) -> bool {
|
||||
sample_interval != usize::MAX
|
||||
}
|
||||
|
||||
pub fn disable(&self) {
|
||||
self.sample_interval.store(usize::MAX, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[must_use = "If you don't want to assign to a local and drop manually, then use time_scope! macro"]
|
||||
pub fn start_timer(&self) -> impl Timer + '_ {
|
||||
self.sample(|sample_interval| RunningTimer::start(self, Instant::now(), sample_interval))
|
||||
}
|
||||
|
||||
/// This will use ths sampling interval and only invoke the sampler periodically to push an
|
||||
/// arbitrary unit value to the histogram.
|
||||
pub fn push(&self, sampler: impl FnOnce() -> usize) {
|
||||
self.sample(|sample_interval| self.push_sample(sampler(), sample_interval));
|
||||
}
|
||||
|
||||
/// Executes the supplied sampler according to the sample interval.
|
||||
fn sample<T>(&self, sampler: impl FnOnce(usize) -> T) -> Option<T> {
|
||||
let sample_interval = self.sample_interval.load(Ordering::Relaxed);
|
||||
if Self::sample_interval_is_enabled(sample_interval) {
|
||||
let previous_counter = self.event_counter.fetch_add(1, Ordering::AcqRel);
|
||||
if previous_counter % sample_interval == (sample_interval - 1) {
|
||||
return Some(sampler(sample_interval));
|
||||
}
|
||||
};
|
||||
None
|
||||
}
|
||||
|
||||
fn push_time_sample(&self, sample: Duration, sample_interval: usize) {
|
||||
let value = match self.options.sample_precision {
|
||||
Precision::Centisecond => sample.as_millis() as usize / 10,
|
||||
Precision::Millisecond => sample.as_millis() as usize,
|
||||
Precision::Microsecond => sample.as_micros() as usize,
|
||||
Precision::Nanosecond => sample.as_nanos() as usize,
|
||||
};
|
||||
self.push_sample(value, sample_interval);
|
||||
}
|
||||
|
||||
fn push_sample(&self, sample: usize, sample_interval: usize) {
|
||||
self.measurements_since_last_report
|
||||
.lock()
|
||||
.push_sample(sample, sample_interval);
|
||||
}
|
||||
|
||||
/// Creates a report of timings and resets the reporter.
|
||||
pub fn report(&self) -> HistogramReport {
|
||||
let event_count = self.event_counter.load(Ordering::Relaxed);
|
||||
let last_sample_interval = self.sample_interval.load(Ordering::Relaxed);
|
||||
|
||||
let since_last_report = {
|
||||
let mut times_since_last_report = self.measurements_since_last_report.lock();
|
||||
|
||||
mem::replace(
|
||||
&mut *times_since_last_report,
|
||||
SinceLastReport::new(event_count),
|
||||
)
|
||||
};
|
||||
|
||||
let events_since_last_report = event_count - since_last_report.initial_event_counter;
|
||||
|
||||
if Self::sample_interval_is_enabled(last_sample_interval) {
|
||||
self.sample_interval.store(
|
||||
Self::calculate_sample_rate(
|
||||
events_since_last_report,
|
||||
self.options.target_sample_rate,
|
||||
),
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
|
||||
HistogramReport {
|
||||
name: self.name,
|
||||
sample_interval: last_sample_interval,
|
||||
histogram: since_last_report.histogram,
|
||||
event_count: events_since_last_report,
|
||||
sample_count: since_last_report.sample_count,
|
||||
sample_precision: self.options.sample_precision,
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_sample_rate(actual_count: usize, target_rate: usize) -> usize {
|
||||
(actual_count / target_rate).max(1)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EventCountReporter {
|
||||
name: &'static str,
|
||||
event_counter: AtomicUsize,
|
||||
}
|
||||
|
||||
impl EventCountReporter {
|
||||
pub fn new(name: &'static str) -> EventCountReporter {
|
||||
EventCountReporter {
|
||||
name,
|
||||
event_counter: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// This will count n events.
|
||||
pub fn count_n(&self, n: usize) {
|
||||
self.event_counter.fetch_add(n, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// This will count an event.
|
||||
pub fn count(&self) {
|
||||
self.count_n(1);
|
||||
}
|
||||
|
||||
/// Grab the event count and reset to zero.
|
||||
pub fn report(&self) -> EventReport {
|
||||
EventReport {
|
||||
name: self.name,
|
||||
event_count: self.event_counter.swap(0, Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RunningTimer<'a> {
|
||||
reporter: &'a NumericValueReporter,
|
||||
start_time: Instant,
|
||||
sample_interval: usize,
|
||||
}
|
||||
|
||||
pub trait Timer {
|
||||
fn stop(self);
|
||||
}
|
||||
|
||||
impl<'a> RunningTimer<'a> {
|
||||
fn start(
|
||||
reporter: &'a NumericValueReporter,
|
||||
start_time: Instant,
|
||||
sample_interval: usize,
|
||||
) -> RunningTimer<'a> {
|
||||
RunningTimer {
|
||||
reporter,
|
||||
start_time,
|
||||
sample_interval,
|
||||
}
|
||||
}
|
||||
|
||||
fn stop(&mut self) {
|
||||
self.reporter
|
||||
.push_time_sample(self.start_time.elapsed(), self.sample_interval);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for RunningTimer<'a> {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Timer for RunningTimer<'a> {
|
||||
fn stop(self) {}
|
||||
}
|
||||
|
||||
impl<T: Timer> Timer for Option<T> {
|
||||
fn stop(self) {
|
||||
if let Some(stoppable) = self {
|
||||
stoppable.stop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct HistogramReport {
|
||||
name: &'static str,
|
||||
sample_interval: usize,
|
||||
event_count: usize,
|
||||
sample_count: usize,
|
||||
pub histogram: Histogram<usize>,
|
||||
sample_precision: Precision,
|
||||
}
|
||||
|
||||
impl HistogramReport {
|
||||
pub fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
/// Actual number of times this event started in this reporting period.
|
||||
pub fn event_count(&self) -> usize {
|
||||
self.event_count
|
||||
}
|
||||
|
||||
/// Number of times this event was sampled in the time period.
|
||||
pub fn sample_count(&self) -> usize {
|
||||
self.sample_count
|
||||
}
|
||||
|
||||
/// 1 in sample_intervals were actually recorded.
|
||||
pub fn sample_interval(&self) -> usize {
|
||||
self.sample_interval
|
||||
}
|
||||
|
||||
pub fn sample_precision(&self) -> Precision {
|
||||
self.sample_precision
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EventReport {
|
||||
name: &'static str,
|
||||
event_count: usize,
|
||||
}
|
||||
|
||||
impl EventReport {
|
||||
pub fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
pub fn event_count(&self) -> usize {
|
||||
self.event_count
|
||||
}
|
||||
}
|
||||
|
||||
impl SinceLastReport {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `event_counter` - The reporter's event counter at the time of creation.
|
||||
fn new(event_counter: usize) -> SinceLastReport {
|
||||
SinceLastReport {
|
||||
histogram: Histogram::default(),
|
||||
initial_event_counter: event_counter,
|
||||
sample_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn push_sample(&mut self, sample: usize, n: usize) {
|
||||
self.histogram.push_n(sample, n);
|
||||
self.sample_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use mock_instant::MockClock;
|
||||
|
||||
use super::*;
|
||||
use crate::metrics::{test_utils::assert_histogram_eq, Precision};
|
||||
|
||||
#[test]
|
||||
fn push_a_value_sample() {
|
||||
let name = "test";
|
||||
let reporter = NumericValueReporter::new(name, Default::default());
|
||||
|
||||
reporter.push(|| 100);
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(1, report.event_count());
|
||||
|
||||
assert_histogram_eq(&report.histogram, vec![(100, 1)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn once_report_is_taken_a_new_report_starts() {
|
||||
let name = "test";
|
||||
let reporter = NumericValueReporter::new(name, Default::default());
|
||||
|
||||
reporter.start_timer().stop();
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(1, report.event_count());
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(0, report.event_count());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_contains_name_of_perf() {
|
||||
let name = "test";
|
||||
let reporter = NumericValueReporter::new(name, Default::default());
|
||||
let report = reporter.report();
|
||||
assert_eq!("test", report.name());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timer_added_to_report_when_dropped() {
|
||||
let name = "test";
|
||||
let reporter = NumericValueReporter::new(name, Default::default());
|
||||
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(1, report.event_count(), "Expect to be counted when start");
|
||||
assert_eq!(
|
||||
0,
|
||||
report.sample_count(),
|
||||
"Expect to be absent from report before drop"
|
||||
);
|
||||
|
||||
drop(timer);
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(
|
||||
0,
|
||||
report.event_count(),
|
||||
"Event didn't start in this reporting period"
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
report.sample_count(),
|
||||
"Expect to be present in report after drop"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accurate_timers_at_millisecond_precision() {
|
||||
let reporter = NumericValueReporter::new(
|
||||
"Mocked Time",
|
||||
TimingOptions {
|
||||
sample_precision: Precision::Millisecond,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
{
|
||||
let _timer = reporter.start_timer();
|
||||
|
||||
for _ in 0..2 {
|
||||
let _timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(2));
|
||||
}
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
}
|
||||
|
||||
let report = reporter.report();
|
||||
|
||||
assert_eq!("Mocked Time", report.name);
|
||||
assert_histogram_eq(&&report.histogram, vec![(2000, 2), (5000, 1)]);
|
||||
assert_eq!(Precision::Millisecond, report.sample_precision());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accurate_timers_at_centisecond_precision() {
|
||||
let reporter = NumericValueReporter::new(
|
||||
"Mocked Time",
|
||||
TimingOptions {
|
||||
sample_precision: Precision::Centisecond,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
{
|
||||
let _timer = reporter.start_timer();
|
||||
MockClock::advance(Duration::from_secs(5));
|
||||
}
|
||||
|
||||
let report = reporter.report();
|
||||
|
||||
assert_eq!("Mocked Time", report.name);
|
||||
assert_histogram_eq(&&report.histogram, vec![(500, 1)]);
|
||||
assert_eq!(Precision::Centisecond, report.sample_precision());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accurate_timers_at_microsecond_precision() {
|
||||
let reporter = NumericValueReporter::new(
|
||||
"Mocked Time",
|
||||
TimingOptions {
|
||||
sample_precision: Precision::Microsecond,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
{
|
||||
let _timer = reporter.start_timer();
|
||||
MockClock::advance(Duration::from_micros(3621));
|
||||
}
|
||||
|
||||
let report = reporter.report();
|
||||
|
||||
assert_eq!("Mocked Time", report.name);
|
||||
assert_histogram_eq(&&report.histogram, vec![(3621, 1)]);
|
||||
assert_eq!(Precision::Microsecond, report.sample_precision());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accurate_timers_at_nanosecond_precision() {
|
||||
let reporter = NumericValueReporter::new(
|
||||
"Mocked Time",
|
||||
TimingOptions {
|
||||
sample_precision: Precision::Nanosecond,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
{
|
||||
let _timer = reporter.start_timer();
|
||||
MockClock::advance(Duration::from_nanos(123));
|
||||
}
|
||||
|
||||
let report = reporter.report();
|
||||
|
||||
assert_eq!("Mocked Time", report.name);
|
||||
assert_histogram_eq(&&report.histogram, vec![(123, 1)]);
|
||||
assert_eq!(Precision::Nanosecond, report.sample_precision());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn early_stop() {
|
||||
let name = "Mocked Time";
|
||||
let reporter = NumericValueReporter::new(name, Default::default());
|
||||
|
||||
{
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
timer.stop();
|
||||
|
||||
MockClock::advance(Duration::from_secs(2));
|
||||
}
|
||||
|
||||
let report = reporter.report();
|
||||
|
||||
assert_histogram_eq(&&report.histogram, vec![(1000, 1)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn early_stop_twice_doesnt_count_twice() {
|
||||
let name = "Mocked Time";
|
||||
let reporter = NumericValueReporter::new(name, Default::default());
|
||||
|
||||
{
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
timer.stop();
|
||||
|
||||
MockClock::advance(Duration::from_secs(2));
|
||||
}
|
||||
|
||||
let report = reporter.report();
|
||||
|
||||
assert_histogram_eq(&&report.histogram, vec![(1000, 1)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_sampling_with_timers() {
|
||||
let reporter = NumericValueReporter::new(
|
||||
"Mocked Time",
|
||||
TimingOptions {
|
||||
target_sample_rate: 1_000,
|
||||
sample_precision: Precision::Millisecond,
|
||||
},
|
||||
);
|
||||
assert_eq!(0, reporter.event_counter.fetch_add(0, Ordering::Acquire));
|
||||
|
||||
for _ in 0..10_000 {
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
timer.stop();
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
10_000,
|
||||
reporter.event_counter.fetch_add(0, Ordering::Acquire)
|
||||
);
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(1, report.sample_interval);
|
||||
assert_eq!(10_000, report.sample_count);
|
||||
assert_eq!(10_000, report.event_count);
|
||||
assert_histogram_eq(&report.histogram, vec![(1000, 10_000)]);
|
||||
|
||||
// because 10K were sampled in the report, the actual sample rate should be adjusted to 10
|
||||
|
||||
for _ in 0..10_000 {
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
timer.stop();
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
20_000,
|
||||
reporter.event_counter.fetch_add(0, Ordering::Acquire)
|
||||
);
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(10, report.sample_interval);
|
||||
assert_eq!(1_000, report.sample_count);
|
||||
assert_eq!(10_000, report.event_count);
|
||||
assert_histogram_eq(&report.histogram, vec![(1000, 10_000)]);
|
||||
|
||||
// because we got exactly as many as expected, no further change should happen
|
||||
for _ in 0..10_000 {
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
timer.stop();
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
30_000,
|
||||
reporter.event_counter.fetch_add(0, Ordering::Acquire)
|
||||
);
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(10, report.sample_interval);
|
||||
assert_eq!(1_000, report.sample_count);
|
||||
assert_eq!(10_000, report.event_count);
|
||||
assert_histogram_eq(&report.histogram, vec![(1000, 10_000)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_sampling_using_push() {
|
||||
let invocations = AtomicUsize::new(0);
|
||||
let count_invocations_return_1000 = || {
|
||||
invocations.fetch_add(1, Ordering::Relaxed);
|
||||
1000
|
||||
};
|
||||
let reporter = NumericValueReporter::new(
|
||||
"Mocked",
|
||||
TimingOptions {
|
||||
target_sample_rate: 1_000,
|
||||
sample_precision: Precision::Millisecond,
|
||||
},
|
||||
);
|
||||
assert_eq!(0, reporter.event_counter.fetch_add(0, Ordering::Acquire));
|
||||
|
||||
for _ in 0..10_000 {
|
||||
reporter.push(count_invocations_return_1000);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
10_000,
|
||||
reporter.event_counter.fetch_add(0, Ordering::Acquire)
|
||||
);
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(1, report.sample_interval);
|
||||
assert_eq!(10_000, report.sample_count);
|
||||
assert_eq!(report.sample_count, invocations.load(Ordering::Relaxed));
|
||||
assert_eq!(10_000, report.event_count);
|
||||
assert_histogram_eq(&report.histogram, vec![(1000, 10_000)]);
|
||||
invocations.store(0, Ordering::Relaxed);
|
||||
|
||||
// because 10K were sampled in the report, the actual sample rate should be adjusted to 10
|
||||
|
||||
for _ in 0..10_000 {
|
||||
reporter.push(count_invocations_return_1000);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
20_000,
|
||||
reporter.event_counter.fetch_add(0, Ordering::Acquire)
|
||||
);
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(10, report.sample_interval);
|
||||
assert_eq!(1_000, report.sample_count);
|
||||
assert_eq!(report.sample_count, invocations.load(Ordering::Relaxed));
|
||||
assert_eq!(10_000, report.event_count);
|
||||
assert_histogram_eq(&report.histogram, vec![(1000, 10_000)]);
|
||||
invocations.store(0, Ordering::Relaxed);
|
||||
|
||||
// because we got exactly as many as expected, no further change should happen
|
||||
for _ in 0..10_000 {
|
||||
reporter.push(count_invocations_return_1000);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
30_000,
|
||||
reporter.event_counter.fetch_add(0, Ordering::Acquire)
|
||||
);
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(10, report.sample_interval);
|
||||
assert_eq!(1_000, report.sample_count);
|
||||
assert_eq!(report.sample_count, invocations.load(Ordering::Relaxed));
|
||||
assert_eq!(10_000, report.event_count);
|
||||
assert_histogram_eq(&report.histogram, vec![(1000, 10_000)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn only_the_nth_sample_is_taken() {
|
||||
let reporter = NumericValueReporter::new(
|
||||
"Mocked Time",
|
||||
TimingOptions {
|
||||
target_sample_rate: 1_000,
|
||||
sample_precision: Precision::Millisecond,
|
||||
},
|
||||
);
|
||||
assert_eq!(0, reporter.event_counter.fetch_add(0, Ordering::Acquire));
|
||||
|
||||
for _ in 0..10_000 {
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
timer.stop();
|
||||
}
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(1, report.sample_interval);
|
||||
|
||||
// because 10K were sampled in the report, the actual sample rate should be adjusted to 10
|
||||
for _ in 0..10_000 {
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
timer.stop();
|
||||
}
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(10, report.sample_interval);
|
||||
|
||||
for _ in 0..9 {
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(1));
|
||||
|
||||
timer.stop();
|
||||
}
|
||||
|
||||
let timer = reporter.start_timer();
|
||||
|
||||
MockClock::advance(Duration::from_secs(3));
|
||||
|
||||
timer.stop();
|
||||
|
||||
let report = reporter.report();
|
||||
assert_eq!(1, report.sample_count);
|
||||
|
||||
// Because there were 9x1 second, followed by 1x3second, this shows the 10th value exactly is being sampled.
|
||||
assert_histogram_eq(&report.histogram, vec![(3000, 10)])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_sample_rate_calculations() {
|
||||
assert_eq!(
|
||||
10,
|
||||
NumericValueReporter::calculate_sample_rate(10_000, 1_000)
|
||||
);
|
||||
assert_eq!(
|
||||
20,
|
||||
NumericValueReporter::calculate_sample_rate(20_000, 1_000)
|
||||
);
|
||||
assert_eq!(1, NumericValueReporter::calculate_sample_rate(0, 1_000));
|
||||
assert_eq!(
|
||||
123,
|
||||
NumericValueReporter::calculate_sample_rate(12_314, 100)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disable_timer() {
|
||||
let timing_reporter = NumericValueReporter::new("c", Default::default());
|
||||
|
||||
timing_reporter.disable();
|
||||
|
||||
timing_reporter.start_timer().stop();
|
||||
timing_reporter.start_timer().stop();
|
||||
|
||||
assert_eq!(0, timing_reporter.report().sample_count);
|
||||
|
||||
timing_reporter.start_timer().stop();
|
||||
|
||||
assert_eq!(0, timing_reporter.report().sample_count);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_counting() {
|
||||
let event_reporter = EventCountReporter::new("event");
|
||||
|
||||
event_reporter.count();
|
||||
|
||||
let report = event_reporter.report();
|
||||
assert_eq!("event", report.name());
|
||||
assert_eq!(1, report.event_count());
|
||||
|
||||
event_reporter.count();
|
||||
event_reporter.count_n(3);
|
||||
|
||||
let report = event_reporter.report();
|
||||
assert_eq!("event", report.name());
|
||||
assert_eq!(4, report.event_count());
|
||||
}
|
||||
}
|
||||
29
src/metrics/test_utils.rs
Normal file
29
src/metrics/test_utils.rs
Normal file
@ -0,0 +1,29 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#![cfg(test)]
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::metrics::Histogram;
|
||||
|
||||
/// Compares the contents of a hash map with an expected vector of Key-Value pairs.
|
||||
pub fn assert_map_eq<K, V>(actual: impl Iterator<Item = (K, V)>, mut expected: Vec<(K, V)>)
|
||||
where
|
||||
K: Ord + PartialEq + Debug + Copy,
|
||||
V: PartialEq + Debug + Copy,
|
||||
{
|
||||
let mut actual: Vec<(K, V)> = actual.into_iter().collect::<Vec<_>>();
|
||||
actual.sort_unstable_by(|(k1, _), (k2, _)| k1.cmp(k2));
|
||||
expected.sort_unstable_by(|(k1, _), (k2, _)| k1.cmp(k2));
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
|
||||
pub fn assert_histogram_eq<K>(histogram: &Histogram<K>, expected: Vec<(K, usize)>)
|
||||
where
|
||||
K: Ord + PartialEq + Debug + Copy,
|
||||
{
|
||||
assert_map_eq(histogram.iter().map(|(k, v)| (*k, *v)), expected)
|
||||
}
|
||||
50
src/metrics/timing_options.rs
Normal file
50
src/metrics/timing_options.rs
Normal file
@ -0,0 +1,50 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TimingOptions {
|
||||
/// How many samples you want to take in the reporting period.
|
||||
///
|
||||
/// The timer will adjust automatically to take this many from the second reporting period onwards.
|
||||
pub target_sample_rate: usize,
|
||||
|
||||
/// The precision the durations will be recorded at on the histogram.
|
||||
pub sample_precision: Precision,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub enum Precision {
|
||||
Centisecond,
|
||||
Millisecond,
|
||||
Microsecond,
|
||||
Nanosecond,
|
||||
}
|
||||
|
||||
impl TimingOptions {
|
||||
const DEFAULT: TimingOptions = TimingOptions {
|
||||
target_sample_rate: 1_000,
|
||||
sample_precision: Precision::Millisecond,
|
||||
};
|
||||
|
||||
pub fn microsecond_1000_per_minute() -> TimingOptions {
|
||||
TimingOptions {
|
||||
target_sample_rate: 1_000,
|
||||
sample_precision: Precision::Microsecond,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nanosecond_1000_per_minute() -> TimingOptions {
|
||||
TimingOptions {
|
||||
target_sample_rate: 1_000,
|
||||
sample_precision: Precision::Nanosecond,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TimingOptions {
|
||||
fn default() -> Self {
|
||||
TimingOptions::DEFAULT
|
||||
}
|
||||
}
|
||||
236
src/metrics_server.rs
Normal file
236
src/metrics_server.rs
Normal file
@ -0,0 +1,236 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{
|
||||
ops::{Deref, DerefMut},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use log::*;
|
||||
use parking_lot::Mutex;
|
||||
use psutil::process::Process;
|
||||
use tokio::sync::oneshot::Receiver;
|
||||
|
||||
use crate::{
|
||||
config,
|
||||
config::Config,
|
||||
metrics::{
|
||||
Client as DatadogClient, Histogram, HistogramReport, PipelineSink, Precision, UdpEventSink,
|
||||
},
|
||||
sfu::Sfu,
|
||||
};
|
||||
|
||||
pub async fn start(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
shutdown_signal_rx: Receiver<()>,
|
||||
) -> Result<()> {
|
||||
match Datadog::new(&config) {
|
||||
None => {
|
||||
metrics!().disable();
|
||||
info!("metrics server not started because not configured, metrics disabled");
|
||||
|
||||
let _ = tokio::select!(
|
||||
_ = shutdown_signal_rx => {},
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Some(mut datadog) => {
|
||||
let tick_handle = tokio::spawn(async move {
|
||||
let mut tick_interval = tokio::time::interval(Duration::from_secs(60));
|
||||
|
||||
loop {
|
||||
tick_interval.tick().await;
|
||||
let mut datadog = datadog.open_pipeline();
|
||||
|
||||
for (metric_name, value) in get_value_metrics() {
|
||||
datadog.gauge(metric_name, value as f64, &None);
|
||||
}
|
||||
|
||||
let report = metrics!().report();
|
||||
for report in report.histograms {
|
||||
datadog.send_timer_histogram(&report, &None);
|
||||
}
|
||||
for report in report.events {
|
||||
datadog.count(report.name(), report.event_count() as f64, &None);
|
||||
}
|
||||
|
||||
let stats = sfu.lock().get_stats();
|
||||
for (name, histogram) in stats.histograms {
|
||||
datadog.send_count_histogram(name, &histogram, &None);
|
||||
}
|
||||
for (name, value) in stats.values {
|
||||
datadog.gauge(name, value as f64, &None);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let _ = tokio::select!(
|
||||
_ = tick_handle => {},
|
||||
_ = shutdown_signal_rx => {},
|
||||
);
|
||||
|
||||
info!("metrics server shutdown");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Datadog {
|
||||
client: DatadogClient<UdpEventSink>,
|
||||
}
|
||||
|
||||
struct DatadogPipeline<'a>(DatadogClient<PipelineSink<'a, UdpEventSink>>);
|
||||
|
||||
impl Datadog {
|
||||
fn new(config: &'static Config) -> Option<Self> {
|
||||
let host = config.metrics.datadog.as_ref()?;
|
||||
|
||||
let sink = UdpEventSink::new(host).unwrap();
|
||||
let source = config
|
||||
.ice_candidate_ip
|
||||
.as_deref()
|
||||
.unwrap_or("unspecified")
|
||||
.to_string();
|
||||
|
||||
let mut point_tags = vec![
|
||||
("region", config.metrics.region.clone()),
|
||||
("source", source),
|
||||
];
|
||||
|
||||
if let Some(version) = &config.metrics.version {
|
||||
point_tags.push(("version", version.to_string()));
|
||||
}
|
||||
|
||||
let constant_tags: Vec<_> = point_tags
|
||||
.iter()
|
||||
.map(|(a, b)| format!("{}:{}", a, b))
|
||||
.collect();
|
||||
|
||||
let constant_tags: Vec<_> = constant_tags.iter().map(|a| a.as_ref()).collect();
|
||||
|
||||
let client = DatadogClient::new(sink, "", Some(constant_tags));
|
||||
|
||||
Some(Self { client })
|
||||
}
|
||||
|
||||
fn open_pipeline(&mut self) -> DatadogPipeline<'_> {
|
||||
DatadogPipeline(self.client.pipeline())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for DatadogPipeline<'a> {
|
||||
type Target = DatadogClient<PipelineSink<'a, UdpEventSink>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> DerefMut for DatadogPipeline<'a> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> DatadogPipeline<'a> {
|
||||
fn send_timer_histogram(
|
||||
&mut self,
|
||||
histogram_report: &HistogramReport,
|
||||
tags: &Option<Vec<&str>>,
|
||||
) {
|
||||
let name = histogram_report.name();
|
||||
|
||||
let precision = histogram_report.sample_precision();
|
||||
|
||||
let factor = match precision {
|
||||
Precision::Centisecond => 10f64,
|
||||
Precision::Millisecond => 1f64,
|
||||
Precision::Microsecond => 0.001f64,
|
||||
Precision::Nanosecond => 0.000_001f64,
|
||||
};
|
||||
|
||||
for (value, frequency) in histogram_report.histogram.iter() {
|
||||
self.timer_at_rate(
|
||||
name,
|
||||
*value as f64 * factor,
|
||||
1f64 / (*frequency as f64),
|
||||
tags,
|
||||
);
|
||||
self.distribution_at_rate(
|
||||
name,
|
||||
*value as f64 * factor,
|
||||
1f64 / (*frequency as f64),
|
||||
tags,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn send_count_histogram(
|
||||
&mut self,
|
||||
metric_name: &str,
|
||||
histogram: &Histogram<usize>,
|
||||
tags: &Option<Vec<&str>>,
|
||||
) {
|
||||
for (value, frequency) in histogram.iter() {
|
||||
self.histogram_at_rate(metric_name, *value as f64, 1f64 / (*frequency as f64), tags);
|
||||
self.distribution_at_rate(metric_name, *value as f64, 1f64 / (*frequency as f64), tags);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets a vector of (metric_names, values)
|
||||
fn get_value_metrics() -> Vec<(&'static str, f32)> {
|
||||
let mut value_metrics = Vec::new();
|
||||
|
||||
value_metrics.extend(get_process_metrics());
|
||||
|
||||
value_metrics
|
||||
}
|
||||
|
||||
/// Gets a vector of (metric_names, values) for current process metrics
|
||||
fn get_process_metrics() -> Vec<(&'static str, f32)> {
|
||||
let mut value_metrics = Vec::new();
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref CURRENT_PROCESS: Mutex<Process> = Mutex::new(Process::current().expect("Can't get current process"));
|
||||
}
|
||||
|
||||
let mut current_process = CURRENT_PROCESS.lock();
|
||||
|
||||
match current_process.memory_percent() {
|
||||
Ok(memory_percentage) => {
|
||||
value_metrics.push(("calling.system.memory.pc", memory_percentage));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error getting memory percentage {:?}", e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")] // open_files is not yet implemented for macos
|
||||
match current_process.open_files() {
|
||||
Ok(open_files) => {
|
||||
value_metrics.push(("calling.system.memory.fd.count", open_files.len() as f32));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error getting fd count {:?}", e)
|
||||
}
|
||||
}
|
||||
|
||||
match current_process.cpu_percent() {
|
||||
Ok(cpu_percentage) => {
|
||||
value_metrics.push(("calling.system.cpu.pc", cpu_percentage));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error getting cpu percentage {:?}", e)
|
||||
}
|
||||
}
|
||||
|
||||
value_metrics
|
||||
}
|
||||
180
src/protos.rs
Normal file
180
src/protos.rs
Normal file
@ -0,0 +1,180 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct DeviceToDevice {
|
||||
#[prost(bytes, optional, tag = "1")]
|
||||
pub group_id: ::std::option::Option<std::vec::Vec<u8>>,
|
||||
#[prost(message, optional, tag = "2")]
|
||||
pub media_key: ::std::option::Option<device_to_device::MediaKey>,
|
||||
#[prost(message, optional, tag = "3")]
|
||||
pub heartbeat: ::std::option::Option<device_to_device::Heartbeat>,
|
||||
#[prost(message, optional, tag = "4")]
|
||||
pub leaving: ::std::option::Option<device_to_device::Leaving>,
|
||||
}
|
||||
pub mod device_to_device {
|
||||
/// Sent over signaling
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct MediaKey {
|
||||
#[prost(uint32, optional, tag = "1")]
|
||||
pub ratchet_counter: ::std::option::Option<u32>,
|
||||
#[prost(bytes, optional, tag = "2")]
|
||||
pub secret: ::std::option::Option<std::vec::Vec<u8>>,
|
||||
#[prost(uint32, optional, tag = "3")]
|
||||
pub demux_id: ::std::option::Option<u32>,
|
||||
}
|
||||
/// Sent over RTP data
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Heartbeat {
|
||||
#[prost(bool, optional, tag = "1")]
|
||||
pub audio_muted: ::std::option::Option<bool>,
|
||||
#[prost(bool, optional, tag = "2")]
|
||||
pub video_muted: ::std::option::Option<bool>,
|
||||
}
|
||||
/// Sent over RTP data channel *and* signaling
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Leaving {
|
||||
/// When sent over signaling, you must indicate which device is leaving.
|
||||
#[prost(uint32, optional, tag = "1")]
|
||||
pub demux_id: ::std::option::Option<u32>,
|
||||
}
|
||||
}
|
||||
/// Called RtpDataChannelMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct DeviceToSfu {
|
||||
/// Called resolutionRequest in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub video_request: ::std::option::Option<device_to_sfu::VideoRequestMessage>,
|
||||
/// Called endpointMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(message, optional, tag = "7")]
|
||||
pub send_to_devices: ::std::option::Option<device_to_sfu::SendToDevices>,
|
||||
}
|
||||
pub mod device_to_sfu {
|
||||
/// Called ResolutionRequestMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct VideoRequestMessage {
|
||||
#[prost(message, repeated, tag = "1")]
|
||||
pub requests: ::std::vec::Vec<video_request_message::VideoRequest>,
|
||||
/// Don't send more than this many, even if they are in the list above
|
||||
/// or if they aren't in the list above.
|
||||
/// Called lastN in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(uint32, optional, tag = "2")]
|
||||
pub max: ::std::option::Option<u32>,
|
||||
}
|
||||
pub mod video_request_message {
|
||||
/// Called Constraint in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct VideoRequest {
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
/// Called endpointSuffix in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(uint64, optional, tag = "1")]
|
||||
pub id: ::std::option::Option<u64>,
|
||||
/// Called idealHeight in the SFU's RtpDataChannelMessages.proto
|
||||
/// This does not allocate bits eagerly.
|
||||
#[prost(uint32, optional, tag = "2")]
|
||||
pub height: ::std::option::Option<u32>,
|
||||
}
|
||||
}
|
||||
/// Called EndpointToEndpointMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct SendToDevices {
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
/// If not set, it will broadcast
|
||||
#[prost(string, optional, tag = "1")]
|
||||
pub long_device_id: ::std::option::Option<std::string::String>,
|
||||
#[prost(bytes, optional, tag = "3")]
|
||||
pub payload: ::std::option::Option<std::vec::Vec<u8>>,
|
||||
}
|
||||
}
|
||||
/// Called RtpDataChannelMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct SfuToDevice {
|
||||
/// Called senderVideoConstraint in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(message, optional, tag = "2")]
|
||||
pub video_request: ::std::option::Option<sfu_to_device::VideoRequest>,
|
||||
/// Called endpointConnectionStatus in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(message, optional, tag = "3")]
|
||||
pub device_connection_status: ::std::option::Option<sfu_to_device::DeviceConnectionStatus>,
|
||||
/// Called dominantSpeaker in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(message, optional, tag = "4")]
|
||||
pub speaker: ::std::option::Option<sfu_to_device::Speaker>,
|
||||
/// Called forwardedEndpoints in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(message, optional, tag = "5")]
|
||||
pub forwarding: ::std::option::Option<sfu_to_device::Forwarding>,
|
||||
/// Called endpointChanged in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(message, optional, tag = "6")]
|
||||
pub device_joined_or_left: ::std::option::Option<sfu_to_device::DeviceJoinedOrLeft>,
|
||||
/// Called forwardedEndpoints in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(message, optional, tag = "7")]
|
||||
pub received_from_device: ::std::option::Option<sfu_to_device::ReceivedFromDevice>,
|
||||
}
|
||||
pub mod sfu_to_device {
|
||||
/// Called EndpointChangedMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct DeviceJoinedOrLeft {
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
/// Called endpoint in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(string, optional, tag = "1")]
|
||||
pub long_device_id: ::std::option::Option<std::string::String>,
|
||||
#[prost(bool, optional, tag = "2")]
|
||||
pub joined: ::std::option::Option<bool>,
|
||||
}
|
||||
/// The current primary/active speaker as calculated by rather complex logic by the SFU.
|
||||
/// Called DominantSpeakerMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Speaker {
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
/// Called endpoint in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(string, optional, tag = "1")]
|
||||
pub id: ::std::option::Option<std::string::String>,
|
||||
}
|
||||
/// The resolution the SFU wants you to send to it to satisfy the requests
|
||||
/// of all of the other devices.
|
||||
/// Called SenderVideoConstraintMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct VideoRequest {
|
||||
/// Called idealHeight in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(uint32, optional, tag = "1")]
|
||||
pub height: ::std::option::Option<u32>,
|
||||
}
|
||||
/// Called EndpointConnectionStatusMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct DeviceConnectionStatus {
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
#[prost(string, optional, tag = "1")]
|
||||
pub long_device_id: ::std::option::Option<std::string::String>,
|
||||
#[prost(bool, optional, tag = "2")]
|
||||
pub active: ::std::option::Option<bool>,
|
||||
}
|
||||
/// Called EndpointToEndpointMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ReceivedFromDevice {
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
#[prost(string, optional, tag = "1")]
|
||||
pub long_device_id: ::std::option::Option<std::string::String>,
|
||||
#[prost(bytes, optional, tag = "3")]
|
||||
pub payload: ::std::option::Option<std::vec::Vec<u8>>,
|
||||
}
|
||||
/// Called ForwardedEndpointsMessage in the SFU's RtpDataChannelMessages.proto
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Forwarding {
|
||||
/// The remote devices from which video is being forwarded.
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
/// Called suffixEndpointsBeingForwarded in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(uint64, repeated, packed = "false", tag = "1")]
|
||||
pub video_forwarded_short_device_ids: ::std::vec::Vec<u64>,
|
||||
/// The remote devices from which video is being forwarded, but from which
|
||||
/// video was not being forwarded in the last Forwarding message.
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
/// Called suffixEndpointsEnteringLastN in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(uint64, repeated, packed = "false", tag = "2")]
|
||||
pub newly_forwarded_short_device_ids: ::std::vec::Vec<u64>,
|
||||
/// All the of devices. The same as the "Devices" messages.
|
||||
/// Functionally the same as a DemuxId, but oddly different.
|
||||
/// Called suffixEndpointsInConference in the SFU's RtpDataChannelMessages.proto
|
||||
#[prost(uint64, repeated, packed = "false", tag = "3")]
|
||||
pub all_devices_short_device_ids: ::std::vec::Vec<u64>,
|
||||
}
|
||||
}
|
||||
2398
src/rtp.rs
Normal file
2398
src/rtp.rs
Normal file
File diff suppressed because it is too large
Load Diff
1274
src/sfu.rs
Normal file
1274
src/sfu.rs
Normal file
File diff suppressed because it is too large
Load Diff
861
src/signaling_server.rs
Normal file
861
src/signaling_server.rs
Normal file
@ -0,0 +1,861 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
//! Implementation of the SFU signaling server. This version is based on warp.
|
||||
//! Supported REST APIs:
|
||||
//! GET /about/health
|
||||
//! GET /v1/info
|
||||
//! GET /v1/call/$call_id/clients
|
||||
//! POST /v1/call/$call_id/client/$demux_id (join)
|
||||
//! DELETE /v1/call/$call_id/client/$demux_id (leave)
|
||||
|
||||
use std::{
|
||||
convert::{Infallible, TryInto},
|
||||
error::Error,
|
||||
net::IpAddr,
|
||||
str::{self, FromStr},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use hex::FromHex;
|
||||
use log::*;
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::oneshot::Receiver;
|
||||
use warp::{http::StatusCode, Filter, Reply};
|
||||
|
||||
use crate::{call, common, common::Instant, config, ice, sfu, sfu::Sfu};
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InfoResponse {
|
||||
pub direct_access_ip: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientsResponse {
|
||||
pub endpoint_ids: Vec<String>, // Aka active_speaker_ids, a concatenation of user_id + '-' + resolution_request_id.
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct JoinRequest {
|
||||
pub endpoint_id: String, // Aka active_speaker_id, a concatenation of user_id + '-' + resolution_request_id.
|
||||
pub client_ice_ufrag: String,
|
||||
pub client_dtls_fingerprint: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct JoinResponse {
|
||||
pub server_ip: String,
|
||||
pub server_port: u16,
|
||||
pub server_ice_ufrag: String,
|
||||
pub server_ice_pwd: String,
|
||||
pub server_dtls_fingerprint: String,
|
||||
}
|
||||
|
||||
/// Struct to support warp rejection (errors) for invalid argument values.
|
||||
#[derive(Debug)]
|
||||
struct InvalidArgument {
|
||||
reason: String,
|
||||
}
|
||||
impl warp::reject::Reject for InvalidArgument {}
|
||||
|
||||
/// Struct to support warp rejection (errors) for internal errors.
|
||||
#[derive(Debug)]
|
||||
struct InternalError {
|
||||
reason: String,
|
||||
}
|
||||
impl warp::reject::Reject for InternalError {}
|
||||
|
||||
/// Get a call_id (Vec<u8>) from a string hex value.
|
||||
fn call_id_from_hex(call_id: &str) -> Result<sfu::CallId> {
|
||||
if call_id.is_empty() {
|
||||
return Err(anyhow!("call_id is empty"));
|
||||
}
|
||||
|
||||
Ok(Vec::from_hex(call_id)
|
||||
.map_err(|_| anyhow!("call_id is invalid"))?
|
||||
.into())
|
||||
}
|
||||
|
||||
/// Parse a user_id hash and resolution_request_id from the provided endpoint. The endpoint
|
||||
/// string has the following format: $hex(sha256(user_id))-$resolution_request_id
|
||||
///
|
||||
/// ```
|
||||
/// use calling_server::signaling_server::parse_user_id_and_resolution_request_id_from_endpoint_id;
|
||||
///
|
||||
/// assert_eq!(parse_user_id_and_resolution_request_id_from_endpoint_id("abcdef-0").unwrap(), (vec![171, 205, 239].into(), 0));
|
||||
/// assert_eq!(parse_user_id_and_resolution_request_id_from_endpoint_id("abcdef-12345").unwrap(), (vec![171, 205, 239].into(), 12345));
|
||||
/// assert_eq!(parse_user_id_and_resolution_request_id_from_endpoint_id("").is_err(), true);
|
||||
/// assert_eq!(parse_user_id_and_resolution_request_id_from_endpoint_id("abcdef-").is_err(), true);
|
||||
/// assert_eq!(parse_user_id_and_resolution_request_id_from_endpoint_id("abcdef-a").is_err(), true);
|
||||
/// assert_eq!(parse_user_id_and_resolution_request_id_from_endpoint_id("abcdef-1-").is_err(), true);
|
||||
/// ```
|
||||
pub fn parse_user_id_and_resolution_request_id_from_endpoint_id(
|
||||
endpoint_id: &str,
|
||||
) -> Result<(sfu::UserId, u64)> {
|
||||
if let [user_id_hex, suffix] = endpoint_id.splitn(2, '-').collect::<Vec<_>>()[..] {
|
||||
let resolution_request_id = u64::from_str(suffix)?;
|
||||
let user_id = Vec::from_hex(&user_id_hex)?;
|
||||
|
||||
Ok((user_id.into(), resolution_request_id))
|
||||
} else {
|
||||
Err(anyhow!("malformed endpoint_id"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtain information about the server.
|
||||
async fn get_info(
|
||||
config: &'static config::Config,
|
||||
_sfu: Arc<Mutex<Sfu>>,
|
||||
) -> Result<warp::reply::Response, warp::Rejection> {
|
||||
trace!("get_info():");
|
||||
|
||||
if let Some(private_ip) = &config.signaling_ip {
|
||||
let response = InfoResponse {
|
||||
direct_access_ip: private_ip.to_string(),
|
||||
};
|
||||
|
||||
Ok(warp::reply::json(&response).into_response())
|
||||
} else {
|
||||
Err(warp::reject::custom(InternalError {
|
||||
reason: "private_ip not set".to_string(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the list of clients for a given call. Returns "Not Found" if
|
||||
/// the call does not exist or an empty list if there are no clients
|
||||
/// currently in the call.
|
||||
async fn get_clients(
|
||||
call_id: String,
|
||||
_config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
) -> Result<warp::reply::Response, warp::Rejection> {
|
||||
trace!("get_clients(): {}", call_id);
|
||||
|
||||
let call_id = call_id_from_hex(&call_id).map_err(|err| {
|
||||
warp::reject::custom(InvalidArgument {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let sfu = sfu.lock();
|
||||
if let Some(signaling) = sfu.get_call_signaling_info(call_id) {
|
||||
let response = ClientsResponse {
|
||||
endpoint_ids: signaling
|
||||
.client_ids
|
||||
.into_iter()
|
||||
.map(|(_demux_id, active_speaker_id)| active_speaker_id)
|
||||
.collect(),
|
||||
};
|
||||
|
||||
Ok(warp::reply::json(&response).into_response())
|
||||
} else {
|
||||
Ok(StatusCode::NOT_FOUND.into_response())
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles a request for a client to join a call.
|
||||
async fn join(
|
||||
call_id: String,
|
||||
demux_id: u32,
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
request: JoinRequest,
|
||||
) -> Result<warp::reply::Response, warp::Rejection> {
|
||||
trace!("join(): {} {}", call_id, demux_id);
|
||||
|
||||
let call_id = call_id_from_hex(&call_id).map_err(|err| {
|
||||
warp::reject::custom(InvalidArgument {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
})?;
|
||||
let demux_id = demux_id.try_into().map_err(|err: call::Error| {
|
||||
warp::reject::custom(InvalidArgument {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let (user_id, resolution_request_id) =
|
||||
parse_user_id_and_resolution_request_id_from_endpoint_id(&request.endpoint_id).map_err(
|
||||
|err| {
|
||||
warp::reject::custom(InvalidArgument {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
},
|
||||
)?;
|
||||
|
||||
let client_dtls_fingerprint = common::colon_separated_hexstring_to_array(
|
||||
&request.client_dtls_fingerprint,
|
||||
)
|
||||
.map_err(|err| {
|
||||
warp::reject::custom(InvalidArgument {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let server_ice_ufrag = ice::random_ufrag();
|
||||
let server_ice_pwd = ice::random_pwd();
|
||||
|
||||
let mut sfu = sfu.lock();
|
||||
match sfu.get_or_create_call_and_add_client(
|
||||
call_id,
|
||||
&user_id,
|
||||
resolution_request_id,
|
||||
request.endpoint_id,
|
||||
demux_id,
|
||||
server_ice_ufrag.to_string(),
|
||||
server_ice_pwd.to_string(),
|
||||
request.client_ice_ufrag,
|
||||
client_dtls_fingerprint,
|
||||
) {
|
||||
Ok(_) => {
|
||||
let media_addr = config::get_server_media_address(config);
|
||||
let response = JoinResponse {
|
||||
server_ip: media_addr.ip().to_string(),
|
||||
server_port: media_addr.port(),
|
||||
server_ice_ufrag,
|
||||
server_ice_pwd,
|
||||
server_dtls_fingerprint: common::bytes_to_colon_separated_hexstring(
|
||||
sfu.server_dtls_fingerprint(),
|
||||
),
|
||||
};
|
||||
|
||||
Ok(warp::reply::json(&response).into_response())
|
||||
}
|
||||
Err(err) => {
|
||||
error!("client failed to join call {}", err.to_string());
|
||||
if err == sfu::SfuError::DuplicateDemuxIdDetected {
|
||||
// Invalid argument because the demux_id is a duplicate.
|
||||
Err(warp::reject::custom(InvalidArgument {
|
||||
reason: err.to_string(),
|
||||
}))
|
||||
} else {
|
||||
Err(warp::reject::custom(InternalError {
|
||||
reason: format!("failed to add client to call {}", err.to_string()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles a request for a client to leave a call.
|
||||
async fn leave(
|
||||
call_id: String,
|
||||
demux_id: u32,
|
||||
_config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
) -> Result<warp::reply::Response, warp::Rejection> {
|
||||
trace!("leave(): {} {}", call_id, demux_id);
|
||||
|
||||
let call_id = call_id_from_hex(&call_id).map_err(|err| {
|
||||
warp::reject::custom(InvalidArgument {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
})?;
|
||||
let demux_id = demux_id.try_into().map_err(|err: call::Error| {
|
||||
warp::reject::custom(InvalidArgument {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
sfu.lock()
|
||||
.remove_client_from_call(Instant::now(), call_id, demux_id);
|
||||
|
||||
Ok(StatusCode::NO_CONTENT.into_response())
|
||||
}
|
||||
|
||||
/// Map rejections to a format that should be presented in the response.
|
||||
async fn rejection_handler(rejection: warp::Rejection) -> Result<impl Reply, Infallible> {
|
||||
let code;
|
||||
let message;
|
||||
|
||||
if rejection.is_not_found() {
|
||||
code = StatusCode::NOT_FOUND;
|
||||
message = "".to_string();
|
||||
} else if let Some(r) = rejection.find::<InvalidArgument>() {
|
||||
// Our detection of invalid request arguments.
|
||||
code = StatusCode::BAD_REQUEST;
|
||||
message = r.reason.to_string();
|
||||
} else if let Some(r) = rejection.find::<InternalError>() {
|
||||
// Our internal errors.
|
||||
code = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
message = r.reason.to_string();
|
||||
} else if let Some(e) = rejection.find::<warp::filters::body::BodyDeserializeError>() {
|
||||
// Warp's detection of invalid requests (when deserializing json).
|
||||
code = StatusCode::BAD_REQUEST;
|
||||
message = match e.source() {
|
||||
Some(cause) => cause.to_string(),
|
||||
None => "".to_string(),
|
||||
};
|
||||
} else {
|
||||
code = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
message = "unknown".to_string();
|
||||
}
|
||||
|
||||
Ok(warp::reply::with_status(message, code))
|
||||
}
|
||||
|
||||
/// A warp filter for providing the config for a route.
|
||||
fn with_config(
|
||||
config: &'static config::Config,
|
||||
) -> impl Filter<Extract = (&'static config::Config,), Error = std::convert::Infallible> + Clone {
|
||||
warp::any().map(move || config)
|
||||
}
|
||||
|
||||
/// A warp filter for extracting the Sfu state for a route.
|
||||
fn with_sfu(
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
) -> impl Filter<Extract = (Arc<Mutex<Sfu>>,), Error = std::convert::Infallible> + Clone {
|
||||
warp::any().map(move || sfu.clone())
|
||||
}
|
||||
|
||||
/// Filter to support the "GET /about/health" API for the server and testing.
|
||||
fn get_health_api(
|
||||
is_healthy: Arc<AtomicBool>,
|
||||
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path!("about" / "health")
|
||||
.and(warp::get())
|
||||
.map(move || {
|
||||
if is_healthy.load(Ordering::Relaxed) {
|
||||
Ok(StatusCode::OK.into_response())
|
||||
} else {
|
||||
Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Filter to support the "GET /v1/info" API for the server and testing.
|
||||
fn get_info_api(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path!("v1" / "info")
|
||||
.and(warp::get())
|
||||
.and(with_config(config))
|
||||
.and(with_sfu(sfu))
|
||||
.and_then(get_info)
|
||||
}
|
||||
|
||||
/// Filter to support the "GET /v1/call/$call_id/clients" API for the server and testing.
|
||||
fn get_clients_api(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path!("v1" / "call" / String / "clients")
|
||||
.and(warp::get())
|
||||
.and(with_config(config))
|
||||
.and(with_sfu(sfu))
|
||||
.and_then(get_clients)
|
||||
}
|
||||
|
||||
/// Filter to support the "POST /v1/call/$call_id/client/$demux_id" API for the server and testing.
|
||||
fn join_api(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path!("v1" / "call" / String / "client" / u32)
|
||||
.and(warp::post())
|
||||
.and(with_config(config))
|
||||
.and(with_sfu(sfu))
|
||||
.and(warp::body::json())
|
||||
.and_then(join)
|
||||
}
|
||||
|
||||
/// Filter to support the "DELETE /v1/call/$call_id/client/$demux_id" API for the server and testing.
|
||||
fn leave_api(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
warp::path!("v1" / "call" / String / "client" / u32)
|
||||
.and(warp::delete())
|
||||
.and(with_config(config))
|
||||
.and(with_sfu(sfu))
|
||||
.and_then(leave)
|
||||
}
|
||||
|
||||
/// The overall signaling api combined as a single filter for the server and testing.
|
||||
pub fn signaling_api(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
is_healthy: Arc<AtomicBool>,
|
||||
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
||||
get_health_api(is_healthy)
|
||||
.or(get_info_api(config, sfu.clone()))
|
||||
.or(get_clients_api(config, sfu.clone()))
|
||||
.or(join_api(config, sfu.clone()))
|
||||
.or(leave_api(config, sfu))
|
||||
}
|
||||
|
||||
pub async fn start(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
ender_rx: Receiver<()>,
|
||||
is_healthy: Arc<AtomicBool>,
|
||||
) -> Result<()> {
|
||||
let api = signaling_api(config, sfu, is_healthy)
|
||||
.with(warp::log("calling_service"))
|
||||
.recover(rejection_handler);
|
||||
|
||||
let (addr, server) = warp::serve(api).bind_with_graceful_shutdown(
|
||||
(IpAddr::from_str(&config.binding_ip)?, config.signaling_port),
|
||||
async {
|
||||
let _ = ender_rx.await;
|
||||
},
|
||||
);
|
||||
|
||||
info!("signaling_server ready: {}", addr);
|
||||
server.await;
|
||||
|
||||
info!("signaling_server shutdown");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod signaling_server_tests {
|
||||
use std::convert::TryInto;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use tokio::sync::oneshot;
|
||||
use warp::test::request;
|
||||
|
||||
use super::*;
|
||||
use crate::sfu::DemuxId;
|
||||
|
||||
const CALL_ID: &str = "fe076d76bffb54b1";
|
||||
const DTLS_FINGERPRINT: &str = "6F:A3:AE:74:FD:FA:2C:85:1B:19:52:55:D1:AE:4E:08:84:42:25:B9:7D:03:C1:62:C6:49:2B:C7:DC:0E:5E:09";
|
||||
const ENDPOINT_ID_1: &str =
|
||||
"7ab9bbf0b71f81598ae1b592aaf82f9b20b638142a9610c3e37965bec7519112-5287417572362992825";
|
||||
const ENDPOINT_ID_2: &str =
|
||||
"b25387a93fd65599bacae4a8f8726e9e818ecf0bec3360593fe542cdb8e611a3-7715148009648537058";
|
||||
const UFRAG: &str = "Ouub";
|
||||
|
||||
lazy_static! {
|
||||
static ref DEFAULT_CONFIG: config::Config = config::default_test_config();
|
||||
|
||||
// Load a config with no signaling_ip set.
|
||||
static ref BAD_IP_CONFIG: config::Config = {
|
||||
let mut config = config::default_test_config();
|
||||
config.signaling_ip = None;
|
||||
config
|
||||
};
|
||||
}
|
||||
|
||||
fn new_sfu(now: Instant, config: &'static config::Config) -> Arc<Mutex<Sfu>> {
|
||||
Arc::new(Mutex::new(
|
||||
Sfu::new(now, config).expect("Sfu::new should work"),
|
||||
))
|
||||
}
|
||||
|
||||
fn add_client_to_sfu(
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
call_id: &str,
|
||||
endpoint_id: &str,
|
||||
demux_id: DemuxId,
|
||||
client_ice_ufrag: &str,
|
||||
client_dtls_fingerprint: &str,
|
||||
) {
|
||||
let call_id = call_id_from_hex(call_id).unwrap();
|
||||
let (user_id, resolution_request_id) =
|
||||
parse_user_id_and_resolution_request_id_from_endpoint_id(endpoint_id).unwrap();
|
||||
let client_dtls_fingerprint =
|
||||
common::colon_separated_hexstring_to_array(client_dtls_fingerprint).unwrap();
|
||||
|
||||
sfu.lock()
|
||||
.get_or_create_call_and_add_client(
|
||||
call_id,
|
||||
&user_id,
|
||||
resolution_request_id,
|
||||
endpoint_id.to_string(),
|
||||
demux_id,
|
||||
ice::random_ufrag(),
|
||||
ice::random_pwd(),
|
||||
client_ice_ufrag.to_string(),
|
||||
client_dtls_fingerprint,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn remove_client_from_sfu(sfu: Arc<Mutex<Sfu>>, call_id: &str, demux_id: DemuxId) {
|
||||
let call_id = call_id_from_hex(call_id).unwrap();
|
||||
|
||||
sfu.lock()
|
||||
.remove_client_from_call(Instant::now(), call_id, demux_id);
|
||||
}
|
||||
|
||||
fn check_call_exists_in_sfu(sfu: Arc<Mutex<Sfu>>, call_id: &str) -> bool {
|
||||
sfu.lock()
|
||||
.get_call_signaling_info(call_id_from_hex(call_id).unwrap())
|
||||
.is_some()
|
||||
}
|
||||
|
||||
fn get_client_count_in_call_from_sfu(sfu: Arc<Mutex<Sfu>>, call_id: &str) -> usize {
|
||||
if let Some(signaling) = sfu
|
||||
.lock()
|
||||
.get_call_signaling_info(call_id_from_hex(call_id).unwrap())
|
||||
{
|
||||
signaling.size
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_start() {
|
||||
let config = &DEFAULT_CONFIG;
|
||||
let sfu = new_sfu(Instant::now(), config);
|
||||
let is_healthy = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let (signaling_ender_tx, signaling_ender_rx) = oneshot::channel();
|
||||
|
||||
let server_handle =
|
||||
tokio::spawn(async move { start(config, sfu, signaling_ender_rx, is_healthy).await });
|
||||
|
||||
let closer_handle = tokio::spawn(async move { signaling_ender_tx.send(()) });
|
||||
|
||||
let (server_result, _) = tokio::join!(server_handle, closer_handle,);
|
||||
|
||||
assert!(server_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_health() {
|
||||
let config = &DEFAULT_CONFIG;
|
||||
let sfu = new_sfu(Instant::now(), config);
|
||||
let is_healthy = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let api = signaling_api(config, sfu, is_healthy.clone()).recover(rejection_handler);
|
||||
|
||||
let response = request()
|
||||
.method("GET")
|
||||
.path("/about/health")
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
is_healthy.store(false, Ordering::Relaxed);
|
||||
|
||||
let response = request()
|
||||
.method("GET")
|
||||
.path("/about/health")
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_info() {
|
||||
let config = &DEFAULT_CONFIG;
|
||||
let sfu = new_sfu(Instant::now(), config);
|
||||
let is_healthy = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let api = signaling_api(config, sfu, is_healthy).recover(rejection_handler);
|
||||
|
||||
let response = request().method("GET").path("/v1/info").reply(&api).await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
response.headers().get("content-type").unwrap(),
|
||||
"application/json"
|
||||
);
|
||||
assert_eq!(response.body(), r#"{"directAccessIp":"127.0.0.1"}"#);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_info_bad_ip() {
|
||||
let config = &BAD_IP_CONFIG;
|
||||
let sfu = new_sfu(Instant::now(), config);
|
||||
let is_healthy = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let api = signaling_api(config, sfu, is_healthy).recover(rejection_handler);
|
||||
|
||||
let response = request().method("GET").path("/v1/info").reply(&api).await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_clients() {
|
||||
let config = &DEFAULT_CONFIG;
|
||||
let sfu = new_sfu(Instant::now(), config);
|
||||
let is_healthy = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let api = signaling_api(config, sfu.clone(), is_healthy).recover(rejection_handler);
|
||||
|
||||
let response = request()
|
||||
.method("GET")
|
||||
.path(&format!("/v1/call/{}/clients", CALL_ID))
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
// No clients were added.
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
// Join with client 16.
|
||||
add_client_to_sfu(
|
||||
sfu.clone(),
|
||||
CALL_ID,
|
||||
ENDPOINT_ID_1,
|
||||
16u32.try_into().unwrap(),
|
||||
UFRAG,
|
||||
DTLS_FINGERPRINT,
|
||||
);
|
||||
|
||||
let response = request()
|
||||
.method("GET")
|
||||
.path(&format!("/v1/call/{}/clients", CALL_ID))
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
response.headers().get("content-type").unwrap(),
|
||||
"application/json"
|
||||
);
|
||||
assert_eq!(
|
||||
response.body(),
|
||||
&format!("{{\"endpointIds\":[\"{}\"]}}", ENDPOINT_ID_1)
|
||||
);
|
||||
|
||||
// Join with client 32.
|
||||
add_client_to_sfu(
|
||||
sfu.clone(),
|
||||
CALL_ID,
|
||||
ENDPOINT_ID_2,
|
||||
32u32.try_into().unwrap(),
|
||||
UFRAG,
|
||||
DTLS_FINGERPRINT,
|
||||
);
|
||||
|
||||
let response = request()
|
||||
.method("GET")
|
||||
.path(&format!("/v1/call/{}/clients", CALL_ID))
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
response.body(),
|
||||
&format!(
|
||||
"{{\"endpointIds\":[\"{}\",\"{}\"]}}",
|
||||
ENDPOINT_ID_1, ENDPOINT_ID_2
|
||||
)
|
||||
);
|
||||
|
||||
remove_client_from_sfu(sfu.clone(), CALL_ID, 16u32.try_into().unwrap());
|
||||
|
||||
let response = request()
|
||||
.method("GET")
|
||||
.path(&format!("/v1/call/{}/clients", CALL_ID))
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
response.body(),
|
||||
&format!("{{\"endpointIds\":[\"{}\"]}}", ENDPOINT_ID_2)
|
||||
);
|
||||
|
||||
remove_client_from_sfu(sfu.clone(), CALL_ID, 32u32.try_into().unwrap());
|
||||
|
||||
let response = request()
|
||||
.method("GET")
|
||||
.path(&format!("/v1/call/{}/clients", CALL_ID))
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(response.body(), "{\"endpointIds\":[]}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_join() {
|
||||
let config = &DEFAULT_CONFIG;
|
||||
let sfu = new_sfu(Instant::now(), config);
|
||||
let is_healthy = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let api = signaling_api(config, sfu.clone(), is_healthy).recover(rejection_handler);
|
||||
|
||||
// Join with a invalid DemuxId.
|
||||
let response = request()
|
||||
.method("POST")
|
||||
.path(&format!("/v1/call/{}/client/{}", CALL_ID, 1))
|
||||
.json(&JoinRequest {
|
||||
endpoint_id: ENDPOINT_ID_1.to_string(),
|
||||
client_ice_ufrag: UFRAG.to_string(),
|
||||
client_dtls_fingerprint: DTLS_FINGERPRINT.to_string(),
|
||||
})
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Join with an invalid CallId.
|
||||
let response = request()
|
||||
.method("POST")
|
||||
.path(&format!("/v1/call/{}/client/{}", "INVALIDNOTHEX", 16))
|
||||
.json(&JoinRequest {
|
||||
endpoint_id: ENDPOINT_ID_1.to_string(),
|
||||
client_ice_ufrag: UFRAG.to_string(),
|
||||
client_dtls_fingerprint: DTLS_FINGERPRINT.to_string(),
|
||||
})
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Join with an invalid endpoint_id.
|
||||
let response = request()
|
||||
.method("POST")
|
||||
.path(&format!("/v1/call/{}/client/{}", CALL_ID, 16))
|
||||
.json(&JoinRequest {
|
||||
endpoint_id: "MALFORMEDNOHYPHEN".to_string(),
|
||||
client_ice_ufrag: UFRAG.to_string(),
|
||||
client_dtls_fingerprint: DTLS_FINGERPRINT.to_string(),
|
||||
})
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Join with an invalid endpoint_id.
|
||||
let response = request()
|
||||
.method("POST")
|
||||
.path(&format!("/v1/call/{}/client/{}", CALL_ID, 16))
|
||||
.json(&JoinRequest {
|
||||
endpoint_id: "MALFORMEDNOHYPHEN".to_string(),
|
||||
client_ice_ufrag: UFRAG.to_string(),
|
||||
client_dtls_fingerprint: DTLS_FINGERPRINT.to_string(),
|
||||
})
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Join with an invalid dtls_fingerprint.
|
||||
let response = request()
|
||||
.method("POST")
|
||||
.path(&format!("/v1/call/{}/client/{}", CALL_ID, 16))
|
||||
.json(&JoinRequest {
|
||||
endpoint_id: ENDPOINT_ID_1.to_string(),
|
||||
client_ice_ufrag: UFRAG.to_string(),
|
||||
client_dtls_fingerprint: "INVALIDNOCOLONS".to_string(),
|
||||
})
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Join with good parameters.
|
||||
let response = request()
|
||||
.method("POST")
|
||||
.path(&format!("/v1/call/{}/client/{}", CALL_ID, 16))
|
||||
.json(&JoinRequest {
|
||||
endpoint_id: ENDPOINT_ID_1.to_string(),
|
||||
client_ice_ufrag: UFRAG.to_string(),
|
||||
client_dtls_fingerprint: DTLS_FINGERPRINT.to_string(),
|
||||
})
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
response.headers().get("content-type").unwrap(),
|
||||
"application/json"
|
||||
);
|
||||
|
||||
let response: JoinResponse = serde_json::from_slice(response.body()).unwrap();
|
||||
assert_eq!(response.server_ip, "127.0.0.1");
|
||||
assert_eq!(response.server_port, 10000);
|
||||
assert_eq!(response.server_dtls_fingerprint.chars().count(), 95);
|
||||
|
||||
assert!(
|
||||
check_call_exists_in_sfu(sfu.clone(), CALL_ID),
|
||||
"Call doesn't exist"
|
||||
);
|
||||
assert_eq!(get_client_count_in_call_from_sfu(sfu.clone(), CALL_ID), 1);
|
||||
|
||||
// Attempt to join again using the same demux_id (should be a bad request).
|
||||
let response = request()
|
||||
.method("POST")
|
||||
.path(&format!("/v1/call/{}/client/{}", CALL_ID, 16))
|
||||
.json(&JoinRequest {
|
||||
endpoint_id: ENDPOINT_ID_1.to_string(),
|
||||
client_ice_ufrag: UFRAG.to_string(),
|
||||
client_dtls_fingerprint: DTLS_FINGERPRINT.to_string(),
|
||||
})
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
assert_eq!(get_client_count_in_call_from_sfu(sfu.clone(), CALL_ID), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_leave() {
|
||||
let config = &DEFAULT_CONFIG;
|
||||
let sfu = new_sfu(Instant::now(), config);
|
||||
let is_healthy = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let api = signaling_api(config, sfu.clone(), is_healthy).recover(rejection_handler);
|
||||
|
||||
// Join with client 16 and verify.
|
||||
add_client_to_sfu(
|
||||
sfu.clone(),
|
||||
CALL_ID,
|
||||
ENDPOINT_ID_1,
|
||||
16u32.try_into().unwrap(),
|
||||
UFRAG,
|
||||
DTLS_FINGERPRINT,
|
||||
);
|
||||
assert!(
|
||||
check_call_exists_in_sfu(sfu.clone(), CALL_ID),
|
||||
"Call doesn't exist"
|
||||
);
|
||||
assert_eq!(get_client_count_in_call_from_sfu(sfu.clone(), CALL_ID), 1);
|
||||
|
||||
let response = request()
|
||||
.method("DELETE")
|
||||
.path(&format!("/v1/call/{}/client/{}", CALL_ID, 16))
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NO_CONTENT);
|
||||
|
||||
assert!(
|
||||
check_call_exists_in_sfu(sfu.clone(), CALL_ID),
|
||||
"Call doesn't exist"
|
||||
);
|
||||
assert_eq!(get_client_count_in_call_from_sfu(sfu.clone(), CALL_ID), 0);
|
||||
|
||||
// Attempt to leave again (response is indifferent since the client has already left).
|
||||
let response = request()
|
||||
.method("DELETE")
|
||||
.path(&format!("/v1/call/{}/client/{}", CALL_ID, 16))
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NO_CONTENT);
|
||||
|
||||
// Attempt to leave again from an unknown call.
|
||||
let response = request()
|
||||
.method("DELETE")
|
||||
.path(&format!("/v1/call/1234/client/{}", 16))
|
||||
.reply(&api)
|
||||
.await;
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NO_CONTENT);
|
||||
}
|
||||
}
|
||||
1181
src/transportcc.rs
Normal file
1181
src/transportcc.rs
Normal file
File diff suppressed because it is too large
Load Diff
129
src/udp_server.rs
Normal file
129
src/udp_server.rs
Normal file
@ -0,0 +1,129 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
//! Implementation of the udp server.
|
||||
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use log::*;
|
||||
use parking_lot::Mutex;
|
||||
use tokio::sync::oneshot::Receiver;
|
||||
|
||||
#[cfg(all(feature = "epoll", target_os = "linux"))]
|
||||
mod epoll;
|
||||
#[cfg(all(feature = "epoll", target_os = "linux"))]
|
||||
use epoll::*;
|
||||
#[cfg(not(all(feature = "epoll", target_os = "linux")))]
|
||||
mod generic;
|
||||
#[cfg(not(all(feature = "epoll", target_os = "linux")))]
|
||||
use generic::*;
|
||||
|
||||
use crate::{
|
||||
common::{Duration, Instant},
|
||||
config,
|
||||
sfu::{Sfu, SfuError},
|
||||
};
|
||||
|
||||
pub async fn start(
|
||||
config: &'static config::Config,
|
||||
sfu: Arc<Mutex<Sfu>>,
|
||||
udp_ender_rx: Receiver<()>,
|
||||
is_healthy: Arc<AtomicBool>,
|
||||
) -> Result<()> {
|
||||
let num_udp_threads = config.udp_threads.unwrap_or_else(|| {
|
||||
// Default to N - 1 CPUs, keeping one clear for the HTTP server.
|
||||
// But clamp to 15 so we don't run out of memory or another contended resource.
|
||||
num_cpus::get().clamp(2, 16) - 1
|
||||
});
|
||||
|
||||
let tick_interval = Duration::from_millis(config.tick_interval_ms);
|
||||
|
||||
let local_addr = SocketAddr::new(config.binding_ip.parse()?, config.ice_candidate_port);
|
||||
|
||||
let udp_handler_state = UdpServerState::new(local_addr, num_udp_threads, tick_interval)?;
|
||||
let udp_handler_state_for_tick = udp_handler_state.clone();
|
||||
|
||||
let sfu_for_tick = sfu.clone();
|
||||
|
||||
info!(
|
||||
"udp_server ready: {:?}; starting {} threads",
|
||||
local_addr, num_udp_threads
|
||||
);
|
||||
|
||||
// Spawn (blocking) threads for the UDP server.
|
||||
let udp_packet_handles = udp_handler_state.start_threads(move |sender_addr, data| {
|
||||
time_scope_us!("calling.udp_server.handle_packet");
|
||||
|
||||
trace!(
|
||||
"received packet of {} bytes from {}",
|
||||
data.len(),
|
||||
sender_addr
|
||||
);
|
||||
|
||||
sampling_histogram!("calling.udp_server.incoming_packet.size_bytes", || data
|
||||
.len());
|
||||
|
||||
Sfu::handle_packet(&sfu, sender_addr, data).unwrap_or_else(|err| {
|
||||
// Check for certain errors that can arise in normal conditions
|
||||
// (say, because UDP packets arrive out of order).
|
||||
// Note that we still use ".sfu" prefixes for these error events.
|
||||
match &err {
|
||||
SfuError::UnknownPacketType(_) => {
|
||||
event!("calling.sfu.error.expected.unhandled");
|
||||
trace!("handle_packet() failed: {}", err);
|
||||
}
|
||||
SfuError::IceBindingRequestUnknownUsername(_) => {
|
||||
event!("calling.sfu.error.expected.ice_binding_request_unknown_username");
|
||||
trace!("handle_packet() failed: {}", err);
|
||||
}
|
||||
_ => {
|
||||
event!("calling.sfu.error.unexpected");
|
||||
debug!("handle_packet() failed: {}", err);
|
||||
}
|
||||
}
|
||||
Vec::new()
|
||||
})
|
||||
});
|
||||
|
||||
// Spawn a normal (cooperative) task to run some regular maintenance on an interval.
|
||||
let tick_handle = tokio::spawn(async move {
|
||||
let mut tick_state = Default::default();
|
||||
loop {
|
||||
time_scope_us!("calling.udp_server.tick");
|
||||
// Use sleep() instead of interval() so that we never wait *less* than one interval
|
||||
// to do the next tick.
|
||||
tokio::time::sleep(tick_interval.into()).await;
|
||||
time_scope_us!("calling.udp_server.tick.processing");
|
||||
|
||||
let tick_output = { sfu_for_tick.lock().tick(Instant::now()) };
|
||||
|
||||
// Process outside the scope of the lock on the sfu.
|
||||
match udp_handler_state_for_tick.tick(tick_output, &mut tick_state) {
|
||||
Ok(()) => {}
|
||||
Err(err) => {
|
||||
error!("{}", err);
|
||||
is_healthy.store(false, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for any task to complete and cancel the rest.
|
||||
let _ = tokio::select!(
|
||||
_ = udp_packet_handles => {},
|
||||
_ = tick_handle => {},
|
||||
_ = udp_ender_rx => {},
|
||||
);
|
||||
|
||||
info!("udp_server shutdown");
|
||||
Ok(())
|
||||
}
|
||||
725
src/udp_server/epoll.rs
Normal file
725
src/udp_server/epoll.rs
Normal file
@ -0,0 +1,725 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{
|
||||
collections::{hash_map, HashMap},
|
||||
future::Future,
|
||||
io,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
os::unix::io::{AsRawFd, FromRawFd, RawFd},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use log::*;
|
||||
use nix::sys::{epoll::*, socket::*};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::{
|
||||
common::{try_scoped, Duration},
|
||||
metrics::TimingOptions,
|
||||
sfu,
|
||||
};
|
||||
|
||||
/// Controls number of sockets a particular thread will handle without going back to epoll.
|
||||
///
|
||||
/// A higher number saves calls into the kernel, but claims more events for a single thread to
|
||||
/// process.
|
||||
const MAX_EPOLL_EVENTS: usize = 16;
|
||||
|
||||
/// The shared state for an epoll-based UDP server.
|
||||
///
|
||||
/// This server is implemented with a "new client" socket that receives new connections, plus a map
|
||||
/// of dedicated sockets for each connected client. Processing these sockets is handled by [epoll],
|
||||
/// with each thread of the UDP server getting its own epoll descriptor to block on. This allows
|
||||
/// events to be level-triggered (as in, threads will be repeatedly woken up if a socket with data
|
||||
/// is not immediately read from) while still only waking one thread for a particular event.
|
||||
///
|
||||
/// The implementation uses two-phase cleanup for clients that have left the call (either gracefully
|
||||
/// or through timeout). This avoids opening a new connection immediately after the old one was
|
||||
/// closed.
|
||||
///
|
||||
/// [epoll]: https://man7.org/linux/man-pages/man7/epoll.7.html
|
||||
pub(super) struct UdpServerState {
|
||||
local_addr: SocketAddr,
|
||||
new_client_socket: UdpSocket,
|
||||
all_epoll_fds: Vec<RawFd>,
|
||||
all_connections: RwLock<ConnectionMap>,
|
||||
tick_interval: Duration,
|
||||
}
|
||||
|
||||
/// The persistent state from tick to tick for the epoll-based server.
|
||||
///
|
||||
/// This tracks clients that have left the call on the previous tick.
|
||||
#[derive(Default)]
|
||||
pub(super) struct TickState {
|
||||
clients_that_left_previously: Vec<SocketAddr>,
|
||||
}
|
||||
|
||||
impl UdpServerState {
|
||||
/// Sets up the server state by binding an initial socket to `local_addr`.
|
||||
///
|
||||
/// Also creates a separate epoll file descriptor for each thread we plan to use.
|
||||
pub fn new(
|
||||
local_addr: SocketAddr,
|
||||
num_threads: usize,
|
||||
tick_interval: Duration,
|
||||
) -> Result<Arc<Self>> {
|
||||
let new_client_socket = Self::open_socket_with_reusable_port(&local_addr)?;
|
||||
let all_epoll_fds = (0..num_threads)
|
||||
.map(|_| epoll_create1(EpollCreateFlags::empty()))
|
||||
.collect::<nix::Result<_>>()?;
|
||||
let result = Self {
|
||||
local_addr,
|
||||
new_client_socket,
|
||||
all_epoll_fds,
|
||||
all_connections: RwLock::new(ConnectionMap::new()),
|
||||
tick_interval,
|
||||
};
|
||||
result.add_socket_to_poll_for_reads(&result.new_client_socket)?;
|
||||
Ok(Arc::new(result))
|
||||
}
|
||||
|
||||
/// Opens a socket and binds it to `local_addr` after setting the `SO_REUSEPORT` sockopt.
|
||||
///
|
||||
/// This allows multiple sockets to bind to the same address.
|
||||
fn open_socket_with_reusable_port(local_addr: &SocketAddr) -> Result<UdpSocket> {
|
||||
// Open an IPv4 UDP socket in blocking mode.
|
||||
let socket_fd = socket(
|
||||
AddressFamily::Inet,
|
||||
SockType::Datagram,
|
||||
SockFlag::empty(),
|
||||
SockProtocol::Udp,
|
||||
)?;
|
||||
// Allow later sockets to handle connections.
|
||||
setsockopt(socket_fd, sockopt::ReusePort, &true)?;
|
||||
// Bind the socket to the given local address.
|
||||
bind(
|
||||
socket_fd,
|
||||
&SockAddr::new_inet(InetAddr::from_std(local_addr)),
|
||||
)?;
|
||||
// Pass ownership into Rust.
|
||||
// std::net::UdpSocket can only be created and bound in one step, which
|
||||
// doesn't allow us to set SO_REUSEPORT.
|
||||
// Safety: we have just created this socket FD, so we know it's valid.
|
||||
let result = unsafe { UdpSocket::from_raw_fd(socket_fd) };
|
||||
// Set a read timeout for a "pseudo-nonblocking" interface.
|
||||
// Why? Because epoll might wake up more than one thread to read from a single socket.
|
||||
result.set_read_timeout(Some(Duration::from_millis(10).into()))?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Adds `socket` to be polled by each of the descriptors in `self.all_epoll_fds`.
|
||||
///
|
||||
/// Specifically, this is only polling for read events with "exclusive" wakeups. That is,
|
||||
/// "out-of-band" data will be ignored, and only one epoll FD will receive an event for any
|
||||
/// particular socket being ready.
|
||||
fn add_socket_to_poll_for_reads(&self, socket: &UdpSocket) -> Result<()> {
|
||||
let socket_fd = socket.as_raw_fd();
|
||||
let mut event_read_only = EpollEvent::new(
|
||||
EpollFlags::EPOLLIN | EpollFlags::EPOLLEXCLUSIVE,
|
||||
socket_fd as u64,
|
||||
);
|
||||
for &epoll_fd in &self.all_epoll_fds {
|
||||
epoll_ctl(
|
||||
epoll_fd,
|
||||
EpollOp::EpollCtlAdd,
|
||||
socket_fd,
|
||||
&mut event_read_only,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Launches the configured number of threads for the server using Tokio's blocking thread pool
|
||||
/// ([`tokio::task::spawn_blocking`]).
|
||||
///
|
||||
/// `handle_packet` should take a single incoming packet's source address and data and produce a
|
||||
/// (possibly empty) set of outgoing packets.
|
||||
///
|
||||
/// This should only be called once.
|
||||
pub fn start_threads(
|
||||
self: Arc<Self>,
|
||||
handle_packet: impl FnMut(SocketAddr, &mut [u8]) -> Vec<(Vec<u8>, SocketAddr)>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
) -> impl Future {
|
||||
let all_handles = self.all_epoll_fds.iter().map(|&epoll_fd| {
|
||||
let self_for_thread = self.clone();
|
||||
let handle_packet_for_thread = handle_packet.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
self_for_thread.run(epoll_fd, handle_packet_for_thread)
|
||||
})
|
||||
});
|
||||
futures::future::select_all(all_handles)
|
||||
}
|
||||
|
||||
/// Runs a single listener on the current thread, polling `epoll_fd`.
|
||||
///
|
||||
/// See [`UdpServerState::start_threads`].
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
epoll_fd: RawFd,
|
||||
mut handle_packet: impl FnMut(SocketAddr, &mut [u8]) -> Vec<(Vec<u8>, SocketAddr)>,
|
||||
) {
|
||||
let new_client_socket_fd = self.new_client_socket.as_raw_fd();
|
||||
let mut buf = [0u8; 1500];
|
||||
|
||||
loop {
|
||||
let mut current_events = [EpollEvent::empty(); MAX_EPOLL_EVENTS];
|
||||
let num_events = epoll_wait(epoll_fd, &mut current_events, -1).unwrap_or_else(|err| {
|
||||
warn!("epoll_wait() failed: {}", err);
|
||||
0
|
||||
});
|
||||
for event in ¤t_events[..num_events] {
|
||||
let socket_fd = event.data() as i32;
|
||||
let connections_lock = self.all_connections.read();
|
||||
let socket = if socket_fd == new_client_socket_fd {
|
||||
&self.new_client_socket
|
||||
} else {
|
||||
match connections_lock.get_by_fd(socket_fd) {
|
||||
Some(socket) => socket,
|
||||
None => {
|
||||
// By the time we got to this event the socket was closed.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if event.events().contains(EpollFlags::EPOLLERR) {
|
||||
match socket.take_error() {
|
||||
Err(err) => {
|
||||
warn!("take_error() failed: {}", err);
|
||||
event!("calling.udp.epoll.take_error_failure");
|
||||
// Hopefully this is a transient failure. Just skip this socket for now.
|
||||
continue;
|
||||
}
|
||||
Ok(None) => {
|
||||
// Assume another thread got here first.
|
||||
continue;
|
||||
}
|
||||
Ok(Some(err)) => {
|
||||
if err.kind() == io::ErrorKind::ConnectionRefused {
|
||||
// This can happen when someone leaves a call
|
||||
// because e.g. their router stops forwarding packets.
|
||||
// This is normal with UDP; technically this error happened
|
||||
// with the *previous* packet and we're just finding out now.
|
||||
trace!("socket error: {}", err);
|
||||
|
||||
match socket.peer_addr() {
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"peer_addr() failed while handling an error: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
Ok(addr) => {
|
||||
// Drop the read lock...
|
||||
drop(connections_lock);
|
||||
// ...and connect with a write lock...
|
||||
let mut write_lock = self.all_connections.write();
|
||||
// ...and mark the connection as closed.
|
||||
// If we changed state (such as already going to Closed)
|
||||
// in between the locks, mark_closed is still safe to call:
|
||||
// - If the connection is still open, we want to close it.
|
||||
// - If the connection is closed, closing it again doesn't hurt.
|
||||
// - If the connection has been removed entirely, closing it does nothing.
|
||||
// - If the connection has been removed and the address gets reused,
|
||||
// we'll close a connection that doesn't belong here anymore.
|
||||
// That's very unlikely because it means we've had at least two ticks,
|
||||
// and it'll (hopefully) heal itself in another two.
|
||||
write_lock.mark_closed(&addr);
|
||||
// No need to read more from this socket.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
event!("calling.udp.epoll.socket_error");
|
||||
warn!("socket error: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We ignore all other events but EPOLLIN; hangups will be handled by tick()
|
||||
// expiring the connection.
|
||||
if !event.events().contains(EpollFlags::EPOLLIN) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// We only read one packet for each socket that's ready. This isn't as efficient
|
||||
// as it could be; if one socket has many packets ready, we have to go back into
|
||||
// the epoll loop to find that out. On the other hand, this does ensure that we
|
||||
// don't get stuck reading from one socket and ignore all others.
|
||||
//
|
||||
// Note that this relies on using epoll in level-triggered mode rather than
|
||||
// edge-triggered.
|
||||
let (size, sender_addr) = match socket.recv_from(&mut buf) {
|
||||
Err(err) => {
|
||||
match err.kind() {
|
||||
io::ErrorKind::TimedOut
|
||||
| io::ErrorKind::WouldBlock
|
||||
| io::ErrorKind::Interrupted => {}
|
||||
io::ErrorKind::ConnectionRefused => {
|
||||
// This can happen when someone leaves a call
|
||||
// because e.g. their router stops forwarding packets.
|
||||
// This is normal with UDP; technically this error happened
|
||||
// with the previous *sent* packet and we're just finding out now.
|
||||
trace!("recv_from() failed: {}", err);
|
||||
}
|
||||
_ => {
|
||||
warn!("recv_from() failed: {}", err);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Ok((size, sender_addr)) => (size, sender_addr),
|
||||
};
|
||||
drop(connections_lock);
|
||||
|
||||
let packets_to_send = handle_packet(sender_addr, &mut buf[..size]);
|
||||
|
||||
for (buf, addr) in packets_to_send {
|
||||
trace!("sending packet of {} bytes to {}", buf.len(), addr);
|
||||
time_scope!(
|
||||
"calling.udp.epoll.send_packet",
|
||||
TimingOptions::nanosecond_1000_per_minute()
|
||||
);
|
||||
sampling_histogram!("calling.epoll.send_packet.size_bytes", || buf.len());
|
||||
|
||||
let connections_lock = self.all_connections.read();
|
||||
match connections_lock.get_by_addr(&addr) {
|
||||
ConnectionState::Connected(socket) => {
|
||||
if let Err(err) = socket.send(&buf) {
|
||||
if err.kind() == io::ErrorKind::ConnectionRefused {
|
||||
// This can happen when someone leaves a call
|
||||
// because e.g. their router stops forwarding packets.
|
||||
// This is normal with UDP; technically this error happened
|
||||
// with the *previous* packet and we're just finding out now.
|
||||
trace!("send() failed: {}", err);
|
||||
|
||||
// Drop the read lock...
|
||||
drop(connections_lock);
|
||||
// ...and connect with a write lock...
|
||||
let mut write_lock = self.all_connections.write();
|
||||
// ...and mark the connection as closed.
|
||||
// If we changed state (such as already going to Closed)
|
||||
// in between the locks, mark_closed is still safe to call:
|
||||
// - If the connection is still open, we want to close it.
|
||||
// - If the connection is closed, closing it again doesn't hurt.
|
||||
// - If the connection has been removed entirely, closing it does nothing.
|
||||
// - If the connection has been removed and the address gets reused,
|
||||
// we'll close a connection that doesn't belong here anymore.
|
||||
// That's very unlikely because it means we've had at least two ticks,
|
||||
// and it'll (hopefully) heal itself in another two.
|
||||
write_lock.mark_closed(&addr);
|
||||
} else {
|
||||
warn!("send() failed: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
ConnectionState::Closed => {
|
||||
trace!("dropping packet (connection already closed)")
|
||||
}
|
||||
ConnectionState::NotYetConnected => {
|
||||
// Drop the read lock...
|
||||
drop(connections_lock);
|
||||
// ...and connect with a write lock...
|
||||
let mut write_lock = self.all_connections.write();
|
||||
|
||||
// ...and check if another thread beat us to it.
|
||||
match write_lock.get_by_addr(&addr) {
|
||||
ConnectionState::Connected(socket) => {
|
||||
if let Err(err) = socket.send(&buf) {
|
||||
if err.kind() == io::ErrorKind::ConnectionRefused {
|
||||
// This can happen when someone leaves a call
|
||||
// because e.g. their router stops forwarding packets.
|
||||
// This is normal with UDP; technically this error happened
|
||||
// with the *previous* packet and we're just finding out now.
|
||||
trace!("send() failed: {}", err);
|
||||
|
||||
// ...and mark the connection as closed.
|
||||
write_lock.mark_closed(&addr);
|
||||
} else {
|
||||
warn!("send() failed: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
ConnectionState::Closed => {
|
||||
trace!("dropping packet (connection already closed)")
|
||||
}
|
||||
ConnectionState::NotYetConnected => {
|
||||
trace!("connecting to {:?}", addr);
|
||||
match try_scoped(|| {
|
||||
let client_socket =
|
||||
Self::open_socket_with_reusable_port(&self.local_addr)?;
|
||||
client_socket.connect(addr)?;
|
||||
self.add_socket_to_poll_for_reads(&client_socket)?;
|
||||
Ok(client_socket)
|
||||
}) {
|
||||
Ok(client_socket) => {
|
||||
let client_socket = write_lock
|
||||
.get_or_insert_connected(client_socket, addr);
|
||||
if let Err(err) = client_socket.send(&buf) {
|
||||
warn!("send() failed: {}", err);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to connect to peer: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the results of [`sfu::SfuServer::tick`].
|
||||
///
|
||||
/// This includes cleaning up connections for clients that have left.
|
||||
pub fn tick(
|
||||
&self,
|
||||
mut tick_update: sfu::TickOutput,
|
||||
persistent_tick_state: &mut TickState,
|
||||
) -> Result<()> {
|
||||
for (buf, addr) in tick_update.packets_to_send {
|
||||
trace!("sending tick packet of {} bytes to {}", buf.len(), addr);
|
||||
|
||||
let connections_lock = self.all_connections.read();
|
||||
match connections_lock.get_by_addr(&addr) {
|
||||
ConnectionState::Connected(socket) => {
|
||||
if let Err(err) = socket.send(&buf) {
|
||||
if err.kind() == io::ErrorKind::ConnectionRefused {
|
||||
// This can happen when someone leaves a call
|
||||
// because e.g. their router stops forwarding packets.
|
||||
// This is normal with UDP; technically this error happened
|
||||
// with the *previous* packet and we're just finding out now.
|
||||
trace!("send() failed: {}", err);
|
||||
|
||||
// This will call mark_closed below
|
||||
tick_update.expired_client_addrs.push(addr)
|
||||
} else {
|
||||
warn!("send() failed: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
ConnectionState::Closed => {
|
||||
trace!("dropping packet (connection already closed)")
|
||||
}
|
||||
ConnectionState::NotYetConnected => {
|
||||
trace!("dropping packet (not yet connected)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up any clients that have already left.
|
||||
if !tick_update.expired_client_addrs.is_empty()
|
||||
|| !persistent_tick_state
|
||||
.clients_that_left_previously
|
||||
.is_empty()
|
||||
{
|
||||
match self
|
||||
.all_connections
|
||||
.try_write_for(self.tick_interval.into())
|
||||
{
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"could not acquire connection lock after {:?}; one of the epoll handler threads is likely deadlocked",
|
||||
self.tick_interval
|
||||
);
|
||||
}
|
||||
Some(mut socket_lock) => {
|
||||
// Clean up clients from the last tick.
|
||||
// This two-phase cleanup makes the following scenario unlikely...
|
||||
// 1. UDP handler produces packets for socket X.
|
||||
// 2. The UDP handler is pre-empted.
|
||||
// 3. The tick handler runs and removes socket X.
|
||||
// 4. The UDP handler resumes and tries to send to socket X.
|
||||
// 5. The UDP handler thinks it needs to make a new connection.
|
||||
// ...but not impossible, since the tick handler could run *twice* in step 3.
|
||||
// In that case the new connection would be leaked.
|
||||
for addr in persistent_tick_state.clients_that_left_previously.iter() {
|
||||
socket_lock.remove_closed(addr);
|
||||
}
|
||||
|
||||
// Mark clients to be cleaned up next tick.
|
||||
for addr in tick_update.expired_client_addrs.iter() {
|
||||
socket_lock.mark_closed(&addr);
|
||||
}
|
||||
persistent_tick_state.clients_that_left_previously =
|
||||
tick_update.expired_client_addrs;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A doubly-keyed map that allows looking up a socket by raw file descriptor (for epoll) or by peer
|
||||
/// address.
|
||||
///
|
||||
/// The map owns the socket, so removal from the map will close the socket as well. However, when a
|
||||
/// socket is removed, the peer address stays in the map to distinguish "recently closed" from "not
|
||||
/// yet connected". [`ConnectionMap::mark_closed`] and [`ConnectionMap::remove_closed`] implement
|
||||
/// the two parts of this two-phase cleanup.
|
||||
///
|
||||
/// The map is generic to support unit testing, but isn't intended for storing anything else.
|
||||
struct ConnectionMap<T = UdpSocket> {
|
||||
/// The primary map from file descriptors to sockets.
|
||||
///
|
||||
/// The use of file descriptors is largely arbitrary; it's a value *already* uniquely associated
|
||||
/// with a socket.
|
||||
by_fd: HashMap<RawFd, T>,
|
||||
|
||||
/// The secondary map from peer addresses to file descriptors.
|
||||
///
|
||||
/// A value may be [`ConnectionMap::TOMBSTONE_FD`], in which case it represents a recently-closed connection.
|
||||
by_peer_addr: HashMap<SocketAddr, RawFd>,
|
||||
}
|
||||
|
||||
/// Represents the state of a connection in a [ConnectionMap].
|
||||
#[derive(Debug)]
|
||||
enum ConnectionState<T> {
|
||||
/// The peer address was not found, so there must be no existing connection.
|
||||
NotYetConnected,
|
||||
/// The given socket is connected to the peer in question.
|
||||
Connected(T),
|
||||
/// There was a connection to this peer but that connection has been closed.
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl<T: AsRawFd> ConnectionMap<T> {
|
||||
/// A placeholder for `self.by_peer_addr` to represent a closed connection.
|
||||
const TOMBSTONE_FD: RawFd = -1;
|
||||
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
by_fd: HashMap::new(),
|
||||
by_peer_addr: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the socket for `peer_addr` or inserts `socket` if there isn't one.
|
||||
///
|
||||
/// If there is already a socket for `peer_addr`, the argument `socket` will be dropped (and the
|
||||
/// underlying socket closed).
|
||||
fn get_or_insert_connected(&mut self, socket: T, peer_addr: SocketAddr) -> &T {
|
||||
let fd = socket.as_raw_fd();
|
||||
match self.by_peer_addr.entry(peer_addr) {
|
||||
hash_map::Entry::Occupied(mut entry) => {
|
||||
if *entry.get() != Self::TOMBSTONE_FD {
|
||||
// This address is already connected to a different socket.
|
||||
return &self.by_fd[entry.get()];
|
||||
}
|
||||
entry.insert(fd);
|
||||
}
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert(fd);
|
||||
}
|
||||
}
|
||||
let inserted_socket = match self.by_fd.entry(fd) {
|
||||
hash_map::Entry::Occupied(_) => {
|
||||
unreachable!("file descriptor reused before socket closed");
|
||||
}
|
||||
hash_map::Entry::Vacant(entry) => entry.insert(socket),
|
||||
};
|
||||
inserted_socket
|
||||
}
|
||||
|
||||
/// Gets the connection for `peer_addr`, which can be in any of the states represented by
|
||||
/// [ConnectionState].
|
||||
fn get_by_addr(&self, peer_addr: &SocketAddr) -> ConnectionState<&T> {
|
||||
match self.by_peer_addr.get(peer_addr) {
|
||||
None => ConnectionState::NotYetConnected,
|
||||
Some(&Self::TOMBSTONE_FD) => ConnectionState::Closed,
|
||||
Some(fd) => ConnectionState::Connected(&self.by_fd[fd]),
|
||||
}
|
||||
}
|
||||
|
||||
/// Looks up a socket by file descriptor.
|
||||
fn get_by_fd(&self, fd: RawFd) -> Option<&T> {
|
||||
self.by_fd.get(&fd)
|
||||
}
|
||||
|
||||
/// Marks the connection for `peer_addr` as closed.
|
||||
///
|
||||
/// The socket associated with that connection will be removed from the map. If there was no
|
||||
/// connection for the given peer, or if it was already closed, returns `None`.
|
||||
fn mark_closed(&mut self, peer_addr: &SocketAddr) -> Option<T> {
|
||||
let entry = self.by_peer_addr.get_mut(peer_addr)?;
|
||||
// Not stricly necessary, but a small perf optimization for when
|
||||
// mark_closed is called more than once.
|
||||
if *entry == Self::TOMBSTONE_FD {
|
||||
return None;
|
||||
}
|
||||
let fd = std::mem::replace(entry, Self::TOMBSTONE_FD);
|
||||
self.by_fd.remove(&fd)
|
||||
}
|
||||
|
||||
/// Removes the entry for `peer_addr` from the map, which must have previously been marked
|
||||
/// closed.
|
||||
///
|
||||
/// This allows a peer address to be reused (perhaps reconnecting to the server). It also keeps
|
||||
/// the peer map from growing indefinitely.
|
||||
///
|
||||
/// See [`ConnectionMap::mark_closed`].
|
||||
fn remove_closed(&mut self, peer_addr: &SocketAddr) {
|
||||
match self.by_peer_addr.remove_entry(peer_addr) {
|
||||
None => {
|
||||
warn!("no connection record to remove for this address");
|
||||
}
|
||||
Some((_, Self::TOMBSTONE_FD)) => {}
|
||||
Some((addr, fd)) => {
|
||||
// There's already a new connection to this address. Put the entry back.
|
||||
self.by_peer_addr.insert(addr, fd);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FakeSocket {
|
||||
fd: RawFd,
|
||||
id: i32,
|
||||
}
|
||||
impl AsRawFd for FakeSocket {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.fd
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_map_absent() {
|
||||
let mut map: ConnectionMap<FakeSocket> = ConnectionMap::new();
|
||||
let addr = "127.0.0.1:80".parse().expect("valid SocketAddr");
|
||||
|
||||
assert!(map.get_by_fd(0).is_none());
|
||||
assert!(matches!(
|
||||
map.get_by_addr(&addr),
|
||||
ConnectionState::NotYetConnected
|
||||
));
|
||||
assert!(map.mark_closed(&addr).is_none());
|
||||
map.remove_closed(&addr); // just don't panic
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_map_lifecycle() {
|
||||
let mut map: ConnectionMap<FakeSocket> = ConnectionMap::new();
|
||||
let addr: SocketAddr = "127.0.0.1:80".parse().expect("valid SocketAddr");
|
||||
|
||||
// Insert
|
||||
let fd = 5;
|
||||
let id = 55;
|
||||
let socket = FakeSocket { fd, id };
|
||||
let socket_ref = map.get_or_insert_connected(socket, addr);
|
||||
assert_eq!(socket_ref.id, id);
|
||||
|
||||
assert_eq!(map.get_by_fd(fd).expect("present").id, id);
|
||||
match map.get_by_addr(&addr) {
|
||||
ConnectionState::Connected(socket) => {
|
||||
assert_eq!(socket.id, id);
|
||||
}
|
||||
state => {
|
||||
panic!("unexpected state: {:?}", state)
|
||||
}
|
||||
}
|
||||
|
||||
// Mark closed.
|
||||
let socket = map.mark_closed(&addr).expect("present");
|
||||
assert_eq!(socket.id, id);
|
||||
|
||||
assert!(map.get_by_fd(fd).is_none());
|
||||
assert!(matches!(map.get_by_addr(&addr), ConnectionState::Closed));
|
||||
|
||||
// Remove closed.
|
||||
map.remove_closed(&addr);
|
||||
|
||||
assert!(map.get_by_fd(fd).is_none());
|
||||
assert!(matches!(
|
||||
map.get_by_addr(&addr),
|
||||
ConnectionState::NotYetConnected
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_map_first_insert_wins() {
|
||||
let mut map: ConnectionMap<FakeSocket> = ConnectionMap::new();
|
||||
let addr: SocketAddr = "127.0.0.1:80".parse().expect("valid SocketAddr");
|
||||
|
||||
let fd = 5;
|
||||
let id = 55;
|
||||
let socket = FakeSocket { fd, id };
|
||||
let socket_ref = map.get_or_insert_connected(socket, addr);
|
||||
assert_eq!(socket_ref.id, id);
|
||||
|
||||
// Check that we don't replace an existing connection.
|
||||
let new_socket = FakeSocket { fd, id: id + 1 };
|
||||
let socket_ref = map.get_or_insert_connected(new_socket, addr);
|
||||
assert_eq!(socket_ref.id, id);
|
||||
|
||||
assert_eq!(map.get_by_fd(fd).expect("present").id, id);
|
||||
match map.get_by_addr(&addr) {
|
||||
ConnectionState::Connected(socket) => {
|
||||
assert_eq!(socket.id, id);
|
||||
}
|
||||
state => {
|
||||
panic!("unexpected state: {:?}", state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_map_can_insert_over_closed() {
|
||||
let mut map: ConnectionMap<FakeSocket> = ConnectionMap::new();
|
||||
let addr: SocketAddr = "127.0.0.1:80".parse().expect("valid SocketAddr");
|
||||
|
||||
let fd = 5;
|
||||
let id = 55;
|
||||
let socket = FakeSocket { fd, id };
|
||||
let socket_ref = map.get_or_insert_connected(socket, addr);
|
||||
assert_eq!(socket_ref.id, id);
|
||||
|
||||
map.mark_closed(&addr);
|
||||
// But don't remove it!
|
||||
|
||||
let new_socket = FakeSocket { fd, id: id + 1 };
|
||||
let socket_ref = map.get_or_insert_connected(new_socket, addr);
|
||||
assert_eq!(socket_ref.id, id + 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_map_remove_open() {
|
||||
let mut map: ConnectionMap<FakeSocket> = ConnectionMap::new();
|
||||
let addr: SocketAddr = "127.0.0.1:80".parse().expect("valid SocketAddr");
|
||||
|
||||
// Insert
|
||||
let fd = 5;
|
||||
let id = 55;
|
||||
let socket = FakeSocket { fd, id };
|
||||
let socket_ref = map.get_or_insert_connected(socket, addr);
|
||||
assert_eq!(socket_ref.id, id);
|
||||
|
||||
// Try to remove.
|
||||
map.remove_closed(&addr);
|
||||
|
||||
assert_eq!(map.get_by_fd(fd).expect("present").id, id);
|
||||
match map.get_by_addr(&addr) {
|
||||
ConnectionState::Connected(socket) => {
|
||||
assert_eq!(socket.id, id);
|
||||
}
|
||||
state => {
|
||||
panic!("unexpected state: {:?}", state)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
119
src/udp_server/generic.rs
Normal file
119
src/udp_server/generic.rs
Normal file
@ -0,0 +1,119 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use std::{
|
||||
future::Future,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use log::*;
|
||||
|
||||
use crate::{common::Duration, metrics::TimingOptions, sfu};
|
||||
|
||||
/// The shared state for a generic UDP server.
|
||||
///
|
||||
/// This server is implemented with a single socket for all sends and receives. Multiple threads can
|
||||
/// use the socket, but this only helps if packet processing takes a long time. Otherwise they'll
|
||||
/// just block in the kernel trying to send.
|
||||
pub(super) struct UdpServerState {
|
||||
socket: UdpSocket,
|
||||
num_threads: usize,
|
||||
}
|
||||
|
||||
/// The persistent state from tick to tick for the generic server.
|
||||
///
|
||||
/// The generic server does not need to carry any state from tick to tick.
|
||||
#[derive(Default)]
|
||||
pub(super) struct TickState;
|
||||
|
||||
impl UdpServerState {
|
||||
/// Sets up the server state by binding a socket to `local_addr`.
|
||||
pub fn new(
|
||||
local_addr: SocketAddr,
|
||||
num_threads: usize,
|
||||
_tick_interval: Duration,
|
||||
) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(Self {
|
||||
socket: UdpSocket::bind(local_addr)?,
|
||||
num_threads,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Launches the configured number of threads for the server using Tokio's blocking thread pool
|
||||
/// ([`tokio::task::spawn_blocking`]).
|
||||
///
|
||||
/// `handle_packet` should take a single incoming packet's source address and data and produce a
|
||||
/// (possibly empty) set of outgoing packets.
|
||||
///
|
||||
/// This should only be called once.
|
||||
pub fn start_threads(
|
||||
self: Arc<Self>,
|
||||
handle_packet: impl FnMut(SocketAddr, &mut [u8]) -> Vec<(Vec<u8>, SocketAddr)>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
) -> impl Future {
|
||||
let all_handles = (0..self.num_threads).map(|_| {
|
||||
let self_for_thread = self.clone();
|
||||
let handle_packet_for_thread = handle_packet.clone();
|
||||
tokio::task::spawn_blocking(move || self_for_thread.run(handle_packet_for_thread))
|
||||
});
|
||||
futures::future::select_all(all_handles)
|
||||
}
|
||||
|
||||
/// Runs a single listener on the current thread.
|
||||
///
|
||||
/// See [`UdpServerState::start_threads`].
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
mut handle_packet: impl FnMut(SocketAddr, &mut [u8]) -> Vec<(Vec<u8>, SocketAddr)>,
|
||||
) {
|
||||
let mut buf = [0u8; 1500];
|
||||
|
||||
loop {
|
||||
let received_packet = match self.socket.recv_from(&mut buf) {
|
||||
Err(err) => {
|
||||
warn!("recv_from() failed: {}", err);
|
||||
None
|
||||
}
|
||||
Ok((size, sender_addr)) => Some((size, sender_addr)),
|
||||
};
|
||||
|
||||
if let Some((size, sender_addr)) = received_packet {
|
||||
let packets_to_send = handle_packet(sender_addr, &mut buf[..size]);
|
||||
for (buf, addr) in packets_to_send {
|
||||
trace!("sending packet of {} bytes to {}", buf.len(), addr);
|
||||
time_scope!(
|
||||
"calling.udp.generic.send_packet",
|
||||
TimingOptions::nanosecond_1000_per_minute()
|
||||
);
|
||||
sampling_histogram!("calling.generic.send_packet.size_bytes", || buf.len());
|
||||
|
||||
if let Err(err) = self.socket.send_to(&buf, addr) {
|
||||
warn!("send_to failed: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process the results of [`sfu::Sfu::tick`].
|
||||
pub fn tick(
|
||||
&self,
|
||||
tick_update: sfu::TickOutput,
|
||||
_persistent_tick_state: &mut TickState,
|
||||
) -> Result<()> {
|
||||
for (buf, addr) in tick_update.packets_to_send {
|
||||
trace!("sending packet of {} bytes to {}", buf.len(), addr);
|
||||
|
||||
if let Err(err) = self.socket.send_to(&buf, addr) {
|
||||
warn!("send_to failed: {}", err);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
537
src/vp8.rs
Normal file
537
src/vp8.rs
Normal file
@ -0,0 +1,537 @@
|
||||
//
|
||||
// Copyright 2021 Signal Messenger, LLC
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
use anyhow::Result;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::common::{expand_truncated_counter, Bits, BytesReader, PixelSize, ReadResult};
|
||||
|
||||
pub type TruncatedPictureId = u16;
|
||||
pub type FullPictureId = u64;
|
||||
pub type TruncatedTl0PicIdx = u8;
|
||||
pub type FullTl0PicIdx = u64;
|
||||
|
||||
/// See https://tools.ietf.org/html/rfc7741 for the format.
|
||||
#[derive(Debug, Default, Clone, Eq, PartialEq)]
|
||||
pub struct ParsedHeader {
|
||||
/// Incremented with each video frame. Really a u15.
|
||||
/// Used to provide indicate frame order and gaps.
|
||||
/// Must be rewritten or cleared when forwarding simulcast.
|
||||
pub picture_id: Option<TruncatedPictureId>,
|
||||
|
||||
// /// If false, the frame can be discarded without disrupting future frames.
|
||||
// /// There doesn't seem to be any use for this field
|
||||
// /// because we don't support dropping frames in the SFU.
|
||||
// referenced: bool,
|
||||
//
|
||||
/// Incremented with each frame with TemporalLayerId == 0.
|
||||
/// Used to indicate temporal layer dependencies.
|
||||
/// Frames with TemporalLayerId > 0 refer to frames with TemporalLayerId == 0
|
||||
/// either directly or through a frame with one less TemoralLayerId.
|
||||
/// Must be rewritten or cleared when forwarding simulcast.
|
||||
pub tl0_pic_idx: Option<TruncatedTl0PicIdx>,
|
||||
|
||||
// /// 0 = temporal base layer. Really a u4.
|
||||
// /// There doesn't seem to be any use for this field
|
||||
// /// because we don't support dropping frames in the SFU.
|
||||
// temporal_layer_id: Option<u8>,
|
||||
|
||||
// /// AKA "layer sync". If true, this frame references temporal layer 0
|
||||
// /// even if this frame's temporal_layer_id > 1. If false, this frame
|
||||
// /// references a frame with temporal_layer_id-1.
|
||||
// /// But there doesn't seem to be any use for this field
|
||||
// /// because we don't support dropping frames in the SFU.
|
||||
// references_temporal_layer0_directly: Option<bool>,
|
||||
//
|
||||
/// Incremented with each key frame. Really a u5.
|
||||
/// There doesn't seem to be any use for this field.
|
||||
/// key_frame_index: Option<u8>,
|
||||
pub is_key_frame: bool,
|
||||
|
||||
/// (width, height). Only included in the header if is_key_frame.
|
||||
/// Subsequent frames must have the same resolution.
|
||||
/// Really u14s.
|
||||
pub resolution: Option<PixelSize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
struct Byte0 {
|
||||
has_extensions: bool,
|
||||
starts_partition: bool,
|
||||
zero_partition_idx: bool,
|
||||
}
|
||||
|
||||
impl Byte0 {
|
||||
fn parse(byte0: u8) -> Self {
|
||||
Self {
|
||||
has_extensions: byte0.ms_bit(0), // X bit
|
||||
//_reserved1: byte0.ms_bit(1), // R bit
|
||||
//_non_ref_frame: byte0.ms_bit(2), // N bit
|
||||
starts_partition: byte0.ms_bit(3), // S bit,
|
||||
//_reserved2: byte0.ms_bit(4), // R bit
|
||||
zero_partition_idx: byte0 & 0b111 == 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Note that the payload header is present only in packets that have the S bit equal to one
|
||||
/// and the PID equal to zero in the payload descriptor.
|
||||
///
|
||||
/// https://datatracker.ietf.org/doc/html/rfc7741#section-4.3
|
||||
fn has_payload_header(&self) -> bool {
|
||||
self.starts_partition && self.zero_partition_idx
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
struct XByte {
|
||||
has_picture_id: bool,
|
||||
has_tl0_pic_idx: bool,
|
||||
has_tid: bool,
|
||||
has_key_idx: bool,
|
||||
}
|
||||
|
||||
impl XByte {
|
||||
fn parse(x_byte: u8) -> Self {
|
||||
Self {
|
||||
has_picture_id: x_byte.ms_bit(0), // I bit
|
||||
has_tl0_pic_idx: x_byte.ms_bit(1), // L bit
|
||||
has_tid: x_byte.ms_bit(2), // T bit
|
||||
has_key_idx: x_byte.ms_bit(3), // K bit
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
struct PayloadHeader {
|
||||
key_frame: bool,
|
||||
}
|
||||
|
||||
impl PayloadHeader {
|
||||
fn parse(byte: u8) -> Self {
|
||||
Self {
|
||||
key_frame: !byte.ms_bit(7), // P bit: Inverse key frame flag.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Eq, PartialEq, Debug, Copy, Clone)]
|
||||
pub enum Vp8Error {
|
||||
#[error("Got a 7-bit VP8 picture ID. Expecting only 15-bit picture IDs.")]
|
||||
SevenBitPictureId,
|
||||
}
|
||||
|
||||
impl ParsedHeader {
|
||||
/// This reads both the "descriptor" and the "header"
|
||||
/// See https://datatracker.ietf.org/doc/html/rfc7741#section-4.2
|
||||
pub fn read(payload: &[u8]) -> Result<Self> {
|
||||
let mut payload = BytesReader::from_slice(payload);
|
||||
let mut header = Self::default();
|
||||
|
||||
let byte0 = Byte0::parse(payload.read_u8()?);
|
||||
|
||||
if byte0.has_extensions {
|
||||
let x_byte = XByte::parse(payload.read_u8()?);
|
||||
|
||||
if x_byte.has_picture_id {
|
||||
let mut peek = BytesReader::clone(&payload);
|
||||
if !peek.read_u8()?.ms_bit(0) {
|
||||
// The spec says it could be 7-bit, but WebRTC only sends 15-bit
|
||||
return Err(Vp8Error::SevenBitPictureId.into());
|
||||
}
|
||||
let picture_id_with_leading_bit = payload.read_u16_be()?;
|
||||
header.picture_id = Some(picture_id_with_leading_bit & 0b0111_1111_1111_1111);
|
||||
}
|
||||
|
||||
if x_byte.has_tl0_pic_idx {
|
||||
let tl0_pic_idx = payload.read_u8()?;
|
||||
header.tl0_pic_idx = Some(tl0_pic_idx);
|
||||
};
|
||||
|
||||
if x_byte.has_tid || x_byte.has_key_idx {
|
||||
let _tk_byte = payload.read_u8()?;
|
||||
// If in the future we want the TID or key frame index, here is how to get it:
|
||||
// if has_tid {
|
||||
// header.temporal_layer_id = Some(tk_byte >> 6);
|
||||
// header.references_temporal_layer0_directly = Some(tk_byte.ms_bit(2));
|
||||
// }
|
||||
// if has_key_idx {
|
||||
// header.key_frame_index = Some(tk_byte & 0b0001_1111);
|
||||
// }
|
||||
};
|
||||
}
|
||||
|
||||
if byte0.has_payload_header() {
|
||||
// The codec bitstream format specifies two different variants of the uncompressed data
|
||||
// chunk: a 3-octet version for interframes and a 10-octet version for key frames.
|
||||
// The first 3 octets are common to both variants.
|
||||
let mut common_header = payload.read(3)?;
|
||||
let payload0 = PayloadHeader::parse(common_header.read_u8()?);
|
||||
header.is_key_frame = payload0.key_frame;
|
||||
if header.is_key_frame {
|
||||
// In the case of a key frame, the remaining 7 octets are considered to be part
|
||||
// of the remaining payload in this RTP format.
|
||||
let mut additional_key_frame_header = payload.read(7)?;
|
||||
header.resolution = Some(ParsedHeader::size_from_additional_key_frame_header(
|
||||
&mut additional_key_frame_header,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(header)
|
||||
}
|
||||
|
||||
/// see https://datatracker.ietf.org/doc/html/rfc6386#section-9.1
|
||||
fn size_from_additional_key_frame_header(
|
||||
additional_key_frame_header: &mut BytesReader,
|
||||
) -> ReadResult<PixelSize> {
|
||||
let _skipped = additional_key_frame_header.read_bytes(3)?;
|
||||
let width_with_scale = additional_key_frame_header.read_u16_le()?;
|
||||
let height_with_scale = additional_key_frame_header.read_u16_le()?;
|
||||
let width = width_with_scale & 0b11_1111_1111_1111;
|
||||
let height = height_with_scale & 0b11_1111_1111_1111;
|
||||
Ok(PixelSize { width, height })
|
||||
}
|
||||
}
|
||||
|
||||
// This assumes that the picture ID and TL0 PIC IDX are present in the packet
|
||||
// and that the picture ID is of the 15-bit variety.
|
||||
// If they aren't, the payload will be corrupted
|
||||
pub fn modify_header(
|
||||
rtp_payload: &mut [u8],
|
||||
picture_id: TruncatedPictureId,
|
||||
tl0_pic_idx: TruncatedTl0PicIdx,
|
||||
) {
|
||||
rtp_payload[2..4].copy_from_slice(&((picture_id | 0b1000_0000_0000_0000).to_be_bytes()));
|
||||
rtp_payload[4] = tl0_pic_idx;
|
||||
}
|
||||
|
||||
pub fn expand_picture_id(truncated: TruncatedPictureId, max: &mut FullPictureId) -> FullPictureId {
|
||||
expand_truncated_counter(truncated, max, 15)
|
||||
}
|
||||
|
||||
pub fn expand_tl0_pic_idx(truncated: TruncatedTl0PicIdx, max: &mut FullTl0PicIdx) -> FullTl0PicIdx {
|
||||
expand_truncated_counter(truncated, max, 8)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod byte_0_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn zero() {
|
||||
assert_eq!(
|
||||
Byte0::parse(0b00000000),
|
||||
Byte0 {
|
||||
has_extensions: false,
|
||||
starts_partition: false,
|
||||
zero_partition_idx: true
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_ones() {
|
||||
assert_eq!(
|
||||
Byte0::parse(0b11111111),
|
||||
Byte0 {
|
||||
has_extensions: true,
|
||||
starts_partition: true,
|
||||
zero_partition_idx: false
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserved_ignored() {
|
||||
assert_eq!(
|
||||
Byte0::parse(0b01001000),
|
||||
Byte0 {
|
||||
has_extensions: false,
|
||||
starts_partition: false,
|
||||
zero_partition_idx: true
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn has_extensions() {
|
||||
assert_eq!(
|
||||
Byte0::parse(0b10000000),
|
||||
Byte0 {
|
||||
has_extensions: true,
|
||||
starts_partition: false,
|
||||
zero_partition_idx: true
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn begins_partition() {
|
||||
assert_eq!(
|
||||
Byte0::parse(0b00010000),
|
||||
Byte0 {
|
||||
has_extensions: false,
|
||||
starts_partition: true,
|
||||
zero_partition_idx: true
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_zero_partitions() {
|
||||
assert_eq!(
|
||||
Byte0::parse(0b00000001),
|
||||
Byte0 {
|
||||
has_extensions: false,
|
||||
starts_partition: false,
|
||||
zero_partition_idx: false
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
Byte0::parse(0b00000010),
|
||||
Byte0 {
|
||||
has_extensions: false,
|
||||
starts_partition: false,
|
||||
zero_partition_idx: false
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
Byte0::parse(0b00000100),
|
||||
Byte0 {
|
||||
has_extensions: false,
|
||||
starts_partition: false,
|
||||
zero_partition_idx: false
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod x_byte_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn zero() {
|
||||
assert_eq!(
|
||||
XByte::parse(0b00000000),
|
||||
XByte {
|
||||
has_picture_id: false,
|
||||
has_tl0_pic_idx: false,
|
||||
has_tid: false,
|
||||
has_key_idx: false
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_ones() {
|
||||
assert_eq!(
|
||||
XByte::parse(0b11111111),
|
||||
XByte {
|
||||
has_picture_id: true,
|
||||
has_tl0_pic_idx: true,
|
||||
has_tid: true,
|
||||
has_key_idx: true
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserved_ignored() {
|
||||
assert_eq!(
|
||||
XByte::parse(0b00001111),
|
||||
XByte {
|
||||
has_picture_id: false,
|
||||
has_tl0_pic_idx: false,
|
||||
has_tid: false,
|
||||
has_key_idx: false
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn has_picture_id() {
|
||||
assert_eq!(
|
||||
XByte::parse(0b10000000),
|
||||
XByte {
|
||||
has_picture_id: true,
|
||||
has_tl0_pic_idx: false,
|
||||
has_tid: false,
|
||||
has_key_idx: false
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn has_t10_pic_index() {
|
||||
assert_eq!(
|
||||
XByte::parse(0b01000000),
|
||||
XByte {
|
||||
has_picture_id: false,
|
||||
has_tl0_pic_idx: true,
|
||||
has_tid: false,
|
||||
has_key_idx: false
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn has_tid() {
|
||||
assert_eq!(
|
||||
XByte::parse(0b00100000),
|
||||
XByte {
|
||||
has_picture_id: false,
|
||||
has_tl0_pic_idx: false,
|
||||
has_tid: true,
|
||||
has_key_idx: false
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn has_key_index() {
|
||||
assert_eq!(
|
||||
XByte::parse(0b00010000),
|
||||
XByte {
|
||||
has_picture_id: false,
|
||||
has_tl0_pic_idx: false,
|
||||
has_tid: false,
|
||||
has_key_idx: true
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod payload_header_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn non_key_frame() {
|
||||
assert_eq!(
|
||||
PayloadHeader::parse(0b00000001),
|
||||
PayloadHeader { key_frame: false }
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_frame() {
|
||||
assert_eq!(
|
||||
PayloadHeader::parse(0b00000000),
|
||||
PayloadHeader { key_frame: true }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod read_header_tests {
|
||||
use hex_literal::hex;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn read_header() {
|
||||
let data = &hex!(
|
||||
"
|
||||
/* byte0 */ 90
|
||||
/* xbyte */ c0
|
||||
/* picture_id */ 9267 // (with leading bit)
|
||||
/* tl0_pic_idx */ dc
|
||||
/* payload0 */ 00
|
||||
/* skipped */ 0000000000
|
||||
/* width and scale */ 8002
|
||||
/* height and scale */ 6801
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
ParsedHeader::read(data).unwrap(),
|
||||
ParsedHeader {
|
||||
picture_id: Some(4711),
|
||||
tl0_pic_idx: Some(220),
|
||||
is_key_frame: true,
|
||||
resolution: Some(PixelSize {
|
||||
width: 640,
|
||||
height: 360
|
||||
})
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_header_alternative_values() {
|
||||
let data = &hex!(
|
||||
"
|
||||
/* byte0 */ 90
|
||||
/* xbyte */ c0
|
||||
/* picture_id */ 81d4 // (with leading bit)
|
||||
/* tl0_pic_idx */ d4
|
||||
/* payload0 */ 00
|
||||
/* skipped */ 0000000000
|
||||
/* width and scale */ 8007
|
||||
/* height and scale */ 38C4
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
ParsedHeader::read(data).unwrap(),
|
||||
ParsedHeader {
|
||||
picture_id: Some(468),
|
||||
tl0_pic_idx: Some(212),
|
||||
is_key_frame: true,
|
||||
resolution: Some(PixelSize {
|
||||
width: 1920,
|
||||
height: 1080
|
||||
})
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_extensions() {
|
||||
let data = &hex!(
|
||||
"
|
||||
/* byte0 */ 10
|
||||
/* payload0 */ 00
|
||||
/* skipped */ 0000000000
|
||||
/* width and scale */ 8002
|
||||
/* height and scale */ 6801
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
ParsedHeader::read(data).unwrap(),
|
||||
ParsedHeader {
|
||||
picture_id: None,
|
||||
tl0_pic_idx: None,
|
||||
is_key_frame: true,
|
||||
resolution: Some(PixelSize {
|
||||
width: 640,
|
||||
height: 360
|
||||
})
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seven_bit_picture_id() {
|
||||
let data = &hex!(
|
||||
"
|
||||
/* byte0 */ 90
|
||||
/* xbyte */ c0
|
||||
/* picture_id */ 12 // seven bits due to no leading bit set
|
||||
/* tl0_pic_idx */ dc
|
||||
/* payload0 */ 00
|
||||
/* skipped */ 0000000000
|
||||
/* width and scale */ 8002
|
||||
/* height and scale */ 6801
|
||||
"
|
||||
);
|
||||
assert_eq!(
|
||||
ParsedHeader::read(data)
|
||||
.unwrap_err()
|
||||
.downcast::<Vp8Error>()
|
||||
.unwrap(),
|
||||
Vp8Error::SevenBitPictureId
|
||||
);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user